diff --git a/api/chat-stream.js b/api/chat-stream.js index c4a4cd1..32e7601 100644 --- a/api/chat-stream.js +++ b/api/chat-stream.js @@ -1,968 +1,3 @@ 'use strict'; -const crypto = require('crypto'); - -const { - extractToolNames, - createToolSieveState, - processToolSieveChunk, - flushToolSieve, - parseToolCalls, - formatOpenAIStreamToolCalls, -} = require('./helpers/stream-tool-sieve'); -const { - BASE_HEADERS, - SKIP_PATTERNS, - SKIP_EXACT_PATHS, -} = require('./shared/deepseek-constants'); - -const DEEPSEEK_COMPLETION_URL = 'https://chat.deepseek.com/api/v0/chat/completion'; - -module.exports = async function handler(req, res) { - setCorsHeaders(res); - if (req.method === 'OPTIONS') { - res.statusCode = 204; - res.end(); - return; - } - if (req.method !== 'POST') { - writeOpenAIError(res, 405, 'method not allowed'); - return; - } - - const rawBody = await readRawBody(req); - - // Hard guard: only use Node data path for streaming on Vercel runtime. - // Any non-Vercel runtime always falls back to Go for full behavior parity. - if (!isVercelRuntime()) { - await proxyToGo(req, res, rawBody); - return; - } - - let payload; - try { - payload = JSON.parse(rawBody.toString('utf8') || '{}'); - } catch (_err) { - writeOpenAIError(res, 400, 'invalid json'); - return; - } - - // Keep all non-stream behavior on Go side to avoid compatibility regressions. - if (!toBool(payload.stream)) { - await proxyToGo(req, res, rawBody); - return; - } - - const prep = await fetchStreamPrepare(req, rawBody); - if (!prep.ok) { - relayPreparedFailure(res, prep); - return; - } - - const model = asString(prep.body.model) || asString(payload.model); - const sessionID = asString(prep.body.session_id) || `chatcmpl-${Date.now()}`; - const leaseID = asString(prep.body.lease_id); - const deepseekToken = asString(prep.body.deepseek_token); - const powHeader = asString(prep.body.pow_header); - const completionPayload = prep.body.payload && typeof prep.body.payload === 'object' ? 
prep.body.payload : null; - const finalPrompt = asString(prep.body.final_prompt); - const thinkingEnabled = toBool(prep.body.thinking_enabled); - const searchEnabled = toBool(prep.body.search_enabled); - const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools); - const toolNames = toolPolicy.toolNames; - - if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) { - writeOpenAIError(res, 500, 'invalid vercel prepare response'); - return; - } - const releaseLease = createLeaseReleaser(req, leaseID); - const upstreamController = new AbortController(); - let clientClosed = false; - let reader = null; - const markClientClosed = () => { - if (clientClosed) { - return; - } - clientClosed = true; - upstreamController.abort(); - if (reader && typeof reader.cancel === 'function') { - Promise.resolve(reader.cancel()).catch(() => {}); - } - }; - const onReqAborted = () => markClientClosed(); - const onResClose = () => { - if (!res.writableEnded) { - markClientClosed(); - } - }; - req.on('aborted', onReqAborted); - res.on('close', onResClose); - try { - let completionRes; - try { - completionRes = await fetch(DEEPSEEK_COMPLETION_URL, { - method: 'POST', - headers: { - ...BASE_HEADERS, - authorization: `Bearer ${deepseekToken}`, - 'x-ds-pow-response': powHeader, - }, - body: JSON.stringify(completionPayload), - signal: upstreamController.signal, - }); - } catch (err) { - if (clientClosed || isAbortError(err)) { - return; - } - throw err; - } - if (clientClosed) { - return; - } - - if (!completionRes.ok || !completionRes.body) { - const detail = await safeReadText(completionRes); - writeOpenAIError(res, 500, detail ? 
`Failed to get completion: ${detail}` : 'Failed to get completion.'); - return; - } - - res.statusCode = 200; - res.setHeader('Content-Type', 'text/event-stream'); - res.setHeader('Cache-Control', 'no-cache, no-transform'); - res.setHeader('Connection', 'keep-alive'); - res.setHeader('X-Accel-Buffering', 'no'); - if (typeof res.flushHeaders === 'function') { - res.flushHeaders(); - } - - const created = Math.floor(Date.now() / 1000); - let firstChunkSent = false; - let currentType = thinkingEnabled ? 'thinking' : 'text'; - let thinkingText = ''; - let outputText = ''; - const toolSieveEnabled = toolPolicy.toolSieveEnabled; - const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas; - const toolSieveState = createToolSieveState(); - let toolCallsEmitted = false; - const streamToolCallIDs = new Map(); - const decoder = new TextDecoder(); - reader = completionRes.body.getReader(); - let buffered = ''; - let ended = false; - - const sendFrame = (obj) => { - if (clientClosed || res.writableEnded || res.destroyed) { - return; - } - res.write(`data: ${JSON.stringify(obj)}\n\n`); - if (typeof res.flush === 'function') { - res.flush(); - } - }; - - const sendDeltaFrame = (delta) => { - const payloadDelta = { ...delta }; - if (!firstChunkSent) { - payloadDelta.role = 'assistant'; - firstChunkSent = true; - } - sendFrame({ - id: sessionID, - object: 'chat.completion.chunk', - created, - model, - choices: [{ delta: payloadDelta, index: 0 }], - }); - }; - - const finish = async (reason) => { - if (ended) { - return; - } - ended = true; - if (clientClosed || res.writableEnded || res.destroyed) { - await releaseLease(); - return; - } - const detected = parseToolCalls(outputText, toolNames); - if (detected.length > 0 && !toolCallsEmitted) { - toolCallsEmitted = true; - sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected) }); - } else if (toolSieveEnabled) { - const tailEvents = flushToolSieve(toolSieveState, toolNames); - for (const evt of tailEvents) { - if 
(evt.text) { - sendDeltaFrame({ content: evt.text }); - } - } - } - if (detected.length > 0 || toolCallsEmitted) { - reason = 'tool_calls'; - } - sendFrame({ - id: sessionID, - object: 'chat.completion.chunk', - created, - model, - choices: [{ delta: {}, index: 0, finish_reason: reason }], - usage: buildUsage(finalPrompt, thinkingText, outputText), - }); - if (!res.writableEnded && !res.destroyed) { - res.write('data: [DONE]\n\n'); - } - await releaseLease(); - if (!res.writableEnded && !res.destroyed) { - res.end(); - } - }; - - try { - // eslint-disable-next-line no-constant-condition - while (true) { - if (clientClosed) { - await finish('stop'); - return; - } - const { value, done } = await reader.read(); - if (done) { - break; - } - buffered += decoder.decode(value, { stream: true }); - const lines = buffered.split('\n'); - buffered = lines.pop() || ''; - - for (const rawLine of lines) { - const line = rawLine.trim(); - if (!line.startsWith('data:')) { - continue; - } - const dataStr = line.slice(5).trim(); - if (!dataStr) { - continue; - } - if (dataStr === '[DONE]') { - await finish('stop'); - return; - } - let chunk; - try { - chunk = JSON.parse(dataStr); - } catch (_err) { - continue; - } - if (chunk.error || chunk.code === 'content_filter') { - await finish('content_filter'); - return; - } - const parsed = parseChunkForContent(chunk, thinkingEnabled, currentType); - currentType = parsed.newType; - if (parsed.finished) { - await finish('stop'); - return; - } - - for (const p of parsed.parts) { - if (!p.text) { - continue; - } - if (searchEnabled && isCitation(p.text)) { - continue; - } - if (p.type === 'thinking') { - if (thinkingEnabled) { - thinkingText += p.text; - sendDeltaFrame({ reasoning_content: p.text }); - } - } else { - outputText += p.text; - if (!toolSieveEnabled) { - sendDeltaFrame({ content: p.text }); - continue; - } - const events = processToolSieveChunk(toolSieveState, p.text, toolNames); - for (const evt of events) { - if (evt.type === 
'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) { - if (!emitEarlyToolDeltas) { - continue; - } - toolCallsEmitted = true; - sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) }); - continue; - } - if (evt.type === 'tool_calls') { - toolCallsEmitted = true; - sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) }); - continue; - } - if (evt.text) { - sendDeltaFrame({ content: evt.text }); - } - } - } - } - } - } - await finish('stop'); - } catch (err) { - if (clientClosed || isAbortError(err)) { - await finish('stop'); - return; - } - await finish('stop'); - } - } finally { - req.removeListener('aborted', onReqAborted); - res.removeListener('close', onResClose); - await releaseLease(); - } -}; - -function setCorsHeaders(res) { - res.setHeader('Access-Control-Allow-Origin', '*'); - res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, PUT, DELETE'); - res.setHeader( - 'Access-Control-Allow-Headers', - 'Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass', - ); -} - -function header(req, key) { - if (!req || !req.headers) { - return ''; - } - return asString(req.headers[key.toLowerCase()]); -} - -async function readRawBody(req) { - if (Buffer.isBuffer(req.body)) { - return req.body; - } - if (typeof req.body === 'string') { - return Buffer.from(req.body); - } - if (req.body && typeof req.body === 'object') { - return Buffer.from(JSON.stringify(req.body)); - } - const chunks = []; - for await (const chunk of req) { - chunks.push(Buffer.isBuffer(chunk) ? 
chunk : Buffer.from(chunk)); - } - return Buffer.concat(chunks); -} - -async function fetchStreamPrepare(req, rawBody) { - const url = buildInternalGoURL(req); - url.searchParams.set('__stream_prepare', '1'); - - const upstream = await fetch(url.toString(), { - method: 'POST', - headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }), - body: rawBody, - }); - - const text = await upstream.text(); - let body = {}; - try { - body = JSON.parse(text || '{}'); - } catch (_err) { - body = {}; - } - - return { - ok: upstream.ok, - status: upstream.status, - contentType: upstream.headers.get('content-type') || 'application/json', - text, - body, - }; -} - -function relayPreparedFailure(res, prep) { - if (prep.status === 401 && looksLikeVercelAuthPage(prep.text)) { - writeOpenAIError( - res, - 401, - 'Vercel Deployment Protection blocked internal prepare request. Disable protection for this deployment or set VERCEL_AUTOMATION_BYPASS_SECRET.', - ); - return; - } - res.statusCode = prep.status || 500; - res.setHeader('Content-Type', prep.contentType || 'application/json'); - if (prep.text) { - res.end(prep.text); - return; - } - writeOpenAIError(res, prep.status || 500, 'vercel prepare failed'); -} - -function resolveToolcallPolicy(prepBody, payloadTools) { - const preparedToolNames = normalizePreparedToolNames(prepBody && prepBody.tool_names); - const toolNames = preparedToolNames.length > 0 ? 
preparedToolNames : extractToolNames(payloadTools); - const featureMatchEnabled = boolDefaultTrue(prepBody && prepBody.toolcall_feature_match); - const emitEarlyToolDeltas = boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high); - return { - toolNames, - toolSieveEnabled: toolNames.length > 0 && featureMatchEnabled, - emitEarlyToolDeltas, - }; -} - -function normalizePreparedToolNames(v) { - if (!Array.isArray(v) || v.length === 0) { - return []; - } - const out = []; - for (const item of v) { - const name = asString(item); - if (!name) { - continue; - } - out.push(name); - } - return out; -} - -function boolDefaultTrue(v) { - return v !== false; -} - -async function safeReadText(resp) { - if (!resp) { - return ''; - } - try { - const text = await resp.text(); - return text.trim(); - } catch (_err) { - return ''; - } -} - -function internalSecret() { - return asString(process.env.DS2API_VERCEL_INTERNAL_SECRET) || asString(process.env.DS2API_ADMIN_KEY) || 'admin'; -} - -function buildInternalGoURL(req) { - const proto = asString(header(req, 'x-forwarded-proto')) || 'https'; - const host = asString(header(req, 'host')); - const url = new URL(`${proto}://${host}${req.url || '/v1/chat/completions'}`); - url.searchParams.set('__go', '1'); - const protectionBypass = resolveProtectionBypass(req); - if (protectionBypass) { - url.searchParams.set('x-vercel-protection-bypass', protectionBypass); - } - return url; -} - -function buildInternalGoHeaders(req, opts = {}) { - const headers = { - authorization: asString(header(req, 'authorization')), - 'x-api-key': asString(header(req, 'x-api-key')), - 'x-ds2-target-account': asString(header(req, 'x-ds2-target-account')), - 'x-vercel-protection-bypass': resolveProtectionBypass(req), - }; - if (opts.withInternalToken) { - headers['x-ds2-internal-token'] = internalSecret(); - } - if (opts.withContentType) { - headers['content-type'] = asString(header(req, 'content-type')) || 'application/json'; - } - return headers; -} - 
-function createLeaseReleaser(req, leaseID) { - let released = false; - return async () => { - if (released || !leaseID) { - return; - } - released = true; - try { - await releaseStreamLease(req, leaseID); - } catch (_err) { - // Ignore release errors. Lease TTL cleanup on Go side still prevents permanent leaks. - } - }; -} - -async function releaseStreamLease(req, leaseID) { - const url = buildInternalGoURL(req); - url.searchParams.set('__stream_release', '1'); - const body = Buffer.from(JSON.stringify({ lease_id: leaseID })); - - const controller = new AbortController(); - const timeout = setTimeout(() => controller.abort(), 1500); - try { - await fetch(url.toString(), { - method: 'POST', - headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }), - body, - signal: controller.signal, - }); - } finally { - clearTimeout(timeout); - } -} - -function resolveProtectionBypass(req) { - const fromHeader = asString(header(req, 'x-vercel-protection-bypass')); - if (fromHeader) { - return fromHeader; - } - return asString(process.env.VERCEL_AUTOMATION_BYPASS_SECRET) || asString(process.env.DS2API_VERCEL_PROTECTION_BYPASS); -} - -function looksLikeVercelAuthPage(text) { - const body = asString(text).toLowerCase(); - if (!body) { - return false; - } - return body.includes('authentication required') && body.includes('vercel'); -} - -function parseChunkForContent(chunk, thinkingEnabled, currentType) { - if (!chunk || typeof chunk !== 'object' || !Object.prototype.hasOwnProperty.call(chunk, 'v')) { - return { parts: [], finished: false, newType: currentType }; - } - const pathValue = asString(chunk.p); - if (shouldSkipPath(pathValue)) { - return { parts: [], finished: false, newType: currentType }; - } - if (pathValue === 'response/status' && asString(chunk.v) === 'FINISHED') { - return { parts: [], finished: true, newType: currentType }; - } - - let newType = currentType; - const parts = []; - - if (pathValue === 'response/fragments' && 
asString(chunk.o).toUpperCase() === 'APPEND' && Array.isArray(chunk.v)) { - for (const frag of chunk.v) { - if (!frag || typeof frag !== 'object') { - continue; - } - const fragType = asString(frag.type).toUpperCase(); - const content = asString(frag.content); - if (!content) { - continue; - } - if (fragType === 'THINK' || fragType === 'THINKING') { - newType = 'thinking'; - parts.push({ text: content, type: 'thinking' }); - } else if (fragType === 'RESPONSE') { - newType = 'text'; - parts.push({ text: content, type: 'text' }); - } else { - parts.push({ text: content, type: 'text' }); - } - } - } - - if (pathValue === 'response' && Array.isArray(chunk.v)) { - for (const item of chunk.v) { - if (!item || typeof item !== 'object') { - continue; - } - if (item.p === 'fragments' && item.o === 'APPEND' && Array.isArray(item.v)) { - for (const frag of item.v) { - const fragType = asString(frag && frag.type).toUpperCase(); - if (fragType === 'THINK' || fragType === 'THINKING') { - newType = 'thinking'; - } else if (fragType === 'RESPONSE') { - newType = 'text'; - } - } - } - } - } - - let partType = 'text'; - if (pathValue === 'response/thinking_content') { - partType = 'thinking'; - } else if (pathValue === 'response/content') { - partType = 'text'; - } else if (pathValue.includes('response/fragments') && pathValue.includes('/content')) { - partType = newType; - } else if (!pathValue && thinkingEnabled) { - partType = newType; - } - - const val = chunk.v; - if (typeof val === 'string') { - if (val === 'FINISHED' && (!pathValue || pathValue === 'status')) { - return { parts: [], finished: true, newType }; - } - if (val) { - parts.push({ text: val, type: partType }); - } - return { parts, finished: false, newType }; - } - - if (Array.isArray(val)) { - const extracted = extractContentRecursive(val, partType); - if (extracted.finished) { - return { parts: [], finished: true, newType }; - } - parts.push(...extracted.parts); - return { parts, finished: false, newType }; - } - 
- if (val && typeof val === 'object') { - const resp = val.response && typeof val.response === 'object' ? val.response : val; - if (Array.isArray(resp.fragments)) { - for (const frag of resp.fragments) { - if (!frag || typeof frag !== 'object') { - continue; - } - const content = asString(frag.content); - if (!content) { - continue; - } - const t = asString(frag.type).toUpperCase(); - if (t === 'THINK' || t === 'THINKING') { - newType = 'thinking'; - parts.push({ text: content, type: 'thinking' }); - } else if (t === 'RESPONSE') { - newType = 'text'; - parts.push({ text: content, type: 'text' }); - } else { - parts.push({ text: content, type: partType }); - } - } - } - } - return { parts, finished: false, newType }; -} - -function extractContentRecursive(items, defaultType) { - const parts = []; - for (const it of items) { - if (!it || typeof it !== 'object') { - continue; - } - if (!Object.prototype.hasOwnProperty.call(it, 'v')) { - continue; - } - const itemPath = asString(it.p); - const itemV = it.v; - if (itemPath === 'status' && asString(itemV) === 'FINISHED') { - return { parts: [], finished: true }; - } - if (shouldSkipPath(itemPath)) { - continue; - } - const content = asString(it.content); - if (content) { - const typeName = asString(it.type).toUpperCase(); - if (typeName === 'THINK' || typeName === 'THINKING') { - parts.push({ text: content, type: 'thinking' }); - } else if (typeName === 'RESPONSE') { - parts.push({ text: content, type: 'text' }); - } else { - parts.push({ text: content, type: defaultType }); - } - continue; - } - - let partType = defaultType; - if (itemPath.includes('thinking')) { - partType = 'thinking'; - } else if (itemPath.includes('content') || itemPath === 'response' || itemPath === 'fragments') { - partType = 'text'; - } - - if (typeof itemV === 'string') { - if (itemV && itemV !== 'FINISHED') { - parts.push({ text: itemV, type: partType }); - } - continue; - } - - if (!Array.isArray(itemV)) { - continue; - } - for (const inner of 
itemV) { - if (typeof inner === 'string') { - if (inner) { - parts.push({ text: inner, type: partType }); - } - continue; - } - if (!inner || typeof inner !== 'object') { - continue; - } - const ct = asString(inner.content); - if (!ct) { - continue; - } - const typeName = asString(inner.type).toUpperCase(); - if (typeName === 'THINK' || typeName === 'THINKING') { - parts.push({ text: ct, type: 'thinking' }); - } else if (typeName === 'RESPONSE') { - parts.push({ text: ct, type: 'text' }); - } else { - parts.push({ text: ct, type: partType }); - } - } - } - return { parts, finished: false }; -} - -function shouldSkipPath(pathValue) { - if (SKIP_EXACT_PATHS.has(pathValue)) { - return true; - } - for (const p of SKIP_PATTERNS) { - if (pathValue.includes(p)) { - return true; - } - } - return false; -} - -function isCitation(text) { - return asString(text).trim().startsWith('[citation:'); -} - -function buildUsage(prompt, thinking, output) { - const promptTokens = estimateTokens(prompt); - const reasoningTokens = estimateTokens(thinking); - const completionTokens = estimateTokens(output); - return { - prompt_tokens: promptTokens, - completion_tokens: reasoningTokens + completionTokens, - total_tokens: promptTokens + reasoningTokens + completionTokens, - completion_tokens_details: { - reasoning_tokens: reasoningTokens, - }, - }; -} - -function formatIncrementalToolCallDeltas(deltas, idStore) { - if (!Array.isArray(deltas) || deltas.length === 0) { - return []; - } - const out = []; - for (const d of deltas) { - if (!d || typeof d !== 'object') { - continue; - } - const index = Number.isInteger(d.index) ? 
d.index : 0; - const id = ensureStreamToolCallID(idStore, index); - const item = { - index, - id, - type: 'function', - }; - const fn = {}; - if (asString(d.name)) { - fn.name = asString(d.name); - } - if (typeof d.arguments === 'string' && d.arguments !== '') { - fn.arguments = d.arguments; - } - if (Object.keys(fn).length > 0) { - item.function = fn; - } - out.push(item); - } - return out; -} - -function ensureStreamToolCallID(idStore, index) { - const key = Number.isInteger(index) ? index : 0; - const existing = idStore.get(key); - if (existing) { - return existing; - } - const next = `call_${newCallID()}`; - idStore.set(key, next); - return next; -} - -function newCallID() { - if (typeof crypto.randomUUID === 'function') { - return crypto.randomUUID().replace(/-/g, ''); - } - return `${Date.now()}${Math.floor(Math.random() * 1e9)}`; -} - -function estimateTokens(text) { - const t = asString(text); - if (!t) { - return 0; - } - let asciiChars = 0; - let nonASCIIChars = 0; - for (const ch of Array.from(t)) { - if (ch.charCodeAt(0) < 128) { - asciiChars += 1; - } else { - nonASCIIChars += 1; - } - } - const n = Math.floor(asciiChars / 4) + Math.floor((nonASCIIChars * 10 + 7) / 13); - return n < 1 ? 
1 : n; -} - -async function proxyToGo(req, res, rawBody) { - const url = buildInternalGoURL(req); - const controller = new AbortController(); - let clientClosed = false; - const markClientClosed = () => { - if (clientClosed) { - return; - } - clientClosed = true; - controller.abort(); - }; - const onReqAborted = () => markClientClosed(); - const onResClose = () => { - if (!res.writableEnded) { - markClientClosed(); - } - }; - req.on('aborted', onReqAborted); - res.on('close', onResClose); - - try { - let upstream; - try { - upstream = await fetch(url.toString(), { - method: 'POST', - headers: buildInternalGoHeaders(req, { withContentType: true }), - body: rawBody, - signal: controller.signal, - }); - } catch (err) { - if (clientClosed || isAbortError(err)) { - if (!res.writableEnded) { - res.end(); - } - return; - } - throw err; - } - if (clientClosed) { - if (!res.writableEnded) { - res.end(); - } - return; - } - - res.statusCode = upstream.status; - upstream.headers.forEach((value, key) => { - if (key.toLowerCase() === 'content-length') { - return; - } - res.setHeader(key, value); - }); - - if (!upstream.body || typeof upstream.body.getReader !== 'function') { - const bytes = Buffer.from(await upstream.arrayBuffer()); - res.end(bytes); - return; - } - - const reader = upstream.body.getReader(); - try { - // eslint-disable-next-line no-constant-condition - while (true) { - if (clientClosed) { - break; - } - const { value, done } = await reader.read(); - if (done) { - break; - } - if (value && value.length > 0) { - res.write(Buffer.from(value)); - if (typeof res.flush === 'function') { - res.flush(); - } - } - } - if (!res.writableEnded) { - res.end(); - } - } catch (err) { - if (!isAbortError(err) && !res.writableEnded) { - res.end(); - } - } - } finally { - req.removeListener('aborted', onReqAborted); - res.removeListener('close', onResClose); - if (!res.writableEnded) { - res.end(); - } - } -} - -function writeOpenAIError(res, status, message) { - res.statusCode 
= status; - res.setHeader('Content-Type', 'application/json'); - res.end( - JSON.stringify({ - error: { - message, - type: openAIErrorType(status), - }, - }), - ); -} - -function openAIErrorType(status) { - switch (status) { - case 400: - return 'invalid_request_error'; - case 401: - return 'authentication_error'; - case 403: - return 'permission_error'; - case 429: - return 'rate_limit_error'; - case 503: - return 'service_unavailable_error'; - default: - return status >= 500 ? 'api_error' : 'invalid_request_error'; - } -} - -function toBool(v) { - return v === true; -} - -function isVercelRuntime() { - return asString(process.env.VERCEL) !== '' || asString(process.env.NOW_REGION) !== ''; -} - -function asString(v) { - if (typeof v === 'string') { - return v.trim(); - } - if (Array.isArray(v)) { - return asString(v[0]); - } - if (v == null) { - return ''; - } - return String(v).trim(); -} - -function isAbortError(err) { - if (!err || typeof err !== 'object') { - return false; - } - return err.name === 'AbortError' || err.code === 'ABORT_ERR'; -} - -module.exports.__test = { - parseChunkForContent, - extractContentRecursive, - shouldSkipPath, - asString, - resolveToolcallPolicy, - normalizePreparedToolNames, - boolDefaultTrue, - estimateTokens, -}; +module.exports = require('./chat-stream/index.js'); diff --git a/api/chat-stream/error_shape.js b/api/chat-stream/error_shape.js new file mode 100644 index 0000000..18aeedb --- /dev/null +++ b/api/chat-stream/error_shape.js @@ -0,0 +1,36 @@ +'use strict'; + +function writeOpenAIError(res, status, message) { + res.statusCode = status; + res.setHeader('Content-Type', 'application/json'); + res.end( + JSON.stringify({ + error: { + message, + type: openAIErrorType(status), + }, + }), + ); +} + +function openAIErrorType(status) { + switch (status) { + case 400: + return 'invalid_request_error'; + case 401: + return 'authentication_error'; + case 403: + return 'permission_error'; + case 429: + return 'rate_limit_error'; + 
case 503: + return 'service_unavailable_error'; + default: + return status >= 500 ? 'api_error' : 'invalid_request_error'; + } +} + +module.exports = { + writeOpenAIError, + openAIErrorType, +}; diff --git a/api/chat-stream/http_internal.js b/api/chat-stream/http_internal.js new file mode 100644 index 0000000..20f24c8 --- /dev/null +++ b/api/chat-stream/http_internal.js @@ -0,0 +1,214 @@ +'use strict'; + +const { + writeOpenAIError, +} = require('./error_shape'); + +function setCorsHeaders(res) { + res.setHeader('Access-Control-Allow-Origin', '*'); + res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, PUT, DELETE'); + res.setHeader( + 'Access-Control-Allow-Headers', + 'Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass', + ); +} + +function header(req, key) { + if (!req || !req.headers) { + return ''; + } + return asString(req.headers[key.toLowerCase()]); +} + +async function readRawBody(req) { + if (Buffer.isBuffer(req.body)) { + return req.body; + } + if (typeof req.body === 'string') { + return Buffer.from(req.body); + } + if (req.body && typeof req.body === 'object') { + return Buffer.from(JSON.stringify(req.body)); + } + const chunks = []; + for await (const chunk of req) { + chunks.push(Buffer.isBuffer(chunk) ? 
chunk : Buffer.from(chunk)); + } + return Buffer.concat(chunks); +} + +async function fetchStreamPrepare(req, rawBody) { + const url = buildInternalGoURL(req); + url.searchParams.set('__stream_prepare', '1'); + + const upstream = await fetch(url.toString(), { + method: 'POST', + headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }), + body: rawBody, + }); + + const text = await upstream.text(); + let body = {}; + try { + body = JSON.parse(text || '{}'); + } catch (_err) { + body = {}; + } + + return { + ok: upstream.ok, + status: upstream.status, + contentType: upstream.headers.get('content-type') || 'application/json', + text, + body, + }; +} + +function relayPreparedFailure(res, prep) { + if (prep.status === 401 && looksLikeVercelAuthPage(prep.text)) { + writeOpenAIError( + res, + 401, + 'Vercel Deployment Protection blocked internal prepare request. Disable protection for this deployment or set VERCEL_AUTOMATION_BYPASS_SECRET.', + ); + return; + } + res.statusCode = prep.status || 500; + res.setHeader('Content-Type', prep.contentType || 'application/json'); + if (prep.text) { + res.end(prep.text); + return; + } + writeOpenAIError(res, prep.status || 500, 'vercel prepare failed'); +} + +async function safeReadText(resp) { + if (!resp) { + return ''; + } + try { + const text = await resp.text(); + return text.trim(); + } catch (_err) { + return ''; + } +} + +function internalSecret() { + return asString(process.env.DS2API_VERCEL_INTERNAL_SECRET) || asString(process.env.DS2API_ADMIN_KEY) || 'admin'; +} + +function buildInternalGoURL(req) { + const proto = asString(header(req, 'x-forwarded-proto')) || 'https'; + const host = asString(header(req, 'host')); + const url = new URL(`${proto}://${host}${req.url || '/v1/chat/completions'}`); + url.searchParams.set('__go', '1'); + const protectionBypass = resolveProtectionBypass(req); + if (protectionBypass) { + url.searchParams.set('x-vercel-protection-bypass', protectionBypass); + } + 
return url; +} + +function buildInternalGoHeaders(req, opts = {}) { + const headers = { + authorization: asString(header(req, 'authorization')), + 'x-api-key': asString(header(req, 'x-api-key')), + 'x-ds2-target-account': asString(header(req, 'x-ds2-target-account')), + 'x-vercel-protection-bypass': resolveProtectionBypass(req), + }; + if (opts.withInternalToken) { + headers['x-ds2-internal-token'] = internalSecret(); + } + if (opts.withContentType) { + headers['content-type'] = asString(header(req, 'content-type')) || 'application/json'; + } + return headers; +} + +function createLeaseReleaser(req, leaseID) { + let released = false; + return async () => { + if (released || !leaseID) { + return; + } + released = true; + try { + await releaseStreamLease(req, leaseID); + } catch (_err) { + // Ignore release errors. Lease TTL cleanup on Go side still prevents permanent leaks. + } + }; +} + +async function releaseStreamLease(req, leaseID) { + const url = buildInternalGoURL(req); + url.searchParams.set('__stream_release', '1'); + const body = Buffer.from(JSON.stringify({ lease_id: leaseID })); + + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), 1500); + try { + await fetch(url.toString(), { + method: 'POST', + headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }), + body, + signal: controller.signal, + }); + } finally { + clearTimeout(timeout); + } +} + +function resolveProtectionBypass(req) { + const fromHeader = asString(header(req, 'x-vercel-protection-bypass')); + if (fromHeader) { + return fromHeader; + } + return asString(process.env.VERCEL_AUTOMATION_BYPASS_SECRET) || asString(process.env.DS2API_VERCEL_PROTECTION_BYPASS); +} + +function looksLikeVercelAuthPage(text) { + const body = asString(text).toLowerCase(); + if (!body) { + return false; + } + return body.includes('authentication required') && body.includes('vercel'); +} + +function asString(v) { + if (typeof v === 'string') 
{ + return v.trim(); + } + if (Array.isArray(v)) { + return asString(v[0]); + } + if (v == null) { + return ''; + } + return String(v).trim(); +} + +function isAbortError(err) { + if (!err || typeof err !== 'object') { + return false; + } + return err.name === 'AbortError' || err.code === 'ABORT_ERR'; +} + +module.exports = { + setCorsHeaders, + header, + readRawBody, + fetchStreamPrepare, + relayPreparedFailure, + safeReadText, + buildInternalGoURL, + buildInternalGoHeaders, + createLeaseReleaser, + releaseStreamLease, + resolveProtectionBypass, + looksLikeVercelAuthPage, + asString, + isAbortError, +}; diff --git a/api/chat-stream/index.js b/api/chat-stream/index.js new file mode 100644 index 0000000..4528924 --- /dev/null +++ b/api/chat-stream/index.js @@ -0,0 +1,88 @@ +'use strict'; + +const { + writeOpenAIError, +} = require('./error_shape'); +const { + parseChunkForContent, + extractContentRecursive, + shouldSkipPath, +} = require('./sse_parse'); +const { + resolveToolcallPolicy, + normalizePreparedToolNames, + boolDefaultTrue, +} = require('./toolcall_policy'); +const { + estimateTokens, +} = require('./token_usage'); +const { + setCorsHeaders, + readRawBody, + asString, +} = require('./http_internal'); +const { + proxyToGo, +} = require('./proxy_go'); +const { + handleVercelStream, +} = require('./vercel_stream'); + +async function handler(req, res) { + setCorsHeaders(res); + if (req.method === 'OPTIONS') { + res.statusCode = 204; + res.end(); + return; + } + if (req.method !== 'POST') { + writeOpenAIError(res, 405, 'method not allowed'); + return; + } + + const rawBody = await readRawBody(req); + + // Hard guard: only use Node data path for streaming on Vercel runtime. + // Any non-Vercel runtime always falls back to Go for full behavior parity. 
+ if (!isVercelRuntime()) { + await proxyToGo(req, res, rawBody); + return; + } + + let payload; + try { + payload = JSON.parse(rawBody.toString('utf8') || '{}'); + } catch (_err) { + writeOpenAIError(res, 400, 'invalid json'); + return; + } + + // Keep all non-stream behavior on Go side to avoid compatibility regressions. + if (!toBool(payload.stream)) { + await proxyToGo(req, res, rawBody); + return; + } + + await handleVercelStream(req, res, rawBody, payload); +} + +function toBool(v) { + return v === true; +} + +function isVercelRuntime() { + return asString(process.env.VERCEL) !== '' || asString(process.env.NOW_REGION) !== ''; +} + +module.exports = handler; + +module.exports.__test = { + parseChunkForContent, + extractContentRecursive, + shouldSkipPath, + asString, + resolveToolcallPolicy, + normalizePreparedToolNames, + boolDefaultTrue, + estimateTokens, +}; diff --git a/api/chat-stream/proxy_go.js b/api/chat-stream/proxy_go.js new file mode 100644 index 0000000..5218df0 --- /dev/null +++ b/api/chat-stream/proxy_go.js @@ -0,0 +1,105 @@ +'use strict'; + +const { + buildInternalGoURL, + buildInternalGoHeaders, + isAbortError, +} = require('./http_internal'); + +async function proxyToGo(req, res, rawBody) { + const url = buildInternalGoURL(req); + const controller = new AbortController(); + let clientClosed = false; + const markClientClosed = () => { + if (clientClosed) { + return; + } + clientClosed = true; + controller.abort(); + }; + const onReqAborted = () => markClientClosed(); + const onResClose = () => { + if (!res.writableEnded) { + markClientClosed(); + } + }; + req.on('aborted', onReqAborted); + res.on('close', onResClose); + + try { + let upstream; + try { + upstream = await fetch(url.toString(), { + method: 'POST', + headers: buildInternalGoHeaders(req, { withContentType: true }), + body: rawBody, + signal: controller.signal, + }); + } catch (err) { + if (clientClosed || isAbortError(err)) { + if (!res.writableEnded) { + res.end(); + } + return; + } 
/**
 * Translate one DeepSeek SSE delta object into OpenAI-style content parts.
 *
 * @param {object} chunk - Parsed SSE JSON (`p` path, `o` op, `v` value).
 * @param {boolean} thinkingEnabled - Whether reasoning output is surfaced.
 * @param {string} currentType - 'thinking' or 'text'; carried across chunks.
 * @returns {{parts: Array<{text: string, type: string}>, finished: boolean, newType: string}}
 */
function parseChunkForContent(chunk, thinkingEnabled, currentType) {
  const idle = { parts: [], finished: false, newType: currentType };
  if (!chunk || typeof chunk !== 'object' || !Object.prototype.hasOwnProperty.call(chunk, 'v')) {
    return idle;
  }
  const path = asString(chunk.p);
  if (shouldSkipPath(path)) {
    return idle;
  }
  if (path === 'response/status' && asString(chunk.v) === 'FINISHED') {
    return { parts: [], finished: true, newType: currentType };
  }

  let newType = currentType;
  const parts = [];

  // Fragment appends carry explicit THINK/RESPONSE typing per fragment.
  if (path === 'response/fragments' && asString(chunk.o).toUpperCase() === 'APPEND' && Array.isArray(chunk.v)) {
    for (const frag of chunk.v) {
      if (!frag || typeof frag !== 'object') {
        continue;
      }
      const kind = asString(frag.type).toUpperCase();
      const body = asString(frag.content);
      if (!body) {
        continue;
      }
      if (kind === 'THINK' || kind === 'THINKING') {
        newType = 'thinking';
        parts.push({ text: body, type: 'thinking' });
      } else if (kind === 'RESPONSE') {
        newType = 'text';
        parts.push({ text: body, type: 'text' });
      } else {
        parts.push({ text: body, type: 'text' });
      }
    }
  }

  // A bare "response" array may only switch the current mode; any content in
  // it is collected by the generic array walk further down.
  if (path === 'response' && Array.isArray(chunk.v)) {
    for (const item of chunk.v) {
      if (!item || typeof item !== 'object') {
        continue;
      }
      if (item.p !== 'fragments' || item.o !== 'APPEND' || !Array.isArray(item.v)) {
        continue;
      }
      for (const frag of item.v) {
        const kind = asString(frag && frag.type).toUpperCase();
        if (kind === 'THINK' || kind === 'THINKING') {
          newType = 'thinking';
        } else if (kind === 'RESPONSE') {
          newType = 'text';
        }
      }
    }
  }

  // Default part type for untyped values at this path.
  let partType = 'text';
  if (path === 'response/thinking_content') {
    partType = 'thinking';
  } else if (path === 'response/content') {
    partType = 'text';
  } else if (path.includes('response/fragments') && path.includes('/content')) {
    partType = newType;
  } else if (!path && thinkingEnabled) {
    partType = newType;
  }

  const val = chunk.v;
  if (typeof val === 'string') {
    if (val === 'FINISHED' && (!path || path === 'status')) {
      return { parts: [], finished: true, newType };
    }
    if (val) {
      parts.push({ text: val, type: partType });
    }
    return { parts, finished: false, newType };
  }

  if (Array.isArray(val)) {
    const walked = extractContentRecursive(val, partType);
    if (walked.finished) {
      return { parts: [], finished: true, newType };
    }
    parts.push(...walked.parts);
    return { parts, finished: false, newType };
  }

  if (val && typeof val === 'object') {
    // Either a wrapper { response: {...} } or the response object itself.
    const resp = val.response && typeof val.response === 'object' ? val.response : val;
    if (Array.isArray(resp.fragments)) {
      for (const frag of resp.fragments) {
        if (!frag || typeof frag !== 'object') {
          continue;
        }
        const body = asString(frag.content);
        if (!body) {
          continue;
        }
        const kind = asString(frag.type).toUpperCase();
        if (kind === 'THINK' || kind === 'THINKING') {
          newType = 'thinking';
          parts.push({ text: body, type: 'thinking' });
        } else if (kind === 'RESPONSE') {
          newType = 'text';
          parts.push({ text: body, type: 'text' });
        } else {
          parts.push({ text: body, type: partType });
        }
      }
    }
  }
  return { parts, finished: false, newType };
}
(!Array.isArray(itemV)) { + continue; + } + for (const inner of itemV) { + if (typeof inner === 'string') { + if (inner) { + parts.push({ text: inner, type: partType }); + } + continue; + } + if (!inner || typeof inner !== 'object') { + continue; + } + const ct = asString(inner.content); + if (!ct) { + continue; + } + const typeName = asString(inner.type).toUpperCase(); + if (typeName === 'THINK' || typeName === 'THINKING') { + parts.push({ text: ct, type: 'thinking' }); + } else if (typeName === 'RESPONSE') { + parts.push({ text: ct, type: 'text' }); + } else { + parts.push({ text: ct, type: partType }); + } + } + } + return { parts, finished: false }; +} + +function shouldSkipPath(pathValue) { + if (SKIP_EXACT_PATHS.has(pathValue)) { + return true; + } + for (const p of SKIP_PATTERNS) { + if (pathValue.includes(p)) { + return true; + } + } + return false; +} + +function isCitation(text) { + return asString(text).trim().startsWith('[citation:'); +} + +function asString(v) { + if (typeof v === 'string') { + return v.trim(); + } + if (Array.isArray(v)) { + return asString(v[0]); + } + if (v == null) { + return ''; + } + return String(v).trim(); +} + +module.exports = { + parseChunkForContent, + extractContentRecursive, + shouldSkipPath, + isCitation, +}; diff --git a/api/chat-stream/stream_emitter.js b/api/chat-stream/stream_emitter.js new file mode 100644 index 0000000..442c24e --- /dev/null +++ b/api/chat-stream/stream_emitter.js @@ -0,0 +1,39 @@ +'use strict'; + +function createChatCompletionEmitter({ res, sessionID, created, model, isClosed }) { + let firstChunkSent = false; + + const sendFrame = (obj) => { + if (isClosed() || res.writableEnded || res.destroyed) { + return; + } + res.write(`data: ${JSON.stringify(obj)}\n\n`); + if (typeof res.flush === 'function') { + res.flush(); + } + }; + + const sendDeltaFrame = (delta) => { + const payloadDelta = { ...delta }; + if (!firstChunkSent) { + payloadDelta.role = 'assistant'; + firstChunkSent = true; + } + 
sendFrame({ + id: sessionID, + object: 'chat.completion.chunk', + created, + model, + choices: [{ delta: payloadDelta, index: 0 }], + }); + }; + + return { + sendFrame, + sendDeltaFrame, + }; +} + +module.exports = { + createChatCompletionEmitter, +}; diff --git a/api/chat-stream/token_usage.js b/api/chat-stream/token_usage.js new file mode 100644 index 0000000..57a36fb --- /dev/null +++ b/api/chat-stream/token_usage.js @@ -0,0 +1,51 @@ +'use strict'; + +function buildUsage(prompt, thinking, output) { + const promptTokens = estimateTokens(prompt); + const reasoningTokens = estimateTokens(thinking); + const completionTokens = estimateTokens(output); + return { + prompt_tokens: promptTokens, + completion_tokens: reasoningTokens + completionTokens, + total_tokens: promptTokens + reasoningTokens + completionTokens, + completion_tokens_details: { + reasoning_tokens: reasoningTokens, + }, + }; +} + +function estimateTokens(text) { + const t = asString(text); + if (!t) { + return 0; + } + let asciiChars = 0; + let nonASCIIChars = 0; + for (const ch of Array.from(t)) { + if (ch.charCodeAt(0) < 128) { + asciiChars += 1; + } else { + nonASCIIChars += 1; + } + } + const n = Math.floor(asciiChars / 4) + Math.floor((nonASCIIChars * 10 + 7) / 13); + return n < 1 ? 
1 : n; +} + +function asString(v) { + if (typeof v === 'string') { + return v.trim(); + } + if (Array.isArray(v)) { + return asString(v[0]); + } + if (v == null) { + return ''; + } + return String(v).trim(); +} + +module.exports = { + buildUsage, + estimateTokens, +}; diff --git a/api/chat-stream/toolcall_policy.js b/api/chat-stream/toolcall_policy.js new file mode 100644 index 0000000..4f4b37c --- /dev/null +++ b/api/chat-stream/toolcall_policy.js @@ -0,0 +1,107 @@ +'use strict'; + +const crypto = require('crypto'); + +const { + extractToolNames, +} = require('../helpers/stream-tool-sieve'); + +function resolveToolcallPolicy(prepBody, payloadTools) { + const preparedToolNames = normalizePreparedToolNames(prepBody && prepBody.tool_names); + const toolNames = preparedToolNames.length > 0 ? preparedToolNames : extractToolNames(payloadTools); + const featureMatchEnabled = boolDefaultTrue(prepBody && prepBody.toolcall_feature_match); + const emitEarlyToolDeltas = boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high); + return { + toolNames, + toolSieveEnabled: toolNames.length > 0 && featureMatchEnabled, + emitEarlyToolDeltas, + }; +} + +function normalizePreparedToolNames(v) { + if (!Array.isArray(v) || v.length === 0) { + return []; + } + const out = []; + for (const item of v) { + const name = asString(item); + if (!name) { + continue; + } + out.push(name); + } + return out; +} + +function boolDefaultTrue(v) { + return v !== false; +} + +function formatIncrementalToolCallDeltas(deltas, idStore) { + if (!Array.isArray(deltas) || deltas.length === 0) { + return []; + } + const out = []; + for (const d of deltas) { + if (!d || typeof d !== 'object') { + continue; + } + const index = Number.isInteger(d.index) ? 
d.index : 0; + const id = ensureStreamToolCallID(idStore, index); + const item = { + index, + id, + type: 'function', + }; + const fn = {}; + if (asString(d.name)) { + fn.name = asString(d.name); + } + if (typeof d.arguments === 'string' && d.arguments !== '') { + fn.arguments = d.arguments; + } + if (Object.keys(fn).length > 0) { + item.function = fn; + } + out.push(item); + } + return out; +} + +function ensureStreamToolCallID(idStore, index) { + const key = Number.isInteger(index) ? index : 0; + const existing = idStore.get(key); + if (existing) { + return existing; + } + const next = `call_${newCallID()}`; + idStore.set(key, next); + return next; +} + +function newCallID() { + if (typeof crypto.randomUUID === 'function') { + return crypto.randomUUID().replace(/-/g, ''); + } + return `${Date.now()}${Math.floor(Math.random() * 1e9)}`; +} + +function asString(v) { + if (typeof v === 'string') { + return v.trim(); + } + if (Array.isArray(v)) { + return asString(v[0]); + } + if (v == null) { + return ''; + } + return String(v).trim(); +} + +module.exports = { + resolveToolcallPolicy, + normalizePreparedToolNames, + boolDefaultTrue, + formatIncrementalToolCallDeltas, +}; diff --git a/api/chat-stream/vercel_stream.js b/api/chat-stream/vercel_stream.js new file mode 100644 index 0000000..324a3d8 --- /dev/null +++ b/api/chat-stream/vercel_stream.js @@ -0,0 +1,297 @@ +'use strict'; + +const { + extractToolNames, + createToolSieveState, + processToolSieveChunk, + flushToolSieve, + parseToolCalls, + formatOpenAIStreamToolCalls, +} = require('../helpers/stream-tool-sieve'); +const { + BASE_HEADERS, +} = require('../shared/deepseek-constants'); + +const { + writeOpenAIError, +} = require('./error_shape'); +const { + parseChunkForContent, + isCitation, +} = require('./sse_parse'); +const { + buildUsage, +} = require('./token_usage'); +const { + resolveToolcallPolicy, + formatIncrementalToolCallDeltas, +} = require('./toolcall_policy'); +const { + createChatCompletionEmitter, 
+} = require('./stream_emitter'); +const { + asString, + isAbortError, + fetchStreamPrepare, + relayPreparedFailure, + safeReadText, + createLeaseReleaser, +} = require('./http_internal'); + +const DEEPSEEK_COMPLETION_URL = 'https://chat.deepseek.com/api/v0/chat/completion'; + +async function handleVercelStream(req, res, rawBody, payload) { + const prep = await fetchStreamPrepare(req, rawBody); + if (!prep.ok) { + relayPreparedFailure(res, prep); + return; + } + + const model = asString(prep.body.model) || asString(payload.model); + const sessionID = asString(prep.body.session_id) || `chatcmpl-${Date.now()}`; + const leaseID = asString(prep.body.lease_id); + const deepseekToken = asString(prep.body.deepseek_token); + const powHeader = asString(prep.body.pow_header); + const completionPayload = prep.body.payload && typeof prep.body.payload === 'object' ? prep.body.payload : null; + const finalPrompt = asString(prep.body.final_prompt); + const thinkingEnabled = toBool(prep.body.thinking_enabled); + const searchEnabled = toBool(prep.body.search_enabled); + const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools); + const toolNames = toolPolicy.toolNames; + + if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) { + writeOpenAIError(res, 500, 'invalid vercel prepare response'); + return; + } + + const releaseLease = createLeaseReleaser(req, leaseID); + const upstreamController = new AbortController(); + let clientClosed = false; + let reader = null; + const markClientClosed = () => { + if (clientClosed) { + return; + } + clientClosed = true; + upstreamController.abort(); + if (reader && typeof reader.cancel === 'function') { + Promise.resolve(reader.cancel()).catch(() => {}); + } + }; + const onReqAborted = () => markClientClosed(); + const onResClose = () => { + if (!res.writableEnded) { + markClientClosed(); + } + }; + req.on('aborted', onReqAborted); + res.on('close', onResClose); + + try { + let completionRes; + try { + 
completionRes = await fetch(DEEPSEEK_COMPLETION_URL, { + method: 'POST', + headers: { + ...BASE_HEADERS, + authorization: `Bearer ${deepseekToken}`, + 'x-ds-pow-response': powHeader, + }, + body: JSON.stringify(completionPayload), + signal: upstreamController.signal, + }); + } catch (err) { + if (clientClosed || isAbortError(err)) { + return; + } + throw err; + } + if (clientClosed) { + return; + } + + if (!completionRes.ok || !completionRes.body) { + const detail = await safeReadText(completionRes); + writeOpenAIError(res, 500, detail ? `Failed to get completion: ${detail}` : 'Failed to get completion.'); + return; + } + + res.statusCode = 200; + res.setHeader('Content-Type', 'text/event-stream'); + res.setHeader('Cache-Control', 'no-cache, no-transform'); + res.setHeader('Connection', 'keep-alive'); + res.setHeader('X-Accel-Buffering', 'no'); + if (typeof res.flushHeaders === 'function') { + res.flushHeaders(); + } + + const created = Math.floor(Date.now() / 1000); + let currentType = thinkingEnabled ? 
'thinking' : 'text'; + let thinkingText = ''; + let outputText = ''; + const toolSieveEnabled = toolPolicy.toolSieveEnabled; + const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas; + const toolSieveState = createToolSieveState(); + let toolCallsEmitted = false; + const streamToolCallIDs = new Map(); + const decoder = new TextDecoder(); + reader = completionRes.body.getReader(); + let buffered = ''; + let ended = false; + const { sendFrame, sendDeltaFrame } = createChatCompletionEmitter({ + res, + sessionID, + created, + model, + isClosed: () => clientClosed, + }); + + const finish = async (reason) => { + if (ended) { + return; + } + ended = true; + if (clientClosed || res.writableEnded || res.destroyed) { + await releaseLease(); + return; + } + const detected = parseToolCalls(outputText, toolNames); + if (detected.length > 0 && !toolCallsEmitted) { + toolCallsEmitted = true; + sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected) }); + } else if (toolSieveEnabled) { + const tailEvents = flushToolSieve(toolSieveState, toolNames); + for (const evt of tailEvents) { + if (evt.text) { + sendDeltaFrame({ content: evt.text }); + } + } + } + if (detected.length > 0 || toolCallsEmitted) { + reason = 'tool_calls'; + } + sendFrame({ + id: sessionID, + object: 'chat.completion.chunk', + created, + model, + choices: [{ delta: {}, index: 0, finish_reason: reason }], + usage: buildUsage(finalPrompt, thinkingText, outputText), + }); + if (!res.writableEnded && !res.destroyed) { + res.write('data: [DONE]\n\n'); + } + await releaseLease(); + if (!res.writableEnded && !res.destroyed) { + res.end(); + } + }; + + try { + // eslint-disable-next-line no-constant-condition + while (true) { + if (clientClosed) { + await finish('stop'); + return; + } + const { value, done } = await reader.read(); + if (done) { + break; + } + buffered += decoder.decode(value, { stream: true }); + const lines = buffered.split('\n'); + buffered = lines.pop() || ''; + + for (const rawLine of 
lines) { + const line = rawLine.trim(); + if (!line.startsWith('data:')) { + continue; + } + const dataStr = line.slice(5).trim(); + if (!dataStr) { + continue; + } + if (dataStr === '[DONE]') { + await finish('stop'); + return; + } + let chunk; + try { + chunk = JSON.parse(dataStr); + } catch (_err) { + continue; + } + if (chunk.error || chunk.code === 'content_filter') { + await finish('content_filter'); + return; + } + const parsed = parseChunkForContent(chunk, thinkingEnabled, currentType); + currentType = parsed.newType; + if (parsed.finished) { + await finish('stop'); + return; + } + + for (const p of parsed.parts) { + if (!p.text) { + continue; + } + if (searchEnabled && isCitation(p.text)) { + continue; + } + if (p.type === 'thinking') { + if (thinkingEnabled) { + thinkingText += p.text; + sendDeltaFrame({ reasoning_content: p.text }); + } + } else { + outputText += p.text; + if (!toolSieveEnabled) { + sendDeltaFrame({ content: p.text }); + continue; + } + const events = processToolSieveChunk(toolSieveState, p.text, toolNames); + for (const evt of events) { + if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) { + if (!emitEarlyToolDeltas) { + continue; + } + toolCallsEmitted = true; + sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) }); + continue; + } + if (evt.type === 'tool_calls') { + toolCallsEmitted = true; + sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) }); + continue; + } + if (evt.text) { + sendDeltaFrame({ content: evt.text }); + } + } + } + } + } + } + await finish('stop'); + } catch (err) { + if (clientClosed || isAbortError(err)) { + await finish('stop'); + return; + } + await finish('stop'); + } + } finally { + req.removeListener('aborted', onReqAborted); + res.removeListener('close', onResClose); + await releaseLease(); + } +} + +function toBool(v) { + return v === true; +} + +module.exports = { + handleVercelStream, +}; diff --git 
a/api/helpers/stream-tool-sieve.js b/api/helpers/stream-tool-sieve.js index 44e31cd..8985478 100644 --- a/api/helpers/stream-tool-sieve.js +++ b/api/helpers/stream-tool-sieve.js @@ -1,957 +1,3 @@ 'use strict'; -const crypto = require('crypto'); -const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s; -const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024; -const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256; - -function extractToolNames(tools) { - if (!Array.isArray(tools) || tools.length === 0) { - return []; - } - const out = []; - for (const t of tools) { - if (!t || typeof t !== 'object') { - continue; - } - const fn = t.function && typeof t.function === 'object' ? t.function : t; - const name = toStringSafe(fn.name); - // Keep parity with Go injectToolPrompt: object tools without name still - // enter tool mode via fallback name "unknown". - out.push(name || 'unknown'); - } - return out; -} - -function createToolSieveState() { - return { - pending: '', - capture: '', - capturing: false, - recentTextTail: '', - toolNameSent: false, - toolName: '', - toolArgsStart: -1, - toolArgsSent: -1, - toolArgsString: false, - toolArgsDone: false, - }; -} - -function resetIncrementalToolState(state) { - state.toolNameSent = false; - state.toolName = ''; - state.toolArgsStart = -1; - state.toolArgsSent = -1; - state.toolArgsString = false; - state.toolArgsDone = false; -} - -function processToolSieveChunk(state, chunk, toolNames) { - if (!state) { - return []; - } - if (chunk) { - state.pending += chunk; - } - const events = []; - // eslint-disable-next-line no-constant-condition - while (true) { - if (state.capturing) { - if (state.pending) { - state.capture += state.pending; - state.pending = ''; - } - const deltas = buildIncrementalToolDeltas(state); - if (deltas.length > 0) { - events.push({ type: 'tool_call_deltas', deltas }); - } - const consumed = consumeToolCapture(state, toolNames); - if (!consumed.ready) { - if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) { - 
noteText(state, state.capture); - events.push({ type: 'text', text: state.capture }); - state.capture = ''; - state.capturing = false; - resetIncrementalToolState(state); - continue; - } - break; - } - state.capture = ''; - state.capturing = false; - resetIncrementalToolState(state); - if (consumed.prefix) { - noteText(state, consumed.prefix); - events.push({ type: 'text', text: consumed.prefix }); - } - if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { - events.push({ type: 'tool_calls', calls: consumed.calls }); - } - if (consumed.suffix) { - state.pending += consumed.suffix; - } - continue; - } - - if (!state.pending) { - break; - } - - const start = findToolSegmentStart(state.pending); - if (start >= 0) { - const prefix = state.pending.slice(0, start); - if (prefix) { - noteText(state, prefix); - events.push({ type: 'text', text: prefix }); - } - state.capture = state.pending.slice(start); - state.pending = ''; - state.capturing = true; - resetIncrementalToolState(state); - continue; - } - - const [safe, hold] = splitSafeContentForToolDetection(state.pending); - if (!safe) { - break; - } - state.pending = hold; - noteText(state, safe); - events.push({ type: 'text', text: safe }); - } - return events; -} - -function flushToolSieve(state, toolNames) { - if (!state) { - return []; - } - const events = processToolSieveChunk(state, '', toolNames); - if (state.capturing) { - const consumed = consumeToolCapture(state, toolNames); - if (consumed.ready) { - if (consumed.prefix) { - noteText(state, consumed.prefix); - events.push({ type: 'text', text: consumed.prefix }); - } - if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { - events.push({ type: 'tool_calls', calls: consumed.calls }); - } - if (consumed.suffix) { - noteText(state, consumed.suffix); - events.push({ type: 'text', text: consumed.suffix }); - } - } else if (state.capture) { - noteText(state, state.capture); - events.push({ type: 'text', text: state.capture }); - } - 
state.capture = ''; - state.capturing = false; - resetIncrementalToolState(state); - } - if (state.pending) { - noteText(state, state.pending); - events.push({ type: 'text', text: state.pending }); - state.pending = ''; - } - return events; -} - -function splitSafeContentForToolDetection(s) { - const text = s || ''; - if (!text) { - return ['', '']; - } - const suspiciousStart = findSuspiciousPrefixStart(text); - if (suspiciousStart < 0) { - return [text, '']; - } - if (suspiciousStart > 0) { - return [text.slice(0, suspiciousStart), text.slice(suspiciousStart)]; - } - // If suspicious content starts at the beginning, keep holding until we can - // either parse a full tool JSON block or reach stream flush. - return ['', text]; -} - -function findSuspiciousPrefixStart(s) { - let start = -1; - for (const needle of ['{', '[', '```']) { - const idx = s.lastIndexOf(needle); - if (idx > start) { - start = idx; - } - } - return start; -} - -function findToolSegmentStart(s) { - if (!s) { - return -1; - } - const lower = s.toLowerCase(); - let offset = 0; - // eslint-disable-next-line no-constant-condition - while (true) { - const keyRel = lower.indexOf('tool_calls', offset); - if (keyRel < 0) { - return -1; - } - const keyIdx = keyRel; - const start = s.slice(0, keyIdx).lastIndexOf('{'); - const candidateStart = start >= 0 ? 
start : keyIdx; - if (!insideCodeFence(s.slice(0, candidateStart))) { - return candidateStart; - } - offset = keyIdx + 'tool_calls'.length; - } -} - -function consumeToolCapture(state, toolNames) { - const captured = state.capture; - if (!captured) { - return { ready: false, prefix: '', calls: [], suffix: '' }; - } - const lower = captured.toLowerCase(); - const keyIdx = lower.indexOf('tool_calls'); - if (keyIdx < 0) { - return { ready: false, prefix: '', calls: [], suffix: '' }; - } - const start = captured.slice(0, keyIdx).lastIndexOf('{'); - if (start < 0) { - return { ready: false, prefix: '', calls: [], suffix: '' }; - } - const obj = extractJSONObjectFrom(captured, start); - if (!obj.ok) { - return { ready: false, prefix: '', calls: [], suffix: '' }; - } - const prefixPart = captured.slice(0, start); - const suffixPart = captured.slice(obj.end); - if (insideCodeFence((state.recentTextTail || '') + prefixPart)) { - return { - ready: true, - prefix: captured, - calls: [], - suffix: '', - }; - } - const parsed = parseStandaloneToolCalls(captured.slice(start, obj.end), toolNames); - if (parsed.length === 0) { - if (state.toolNameSent) { - return { - ready: true, - prefix: prefixPart, - calls: [], - suffix: suffixPart, - }; - } - return { - ready: true, - prefix: captured, - calls: [], - suffix: '', - }; - } - if (state.toolNameSent) { - if (parsed.length > 1) { - return { - ready: true, - prefix: prefixPart, - calls: parsed.slice(1), - suffix: suffixPart, - }; - } - return { - ready: true, - prefix: prefixPart, - calls: [], - suffix: suffixPart, - }; - } - return { - ready: true, - prefix: prefixPart, - calls: parsed, - suffix: suffixPart, - }; -} - -function buildIncrementalToolDeltas(state) { - const captured = state.capture || ''; - if (!captured) { - return []; - } - if (looksLikeToolExampleContext(state.recentTextTail)) { - return []; - } - const lower = captured.toLowerCase(); - const keyIdx = lower.indexOf('tool_calls'); - if (keyIdx < 0) { - return []; - 
} - const start = captured.slice(0, keyIdx).lastIndexOf('{'); - if (start < 0) { - return []; - } - if (insideCodeFence((state.recentTextTail || '') + captured.slice(0, start))) { - return []; - } - const callStart = findFirstToolCallObjectStart(captured, keyIdx); - if (callStart < 0) { - return []; - } - - const deltas = []; - if (!state.toolName) { - const name = extractToolCallName(captured, callStart); - if (!name) { - return []; - } - state.toolName = name; - } - - if (state.toolArgsStart < 0) { - const args = findToolCallArgsStart(captured, callStart); - if (args) { - state.toolArgsString = Boolean(args.stringMode); - state.toolArgsStart = state.toolArgsString ? args.start + 1 : args.start; - state.toolArgsSent = state.toolArgsStart; - } - } - if (!state.toolNameSent) { - if (state.toolArgsStart < 0) { - return []; - } - state.toolNameSent = true; - deltas.push({ index: 0, name: state.toolName }); - } - if (state.toolArgsStart < 0 || state.toolArgsDone) { - return deltas; - } - const progress = scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString); - if (!progress) { - return deltas; - } - if (progress.end > state.toolArgsSent) { - deltas.push({ - index: 0, - arguments: captured.slice(state.toolArgsSent, progress.end), - }); - state.toolArgsSent = progress.end; - } - if (progress.complete) { - state.toolArgsDone = true; - } - return deltas; -} - -function findFirstToolCallObjectStart(text, keyIdx) { - const arrStart = findToolCallsArrayStart(text, keyIdx); - if (arrStart < 0) { - return -1; - } - const i = skipSpaces(text, arrStart + 1); - if (i >= text.length || text[i] !== '{') { - return -1; - } - return i; -} - -function findToolCallsArrayStart(text, keyIdx) { - let i = keyIdx + 'tool_calls'.length; - while (i < text.length && text[i] !== ':') { - i += 1; - } - if (i >= text.length) { - return -1; - } - i = skipSpaces(text, i + 1); - if (i >= text.length || text[i] !== '[') { - return -1; - } - return i; -} - -function 
extractToolCallName(text, callStart) { - let valueStart = findObjectFieldValueStart(text, callStart, ['name']); - if (valueStart < 0 || text[valueStart] !== '"') { - const fnStart = findFunctionObjectStart(text, callStart); - if (fnStart < 0) { - return ''; - } - valueStart = findObjectFieldValueStart(text, fnStart, ['name']); - if (valueStart < 0 || text[valueStart] !== '"') { - return ''; - } - } - const parsed = parseJSONStringLiteral(text, valueStart); - if (!parsed) { - return ''; - } - return parsed.value; -} - -function findToolCallArgsStart(text, callStart) { - const keys = ['input', 'arguments', 'args', 'parameters', 'params']; - let valueStart = findObjectFieldValueStart(text, callStart, keys); - if (valueStart < 0) { - const fnStart = findFunctionObjectStart(text, callStart); - if (fnStart < 0) { - return null; - } - valueStart = findObjectFieldValueStart(text, fnStart, keys); - if (valueStart < 0) { - return null; - } - } - if (valueStart >= text.length) { - return null; - } - const ch = text[valueStart]; - if (ch === '{' || ch === '[') { - return { start: valueStart, stringMode: false }; - } - if (ch === '"') { - return { start: valueStart, stringMode: true }; - } - return null; -} - -function scanToolCallArgsProgress(text, start, stringMode) { - if (start < 0 || start > text.length) { - return null; - } - if (stringMode) { - let escaped = false; - for (let i = start; i < text.length; i += 1) { - const ch = text[i]; - if (escaped) { - escaped = false; - continue; - } - if (ch === '\\') { - escaped = true; - continue; - } - if (ch === '"') { - return { end: i, complete: true }; - } - } - return { end: text.length, complete: false }; - } - if (start >= text.length || (text[start] !== '{' && text[start] !== '[')) { - return null; - } - let depth = 0; - let quote = ''; - let escaped = false; - for (let i = start; i < text.length; i += 1) { - const ch = text[i]; - if (quote) { - if (escaped) { - escaped = false; - continue; - } - if (ch === '\\') { - 
escaped = true; - continue; - } - if (ch === quote) { - quote = ''; - } - continue; - } - if (ch === '"' || ch === "'") { - quote = ch; - continue; - } - if (ch === '{' || ch === '[') { - depth += 1; - continue; - } - if (ch === '}' || ch === ']') { - depth -= 1; - if (depth === 0) { - return { end: i + 1, complete: true }; - } - } - } - return { end: text.length, complete: false }; -} - -function findObjectFieldValueStart(text, objStart, keys) { - if (!text || objStart < 0 || objStart >= text.length || text[objStart] !== '{') { - return -1; - } - let depth = 0; - let quote = ''; - let escaped = false; - for (let i = objStart; i < text.length; i += 1) { - const ch = text[i]; - if (quote) { - if (escaped) { - escaped = false; - continue; - } - if (ch === '\\') { - escaped = true; - continue; - } - if (ch === quote) { - quote = ''; - } - continue; - } - if (ch === '"' || ch === "'") { - if (depth === 1) { - const parsed = parseJSONStringLiteral(text, i); - if (!parsed) { - return -1; - } - let j = skipSpaces(text, parsed.end); - if (j >= text.length || text[j] !== ':') { - i = parsed.end - 1; - continue; - } - j = skipSpaces(text, j + 1); - if (j >= text.length) { - return -1; - } - if (keys.includes(parsed.value)) { - return j; - } - i = j - 1; - continue; - } - quote = ch; - continue; - } - if (ch === '{') { - depth += 1; - continue; - } - if (ch === '}') { - depth -= 1; - if (depth === 0) { - break; - } - } - } - return -1; -} - -function findFunctionObjectStart(text, callStart) { - const valueStart = findObjectFieldValueStart(text, callStart, ['function']); - if (valueStart < 0 || valueStart >= text.length || text[valueStart] !== '{') { - return -1; - } - return valueStart; -} - -function parseJSONStringLiteral(text, start) { - if (!text || start < 0 || start >= text.length || text[start] !== '"') { - return null; - } - let out = ''; - let escaped = false; - for (let i = start + 1; i < text.length; i += 1) { - const ch = text[i]; - if (escaped) { - out += ch; - 
escaped = false; - continue; - } - if (ch === '\\') { - escaped = true; - continue; - } - if (ch === '"') { - return { value: out, end: i + 1 }; - } - out += ch; - } - return null; -} - -function skipSpaces(text, i) { - let idx = i; - while (idx < text.length) { - const ch = text[idx]; - if (ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r') { - idx += 1; - continue; - } - break; - } - return idx; -} - -function extractJSONObjectFrom(text, start) { - if (!text || start < 0 || start >= text.length || text[start] !== '{') { - return { ok: false, end: 0 }; - } - let depth = 0; - let quote = ''; - let escaped = false; - for (let i = start; i < text.length; i += 1) { - const ch = text[i]; - if (quote) { - if (escaped) { - escaped = false; - continue; - } - if (ch === '\\') { - escaped = true; - continue; - } - if (ch === quote) { - quote = ''; - } - continue; - } - if (ch === '"' || ch === "'") { - quote = ch; - continue; - } - if (ch === '{') { - depth += 1; - continue; - } - if (ch === '}') { - depth -= 1; - if (depth === 0) { - return { ok: true, end: i + 1 }; - } - } - } - return { ok: false, end: 0 }; -} - -function parseToolCalls(text, toolNames) { - if (!toStringSafe(text)) { - return []; - } - const sanitized = stripFencedCodeBlocks(text); - if (!toStringSafe(sanitized)) { - return []; - } - const candidates = buildToolCallCandidates(sanitized); - let parsed = []; - for (const c of candidates) { - parsed = parseToolCallsPayload(c); - if (parsed.length > 0) { - break; - } - } - if (parsed.length === 0) { - return []; - } - return filterToolCalls(parsed, toolNames); -} - -function stripFencedCodeBlocks(text) { - const t = typeof text === 'string' ? 
text : ''; - if (!t) { - return ''; - } - return t.replace(/```[\s\S]*?```/g, ' '); -} - -function parseStandaloneToolCalls(text, toolNames) { - const trimmed = toStringSafe(text); - if (!trimmed) { - return []; - } - if ((trimmed.startsWith('```') && trimmed.endsWith('```')) || trimmed.includes('```')) { - return []; - } - if (looksLikeToolExampleContext(trimmed)) { - return []; - } - const candidates = [trimmed]; - if (trimmed.startsWith('```') && trimmed.endsWith('```')) { - const m = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/i); - if (m && m[1]) { - candidates.push(toStringSafe(m[1])); - } - } - for (const candidate of candidates) { - const c = toStringSafe(candidate); - if (!c) { - continue; - } - if (!c.startsWith('{') && !c.startsWith('[')) { - continue; - } - const parsed = parseToolCallsPayload(c); - if (parsed.length > 0) { - return filterToolCalls(parsed, toolNames); - } - } - return []; -} - -function buildToolCallCandidates(text) { - const trimmed = toStringSafe(text); - const candidates = [trimmed]; - const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || []; - for (const block of fenced) { - const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i); - if (m && m[1]) { - candidates.push(toStringSafe(m[1])); - } - } - for (const candidate of extractToolCallObjects(trimmed)) { - candidates.push(toStringSafe(candidate)); - } - const first = trimmed.indexOf('{'); - const last = trimmed.lastIndexOf('}'); - if (first >= 0 && last > first) { - candidates.push(toStringSafe(trimmed.slice(first, last + 1))); - } - const m = trimmed.match(TOOL_CALL_PATTERN); - if (m && m[1]) { - candidates.push(`{"tool_calls":[${m[1]}]}`); - } - return [...new Set(candidates.filter(Boolean))]; -} - -function extractToolCallObjects(text) { - const raw = toStringSafe(text); - if (!raw) { - return []; - } - const lower = raw.toLowerCase(); - const out = []; - let offset = 0; - // eslint-disable-next-line no-constant-condition - while (true) { - let idx = 
lower.indexOf('tool_calls', offset); - if (idx < 0) { - break; - } - let start = raw.slice(0, idx).lastIndexOf('{'); - while (start >= 0) { - const obj = extractJSONObjectFrom(raw, start); - if (obj.ok) { - out.push(raw.slice(start, obj.end).trim()); - offset = obj.end; - idx = -1; - break; - } - start = raw.slice(0, start).lastIndexOf('{'); - } - if (idx >= 0) { - offset = idx + 'tool_calls'.length; - } - } - return out; -} - -function parseToolCallsPayload(payload) { - let decoded; - try { - decoded = JSON.parse(payload); - } catch (_err) { - return []; - } - if (Array.isArray(decoded)) { - return parseToolCallList(decoded); - } - if (!decoded || typeof decoded !== 'object') { - return []; - } - if (decoded.tool_calls) { - return parseToolCallList(decoded.tool_calls); - } - const one = parseToolCallItem(decoded); - return one ? [one] : []; -} - -function parseToolCallList(v) { - if (!Array.isArray(v)) { - return []; - } - const out = []; - for (const item of v) { - if (!item || typeof item !== 'object') { - continue; - } - const one = parseToolCallItem(item); - if (one) { - out.push(one); - } - } - return out; -} - -function parseToolCallItem(m) { - let name = toStringSafe(m.name); - let inputRaw = m.input; - let hasInput = Object.prototype.hasOwnProperty.call(m, 'input'); - const fn = m.function && typeof m.function === 'object' ? 
m.function : null; - if (fn) { - if (!name) { - name = toStringSafe(fn.name); - } - if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) { - inputRaw = fn.arguments; - hasInput = true; - } - } - if (!hasInput) { - for (const k of ['arguments', 'args', 'parameters', 'params']) { - if (Object.prototype.hasOwnProperty.call(m, k)) { - inputRaw = m[k]; - hasInput = true; - break; - } - } - } - if (!name) { - return null; - } - return { - name, - input: parseToolCallInput(inputRaw), - }; -} - -function parseToolCallInput(v) { - if (v == null) { - return {}; - } - if (typeof v === 'string') { - const raw = toStringSafe(v); - if (!raw) { - return {}; - } - try { - const parsed = JSON.parse(raw); - if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { - return parsed; - } - return { _raw: raw }; - } catch (_err) { - return { _raw: raw }; - } - } - if (typeof v === 'object' && !Array.isArray(v)) { - return v; - } - try { - const parsed = JSON.parse(JSON.stringify(v)); - if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { - return parsed; - } - } catch (_err) { - return {}; - } - return {}; -} - -function filterToolCalls(parsed, toolNames) { - const allowed = new Set((toolNames || []).filter(Boolean)); - const out = []; - for (const tc of parsed) { - if (!tc || !tc.name) { - continue; - } - if (allowed.size > 0 && !allowed.has(tc.name)) { - continue; - } - out.push({ name: tc.name, input: tc.input || {} }); - } - if (out.length === 0 && parsed.length > 0) { - for (const tc of parsed) { - if (!tc || !tc.name) { - continue; - } - out.push({ name: tc.name, input: tc.input || {} }); - } - } - return out; -} - -function noteText(state, text) { - if (!state || !hasMeaningfulText(text)) { - return; - } - state.recentTextTail = appendTail(state.recentTextTail, text, TOOL_SIEVE_CONTEXT_TAIL_LIMIT); -} - -function appendTail(prev, next, max) { - const left = typeof prev === 'string' ? 
prev : ''; - const right = typeof next === 'string' ? next : ''; - if (!Number.isFinite(max) || max <= 0) { - return ''; - } - const combined = left + right; - if (combined.length <= max) { - return combined; - } - return combined.slice(combined.length - max); -} - -function looksLikeToolExampleContext(text) { - return insideCodeFence(text); -} - -function insideCodeFence(text) { - const t = typeof text === 'string' ? text : ''; - if (!t) { - return false; - } - const ticks = (t.match(/```/g) || []).length; - return ticks % 2 === 1; -} - -function hasMeaningfulText(text) { - return toStringSafe(text) !== ''; -} - -function formatOpenAIStreamToolCalls(calls) { - if (!Array.isArray(calls) || calls.length === 0) { - return []; - } - return calls.map((c, idx) => ({ - index: idx, - id: `call_${newCallID()}`, - type: 'function', - function: { - name: c.name, - arguments: JSON.stringify(c.input || {}), - }, - })); -} - -function newCallID() { - if (typeof crypto.randomUUID === 'function') { - return crypto.randomUUID().replace(/-/g, ''); - } - return `${Date.now()}${Math.floor(Math.random() * 1e9)}`; -} - -function toStringSafe(v) { - if (typeof v === 'string') { - return v.trim(); - } - if (Array.isArray(v)) { - return toStringSafe(v[0]); - } - if (v == null) { - return ''; - } - return String(v).trim(); -} - -module.exports = { - extractToolNames, - createToolSieveState, - processToolSieveChunk, - flushToolSieve, - parseToolCalls, - parseStandaloneToolCalls, - formatOpenAIStreamToolCalls, -}; +module.exports = require('./stream-tool-sieve/index.js'); diff --git a/api/helpers/stream-tool-sieve/format.js b/api/helpers/stream-tool-sieve/format.js new file mode 100644 index 0000000..ff1dcef --- /dev/null +++ b/api/helpers/stream-tool-sieve/format.js @@ -0,0 +1,29 @@ +'use strict'; + +const crypto = require('crypto'); + +function formatOpenAIStreamToolCalls(calls) { + if (!Array.isArray(calls) || calls.length === 0) { + return []; + } + return calls.map((c, idx) => ({ + 
index: idx, + id: `call_${newCallID()}`, + type: 'function', + function: { + name: c.name, + arguments: JSON.stringify(c.input || {}), + }, + })); +} + +function newCallID() { + if (typeof crypto.randomUUID === 'function') { + return crypto.randomUUID().replace(/-/g, ''); + } + return `${Date.now()}${Math.floor(Math.random() * 1e9)}`; +} + +module.exports = { + formatOpenAIStreamToolCalls, +}; diff --git a/api/helpers/stream-tool-sieve/incremental.js b/api/helpers/stream-tool-sieve/incremental.js new file mode 100644 index 0000000..1895075 --- /dev/null +++ b/api/helpers/stream-tool-sieve/incremental.js @@ -0,0 +1,226 @@ +'use strict'; + +const { + looksLikeToolExampleContext, + insideCodeFence, +} = require('./state'); +const { + findObjectFieldValueStart, + parseJSONStringLiteral, + skipSpaces, +} = require('./jsonscan'); + +function buildIncrementalToolDeltas(state) { + const captured = state.capture || ''; + if (!captured) { + return []; + } + if (looksLikeToolExampleContext(state.recentTextTail)) { + return []; + } + const lower = captured.toLowerCase(); + const keyIdx = lower.indexOf('tool_calls'); + if (keyIdx < 0) { + return []; + } + const start = captured.slice(0, keyIdx).lastIndexOf('{'); + if (start < 0) { + return []; + } + if (insideCodeFence((state.recentTextTail || '') + captured.slice(0, start))) { + return []; + } + const callStart = findFirstToolCallObjectStart(captured, keyIdx); + if (callStart < 0) { + return []; + } + + const deltas = []; + if (!state.toolName) { + const name = extractToolCallName(captured, callStart); + if (!name) { + return []; + } + state.toolName = name; + } + + if (state.toolArgsStart < 0) { + const args = findToolCallArgsStart(captured, callStart); + if (args) { + state.toolArgsString = Boolean(args.stringMode); + state.toolArgsStart = state.toolArgsString ? 
args.start + 1 : args.start; + state.toolArgsSent = state.toolArgsStart; + } + } + if (!state.toolNameSent) { + if (state.toolArgsStart < 0) { + return []; + } + state.toolNameSent = true; + deltas.push({ index: 0, name: state.toolName }); + } + if (state.toolArgsStart < 0 || state.toolArgsDone) { + return deltas; + } + const progress = scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString); + if (!progress) { + return deltas; + } + if (progress.end > state.toolArgsSent) { + deltas.push({ + index: 0, + arguments: captured.slice(state.toolArgsSent, progress.end), + }); + state.toolArgsSent = progress.end; + } + if (progress.complete) { + state.toolArgsDone = true; + } + return deltas; +} + +function findFirstToolCallObjectStart(text, keyIdx) { + const arrStart = findToolCallsArrayStart(text, keyIdx); + if (arrStart < 0) { + return -1; + } + const i = skipSpaces(text, arrStart + 1); + if (i >= text.length || text[i] !== '{') { + return -1; + } + return i; +} + +function findToolCallsArrayStart(text, keyIdx) { + let i = keyIdx + 'tool_calls'.length; + while (i < text.length && text[i] !== ':') { + i += 1; + } + if (i >= text.length) { + return -1; + } + i = skipSpaces(text, i + 1); + if (i >= text.length || text[i] !== '[') { + return -1; + } + return i; +} + +function extractToolCallName(text, callStart) { + let valueStart = findObjectFieldValueStart(text, callStart, ['name']); + if (valueStart < 0 || text[valueStart] !== '"') { + const fnStart = findFunctionObjectStart(text, callStart); + if (fnStart < 0) { + return ''; + } + valueStart = findObjectFieldValueStart(text, fnStart, ['name']); + if (valueStart < 0 || text[valueStart] !== '"') { + return ''; + } + } + const parsed = parseJSONStringLiteral(text, valueStart); + if (!parsed) { + return ''; + } + return parsed.value; +} + +function findToolCallArgsStart(text, callStart) { + const keys = ['input', 'arguments', 'args', 'parameters', 'params']; + let valueStart = 
findObjectFieldValueStart(text, callStart, keys); + if (valueStart < 0) { + const fnStart = findFunctionObjectStart(text, callStart); + if (fnStart < 0) { + return null; + } + valueStart = findObjectFieldValueStart(text, fnStart, keys); + if (valueStart < 0) { + return null; + } + } + if (valueStart >= text.length) { + return null; + } + const ch = text[valueStart]; + if (ch === '{' || ch === '[') { + return { start: valueStart, stringMode: false }; + } + if (ch === '"') { + return { start: valueStart, stringMode: true }; + } + return null; +} + +function scanToolCallArgsProgress(text, start, stringMode) { + if (start < 0 || start > text.length) { + return null; + } + if (stringMode) { + let escaped = false; + for (let i = start; i < text.length; i += 1) { + const ch = text[i]; + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === '"') { + return { end: i, complete: true }; + } + } + return { end: text.length, complete: false }; + } + if (start >= text.length || (text[start] !== '{' && text[start] !== '[')) { + return null; + } + let depth = 0; + let quote = ''; + let escaped = false; + for (let i = start; i < text.length; i += 1) { + const ch = text[i]; + if (quote) { + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === quote) { + quote = ''; + } + continue; + } + if (ch === '"' || ch === "'") { + quote = ch; + continue; + } + if (ch === '{' || ch === '[') { + depth += 1; + continue; + } + if (ch === '}' || ch === ']') { + depth -= 1; + if (depth === 0) { + return { end: i + 1, complete: true }; + } + } + } + return { end: text.length, complete: false }; +} + +function findFunctionObjectStart(text, callStart) { + const valueStart = findObjectFieldValueStart(text, callStart, ['function']); + if (valueStart < 0 || valueStart >= text.length || text[valueStart] !== '{') { + return -1; + } + return valueStart; +} + +module.exports = { + 
buildIncrementalToolDeltas, +}; diff --git a/api/helpers/stream-tool-sieve/index.js b/api/helpers/stream-tool-sieve/index.js new file mode 100644 index 0000000..f218b52 --- /dev/null +++ b/api/helpers/stream-tool-sieve/index.js @@ -0,0 +1,27 @@ +'use strict'; + +const { + createToolSieveState, +} = require('./state'); +const { + processToolSieveChunk, + flushToolSieve, +} = require('./sieve'); +const { + extractToolNames, + parseToolCalls, + parseStandaloneToolCalls, +} = require('./parse'); +const { + formatOpenAIStreamToolCalls, +} = require('./format'); + +module.exports = { + extractToolNames, + createToolSieveState, + processToolSieveChunk, + flushToolSieve, + parseToolCalls, + parseStandaloneToolCalls, + formatOpenAIStreamToolCalls, +}; diff --git a/api/helpers/stream-tool-sieve/jsonscan.js b/api/helpers/stream-tool-sieve/jsonscan.js new file mode 100644 index 0000000..a86ed05 --- /dev/null +++ b/api/helpers/stream-tool-sieve/jsonscan.js @@ -0,0 +1,148 @@ +'use strict'; + +function findObjectFieldValueStart(text, objStart, keys) { + if (!text || objStart < 0 || objStart >= text.length || text[objStart] !== '{') { + return -1; + } + let depth = 0; + let quote = ''; + let escaped = false; + for (let i = objStart; i < text.length; i += 1) { + const ch = text[i]; + if (quote) { + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === quote) { + quote = ''; + } + continue; + } + if (ch === '"' || ch === "'") { + if (depth === 1) { + const parsed = parseJSONStringLiteral(text, i); + if (!parsed) { + return -1; + } + let j = skipSpaces(text, parsed.end); + if (j >= text.length || text[j] !== ':') { + i = parsed.end - 1; + continue; + } + j = skipSpaces(text, j + 1); + if (j >= text.length) { + return -1; + } + if (keys.includes(parsed.value)) { + return j; + } + i = j - 1; + continue; + } + quote = ch; + continue; + } + if (ch === '{') { + depth += 1; + continue; + } + if (ch === '}') { + depth -= 1; + 
if (depth === 0) { + break; + } + } + } + return -1; +} + +function parseJSONStringLiteral(text, start) { + if (!text || start < 0 || start >= text.length || text[start] !== '"') { + return null; + } + let out = ''; + let escaped = false; + for (let i = start + 1; i < text.length; i += 1) { + const ch = text[i]; + if (escaped) { + out += ch; + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === '"') { + return { value: out, end: i + 1 }; + } + out += ch; + } + return null; +} + +function skipSpaces(text, i) { + let idx = i; + while (idx < text.length) { + const ch = text[idx]; + if (ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r') { + idx += 1; + continue; + } + break; + } + return idx; +} + +function extractJSONObjectFrom(text, start) { + if (!text || start < 0 || start >= text.length || text[start] !== '{') { + return { ok: false, end: 0 }; + } + let depth = 0; + let quote = ''; + let escaped = false; + for (let i = start; i < text.length; i += 1) { + const ch = text[i]; + if (quote) { + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === quote) { + quote = ''; + } + continue; + } + if (ch === '"' || ch === "'") { + quote = ch; + continue; + } + if (ch === '{') { + depth += 1; + continue; + } + if (ch === '}') { + depth -= 1; + if (depth === 0) { + return { ok: true, end: i + 1 }; + } + } + } + return { ok: false, end: 0 }; +} + +module.exports = { + findObjectFieldValueStart, + parseJSONStringLiteral, + skipSpaces, + extractJSONObjectFrom, +}; diff --git a/api/helpers/stream-tool-sieve/parse.js b/api/helpers/stream-tool-sieve/parse.js new file mode 100644 index 0000000..def46db --- /dev/null +++ b/api/helpers/stream-tool-sieve/parse.js @@ -0,0 +1,281 @@ +'use strict'; + +const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s; + +const { + toStringSafe, + looksLikeToolExampleContext, +} = require('./state'); +const { + 
extractJSONObjectFrom, +} = require('./jsonscan'); + +function extractToolNames(tools) { + if (!Array.isArray(tools) || tools.length === 0) { + return []; + } + const out = []; + for (const t of tools) { + if (!t || typeof t !== 'object') { + continue; + } + const fn = t.function && typeof t.function === 'object' ? t.function : t; + const name = toStringSafe(fn.name); + // Keep parity with Go injectToolPrompt: object tools without name still + // enter tool mode via fallback name "unknown". + out.push(name || 'unknown'); + } + return out; +} + +function parseToolCalls(text, toolNames) { + if (!toStringSafe(text)) { + return []; + } + const sanitized = stripFencedCodeBlocks(text); + if (!toStringSafe(sanitized)) { + return []; + } + const candidates = buildToolCallCandidates(sanitized); + let parsed = []; + for (const c of candidates) { + parsed = parseToolCallsPayload(c); + if (parsed.length > 0) { + break; + } + } + if (parsed.length === 0) { + return []; + } + return filterToolCalls(parsed, toolNames); +} + +function stripFencedCodeBlocks(text) { + const t = typeof text === 'string' ? 
text : ''; + if (!t) { + return ''; + } + return t.replace(/```[\s\S]*?```/g, ' '); +} + +function parseStandaloneToolCalls(text, toolNames) { + const trimmed = toStringSafe(text); + if (!trimmed) { + return []; + } + if ((trimmed.startsWith('```') && trimmed.endsWith('```')) || trimmed.includes('```')) { + return []; + } + if (looksLikeToolExampleContext(trimmed)) { + return []; + } + const candidates = [trimmed]; + if (trimmed.startsWith('```') && trimmed.endsWith('```')) { + const m = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/i); + if (m && m[1]) { + candidates.push(toStringSafe(m[1])); + } + } + for (const candidate of candidates) { + const c = toStringSafe(candidate); + if (!c) { + continue; + } + if (!c.startsWith('{') && !c.startsWith('[')) { + continue; + } + const parsed = parseToolCallsPayload(c); + if (parsed.length > 0) { + return filterToolCalls(parsed, toolNames); + } + } + return []; +} + +function buildToolCallCandidates(text) { + const trimmed = toStringSafe(text); + const candidates = [trimmed]; + const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || []; + for (const block of fenced) { + const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i); + if (m && m[1]) { + candidates.push(toStringSafe(m[1])); + } + } + for (const candidate of extractToolCallObjects(trimmed)) { + candidates.push(toStringSafe(candidate)); + } + const first = trimmed.indexOf('{'); + const last = trimmed.lastIndexOf('}'); + if (first >= 0 && last > first) { + candidates.push(toStringSafe(trimmed.slice(first, last + 1))); + } + const m = trimmed.match(TOOL_CALL_PATTERN); + if (m && m[1]) { + candidates.push(`{"tool_calls":[${m[1]}]}`); + } + return [...new Set(candidates.filter(Boolean))]; +} + +function extractToolCallObjects(text) { + const raw = toStringSafe(text); + if (!raw) { + return []; + } + const lower = raw.toLowerCase(); + const out = []; + let offset = 0; + // eslint-disable-next-line no-constant-condition + while (true) { + let idx = 
lower.indexOf('tool_calls', offset); + if (idx < 0) { + break; + } + let start = raw.slice(0, idx).lastIndexOf('{'); + while (start >= 0) { + const obj = extractJSONObjectFrom(raw, start); + if (obj.ok) { + out.push(raw.slice(start, obj.end).trim()); + offset = obj.end; + idx = -1; + break; + } + start = raw.slice(0, start).lastIndexOf('{'); + } + if (idx >= 0) { + offset = idx + 'tool_calls'.length; + } + } + return out; +} + +function parseToolCallsPayload(payload) { + let decoded; + try { + decoded = JSON.parse(payload); + } catch (_err) { + return []; + } + if (Array.isArray(decoded)) { + return parseToolCallList(decoded); + } + if (!decoded || typeof decoded !== 'object') { + return []; + } + if (decoded.tool_calls) { + return parseToolCallList(decoded.tool_calls); + } + const one = parseToolCallItem(decoded); + return one ? [one] : []; +} + +function parseToolCallList(v) { + if (!Array.isArray(v)) { + return []; + } + const out = []; + for (const item of v) { + if (!item || typeof item !== 'object') { + continue; + } + const one = parseToolCallItem(item); + if (one) { + out.push(one); + } + } + return out; +} + +function parseToolCallItem(m) { + let name = toStringSafe(m.name); + let inputRaw = m.input; + let hasInput = Object.prototype.hasOwnProperty.call(m, 'input'); + const fn = m.function && typeof m.function === 'object' ? 
m.function : null; + if (fn) { + if (!name) { + name = toStringSafe(fn.name); + } + if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) { + inputRaw = fn.arguments; + hasInput = true; + } + } + if (!hasInput) { + for (const k of ['arguments', 'args', 'parameters', 'params']) { + if (Object.prototype.hasOwnProperty.call(m, k)) { + inputRaw = m[k]; + hasInput = true; + break; + } + } + } + if (!name) { + return null; + } + return { + name, + input: parseToolCallInput(inputRaw), + }; +} + +function parseToolCallInput(v) { + if (v == null) { + return {}; + } + if (typeof v === 'string') { + const raw = toStringSafe(v); + if (!raw) { + return {}; + } + try { + const parsed = JSON.parse(raw); + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + return parsed; + } + return { _raw: raw }; + } catch (_err) { + return { _raw: raw }; + } + } + if (typeof v === 'object' && !Array.isArray(v)) { + return v; + } + try { + const parsed = JSON.parse(JSON.stringify(v)); + if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) { + return parsed; + } + } catch (_err) { + return {}; + } + return {}; +} + +function filterToolCalls(parsed, toolNames) { + const allowed = new Set((toolNames || []).filter(Boolean)); + const out = []; + for (const tc of parsed) { + if (!tc || !tc.name) { + continue; + } + if (allowed.size > 0 && !allowed.has(tc.name)) { + continue; + } + out.push({ name: tc.name, input: tc.input || {} }); + } + if (out.length === 0 && parsed.length > 0) { + for (const tc of parsed) { + if (!tc || !tc.name) { + continue; + } + out.push({ name: tc.name, input: tc.input || {} }); + } + } + return out; +} + +module.exports = { + extractToolNames, + parseToolCalls, + parseStandaloneToolCalls, +}; diff --git a/api/helpers/stream-tool-sieve/sieve.js b/api/helpers/stream-tool-sieve/sieve.js new file mode 100644 index 0000000..c10e636 --- /dev/null +++ b/api/helpers/stream-tool-sieve/sieve.js @@ -0,0 +1,252 @@ +'use strict'; + 
+const { + TOOL_SIEVE_CAPTURE_LIMIT, + resetIncrementalToolState, + noteText, + insideCodeFence, +} = require('./state'); +const { + buildIncrementalToolDeltas, +} = require('./incremental'); +const { + parseStandaloneToolCalls, +} = require('./parse'); +const { + extractJSONObjectFrom, +} = require('./jsonscan'); + +function processToolSieveChunk(state, chunk, toolNames) { + if (!state) { + return []; + } + if (chunk) { + state.pending += chunk; + } + const events = []; + // eslint-disable-next-line no-constant-condition + while (true) { + if (state.capturing) { + if (state.pending) { + state.capture += state.pending; + state.pending = ''; + } + const deltas = buildIncrementalToolDeltas(state); + if (deltas.length > 0) { + events.push({ type: 'tool_call_deltas', deltas }); + } + const consumed = consumeToolCapture(state, toolNames); + if (!consumed.ready) { + if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) { + noteText(state, state.capture); + events.push({ type: 'text', text: state.capture }); + state.capture = ''; + state.capturing = false; + resetIncrementalToolState(state); + continue; + } + break; + } + state.capture = ''; + state.capturing = false; + resetIncrementalToolState(state); + if (consumed.prefix) { + noteText(state, consumed.prefix); + events.push({ type: 'text', text: consumed.prefix }); + } + if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { + events.push({ type: 'tool_calls', calls: consumed.calls }); + } + if (consumed.suffix) { + state.pending += consumed.suffix; + } + continue; + } + + if (!state.pending) { + break; + } + + const start = findToolSegmentStart(state.pending); + if (start >= 0) { + const prefix = state.pending.slice(0, start); + if (prefix) { + noteText(state, prefix); + events.push({ type: 'text', text: prefix }); + } + state.capture = state.pending.slice(start); + state.pending = ''; + state.capturing = true; + resetIncrementalToolState(state); + continue; + } + + const [safe, hold] = 
splitSafeContentForToolDetection(state.pending); + if (!safe) { + break; + } + state.pending = hold; + noteText(state, safe); + events.push({ type: 'text', text: safe }); + } + return events; +} + +function flushToolSieve(state, toolNames) { + if (!state) { + return []; + } + const events = processToolSieveChunk(state, '', toolNames); + if (state.capturing) { + const consumed = consumeToolCapture(state, toolNames); + if (consumed.ready) { + if (consumed.prefix) { + noteText(state, consumed.prefix); + events.push({ type: 'text', text: consumed.prefix }); + } + if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { + events.push({ type: 'tool_calls', calls: consumed.calls }); + } + if (consumed.suffix) { + noteText(state, consumed.suffix); + events.push({ type: 'text', text: consumed.suffix }); + } + } else if (state.capture) { + noteText(state, state.capture); + events.push({ type: 'text', text: state.capture }); + } + state.capture = ''; + state.capturing = false; + resetIncrementalToolState(state); + } + if (state.pending) { + noteText(state, state.pending); + events.push({ type: 'text', text: state.pending }); + state.pending = ''; + } + return events; +} + +function splitSafeContentForToolDetection(s) { + const text = s || ''; + if (!text) { + return ['', '']; + } + const suspiciousStart = findSuspiciousPrefixStart(text); + if (suspiciousStart < 0) { + return [text, '']; + } + if (suspiciousStart > 0) { + return [text.slice(0, suspiciousStart), text.slice(suspiciousStart)]; + } + // If suspicious content starts at the beginning, keep holding until we can + // either parse a full tool JSON block or reach stream flush. 
+ return ['', text]; +} + +function findSuspiciousPrefixStart(s) { + let start = -1; + for (const needle of ['{', '[', '```']) { + const idx = s.lastIndexOf(needle); + if (idx > start) { + start = idx; + } + } + return start; +} + +function findToolSegmentStart(s) { + if (!s) { + return -1; + } + const lower = s.toLowerCase(); + let offset = 0; + // eslint-disable-next-line no-constant-condition + while (true) { + const keyRel = lower.indexOf('tool_calls', offset); + if (keyRel < 0) { + return -1; + } + const keyIdx = keyRel; + const start = s.slice(0, keyIdx).lastIndexOf('{'); + const candidateStart = start >= 0 ? start : keyIdx; + if (!insideCodeFence(s.slice(0, candidateStart))) { + return candidateStart; + } + offset = keyIdx + 'tool_calls'.length; + } +} + +function consumeToolCapture(state, toolNames) { + const captured = state.capture; + if (!captured) { + return { ready: false, prefix: '', calls: [], suffix: '' }; + } + const lower = captured.toLowerCase(); + const keyIdx = lower.indexOf('tool_calls'); + if (keyIdx < 0) { + return { ready: false, prefix: '', calls: [], suffix: '' }; + } + const start = captured.slice(0, keyIdx).lastIndexOf('{'); + if (start < 0) { + return { ready: false, prefix: '', calls: [], suffix: '' }; + } + const obj = extractJSONObjectFrom(captured, start); + if (!obj.ok) { + return { ready: false, prefix: '', calls: [], suffix: '' }; + } + const prefixPart = captured.slice(0, start); + const suffixPart = captured.slice(obj.end); + if (insideCodeFence((state.recentTextTail || '') + prefixPart)) { + return { + ready: true, + prefix: captured, + calls: [], + suffix: '', + }; + } + const parsed = parseStandaloneToolCalls(captured.slice(start, obj.end), toolNames); + if (parsed.length === 0) { + if (state.toolNameSent) { + return { + ready: true, + prefix: prefixPart, + calls: [], + suffix: suffixPart, + }; + } + return { + ready: true, + prefix: captured, + calls: [], + suffix: '', + }; + } + if (state.toolNameSent) { + if 
(parsed.length > 1) { + return { + ready: true, + prefix: prefixPart, + calls: parsed.slice(1), + suffix: suffixPart, + }; + } + return { + ready: true, + prefix: prefixPart, + calls: [], + suffix: suffixPart, + }; + } + return { + ready: true, + prefix: prefixPart, + calls: parsed, + suffix: suffixPart, + }; +} + +module.exports = { + processToolSieveChunk, + flushToolSieve, +}; diff --git a/api/helpers/stream-tool-sieve/state.js b/api/helpers/stream-tool-sieve/state.js new file mode 100644 index 0000000..a2d2b5c --- /dev/null +++ b/api/helpers/stream-tool-sieve/state.js @@ -0,0 +1,91 @@ +'use strict'; + +const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024; +const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256; + +function createToolSieveState() { + return { + pending: '', + capture: '', + capturing: false, + recentTextTail: '', + toolNameSent: false, + toolName: '', + toolArgsStart: -1, + toolArgsSent: -1, + toolArgsString: false, + toolArgsDone: false, + }; +} + +function resetIncrementalToolState(state) { + state.toolNameSent = false; + state.toolName = ''; + state.toolArgsStart = -1; + state.toolArgsSent = -1; + state.toolArgsString = false; + state.toolArgsDone = false; +} + +function noteText(state, text) { + if (!state || !hasMeaningfulText(text)) { + return; + } + state.recentTextTail = appendTail(state.recentTextTail, text, TOOL_SIEVE_CONTEXT_TAIL_LIMIT); +} + +function appendTail(prev, next, max) { + const left = typeof prev === 'string' ? prev : ''; + const right = typeof next === 'string' ? next : ''; + if (!Number.isFinite(max) || max <= 0) { + return ''; + } + const combined = left + right; + if (combined.length <= max) { + return combined; + } + return combined.slice(combined.length - max); +} + +function looksLikeToolExampleContext(text) { + return insideCodeFence(text); +} + +function insideCodeFence(text) { + const t = typeof text === 'string' ? 
text : ''; + if (!t) { + return false; + } + const ticks = (t.match(/```/g) || []).length; + return ticks % 2 === 1; +} + +function hasMeaningfulText(text) { + return toStringSafe(text) !== ''; +} + +function toStringSafe(v) { + if (typeof v === 'string') { + return v.trim(); + } + if (Array.isArray(v)) { + return toStringSafe(v[0]); + } + if (v == null) { + return ''; + } + return String(v).trim(); +} + +module.exports = { + TOOL_SIEVE_CAPTURE_LIMIT, + TOOL_SIEVE_CONTEXT_TAIL_LIMIT, + createToolSieveState, + resetIncrementalToolState, + noteText, + appendTail, + looksLikeToolExampleContext, + insideCodeFence, + hasMeaningfulText, + toStringSafe, +}; diff --git a/internal/account/pool.go b/internal/account/pool.go deleted file mode 100644 index 12d8874..0000000 --- a/internal/account/pool.go +++ /dev/null @@ -1,363 +0,0 @@ -package account - -import ( - "context" - "os" - "sort" - "strconv" - "strings" - "sync" - - "ds2api/internal/config" -) - -type Pool struct { - store *config.Store - mu sync.Mutex - queue []string - inUse map[string]int - waiters []chan struct{} - maxInflightPerAccount int - recommendedConcurrency int - maxQueueSize int - globalMaxInflight int -} - -func NewPool(store *config.Store) *Pool { - maxPer := 2 - if store != nil { - maxPer = store.RuntimeAccountMaxInflight() - } - p := &Pool{ - store: store, - inUse: map[string]int{}, - maxInflightPerAccount: maxPer, - } - p.Reset() - return p -} - -func (p *Pool) Reset() { - accounts := p.store.Accounts() - sort.SliceStable(accounts, func(i, j int) bool { - iHas := accounts[i].Token != "" - jHas := accounts[j].Token != "" - if iHas == jHas { - return i < j - } - return iHas - }) - ids := make([]string, 0, len(accounts)) - for _, a := range accounts { - id := a.Identifier() - if id != "" { - ids = append(ids, id) - } - } - if p.store != nil { - p.maxInflightPerAccount = p.store.RuntimeAccountMaxInflight() - } else { - p.maxInflightPerAccount = maxInflightFromEnv() - } - recommended := 
defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount) - queueLimit := maxQueueFromEnv(recommended) - globalLimit := recommended - if p.store != nil { - queueLimit = p.store.RuntimeAccountMaxQueue(recommended) - globalLimit = p.store.RuntimeGlobalMaxInflight(recommended) - } - p.mu.Lock() - defer p.mu.Unlock() - p.drainWaitersLocked() - p.queue = ids - p.inUse = map[string]int{} - p.recommendedConcurrency = recommended - p.maxQueueSize = queueLimit - p.globalMaxInflight = globalLimit - config.Logger.Info( - "[init_account_queue] initialized", - "total", len(ids), - "max_inflight_per_account", p.maxInflightPerAccount, - "global_max_inflight", p.globalMaxInflight, - "recommended_concurrency", p.recommendedConcurrency, - "max_queue_size", p.maxQueueSize, - ) -} - -func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) { - p.mu.Lock() - defer p.mu.Unlock() - return p.acquireLocked(target, normalizeExclude(exclude)) -} - -func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) { - if ctx == nil { - ctx = context.Background() - } - exclude = normalizeExclude(exclude) - for { - if ctx.Err() != nil { - return config.Account{}, false - } - - p.mu.Lock() - if acc, ok := p.acquireLocked(target, exclude); ok { - p.mu.Unlock() - return acc, true - } - if !p.canQueueLocked(target, exclude) { - p.mu.Unlock() - return config.Account{}, false - } - waiter := make(chan struct{}) - p.waiters = append(p.waiters, waiter) - p.mu.Unlock() - - select { - case <-ctx.Done(): - p.mu.Lock() - p.removeWaiterLocked(waiter) - p.mu.Unlock() - return config.Account{}, false - case <-waiter: - } - } -} - -func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) { - if target != "" { - if exclude[target] || !p.canAcquireIDLocked(target) { - return config.Account{}, false - } - acc, ok := p.store.FindAccount(target) - if !ok { - return config.Account{}, false - } - 
p.inUse[target]++ - p.bumpQueue(target) - return acc, true - } - - if acc, ok := p.tryAcquire(exclude, true); ok { - return acc, true - } - if acc, ok := p.tryAcquire(exclude, false); ok { - return acc, true - } - return config.Account{}, false -} - -func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) { - for i := 0; i < len(p.queue); i++ { - id := p.queue[i] - if exclude[id] || !p.canAcquireIDLocked(id) { - continue - } - acc, ok := p.store.FindAccount(id) - if !ok { - continue - } - if requireToken && acc.Token == "" { - continue - } - p.inUse[id]++ - p.bumpQueue(id) - return acc, true - } - return config.Account{}, false -} - -func (p *Pool) bumpQueue(accountID string) { - for i, id := range p.queue { - if id != accountID { - continue - } - p.queue = append(p.queue[:i], p.queue[i+1:]...) - p.queue = append(p.queue, accountID) - return - } -} - -func (p *Pool) Release(accountID string) { - if accountID == "" { - return - } - p.mu.Lock() - defer p.mu.Unlock() - count := p.inUse[accountID] - if count <= 0 { - return - } - if count == 1 { - delete(p.inUse, accountID) - p.notifyWaiterLocked() - return - } - p.inUse[accountID] = count - 1 - p.notifyWaiterLocked() -} - -func (p *Pool) Status() map[string]any { - p.mu.Lock() - defer p.mu.Unlock() - available := make([]string, 0, len(p.queue)) - inUseAccounts := make([]string, 0, len(p.inUse)) - inUseSlots := 0 - for _, id := range p.queue { - if p.inUse[id] < p.maxInflightPerAccount { - available = append(available, id) - } - } - for id, count := range p.inUse { - if count > 0 { - inUseAccounts = append(inUseAccounts, id) - inUseSlots += count - } - } - sort.Strings(inUseAccounts) - return map[string]any{ - "available": len(available), - "in_use": inUseSlots, - "total": len(p.store.Accounts()), - "available_accounts": available, - "in_use_accounts": inUseAccounts, - "max_inflight_per_account": p.maxInflightPerAccount, - "global_max_inflight": p.globalMaxInflight, - 
"recommended_concurrency": p.recommendedConcurrency, - "waiting": len(p.waiters), - "max_queue_size": p.maxQueueSize, - } -} - -func (p *Pool) ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) { - if maxInflightPerAccount <= 0 { - maxInflightPerAccount = 1 - } - if maxQueueSize < 0 { - maxQueueSize = 0 - } - if globalMaxInflight <= 0 { - globalMaxInflight = maxInflightPerAccount * len(p.store.Accounts()) - if globalMaxInflight <= 0 { - globalMaxInflight = maxInflightPerAccount - } - } - p.mu.Lock() - defer p.mu.Unlock() - p.maxInflightPerAccount = maxInflightPerAccount - p.maxQueueSize = maxQueueSize - p.globalMaxInflight = globalMaxInflight - p.recommendedConcurrency = defaultRecommendedConcurrency(len(p.queue), p.maxInflightPerAccount) - p.notifyWaiterLocked() -} - -func maxInflightFromEnv() int { - for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} { - raw := strings.TrimSpace(os.Getenv(key)) - if raw == "" { - continue - } - n, err := strconv.Atoi(raw) - if err == nil && n > 0 { - return n - } - } - return 2 -} - -func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount int) int { - if accountCount <= 0 { - return 0 - } - if maxInflightPerAccount <= 0 { - maxInflightPerAccount = 2 - } - return accountCount * maxInflightPerAccount -} - -func normalizeExclude(exclude map[string]bool) map[string]bool { - if exclude == nil { - return map[string]bool{} - } - return exclude -} - -func (p *Pool) canQueueLocked(target string, exclude map[string]bool) bool { - if target != "" { - if exclude[target] { - return false - } - if _, ok := p.store.FindAccount(target); !ok { - return false - } - } - if p.maxQueueSize <= 0 { - return false - } - return len(p.waiters) < p.maxQueueSize -} - -func (p *Pool) notifyWaiterLocked() { - if len(p.waiters) == 0 { - return - } - waiter := p.waiters[0] - p.waiters = p.waiters[1:] - close(waiter) -} - -func (p *Pool) removeWaiterLocked(waiter chan struct{}) 
bool { - for i, w := range p.waiters { - if w != waiter { - continue - } - p.waiters = append(p.waiters[:i], p.waiters[i+1:]...) - return true - } - return false -} - -func (p *Pool) drainWaitersLocked() { - for _, waiter := range p.waiters { - close(waiter) - } - p.waiters = nil -} - -func maxQueueFromEnv(defaultSize int) int { - for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} { - raw := strings.TrimSpace(os.Getenv(key)) - if raw == "" { - continue - } - n, err := strconv.Atoi(raw) - if err == nil && n >= 0 { - return n - } - } - if defaultSize < 0 { - return 0 - } - return defaultSize -} - -func (p *Pool) canAcquireIDLocked(accountID string) bool { - if accountID == "" { - return false - } - if p.inUse[accountID] >= p.maxInflightPerAccount { - return false - } - if p.globalMaxInflight > 0 && p.currentInUseLocked() >= p.globalMaxInflight { - return false - } - return true -} - -func (p *Pool) currentInUseLocked() int { - total := 0 - for _, n := range p.inUse { - total += n - } - return total -} diff --git a/internal/account/pool_acquire.go b/internal/account/pool_acquire.go new file mode 100644 index 0000000..b0c548c --- /dev/null +++ b/internal/account/pool_acquire.go @@ -0,0 +1,108 @@ +package account + +import ( + "context" + + "ds2api/internal/config" +) + +func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) { + p.mu.Lock() + defer p.mu.Unlock() + return p.acquireLocked(target, normalizeExclude(exclude)) +} + +func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) { + if ctx == nil { + ctx = context.Background() + } + exclude = normalizeExclude(exclude) + for { + if ctx.Err() != nil { + return config.Account{}, false + } + + p.mu.Lock() + if acc, ok := p.acquireLocked(target, exclude); ok { + p.mu.Unlock() + return acc, true + } + if !p.canQueueLocked(target, exclude) { + p.mu.Unlock() + return config.Account{}, false + } + waiter := 
make(chan struct{}) + p.waiters = append(p.waiters, waiter) + p.mu.Unlock() + + select { + case <-ctx.Done(): + p.mu.Lock() + p.removeWaiterLocked(waiter) + p.mu.Unlock() + return config.Account{}, false + case <-waiter: + } + } +} + +func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) { + if target != "" { + if exclude[target] || !p.canAcquireIDLocked(target) { + return config.Account{}, false + } + acc, ok := p.store.FindAccount(target) + if !ok { + return config.Account{}, false + } + p.inUse[target]++ + p.bumpQueue(target) + return acc, true + } + + if acc, ok := p.tryAcquire(exclude, true); ok { + return acc, true + } + if acc, ok := p.tryAcquire(exclude, false); ok { + return acc, true + } + return config.Account{}, false +} + +func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) { + for i := 0; i < len(p.queue); i++ { + id := p.queue[i] + if exclude[id] || !p.canAcquireIDLocked(id) { + continue + } + acc, ok := p.store.FindAccount(id) + if !ok { + continue + } + if requireToken && acc.Token == "" { + continue + } + p.inUse[id]++ + p.bumpQueue(id) + return acc, true + } + return config.Account{}, false +} + +func (p *Pool) bumpQueue(accountID string) { + for i, id := range p.queue { + if id != accountID { + continue + } + p.queue = append(p.queue[:i], p.queue[i+1:]...) 
+ p.queue = append(p.queue, accountID) + return + } +} + +func normalizeExclude(exclude map[string]bool) map[string]bool { + if exclude == nil { + return map[string]bool{} + } + return exclude +} diff --git a/internal/account/pool_core.go b/internal/account/pool_core.go new file mode 100644 index 0000000..90e2594 --- /dev/null +++ b/internal/account/pool_core.go @@ -0,0 +1,132 @@ +package account + +import ( + "sort" + "sync" + + "ds2api/internal/config" +) + +type Pool struct { + store *config.Store + mu sync.Mutex + queue []string + inUse map[string]int + waiters []chan struct{} + maxInflightPerAccount int + recommendedConcurrency int + maxQueueSize int + globalMaxInflight int +} + +func NewPool(store *config.Store) *Pool { + maxPer := 2 + if store != nil { + maxPer = store.RuntimeAccountMaxInflight() + } + p := &Pool{ + store: store, + inUse: map[string]int{}, + maxInflightPerAccount: maxPer, + } + p.Reset() + return p +} + +func (p *Pool) Reset() { + accounts := p.store.Accounts() + sort.SliceStable(accounts, func(i, j int) bool { + iHas := accounts[i].Token != "" + jHas := accounts[j].Token != "" + if iHas == jHas { + return i < j + } + return iHas + }) + ids := make([]string, 0, len(accounts)) + for _, a := range accounts { + id := a.Identifier() + if id != "" { + ids = append(ids, id) + } + } + if p.store != nil { + p.maxInflightPerAccount = p.store.RuntimeAccountMaxInflight() + } else { + p.maxInflightPerAccount = maxInflightFromEnv() + } + recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount) + queueLimit := maxQueueFromEnv(recommended) + globalLimit := recommended + if p.store != nil { + queueLimit = p.store.RuntimeAccountMaxQueue(recommended) + globalLimit = p.store.RuntimeGlobalMaxInflight(recommended) + } + p.mu.Lock() + defer p.mu.Unlock() + p.drainWaitersLocked() + p.queue = ids + p.inUse = map[string]int{} + p.recommendedConcurrency = recommended + p.maxQueueSize = queueLimit + p.globalMaxInflight = globalLimit + 
config.Logger.Info( + "[init_account_queue] initialized", + "total", len(ids), + "max_inflight_per_account", p.maxInflightPerAccount, + "global_max_inflight", p.globalMaxInflight, + "recommended_concurrency", p.recommendedConcurrency, + "max_queue_size", p.maxQueueSize, + ) +} + +func (p *Pool) Release(accountID string) { + if accountID == "" { + return + } + p.mu.Lock() + defer p.mu.Unlock() + count := p.inUse[accountID] + if count <= 0 { + return + } + if count == 1 { + delete(p.inUse, accountID) + p.notifyWaiterLocked() + return + } + p.inUse[accountID] = count - 1 + p.notifyWaiterLocked() +} + +func (p *Pool) Status() map[string]any { + p.mu.Lock() + defer p.mu.Unlock() + available := make([]string, 0, len(p.queue)) + inUseAccounts := make([]string, 0, len(p.inUse)) + inUseSlots := 0 + for _, id := range p.queue { + if p.inUse[id] < p.maxInflightPerAccount { + available = append(available, id) + } + } + for id, count := range p.inUse { + if count > 0 { + inUseAccounts = append(inUseAccounts, id) + inUseSlots += count + } + } + sort.Strings(inUseAccounts) + return map[string]any{ + "available": len(available), + "in_use": inUseSlots, + "total": len(p.store.Accounts()), + "available_accounts": available, + "in_use_accounts": inUseAccounts, + "max_inflight_per_account": p.maxInflightPerAccount, + "global_max_inflight": p.globalMaxInflight, + "recommended_concurrency": p.recommendedConcurrency, + "waiting": len(p.waiters), + "max_queue_size": p.maxQueueSize, + } +} diff --git a/internal/account/pool_limits.go b/internal/account/pool_limits.go new file mode 100644 index 0000000..0f0854f --- /dev/null +++ b/internal/account/pool_limits.go @@ -0,0 +1,91 @@ +package account + +import ( + "os" + "strconv" + "strings" +) + +func (p *Pool) ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) { + if maxInflightPerAccount <= 0 { + maxInflightPerAccount = 1 + } + if maxQueueSize < 0 { + maxQueueSize = 0 + } + if globalMaxInflight <= 0 { + 
globalMaxInflight = maxInflightPerAccount * len(p.store.Accounts()) + if globalMaxInflight <= 0 { + globalMaxInflight = maxInflightPerAccount + } + } + p.mu.Lock() + defer p.mu.Unlock() + p.maxInflightPerAccount = maxInflightPerAccount + p.maxQueueSize = maxQueueSize + p.globalMaxInflight = globalMaxInflight + p.recommendedConcurrency = defaultRecommendedConcurrency(len(p.queue), p.maxInflightPerAccount) + p.notifyWaiterLocked() +} + +func maxInflightFromEnv() int { + for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n > 0 { + return n + } + } + return 2 +} + +func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount int) int { + if accountCount <= 0 { + return 0 + } + if maxInflightPerAccount <= 0 { + maxInflightPerAccount = 2 + } + return accountCount * maxInflightPerAccount +} + +func maxQueueFromEnv(defaultSize int) int { + for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n >= 0 { + return n + } + } + if defaultSize < 0 { + return 0 + } + return defaultSize +} + +func (p *Pool) canAcquireIDLocked(accountID string) bool { + if accountID == "" { + return false + } + if p.inUse[accountID] >= p.maxInflightPerAccount { + return false + } + if p.globalMaxInflight > 0 && p.currentInUseLocked() >= p.globalMaxInflight { + return false + } + return true +} + +func (p *Pool) currentInUseLocked() int { + total := 0 + for _, n := range p.inUse { + total += n + } + return total +} diff --git a/internal/account/pool_waiters.go b/internal/account/pool_waiters.go new file mode 100644 index 0000000..40bd146 --- /dev/null +++ b/internal/account/pool_waiters.go @@ -0,0 +1,43 @@ +package account + +func (p *Pool) canQueueLocked(target string, 
exclude map[string]bool) bool { + if target != "" { + if exclude[target] { + return false + } + if _, ok := p.store.FindAccount(target); !ok { + return false + } + } + if p.maxQueueSize <= 0 { + return false + } + return len(p.waiters) < p.maxQueueSize +} + +func (p *Pool) notifyWaiterLocked() { + if len(p.waiters) == 0 { + return + } + waiter := p.waiters[0] + p.waiters = p.waiters[1:] + close(waiter) +} + +func (p *Pool) removeWaiterLocked(waiter chan struct{}) bool { + for i, w := range p.waiters { + if w != waiter { + continue + } + p.waiters = append(p.waiters[:i], p.waiters[i+1:]...) + return true + } + return false +} + +func (p *Pool) drainWaitersLocked() { + for _, waiter := range p.waiters { + close(waiter) + } + p.waiters = nil +} diff --git a/internal/adapter/claude/error_shape_test.go b/internal/adapter/claude/error_shape_test.go index 910fce8..b9dc469 100644 --- a/internal/adapter/claude/error_shape_test.go +++ b/internal/adapter/claude/error_shape_test.go @@ -32,4 +32,3 @@ func TestWriteClaudeErrorIncludesUnifiedFields(t *testing.T) { t.Fatal("expected param field") } } - diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go deleted file mode 100644 index 2fa6796..0000000 --- a/internal/adapter/claude/handler.go +++ /dev/null @@ -1,368 +0,0 @@ -package claude - -import ( - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "time" - - "github.com/go-chi/chi/v5" - - "ds2api/internal/auth" - "ds2api/internal/config" - "ds2api/internal/deepseek" - claudefmt "ds2api/internal/format/claude" - "ds2api/internal/sse" - streamengine "ds2api/internal/stream" - "ds2api/internal/util" -) - -// writeJSON is a package-internal alias to avoid mass-renaming all call-sites. 
-var writeJSON = util.WriteJSON - -type Handler struct { - Store ConfigReader - Auth AuthResolver - DS DeepSeekCaller -} - -var ( - claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second - claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second - claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount -) - -func RegisterRoutes(r chi.Router, h *Handler) { - r.Get("/anthropic/v1/models", h.ListModels) - r.Post("/anthropic/v1/messages", h.Messages) - r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens) - r.Post("/v1/messages", h.Messages) - r.Post("/messages", h.Messages) - r.Post("/v1/messages/count_tokens", h.CountTokens) - r.Post("/messages/count_tokens", h.CountTokens) -} - -func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { - writeJSON(w, http.StatusOK, config.ClaudeModelsResponse()) -} - -func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { - if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" { - r.Header.Set("anthropic-version", "2023-06-01") - } - a, err := h.Auth.Determine(r) - if err != nil { - status := http.StatusUnauthorized - detail := err.Error() - if err == auth.ErrNoAccount { - status = http.StatusTooManyRequests - } - writeClaudeError(w, status, detail) - return - } - defer h.Auth.Release(a) - - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeClaudeError(w, http.StatusBadRequest, "invalid json") - return - } - norm, err := normalizeClaudeRequest(h.Store, req) - if err != nil { - writeClaudeError(w, http.StatusBadRequest, err.Error()) - return - } - stdReq := norm.Standard - - sessionID, err := h.DS.CreateSession(r.Context(), a, 3) - if err != nil { - writeClaudeError(w, http.StatusUnauthorized, "invalid token.") - return - } - pow, err := h.DS.GetPow(r.Context(), a, 3) - if err != nil { - writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW") - return - } - requestPayload := 
stdReq.CompletionPayload(sessionID) - resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3) - if err != nil { - writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.") - return - } - if resp.StatusCode != http.StatusOK { - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - writeClaudeError(w, http.StatusInternalServerError, string(body)) - return - } - - if stdReq.Stream { - h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) - return - } - result := sse.CollectStream(resp, stdReq.Thinking, true) - respBody := claudefmt.BuildMessageResponse( - fmt.Sprintf("msg_%d", time.Now().UnixNano()), - stdReq.ResponseModel, - norm.NormalizedMessages, - result.Thinking, - result.Text, - stdReq.ToolNames, - ) - writeJSON(w, http.StatusOK, respBody) -} - -func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) { - a, err := h.Auth.Determine(r) - if err != nil { - writeClaudeError(w, http.StatusUnauthorized, err.Error()) - return - } - defer h.Auth.Release(a) - - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeClaudeError(w, http.StatusBadRequest, "invalid json") - return - } - model, _ := req["model"].(string) - messages, _ := req["messages"].([]any) - if model == "" || len(messages) == 0 { - writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") - return - } - inputTokens := 0 - if sys, ok := req["system"].(string); ok { - inputTokens += util.EstimateTokens(sys) - } - for _, item := range messages { - msg, ok := item.(map[string]any) - if !ok { - continue - } - inputTokens += 2 - inputTokens += util.EstimateTokens(extractMessageContent(msg["content"])) - } - if tools, ok := req["tools"].([]any); ok { - for _, t := range tools { - b, _ := json.Marshal(t) - inputTokens += util.EstimateTokens(string(b)) - } - } - if inputTokens < 1 { - 
inputTokens = 1 - } - writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens}) -} - -func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) { - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - writeClaudeError(w, http.StatusInternalServerError, string(body)) - return - } - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache, no-transform") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("X-Accel-Buffering", "no") - rc := http.NewResponseController(w) - _, canFlush := w.(http.Flusher) - if !canFlush { - config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered") - } - - streamRuntime := newClaudeStreamRuntime( - w, - rc, - canFlush, - model, - messages, - thinkingEnabled, - searchEnabled, - toolNames, - ) - streamRuntime.sendMessageStart() - - initialType := "text" - if thinkingEnabled { - initialType = "thinking" - } - streamengine.ConsumeSSE(streamengine.ConsumeConfig{ - Context: r.Context(), - Body: resp.Body, - ThinkingEnabled: thinkingEnabled, - InitialType: initialType, - KeepAliveInterval: claudeStreamPingInterval, - IdleTimeout: claudeStreamIdleTimeout, - MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt, - }, streamengine.ConsumeHooks{ - OnKeepAlive: func() { - streamRuntime.sendPing() - }, - OnParsed: streamRuntime.onParsed, - OnFinalize: streamRuntime.onFinalize, - }) -} - -func writeClaudeError(w http.ResponseWriter, status int, message string) { - code := "invalid_request" - switch status { - case http.StatusUnauthorized: - code = "authentication_failed" - case http.StatusTooManyRequests: - code = "rate_limit_exceeded" - case http.StatusNotFound: - code = "not_found" - case http.StatusInternalServerError: - code = "internal_error" - } - writeJSON(w, 
status, map[string]any{ - "error": map[string]any{ - "type": "invalid_request_error", - "message": message, - "code": code, - "param": nil, - }, - }) -} - -func normalizeClaudeMessages(messages []any) []any { - out := make([]any, 0, len(messages)) - for _, m := range messages { - msg, ok := m.(map[string]any) - if !ok { - continue - } - copied := cloneMap(msg) - switch content := msg["content"].(type) { - case []any: - parts := make([]string, 0, len(content)) - for _, block := range content { - b, ok := block.(map[string]any) - if !ok { - continue - } - typeStr, _ := b["type"].(string) - if typeStr == "text" { - if t, ok := b["text"].(string); ok { - parts = append(parts, t) - } - } - if typeStr == "tool_result" { - parts = append(parts, formatClaudeToolResultForPrompt(b)) - } - } - copied["content"] = strings.Join(parts, "\n") - } - out = append(out, copied) - } - return out -} - -func buildClaudeToolPrompt(tools []any) string { - parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"} - for _, t := range tools { - m, ok := t.(map[string]any) - if !ok { - continue - } - name, _ := m["name"].(string) - desc, _ := m["description"].(string) - schema, _ := json.Marshal(m["input_schema"]) - parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema)) - } - parts = append(parts, - "When you need to use tools, you can call multiple tools in one response. 
Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}", - "History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.", - "After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.", - ) - return strings.Join(parts, "\n\n") -} - -func formatClaudeToolResultForPrompt(block map[string]any) string { - if block == nil { - return "" - } - toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"])) - if toolCallID == "" { - toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"])) - } - if toolCallID == "" { - toolCallID = "unknown" - } - name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])) - if name == "" { - name = "unknown" - } - content := strings.TrimSpace(fmt.Sprintf("%v", block["content"])) - if content == "" { - content = "null" - } - return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content) -} - -func hasSystemMessage(messages []any) bool { - for _, m := range messages { - msg, ok := m.(map[string]any) - if ok && msg["role"] == "system" { - return true - } - } - return false -} - -func extractClaudeToolNames(tools []any) []string { - out := make([]string, 0, len(tools)) - for _, t := range tools { - m, ok := t.(map[string]any) - if !ok { - continue - } - if name, ok := m["name"].(string); ok && name != "" { - out = append(out, name) - } - } - return out -} - -func toMessageMaps(v any) []map[string]any { - arr, ok := v.([]any) - if !ok { - return nil - } - out := make([]map[string]any, 0, len(arr)) - for _, item := range arr { - if m, ok := item.(map[string]any); ok { - out = append(out, m) - } - } - return out -} - -func extractMessageContent(v 
any) string { - switch x := v.(type) { - case string: - return x - case []any: - parts := make([]string, 0, len(x)) - for _, it := range x { - parts = append(parts, fmt.Sprintf("%v", it)) - } - return strings.Join(parts, "\n") - default: - return fmt.Sprintf("%v", x) - } -} - -func cloneMap(in map[string]any) map[string]any { - out := make(map[string]any, len(in)) - for k, v := range in { - out[k] = v - } - return out -} diff --git a/internal/adapter/claude/handler_errors.go b/internal/adapter/claude/handler_errors.go new file mode 100644 index 0000000..f1188d6 --- /dev/null +++ b/internal/adapter/claude/handler_errors.go @@ -0,0 +1,25 @@ +package claude + +import "net/http" + +func writeClaudeError(w http.ResponseWriter, status int, message string) { + code := "invalid_request" + switch status { + case http.StatusUnauthorized: + code = "authentication_failed" + case http.StatusTooManyRequests: + code = "rate_limit_exceeded" + case http.StatusNotFound: + code = "not_found" + case http.StatusInternalServerError: + code = "internal_error" + } + writeJSON(w, status, map[string]any{ + "error": map[string]any{ + "type": "invalid_request_error", + "message": message, + "code": code, + "param": nil, + }, + }) +} diff --git a/internal/adapter/claude/handler_messages.go b/internal/adapter/claude/handler_messages.go new file mode 100644 index 0000000..1c4272b --- /dev/null +++ b/internal/adapter/claude/handler_messages.go @@ -0,0 +1,134 @@ +package claude + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "ds2api/internal/auth" + "ds2api/internal/config" + claudefmt "ds2api/internal/format/claude" + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" +) + +func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { + if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" { + r.Header.Set("anthropic-version", "2023-06-01") + } + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + 
detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeClaudeError(w, status, detail) + return + } + defer h.Auth.Release(a) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeClaudeError(w, http.StatusBadRequest, "invalid json") + return + } + norm, err := normalizeClaudeRequest(h.Store, req) + if err != nil { + writeClaudeError(w, http.StatusBadRequest, err.Error()) + return + } + stdReq := norm.Standard + + sessionID, err := h.DS.CreateSession(r.Context(), a, 3) + if err != nil { + writeClaudeError(w, http.StatusUnauthorized, "invalid token.") + return + } + pow, err := h.DS.GetPow(r.Context(), a, 3) + if err != nil { + writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW") + return + } + requestPayload := stdReq.CompletionPayload(sessionID) + resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3) + if err != nil { + writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.") + return + } + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + writeClaudeError(w, http.StatusInternalServerError, string(body)) + return + } + + if stdReq.Stream { + h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) + return + } + result := sse.CollectStream(resp, stdReq.Thinking, true) + respBody := claudefmt.BuildMessageResponse( + fmt.Sprintf("msg_%d", time.Now().UnixNano()), + stdReq.ResponseModel, + norm.NormalizedMessages, + result.Thinking, + result.Text, + stdReq.ToolNames, + ) + writeJSON(w, http.StatusOK, respBody) +} + +func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := 
io.ReadAll(resp.Body) + writeClaudeError(w, http.StatusInternalServerError, string(body)) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + rc := http.NewResponseController(w) + _, canFlush := w.(http.Flusher) + if !canFlush { + config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered") + } + + streamRuntime := newClaudeStreamRuntime( + w, + rc, + canFlush, + model, + messages, + thinkingEnabled, + searchEnabled, + toolNames, + ) + streamRuntime.sendMessageStart() + + initialType := "text" + if thinkingEnabled { + initialType = "thinking" + } + streamengine.ConsumeSSE(streamengine.ConsumeConfig{ + Context: r.Context(), + Body: resp.Body, + ThinkingEnabled: thinkingEnabled, + InitialType: initialType, + KeepAliveInterval: claudeStreamPingInterval, + IdleTimeout: claudeStreamIdleTimeout, + MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt, + }, streamengine.ConsumeHooks{ + OnKeepAlive: func() { + streamRuntime.sendPing() + }, + OnParsed: streamRuntime.onParsed, + OnFinalize: streamRuntime.onFinalize, + }) +} diff --git a/internal/adapter/claude/handler_routes.go b/internal/adapter/claude/handler_routes.go new file mode 100644 index 0000000..0376b2c --- /dev/null +++ b/internal/adapter/claude/handler_routes.go @@ -0,0 +1,41 @@ +package claude + +import ( + "net/http" + "time" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/config" + "ds2api/internal/deepseek" + "ds2api/internal/util" +) + +// writeJSON is a package-internal alias to avoid mass-renaming all call-sites. 
+var writeJSON = util.WriteJSON + +type Handler struct { + Store ConfigReader + Auth AuthResolver + DS DeepSeekCaller +} + +var ( + claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second + claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second + claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount +) + +func RegisterRoutes(r chi.Router, h *Handler) { + r.Get("/anthropic/v1/models", h.ListModels) + r.Post("/anthropic/v1/messages", h.Messages) + r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens) + r.Post("/v1/messages", h.Messages) + r.Post("/messages", h.Messages) + r.Post("/v1/messages/count_tokens", h.CountTokens) + r.Post("/messages/count_tokens", h.CountTokens) +} + +func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, config.ClaudeModelsResponse()) +} diff --git a/internal/adapter/claude/handler_tokens.go b/internal/adapter/claude/handler_tokens.go new file mode 100644 index 0000000..a369345 --- /dev/null +++ b/internal/adapter/claude/handler_tokens.go @@ -0,0 +1,51 @@ +package claude + +import ( + "encoding/json" + "net/http" + + "ds2api/internal/util" +) + +func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) { + a, err := h.Auth.Determine(r) + if err != nil { + writeClaudeError(w, http.StatusUnauthorized, err.Error()) + return + } + defer h.Auth.Release(a) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeClaudeError(w, http.StatusBadRequest, "invalid json") + return + } + model, _ := req["model"].(string) + messages, _ := req["messages"].([]any) + if model == "" || len(messages) == 0 { + writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") + return + } + inputTokens := 0 + if sys, ok := req["system"].(string); ok { + inputTokens += util.EstimateTokens(sys) + } + for _, item := range messages { + msg, ok := item.(map[string]any) + if !ok { + 
continue + } + inputTokens += 2 + inputTokens += util.EstimateTokens(extractMessageContent(msg["content"])) + } + if tools, ok := req["tools"].([]any); ok { + for _, t := range tools { + b, _ := json.Marshal(t) + inputTokens += util.EstimateTokens(string(b)) + } + } + if inputTokens < 1 { + inputTokens = 1 + } + writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens}) +} diff --git a/internal/adapter/claude/handler_utils.go b/internal/adapter/claude/handler_utils.go new file mode 100644 index 0000000..df4c6b2 --- /dev/null +++ b/internal/adapter/claude/handler_utils.go @@ -0,0 +1,143 @@ +package claude + +import ( + "encoding/json" + "fmt" + "strings" +) + +func normalizeClaudeMessages(messages []any) []any { + out := make([]any, 0, len(messages)) + for _, m := range messages { + msg, ok := m.(map[string]any) + if !ok { + continue + } + copied := cloneMap(msg) + switch content := msg["content"].(type) { + case []any: + parts := make([]string, 0, len(content)) + for _, block := range content { + b, ok := block.(map[string]any) + if !ok { + continue + } + typeStr, _ := b["type"].(string) + if typeStr == "text" { + if t, ok := b["text"].(string); ok { + parts = append(parts, t) + } + } + if typeStr == "tool_result" { + parts = append(parts, formatClaudeToolResultForPrompt(b)) + } + } + copied["content"] = strings.Join(parts, "\n") + } + out = append(out, copied) + } + return out +} + +func buildClaudeToolPrompt(tools []any) string { + parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"} + for _, t := range tools { + m, ok := t.(map[string]any) + if !ok { + continue + } + name, _ := m["name"].(string) + desc, _ := m["description"].(string) + schema, _ := json.Marshal(m["input_schema"]) + parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema)) + } + parts = append(parts, + "When you need to use tools, you can call multiple tools in one response. 
Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}", + "History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.", + "After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.", + ) + return strings.Join(parts, "\n\n") +} + +func formatClaudeToolResultForPrompt(block map[string]any) string { + if block == nil { + return "" + } + toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"])) + if toolCallID == "" { + toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"])) + } + if toolCallID == "" { + toolCallID = "unknown" + } + name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])) + if name == "" { + name = "unknown" + } + content := strings.TrimSpace(fmt.Sprintf("%v", block["content"])) + if content == "" { + content = "null" + } + return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content) +} + +func hasSystemMessage(messages []any) bool { + for _, m := range messages { + msg, ok := m.(map[string]any) + if ok && msg["role"] == "system" { + return true + } + } + return false +} + +func extractClaudeToolNames(tools []any) []string { + out := make([]string, 0, len(tools)) + for _, t := range tools { + m, ok := t.(map[string]any) + if !ok { + continue + } + if name, ok := m["name"].(string); ok && name != "" { + out = append(out, name) + } + } + return out +} + +func toMessageMaps(v any) []map[string]any { + arr, ok := v.([]any) + if !ok { + return nil + } + out := make([]map[string]any, 0, len(arr)) + for _, item := range arr { + if m, ok := item.(map[string]any); ok { + out = append(out, m) + } + } + return out +} + +func extractMessageContent(v 
any) string { + switch x := v.(type) { + case string: + return x + case []any: + parts := make([]string, 0, len(x)) + for _, it := range x { + parts = append(parts, fmt.Sprintf("%v", it)) + } + return strings.Join(parts, "\n") + default: + return fmt.Sprintf("%v", x) + } +} + +func cloneMap(in map[string]any) map[string]any { + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} diff --git a/internal/adapter/claude/stream_runtime.go b/internal/adapter/claude/stream_runtime.go deleted file mode 100644 index 01e07a9..0000000 --- a/internal/adapter/claude/stream_runtime.go +++ /dev/null @@ -1,308 +0,0 @@ -package claude - -import ( - "encoding/json" - "fmt" - "net/http" - "strings" - "time" - - "ds2api/internal/sse" - streamengine "ds2api/internal/stream" - "ds2api/internal/util" -) - -type claudeStreamRuntime struct { - w http.ResponseWriter - rc *http.ResponseController - canFlush bool - - model string - toolNames []string - messages []any - - thinkingEnabled bool - searchEnabled bool - bufferToolContent bool - - messageID string - thinking strings.Builder - text strings.Builder - - nextBlockIndex int - thinkingBlockOpen bool - thinkingBlockIndex int - textBlockOpen bool - textBlockIndex int - ended bool - upstreamErr string -} - -func newClaudeStreamRuntime( - w http.ResponseWriter, - rc *http.ResponseController, - canFlush bool, - model string, - messages []any, - thinkingEnabled bool, - searchEnabled bool, - toolNames []string, -) *claudeStreamRuntime { - return &claudeStreamRuntime{ - w: w, - rc: rc, - canFlush: canFlush, - model: model, - messages: messages, - thinkingEnabled: thinkingEnabled, - searchEnabled: searchEnabled, - bufferToolContent: len(toolNames) > 0, - toolNames: toolNames, - messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()), - thinkingBlockIndex: -1, - textBlockIndex: -1, - } -} - -func (s *claudeStreamRuntime) send(event string, v any) { - b, _ := json.Marshal(v) - _, _ = s.w.Write([]byte("event: 
")) - _, _ = s.w.Write([]byte(event)) - _, _ = s.w.Write([]byte("\n")) - _, _ = s.w.Write([]byte("data: ")) - _, _ = s.w.Write(b) - _, _ = s.w.Write([]byte("\n\n")) - if s.canFlush { - _ = s.rc.Flush() - } -} - -func (s *claudeStreamRuntime) sendError(message string) { - msg := strings.TrimSpace(message) - if msg == "" { - msg = "upstream stream error" - } - s.send("error", map[string]any{ - "type": "error", - "error": map[string]any{ - "type": "api_error", - "message": msg, - "code": "internal_error", - "param": nil, - }, - }) -} - -func (s *claudeStreamRuntime) sendPing() { - s.send("ping", map[string]any{"type": "ping"}) -} - -func (s *claudeStreamRuntime) sendMessageStart() { - inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages)) - s.send("message_start", map[string]any{ - "type": "message_start", - "message": map[string]any{ - "id": s.messageID, - "type": "message", - "role": "assistant", - "model": s.model, - "content": []any{}, - "stop_reason": nil, - "stop_sequence": nil, - "usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0}, - }, - }) -} - -func (s *claudeStreamRuntime) closeThinkingBlock() { - if !s.thinkingBlockOpen { - return - } - s.send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": s.thinkingBlockIndex, - }) - s.thinkingBlockOpen = false - s.thinkingBlockIndex = -1 -} - -func (s *claudeStreamRuntime) closeTextBlock() { - if !s.textBlockOpen { - return - } - s.send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": s.textBlockIndex, - }) - s.textBlockOpen = false - s.textBlockIndex = -1 -} - -func (s *claudeStreamRuntime) finalize(stopReason string) { - if s.ended { - return - } - s.ended = true - - s.closeThinkingBlock() - s.closeTextBlock() - - finalThinking := s.thinking.String() - finalText := s.text.String() - - if s.bufferToolContent { - detected := util.ParseToolCalls(finalText, s.toolNames) - if len(detected) > 0 { - stopReason = "tool_use" - for 
i, tc := range detected { - idx := s.nextBlockIndex + i - s.send("content_block_start", map[string]any{ - "type": "content_block_start", - "index": idx, - "content_block": map[string]any{ - "type": "tool_use", - "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx), - "name": tc.Name, - "input": tc.Input, - }, - }) - s.send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": idx, - }) - } - s.nextBlockIndex += len(detected) - } else if finalText != "" { - idx := s.nextBlockIndex - s.nextBlockIndex++ - s.send("content_block_start", map[string]any{ - "type": "content_block_start", - "index": idx, - "content_block": map[string]any{ - "type": "text", - "text": "", - }, - }) - s.send("content_block_delta", map[string]any{ - "type": "content_block_delta", - "index": idx, - "delta": map[string]any{ - "type": "text_delta", - "text": finalText, - }, - }) - s.send("content_block_stop", map[string]any{ - "type": "content_block_stop", - "index": idx, - }) - } - } - - outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText) - s.send("message_delta", map[string]any{ - "type": "message_delta", - "delta": map[string]any{ - "stop_reason": stopReason, - "stop_sequence": nil, - }, - "usage": map[string]any{ - "output_tokens": outputTokens, - }, - }) - s.send("message_stop", map[string]any{"type": "message_stop"}) -} - -func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { - if !parsed.Parsed { - return streamengine.ParsedDecision{} - } - if parsed.ErrorMessage != "" { - s.upstreamErr = parsed.ErrorMessage - return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")} - } - if parsed.Stop { - return streamengine.ParsedDecision{Stop: true} - } - - contentSeen := false - for _, p := range parsed.Parts { - if p.Text == "" { - continue - } - if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) { - continue - } - contentSeen = true - - if 
p.Type == "thinking" { - if !s.thinkingEnabled { - continue - } - s.thinking.WriteString(p.Text) - s.closeTextBlock() - if !s.thinkingBlockOpen { - s.thinkingBlockIndex = s.nextBlockIndex - s.nextBlockIndex++ - s.send("content_block_start", map[string]any{ - "type": "content_block_start", - "index": s.thinkingBlockIndex, - "content_block": map[string]any{ - "type": "thinking", - "thinking": "", - }, - }) - s.thinkingBlockOpen = true - } - s.send("content_block_delta", map[string]any{ - "type": "content_block_delta", - "index": s.thinkingBlockIndex, - "delta": map[string]any{ - "type": "thinking_delta", - "thinking": p.Text, - }, - }) - continue - } - - s.text.WriteString(p.Text) - if s.bufferToolContent { - continue - } - s.closeThinkingBlock() - if !s.textBlockOpen { - s.textBlockIndex = s.nextBlockIndex - s.nextBlockIndex++ - s.send("content_block_start", map[string]any{ - "type": "content_block_start", - "index": s.textBlockIndex, - "content_block": map[string]any{ - "type": "text", - "text": "", - }, - }) - s.textBlockOpen = true - } - s.send("content_block_delta", map[string]any{ - "type": "content_block_delta", - "index": s.textBlockIndex, - "delta": map[string]any{ - "type": "text_delta", - "text": p.Text, - }, - }) - } - - return streamengine.ParsedDecision{ContentSeen: contentSeen} -} - -func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) { - if string(reason) == "upstream_error" { - s.sendError(s.upstreamErr) - return - } - if scannerErr != nil { - s.sendError(scannerErr.Error()) - return - } - s.finalize("end_turn") -} diff --git a/internal/adapter/claude/stream_runtime_core.go b/internal/adapter/claude/stream_runtime_core.go new file mode 100644 index 0000000..cb24bdd --- /dev/null +++ b/internal/adapter/claude/stream_runtime_core.go @@ -0,0 +1,146 @@ +package claude + +import ( + "fmt" + "net/http" + "strings" + "time" + + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" +) + +type claudeStreamRuntime 
struct { + w http.ResponseWriter + rc *http.ResponseController + canFlush bool + + model string + toolNames []string + messages []any + + thinkingEnabled bool + searchEnabled bool + bufferToolContent bool + + messageID string + thinking strings.Builder + text strings.Builder + + nextBlockIndex int + thinkingBlockOpen bool + thinkingBlockIndex int + textBlockOpen bool + textBlockIndex int + ended bool + upstreamErr string +} + +func newClaudeStreamRuntime( + w http.ResponseWriter, + rc *http.ResponseController, + canFlush bool, + model string, + messages []any, + thinkingEnabled bool, + searchEnabled bool, + toolNames []string, +) *claudeStreamRuntime { + return &claudeStreamRuntime{ + w: w, + rc: rc, + canFlush: canFlush, + model: model, + messages: messages, + thinkingEnabled: thinkingEnabled, + searchEnabled: searchEnabled, + bufferToolContent: len(toolNames) > 0, + toolNames: toolNames, + messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()), + thinkingBlockIndex: -1, + textBlockIndex: -1, + } +} + +func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { + if !parsed.Parsed { + return streamengine.ParsedDecision{} + } + if parsed.ErrorMessage != "" { + s.upstreamErr = parsed.ErrorMessage + return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")} + } + if parsed.Stop { + return streamengine.ParsedDecision{Stop: true} + } + + contentSeen := false + for _, p := range parsed.Parts { + if p.Text == "" { + continue + } + if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) { + continue + } + contentSeen = true + + if p.Type == "thinking" { + if !s.thinkingEnabled { + continue + } + s.thinking.WriteString(p.Text) + s.closeTextBlock() + if !s.thinkingBlockOpen { + s.thinkingBlockIndex = s.nextBlockIndex + s.nextBlockIndex++ + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": s.thinkingBlockIndex, + "content_block": map[string]any{ + 
"type": "thinking", + "thinking": "", + }, + }) + s.thinkingBlockOpen = true + } + s.send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": s.thinkingBlockIndex, + "delta": map[string]any{ + "type": "thinking_delta", + "thinking": p.Text, + }, + }) + continue + } + + s.text.WriteString(p.Text) + if s.bufferToolContent { + continue + } + s.closeThinkingBlock() + if !s.textBlockOpen { + s.textBlockIndex = s.nextBlockIndex + s.nextBlockIndex++ + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": s.textBlockIndex, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + s.textBlockOpen = true + } + s.send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": s.textBlockIndex, + "delta": map[string]any{ + "type": "text_delta", + "text": p.Text, + }, + }) + } + + return streamengine.ParsedDecision{ContentSeen: contentSeen} +} diff --git a/internal/adapter/claude/stream_runtime_emit.go b/internal/adapter/claude/stream_runtime_emit.go new file mode 100644 index 0000000..c2fba19 --- /dev/null +++ b/internal/adapter/claude/stream_runtime_emit.go @@ -0,0 +1,59 @@ +package claude + +import ( + "encoding/json" + "fmt" + "strings" + + "ds2api/internal/util" +) + +func (s *claudeStreamRuntime) send(event string, v any) { + b, _ := json.Marshal(v) + _, _ = s.w.Write([]byte("event: ")) + _, _ = s.w.Write([]byte(event)) + _, _ = s.w.Write([]byte("\n")) + _, _ = s.w.Write([]byte("data: ")) + _, _ = s.w.Write(b) + _, _ = s.w.Write([]byte("\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *claudeStreamRuntime) sendError(message string) { + msg := strings.TrimSpace(message) + if msg == "" { + msg = "upstream stream error" + } + s.send("error", map[string]any{ + "type": "error", + "error": map[string]any{ + "type": "api_error", + "message": msg, + "code": "internal_error", + "param": nil, + }, + }) +} + +func (s *claudeStreamRuntime) sendPing() { + 
s.send("ping", map[string]any{"type": "ping"}) +} + +func (s *claudeStreamRuntime) sendMessageStart() { + inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages)) + s.send("message_start", map[string]any{ + "type": "message_start", + "message": map[string]any{ + "id": s.messageID, + "type": "message", + "role": "assistant", + "model": s.model, + "content": []any{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0}, + }, + }) +} diff --git a/internal/adapter/claude/stream_runtime_finalize.go b/internal/adapter/claude/stream_runtime_finalize.go new file mode 100644 index 0000000..f957ba1 --- /dev/null +++ b/internal/adapter/claude/stream_runtime_finalize.go @@ -0,0 +1,119 @@ +package claude + +import ( + "fmt" + "time" + + streamengine "ds2api/internal/stream" + "ds2api/internal/util" +) + +func (s *claudeStreamRuntime) closeThinkingBlock() { + if !s.thinkingBlockOpen { + return + } + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": s.thinkingBlockIndex, + }) + s.thinkingBlockOpen = false + s.thinkingBlockIndex = -1 +} + +func (s *claudeStreamRuntime) closeTextBlock() { + if !s.textBlockOpen { + return + } + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": s.textBlockIndex, + }) + s.textBlockOpen = false + s.textBlockIndex = -1 +} + +func (s *claudeStreamRuntime) finalize(stopReason string) { + if s.ended { + return + } + s.ended = true + + s.closeThinkingBlock() + s.closeTextBlock() + + finalThinking := s.thinking.String() + finalText := s.text.String() + + if s.bufferToolContent { + detected := util.ParseToolCalls(finalText, s.toolNames) + if len(detected) > 0 { + stopReason = "tool_use" + for i, tc := range detected { + idx := s.nextBlockIndex + i + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": idx, + "content_block": map[string]any{ + "type": "tool_use", + "id": 
fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx), + "name": tc.Name, + "input": tc.Input, + }, + }) + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": idx, + }) + } + s.nextBlockIndex += len(detected) + } else if finalText != "" { + idx := s.nextBlockIndex + s.nextBlockIndex++ + s.send("content_block_start", map[string]any{ + "type": "content_block_start", + "index": idx, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + s.send("content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": idx, + "delta": map[string]any{ + "type": "text_delta", + "text": finalText, + }, + }) + s.send("content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": idx, + }) + } + } + + outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText) + s.send("message_delta", map[string]any{ + "type": "message_delta", + "delta": map[string]any{ + "stop_reason": stopReason, + "stop_sequence": nil, + }, + "usage": map[string]any{ + "output_tokens": outputTokens, + }, + }) + s.send("message_stop", map[string]any{"type": "message_stop"}) +} + +func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) { + if string(reason) == "upstream_error" { + s.sendError(s.upstreamErr) + return + } + if scannerErr != nil { + s.sendError(scannerErr.Error()) + return + } + s.finalize("end_turn") +} diff --git a/internal/adapter/gemini/convert.go b/internal/adapter/gemini/convert.go deleted file mode 100644 index 3f63579..0000000 --- a/internal/adapter/gemini/convert.go +++ /dev/null @@ -1,313 +0,0 @@ -package gemini - -import ( - "encoding/json" - "fmt" - "strings" - - "ds2api/internal/adapter/openai" - "ds2api/internal/config" - "ds2api/internal/util" -) - -func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) { - requestedModel := strings.TrimSpace(routeModel) - if 
requestedModel == "" { - return util.StandardRequest{}, fmt.Errorf("model is required in request path") - } - - resolvedModel, ok := config.ResolveModel(store, requestedModel) - if !ok { - return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel) - } - thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) - - messagesRaw := geminiMessagesFromRequest(req) - if len(messagesRaw) == 0 { - return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.") - } - - toolsRaw := convertGeminiTools(req["tools"]) - finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "") - passThrough := collectGeminiPassThrough(req) - - return util.StandardRequest{ - Surface: "google_gemini", - RequestedModel: requestedModel, - ResolvedModel: resolvedModel, - ResponseModel: requestedModel, - Messages: messagesRaw, - FinalPrompt: finalPrompt, - ToolNames: toolNames, - Stream: stream, - Thinking: thinkingEnabled, - Search: searchEnabled, - PassThrough: passThrough, - }, nil -} - -func geminiMessagesFromRequest(req map[string]any) []any { - out := make([]any, 0, 8) - if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" { - out = append(out, map[string]any{ - "role": "system", - "content": sys, - }) - } - - contents, _ := req["contents"].([]any) - for _, item := range contents { - content, ok := item.(map[string]any) - if !ok { - continue - } - role := mapGeminiRole(content["role"]) - if role == "" { - role = "user" - } - parts, _ := content["parts"].([]any) - if len(parts) == 0 { - if text := strings.TrimSpace(asString(content["text"])); text != "" { - out = append(out, map[string]any{ - "role": role, - "content": text, - }) - } - continue - } - - textParts := make([]string, 0, len(parts)) - flushText := func() { - if len(textParts) == 0 { - return - } - out = append(out, map[string]any{ - "role": role, - "content": strings.Join(textParts, "\n"), - }) - 
textParts = textParts[:0] - } - - for _, rawPart := range parts { - part, ok := rawPart.(map[string]any) - if !ok { - continue - } - if text := strings.TrimSpace(asString(part["text"])); text != "" { - textParts = append(textParts, text) - continue - } - - if fnCall, ok := part["functionCall"].(map[string]any); ok { - flushText() - if name := strings.TrimSpace(asString(fnCall["name"])); name != "" { - callID := strings.TrimSpace(asString(fnCall["id"])) - if callID == "" { - callID = "call_gemini" - } - out = append(out, map[string]any{ - "role": "assistant", - "tool_calls": []any{ - map[string]any{ - "id": callID, - "type": "function", - "function": map[string]any{ - "name": name, - "arguments": stringifyJSON(fnCall["args"]), - }, - }, - }, - }) - } - continue - } - - if fnResp, ok := part["functionResponse"].(map[string]any); ok { - flushText() - name := strings.TrimSpace(asString(fnResp["name"])) - callID := strings.TrimSpace(asString(fnResp["id"])) - if callID == "" { - callID = strings.TrimSpace(asString(fnResp["callId"])) - } - if callID == "" { - callID = strings.TrimSpace(asString(fnResp["tool_call_id"])) - } - if callID == "" { - callID = "call_gemini" - } - content := fnResp["response"] - if content == nil { - content = fnResp["output"] - } - if content == nil { - content = "" - } - msg := map[string]any{ - "role": "tool", - "tool_call_id": callID, - "content": content, - } - if name != "" { - msg["name"] = name - } - out = append(out, msg) - } - } - flushText() - } - return out -} - -func normalizeGeminiSystemInstruction(raw any) string { - switch v := raw.(type) { - case string: - return strings.TrimSpace(v) - case map[string]any: - if parts, ok := v["parts"].([]any); ok { - texts := make([]string, 0, len(parts)) - for _, item := range parts { - part, ok := item.(map[string]any) - if !ok { - continue - } - if text := strings.TrimSpace(asString(part["text"])); text != "" { - texts = append(texts, text) - } - } - return strings.Join(texts, "\n") - } - if 
text := strings.TrimSpace(asString(v["text"])); text != "" { - return text - } - } - return "" -} - -func mapGeminiRole(v any) string { - switch strings.ToLower(strings.TrimSpace(asString(v))) { - case "user": - return "user" - case "model", "assistant": - return "assistant" - case "system": - return "system" - default: - return "" - } -} - -func convertGeminiTools(raw any) []any { - tools, _ := raw.([]any) - if len(tools) == 0 { - return nil - } - out := make([]any, 0, len(tools)) - for _, item := range tools { - tool, ok := item.(map[string]any) - if !ok { - continue - } - - if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 { - for _, declRaw := range fnDecls { - decl, ok := declRaw.(map[string]any) - if !ok { - continue - } - name := strings.TrimSpace(asString(decl["name"])) - if name == "" { - continue - } - function := map[string]any{ - "name": name, - } - if desc := strings.TrimSpace(asString(decl["description"])); desc != "" { - function["description"] = desc - } - if params, ok := decl["parameters"].(map[string]any); ok { - function["parameters"] = params - } - out = append(out, map[string]any{ - "type": "function", - "function": function, - }) - } - continue - } - - // OpenAI-style passthrough fallback. - if _, ok := tool["function"].(map[string]any); ok { - out = append(out, tool) - continue - } - - // Loose fallback for flattened function schema objects. 
- name := strings.TrimSpace(asString(tool["name"])) - if name == "" { - continue - } - fn := map[string]any{"name": name} - if desc := strings.TrimSpace(asString(tool["description"])); desc != "" { - fn["description"] = desc - } - if params, ok := tool["parameters"].(map[string]any); ok { - fn["parameters"] = params - } - out = append(out, map[string]any{ - "type": "function", - "function": fn, - }) - } - if len(out) == 0 { - return nil - } - return out -} - -func collectGeminiPassThrough(req map[string]any) map[string]any { - cfg, _ := req["generationConfig"].(map[string]any) - if len(cfg) == 0 { - return nil - } - out := map[string]any{} - if v, ok := cfg["temperature"]; ok { - out["temperature"] = v - } - if v, ok := cfg["topP"]; ok { - out["top_p"] = v - } - if v, ok := cfg["maxOutputTokens"]; ok { - out["max_tokens"] = v - } - if v, ok := cfg["stopSequences"]; ok { - out["stop"] = v - } - if len(out) == 0 { - return nil - } - return out -} - -func asString(v any) string { - s, _ := v.(string) - return s -} - -func stringifyJSON(v any) string { - switch x := v.(type) { - case nil: - return "{}" - case string: - s := strings.TrimSpace(x) - if s == "" { - return "{}" - } - return s - default: - b, err := json.Marshal(x) - if err != nil || len(b) == 0 { - return "{}" - } - return string(b) - } -} diff --git a/internal/adapter/gemini/convert_messages.go b/internal/adapter/gemini/convert_messages.go new file mode 100644 index 0000000..1148a7a --- /dev/null +++ b/internal/adapter/gemini/convert_messages.go @@ -0,0 +1,153 @@ +package gemini + +import "strings" + +func geminiMessagesFromRequest(req map[string]any) []any { + out := make([]any, 0, 8) + if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" { + out = append(out, map[string]any{ + "role": "system", + "content": sys, + }) + } + + contents, _ := req["contents"].([]any) + for _, item := range contents { + content, ok := item.(map[string]any) + if !ok { + continue + 
} + role := mapGeminiRole(content["role"]) + if role == "" { + role = "user" + } + parts, _ := content["parts"].([]any) + if len(parts) == 0 { + if text := strings.TrimSpace(asString(content["text"])); text != "" { + out = append(out, map[string]any{ + "role": role, + "content": text, + }) + } + continue + } + + textParts := make([]string, 0, len(parts)) + flushText := func() { + if len(textParts) == 0 { + return + } + out = append(out, map[string]any{ + "role": role, + "content": strings.Join(textParts, "\n"), + }) + textParts = textParts[:0] + } + + for _, rawPart := range parts { + part, ok := rawPart.(map[string]any) + if !ok { + continue + } + if text := strings.TrimSpace(asString(part["text"])); text != "" { + textParts = append(textParts, text) + continue + } + + if fnCall, ok := part["functionCall"].(map[string]any); ok { + flushText() + if name := strings.TrimSpace(asString(fnCall["name"])); name != "" { + callID := strings.TrimSpace(asString(fnCall["id"])) + if callID == "" { + callID = "call_gemini" + } + out = append(out, map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": callID, + "type": "function", + "function": map[string]any{ + "name": name, + "arguments": stringifyJSON(fnCall["args"]), + }, + }, + }, + }) + } + continue + } + + if fnResp, ok := part["functionResponse"].(map[string]any); ok { + flushText() + name := strings.TrimSpace(asString(fnResp["name"])) + callID := strings.TrimSpace(asString(fnResp["id"])) + if callID == "" { + callID = strings.TrimSpace(asString(fnResp["callId"])) + } + if callID == "" { + callID = strings.TrimSpace(asString(fnResp["tool_call_id"])) + } + if callID == "" { + callID = "call_gemini" + } + content := fnResp["response"] + if content == nil { + content = fnResp["output"] + } + if content == nil { + content = "" + } + msg := map[string]any{ + "role": "tool", + "tool_call_id": callID, + "content": content, + } + if name != "" { + msg["name"] = name + } + out = append(out, msg) + 
} + } + flushText() + } + return out +} + +func normalizeGeminiSystemInstruction(raw any) string { + switch v := raw.(type) { + case string: + return strings.TrimSpace(v) + case map[string]any: + if parts, ok := v["parts"].([]any); ok { + texts := make([]string, 0, len(parts)) + for _, item := range parts { + part, ok := item.(map[string]any) + if !ok { + continue + } + if text := strings.TrimSpace(asString(part["text"])); text != "" { + texts = append(texts, text) + } + } + return strings.Join(texts, "\n") + } + if text := strings.TrimSpace(asString(v["text"])); text != "" { + return text + } + } + return "" +} + +func mapGeminiRole(v any) string { + switch strings.ToLower(strings.TrimSpace(asString(v))) { + case "user": + return "user" + case "model", "assistant": + return "assistant" + case "system": + return "system" + default: + return "" + } +} diff --git a/internal/adapter/gemini/convert_passthrough.go b/internal/adapter/gemini/convert_passthrough.go new file mode 100644 index 0000000..05cd6cd --- /dev/null +++ b/internal/adapter/gemini/convert_passthrough.go @@ -0,0 +1,54 @@ +package gemini + +import ( + "encoding/json" + "strings" +) + +func collectGeminiPassThrough(req map[string]any) map[string]any { + cfg, _ := req["generationConfig"].(map[string]any) + if len(cfg) == 0 { + return nil + } + out := map[string]any{} + if v, ok := cfg["temperature"]; ok { + out["temperature"] = v + } + if v, ok := cfg["topP"]; ok { + out["top_p"] = v + } + if v, ok := cfg["maxOutputTokens"]; ok { + out["max_tokens"] = v + } + if v, ok := cfg["stopSequences"]; ok { + out["stop"] = v + } + if len(out) == 0 { + return nil + } + return out +} + +func asString(v any) string { + s, _ := v.(string) + return s +} + +func stringifyJSON(v any) string { + switch x := v.(type) { + case nil: + return "{}" + case string: + s := strings.TrimSpace(x) + if s == "" { + return "{}" + } + return s + default: + b, err := json.Marshal(x) + if err != nil || len(b) == 0 { + return "{}" + } + 
return string(b) + } +} diff --git a/internal/adapter/gemini/convert_request.go b/internal/adapter/gemini/convert_request.go new file mode 100644 index 0000000..2eca687 --- /dev/null +++ b/internal/adapter/gemini/convert_request.go @@ -0,0 +1,46 @@ +package gemini + +import ( + "fmt" + "strings" + + "ds2api/internal/adapter/openai" + "ds2api/internal/config" + "ds2api/internal/util" +) + +func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) { + requestedModel := strings.TrimSpace(routeModel) + if requestedModel == "" { + return util.StandardRequest{}, fmt.Errorf("model is required in request path") + } + + resolvedModel, ok := config.ResolveModel(store, requestedModel) + if !ok { + return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel) + } + thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel) + + messagesRaw := geminiMessagesFromRequest(req) + if len(messagesRaw) == 0 { + return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.") + } + + toolsRaw := convertGeminiTools(req["tools"]) + finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "") + passThrough := collectGeminiPassThrough(req) + + return util.StandardRequest{ + Surface: "google_gemini", + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + ResponseModel: requestedModel, + Messages: messagesRaw, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: stream, + Thinking: thinkingEnabled, + Search: searchEnabled, + PassThrough: passThrough, + }, nil +} diff --git a/internal/adapter/gemini/convert_tools.go b/internal/adapter/gemini/convert_tools.go new file mode 100644 index 0000000..4611f85 --- /dev/null +++ b/internal/adapter/gemini/convert_tools.go @@ -0,0 +1,71 @@ +package gemini + +import "strings" + +func convertGeminiTools(raw any) []any { + tools, _ := raw.([]any) + if len(tools) == 0 { + return nil 
+ } + out := make([]any, 0, len(tools)) + for _, item := range tools { + tool, ok := item.(map[string]any) + if !ok { + continue + } + + if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 { + for _, declRaw := range fnDecls { + decl, ok := declRaw.(map[string]any) + if !ok { + continue + } + name := strings.TrimSpace(asString(decl["name"])) + if name == "" { + continue + } + function := map[string]any{ + "name": name, + } + if desc := strings.TrimSpace(asString(decl["description"])); desc != "" { + function["description"] = desc + } + if params, ok := decl["parameters"].(map[string]any); ok { + function["parameters"] = params + } + out = append(out, map[string]any{ + "type": "function", + "function": function, + }) + } + continue + } + + // OpenAI-style passthrough fallback. + if _, ok := tool["function"].(map[string]any); ok { + out = append(out, tool) + continue + } + + // Loose fallback for flattened function schema objects. + name := strings.TrimSpace(asString(tool["name"])) + if name == "" { + continue + } + fn := map[string]any{"name": name} + if desc := strings.TrimSpace(asString(tool["description"])); desc != "" { + fn["description"] = desc + } + if params, ok := tool["parameters"].(map[string]any); ok { + fn["parameters"] = params + } + out = append(out, map[string]any{ + "type": "function", + "function": fn, + }) + } + if len(out) == 0 { + return nil + } + return out +} diff --git a/internal/adapter/gemini/handler.go b/internal/adapter/gemini/handler.go deleted file mode 100644 index 8daaeda..0000000 --- a/internal/adapter/gemini/handler.go +++ /dev/null @@ -1,348 +0,0 @@ -package gemini - -import ( - "encoding/json" - "io" - "net/http" - "strings" - "time" - - "github.com/go-chi/chi/v5" - - "ds2api/internal/auth" - "ds2api/internal/deepseek" - "ds2api/internal/sse" - streamengine "ds2api/internal/stream" - "ds2api/internal/util" -) - -var writeJSON = util.WriteJSON - -type Handler struct { - Store ConfigReader - Auth 
AuthResolver - DS DeepSeekCaller -} - -func RegisterRoutes(r chi.Router, h *Handler) { - r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent) - r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent) - r.Post("/v1/models/{model}:generateContent", h.GenerateContent) - r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent) -} - -func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) { - h.handleGenerateContent(w, r, false) -} - -func (h *Handler) StreamGenerateContent(w http.ResponseWriter, r *http.Request) { - h.handleGenerateContent(w, r, true) -} - -func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) { - a, err := h.Auth.Determine(r) - if err != nil { - status := http.StatusUnauthorized - detail := err.Error() - if err == auth.ErrNoAccount { - status = http.StatusTooManyRequests - } - writeGeminiError(w, status, detail) - return - } - defer h.Auth.Release(a) - - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeGeminiError(w, http.StatusBadRequest, "invalid json") - return - } - - routeModel := strings.TrimSpace(chi.URLParam(r, "model")) - stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream) - if err != nil { - writeGeminiError(w, http.StatusBadRequest, err.Error()) - return - } - - sessionID, err := h.DS.CreateSession(r.Context(), a, 3) - if err != nil { - if a.UseConfigToken { - writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. 
Please re-login the account in admin.") - } else { - writeGeminiError(w, http.StatusUnauthorized, "Invalid token.") - } - return - } - pow, err := h.DS.GetPow(r.Context(), a, 3) - if err != nil { - writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") - return - } - payload := stdReq.CompletionPayload(sessionID) - resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) - if err != nil { - writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.") - return - } - - if stream { - h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) - return - } - h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) -} - -func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body))) - return - } - - result := sse.CollectStream(resp, thinkingEnabled, true) - writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames)) -} - -func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body))) - return - } - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache, no-transform") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("X-Accel-Buffering", "no") - - rc := 
http.NewResponseController(w) - _, canFlush := w.(http.Flusher) - runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) - - initialType := "text" - if thinkingEnabled { - initialType = "thinking" - } - streamengine.ConsumeSSE(streamengine.ConsumeConfig{ - Context: r.Context(), - Body: resp.Body, - ThinkingEnabled: thinkingEnabled, - InitialType: initialType, - KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second, - IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second, - MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount, - }, streamengine.ConsumeHooks{ - OnParsed: runtime.onParsed, - OnFinalize: func(_ streamengine.StopReason, _ error) { - runtime.finalize() - }, - }) -} - -func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { - parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames) - usage := buildGeminiUsage(finalPrompt, finalThinking, finalText) - return map[string]any{ - "candidates": []map[string]any{ - { - "index": 0, - "content": map[string]any{ - "role": "model", - "parts": parts, - }, - "finishReason": "STOP", - }, - }, - "modelVersion": model, - "usageMetadata": usage, - } -} - -func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any { - promptTokens := util.EstimateTokens(finalPrompt) - reasoningTokens := util.EstimateTokens(finalThinking) - completionTokens := util.EstimateTokens(finalText) - return map[string]any{ - "promptTokenCount": promptTokens, - "candidatesTokenCount": reasoningTokens + completionTokens, - "totalTokenCount": promptTokens + reasoningTokens + completionTokens, - } -} - -func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any { - detected := util.ParseToolCalls(finalText, toolNames) - if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" { - detected = 
util.ParseToolCalls(finalThinking, toolNames) - } - if len(detected) > 0 { - parts := make([]map[string]any, 0, len(detected)) - for _, tc := range detected { - parts = append(parts, map[string]any{ - "functionCall": map[string]any{ - "name": tc.Name, - "args": tc.Input, - }, - }) - } - return parts - } - - text := finalText - if strings.TrimSpace(text) == "" { - text = finalThinking - } - return []map[string]any{{"text": text}} -} - -type geminiStreamRuntime struct { - w http.ResponseWriter - rc *http.ResponseController - canFlush bool - - model string - finalPrompt string - - thinkingEnabled bool - searchEnabled bool - bufferContent bool - toolNames []string - - thinking strings.Builder - text strings.Builder -} - -func newGeminiStreamRuntime( - w http.ResponseWriter, - rc *http.ResponseController, - canFlush bool, - model string, - finalPrompt string, - thinkingEnabled bool, - searchEnabled bool, - toolNames []string, -) *geminiStreamRuntime { - return &geminiStreamRuntime{ - w: w, - rc: rc, - canFlush: canFlush, - model: model, - finalPrompt: finalPrompt, - thinkingEnabled: thinkingEnabled, - searchEnabled: searchEnabled, - bufferContent: len(toolNames) > 0, - toolNames: toolNames, - } -} - -func (s *geminiStreamRuntime) sendChunk(payload map[string]any) { - b, _ := json.Marshal(payload) - _, _ = s.w.Write([]byte("data: ")) - _, _ = s.w.Write(b) - _, _ = s.w.Write([]byte("\n\n")) - if s.canFlush { - _ = s.rc.Flush() - } -} - -func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { - if !parsed.Parsed { - return streamengine.ParsedDecision{} - } - if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { - return streamengine.ParsedDecision{Stop: true} - } - - contentSeen := false - for _, p := range parsed.Parts { - if p.Text == "" { - continue - } - if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) { - continue - } - contentSeen = true - if p.Type == "thinking" { - if s.thinkingEnabled { - 
s.thinking.WriteString(p.Text) - } - continue - } - s.text.WriteString(p.Text) - if s.bufferContent { - continue - } - s.sendChunk(map[string]any{ - "candidates": []map[string]any{ - { - "index": 0, - "content": map[string]any{ - "role": "model", - "parts": []map[string]any{{"text": p.Text}}, - }, - }, - }, - "modelVersion": s.model, - }) - } - return streamengine.ParsedDecision{ContentSeen: contentSeen} -} - -func (s *geminiStreamRuntime) finalize() { - finalThinking := s.thinking.String() - finalText := s.text.String() - - if s.bufferContent { - parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames) - s.sendChunk(map[string]any{ - "candidates": []map[string]any{ - { - "index": 0, - "content": map[string]any{ - "role": "model", - "parts": parts, - }, - }, - }, - "modelVersion": s.model, - }) - } - - s.sendChunk(map[string]any{ - "candidates": []map[string]any{ - { - "index": 0, - "finishReason": "STOP", - }, - }, - "modelVersion": s.model, - "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText), - }) -} - -func writeGeminiError(w http.ResponseWriter, status int, message string) { - errorStatus := "INVALID_ARGUMENT" - switch status { - case http.StatusUnauthorized: - errorStatus = "UNAUTHENTICATED" - case http.StatusForbidden: - errorStatus = "PERMISSION_DENIED" - case http.StatusTooManyRequests: - errorStatus = "RESOURCE_EXHAUSTED" - case http.StatusNotFound: - errorStatus = "NOT_FOUND" - default: - if status >= 500 { - errorStatus = "INTERNAL" - } - } - writeJSON(w, status, map[string]any{ - "error": map[string]any{ - "code": status, - "message": message, - "status": errorStatus, - }, - }) -} diff --git a/internal/adapter/gemini/handler_errors.go b/internal/adapter/gemini/handler_errors.go new file mode 100644 index 0000000..09df09b --- /dev/null +++ b/internal/adapter/gemini/handler_errors.go @@ -0,0 +1,28 @@ +package gemini + +import "net/http" + +func writeGeminiError(w http.ResponseWriter, status int, message string) { 
+ errorStatus := "INVALID_ARGUMENT" + switch status { + case http.StatusUnauthorized: + errorStatus = "UNAUTHENTICATED" + case http.StatusForbidden: + errorStatus = "PERMISSION_DENIED" + case http.StatusTooManyRequests: + errorStatus = "RESOURCE_EXHAUSTED" + case http.StatusNotFound: + errorStatus = "NOT_FOUND" + default: + if status >= 500 { + errorStatus = "INTERNAL" + } + } + writeJSON(w, status, map[string]any{ + "error": map[string]any{ + "code": status, + "message": message, + "status": errorStatus, + }, + }) +} diff --git a/internal/adapter/gemini/handler_generate.go b/internal/adapter/gemini/handler_generate.go new file mode 100644 index 0000000..9144a42 --- /dev/null +++ b/internal/adapter/gemini/handler_generate.go @@ -0,0 +1,135 @@ +package gemini + +import ( + "encoding/json" + "io" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/auth" + "ds2api/internal/sse" + "ds2api/internal/util" +) + +func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) { + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeGeminiError(w, status, detail) + return + } + defer h.Auth.Release(a) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeGeminiError(w, http.StatusBadRequest, "invalid json") + return + } + + routeModel := strings.TrimSpace(chi.URLParam(r, "model")) + stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream) + if err != nil { + writeGeminiError(w, http.StatusBadRequest, err.Error()) + return + } + + sessionID, err := h.DS.CreateSession(r.Context(), a, 3) + if err != nil { + if a.UseConfigToken { + writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. 
Please re-login the account in admin.") + } else { + writeGeminiError(w, http.StatusUnauthorized, "Invalid token.") + } + return + } + pow, err := h.DS.GetPow(r.Context(), a, 3) + if err != nil { + writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") + return + } + payload := stdReq.CompletionPayload(sessionID) + resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) + if err != nil { + writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.") + return + } + + if stream { + h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) + return + } + h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) +} + +func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body))) + return + } + + result := sse.CollectStream(resp, thinkingEnabled, true) + writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames)) +} + +func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames) + usage := buildGeminiUsage(finalPrompt, finalThinking, finalText) + return map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "content": map[string]any{ + "role": "model", + "parts": parts, + }, + "finishReason": "STOP", + }, + }, + "modelVersion": model, + "usageMetadata": usage, + } +} + +func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any { + promptTokens := 
util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + return map[string]any{ + "promptTokenCount": promptTokens, + "candidatesTokenCount": reasoningTokens + completionTokens, + "totalTokenCount": promptTokens + reasoningTokens + completionTokens, + } +} + +func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any { + detected := util.ParseToolCalls(finalText, toolNames) + if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" { + detected = util.ParseToolCalls(finalThinking, toolNames) + } + if len(detected) > 0 { + parts := make([]map[string]any, 0, len(detected)) + for _, tc := range detected { + parts = append(parts, map[string]any{ + "functionCall": map[string]any{ + "name": tc.Name, + "args": tc.Input, + }, + }) + } + return parts + } + + text := finalText + if strings.TrimSpace(text) == "" { + text = finalThinking + } + return []map[string]any{{"text": text}} +} diff --git a/internal/adapter/gemini/handler_routes.go b/internal/adapter/gemini/handler_routes.go new file mode 100644 index 0000000..6850b51 --- /dev/null +++ b/internal/adapter/gemini/handler_routes.go @@ -0,0 +1,32 @@ +package gemini + +import ( + "net/http" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/util" +) + +var writeJSON = util.WriteJSON + +type Handler struct { + Store ConfigReader + Auth AuthResolver + DS DeepSeekCaller +} + +func RegisterRoutes(r chi.Router, h *Handler) { + r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent) + r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent) + r.Post("/v1/models/{model}:generateContent", h.GenerateContent) + r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent) +} + +func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) { + h.handleGenerateContent(w, r, false) +} + +func (h *Handler) StreamGenerateContent(w 
http.ResponseWriter, r *http.Request) { + h.handleGenerateContent(w, r, true) +} diff --git a/internal/adapter/gemini/handler_stream_runtime.go b/internal/adapter/gemini/handler_stream_runtime.go new file mode 100644 index 0000000..83a393d --- /dev/null +++ b/internal/adapter/gemini/handler_stream_runtime.go @@ -0,0 +1,175 @@ +package gemini + +import ( + "encoding/json" + "io" + "net/http" + "strings" + "time" + + "ds2api/internal/deepseek" + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" +) + +func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body))) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + + rc := http.NewResponseController(w) + _, canFlush := w.(http.Flusher) + runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames) + + initialType := "text" + if thinkingEnabled { + initialType = "thinking" + } + streamengine.ConsumeSSE(streamengine.ConsumeConfig{ + Context: r.Context(), + Body: resp.Body, + ThinkingEnabled: thinkingEnabled, + InitialType: initialType, + KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second, + IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second, + MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount, + }, streamengine.ConsumeHooks{ + OnParsed: runtime.onParsed, + OnFinalize: func(_ streamengine.StopReason, _ error) { + runtime.finalize() + }, + }) +} + +type geminiStreamRuntime struct { + w http.ResponseWriter + rc *http.ResponseController + canFlush bool + + 
model string + finalPrompt string + + thinkingEnabled bool + searchEnabled bool + bufferContent bool + toolNames []string + + thinking strings.Builder + text strings.Builder +} + +func newGeminiStreamRuntime( + w http.ResponseWriter, + rc *http.ResponseController, + canFlush bool, + model string, + finalPrompt string, + thinkingEnabled bool, + searchEnabled bool, + toolNames []string, +) *geminiStreamRuntime { + return &geminiStreamRuntime{ + w: w, + rc: rc, + canFlush: canFlush, + model: model, + finalPrompt: finalPrompt, + thinkingEnabled: thinkingEnabled, + searchEnabled: searchEnabled, + bufferContent: len(toolNames) > 0, + toolNames: toolNames, + } +} + +func (s *geminiStreamRuntime) sendChunk(payload map[string]any) { + b, _ := json.Marshal(payload) + _, _ = s.w.Write([]byte("data: ")) + _, _ = s.w.Write(b) + _, _ = s.w.Write([]byte("\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { + if !parsed.Parsed { + return streamengine.ParsedDecision{} + } + if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { + return streamengine.ParsedDecision{Stop: true} + } + + contentSeen := false + for _, p := range parsed.Parts { + if p.Text == "" { + continue + } + if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) { + continue + } + contentSeen = true + if p.Type == "thinking" { + if s.thinkingEnabled { + s.thinking.WriteString(p.Text) + } + continue + } + s.text.WriteString(p.Text) + if s.bufferContent { + continue + } + s.sendChunk(map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "content": map[string]any{ + "role": "model", + "parts": []map[string]any{{"text": p.Text}}, + }, + }, + }, + "modelVersion": s.model, + }) + } + return streamengine.ParsedDecision{ContentSeen: contentSeen} +} + +func (s *geminiStreamRuntime) finalize() { + finalThinking := s.thinking.String() + finalText := s.text.String() + + if s.bufferContent { 
+ parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames) + s.sendChunk(map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "content": map[string]any{ + "role": "model", + "parts": parts, + }, + }, + }, + "modelVersion": s.model, + }) + } + + s.sendChunk(map[string]any{ + "candidates": []map[string]any{ + { + "index": 0, + "finishReason": "STOP", + }, + }, + "modelVersion": s.model, + "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText), + }) +} diff --git a/internal/adapter/openai/error_shape_test.go b/internal/adapter/openai/error_shape_test.go index c169e04..8c73e4b 100644 --- a/internal/adapter/openai/error_shape_test.go +++ b/internal/adapter/openai/error_shape_test.go @@ -32,4 +32,3 @@ func TestWriteOpenAIErrorIncludesUnifiedFields(t *testing.T) { t.Fatal("expected param field") } } - diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go deleted file mode 100644 index 391a035..0000000 --- a/internal/adapter/openai/handler.go +++ /dev/null @@ -1,386 +0,0 @@ -package openai - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "sync" - "time" - - "github.com/go-chi/chi/v5" - "github.com/google/uuid" - - "ds2api/internal/auth" - "ds2api/internal/config" - "ds2api/internal/deepseek" - openaifmt "ds2api/internal/format/openai" - "ds2api/internal/sse" - streamengine "ds2api/internal/stream" - "ds2api/internal/util" -) - -// writeJSON is a package-internal alias kept to avoid mass-renaming across -// every call-site in this file. It delegates to the shared util version. 
-var writeJSON = util.WriteJSON - -type Handler struct { - Store ConfigReader - Auth AuthResolver - DS DeepSeekCaller - - leaseMu sync.Mutex - streamLeases map[string]streamLease - responsesMu sync.Mutex - responses *responseStore -} - -type streamLease struct { - Auth *auth.RequestAuth - ExpiresAt time.Time -} - -func RegisterRoutes(r chi.Router, h *Handler) { - r.Get("/v1/models", h.ListModels) - r.Get("/v1/models/{model_id}", h.GetModel) - r.Post("/v1/chat/completions", h.ChatCompletions) - r.Post("/v1/responses", h.Responses) - r.Get("/v1/responses/{response_id}", h.GetResponseByID) - r.Post("/v1/embeddings", h.Embeddings) -} - -func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { - writeJSON(w, http.StatusOK, config.OpenAIModelsResponse()) -} - -func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) { - modelID := strings.TrimSpace(chi.URLParam(r, "model_id")) - model, ok := config.OpenAIModelByID(h.Store, modelID) - if !ok { - writeOpenAIError(w, http.StatusNotFound, "Model not found.") - return - } - writeJSON(w, http.StatusOK, model) -} - -func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { - if isVercelStreamReleaseRequest(r) { - h.handleVercelStreamRelease(w, r) - return - } - if isVercelStreamPrepareRequest(r) { - h.handleVercelStreamPrepare(w, r) - return - } - - a, err := h.Auth.Determine(r) - if err != nil { - status := http.StatusUnauthorized - detail := err.Error() - if err == auth.ErrNoAccount { - status = http.StatusTooManyRequests - } - writeOpenAIError(w, status, detail) - return - } - defer h.Auth.Release(a) - r = r.WithContext(auth.WithAuth(r.Context(), a)) - - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeOpenAIError(w, http.StatusBadRequest, "invalid json") - return - } - stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r)) - if err != nil { - writeOpenAIError(w, http.StatusBadRequest, err.Error()) - return - } - - 
sessionID, err := h.DS.CreateSession(r.Context(), a, 3) - if err != nil { - if a.UseConfigToken { - writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.") - } else { - writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.") - } - return - } - pow, err := h.DS.GetPow(r.Context(), a, 3) - if err != nil { - writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") - return - } - payload := stdReq.CompletionPayload(sessionID) - resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) - if err != nil { - writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") - return - } - if stdReq.Stream { - h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) - return - } - h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) -} - -func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { - if resp.StatusCode != http.StatusOK { - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - writeOpenAIError(w, resp.StatusCode, string(body)) - return - } - _ = ctx - result := sse.CollectStream(resp, thinkingEnabled, true) - - finalThinking := result.Thinking - finalText := result.Text - respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames) - writeJSON(w, http.StatusOK, respBody) -} - -func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := 
io.ReadAll(resp.Body) - writeOpenAIError(w, resp.StatusCode, string(body)) - return - } - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache, no-transform") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("X-Accel-Buffering", "no") - rc := http.NewResponseController(w) - _, canFlush := w.(http.Flusher) - if !canFlush { - config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered") - } - - created := time.Now().Unix() - bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled() - emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence() - initialType := "text" - if thinkingEnabled { - initialType = "thinking" - } - - streamRuntime := newChatStreamRuntime( - w, - rc, - canFlush, - completionID, - created, - model, - finalPrompt, - thinkingEnabled, - searchEnabled, - toolNames, - bufferToolContent, - emitEarlyToolDeltas, - ) - - streamengine.ConsumeSSE(streamengine.ConsumeConfig{ - Context: r.Context(), - Body: resp.Body, - ThinkingEnabled: thinkingEnabled, - InitialType: initialType, - KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second, - IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second, - MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount, - }, streamengine.ConsumeHooks{ - OnKeepAlive: func() { - streamRuntime.sendKeepAlive() - }, - OnParsed: streamRuntime.onParsed, - OnFinalize: func(reason streamengine.StopReason, _ error) { - if string(reason) == "content_filter" { - streamRuntime.finalize("content_filter") - return - } - streamRuntime.finalize("stop") - }, - }) -} - -func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) { - toolSchemas := make([]string, 0, len(tools)) - names := make([]string, 0, len(tools)) - for _, t := range tools { - tool, ok := t.(map[string]any) - if !ok { - continue - } - fn, _ := tool["function"].(map[string]any) - if len(fn) == 0 { - fn = tool - } - 
name, _ := fn["name"].(string) - desc, _ := fn["description"].(string) - schema, _ := fn["parameters"].(map[string]any) - if name == "" { - name = "unknown" - } - names = append(names, name) - if desc == "" { - desc = "No description available" - } - b, _ := json.Marshal(schema) - toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b))) - } - if len(toolSchemas) == 0 { - return messages, names - } - toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block." - - for i := range messages { - if messages[i]["role"] == "system" { - old, _ := messages[i]["content"].(string) - messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt) - return messages, names - } - } - messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...) 
- return messages, names -} - -func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any { - if len(deltas) == 0 { - return nil - } - out := make([]map[string]any, 0, len(deltas)) - for _, d := range deltas { - if d.Name == "" && d.Arguments == "" { - continue - } - callID, ok := ids[d.Index] - if !ok || callID == "" { - callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") - ids[d.Index] = callID - } - item := map[string]any{ - "index": d.Index, - "id": callID, - "type": "function", - } - fn := map[string]any{} - if d.Name != "" { - fn["name"] = d.Name - } - if d.Arguments != "" { - fn["arguments"] = d.Arguments - } - if len(fn) > 0 { - item["function"] = fn - } - out = append(out, item) - } - return out -} - -func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any { - if len(calls) == 0 { - return nil - } - out := make([]map[string]any, 0, len(calls)) - for i, c := range calls { - callID := "" - if ids != nil { - callID = strings.TrimSpace(ids[i]) - } - if callID == "" { - callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") - if ids != nil { - ids[i] = callID - } - } - args, _ := json.Marshal(c.Input) - out = append(out, map[string]any{ - "index": i, - "id": callID, - "type": "function", - "function": map[string]any{ - "name": c.Name, - "arguments": string(args), - }, - }) - } - return out -} - -func writeOpenAIError(w http.ResponseWriter, status int, message string) { - writeJSON(w, status, map[string]any{ - "error": map[string]any{ - "message": message, - "type": openAIErrorType(status), - "code": openAIErrorCode(status), - "param": nil, - }, - }) -} - -func openAIErrorType(status int) string { - switch status { - case http.StatusBadRequest: - return "invalid_request_error" - case http.StatusUnauthorized: - return "authentication_error" - case http.StatusForbidden: - return "permission_error" - case http.StatusTooManyRequests: - return 
"rate_limit_error" - case http.StatusServiceUnavailable: - return "service_unavailable_error" - default: - if status >= 500 { - return "api_error" - } - return "invalid_request_error" - } -} - -func openAIErrorCode(status int) string { - switch status { - case http.StatusBadRequest: - return "invalid_request" - case http.StatusUnauthorized: - return "authentication_failed" - case http.StatusForbidden: - return "forbidden" - case http.StatusTooManyRequests: - return "rate_limit_exceeded" - case http.StatusNotFound: - return "not_found" - case http.StatusServiceUnavailable: - return "service_unavailable" - default: - if status >= 500 { - return "internal_error" - } - return "invalid_request" - } -} - -func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) { - for k, v := range collectOpenAIChatPassThrough(req) { - payload[k] = v - } -} - -func (h *Handler) toolcallFeatureMatchEnabled() bool { - if h == nil || h.Store == nil { - return true - } - mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode())) - return mode == "" || mode == "feature_match" -} - -func (h *Handler) toolcallEarlyEmitHighConfidence() bool { - if h == nil || h.Store == nil { - return true - } - level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence())) - return level == "" || level == "high" -} diff --git a/internal/adapter/openai/handler_chat.go b/internal/adapter/openai/handler_chat.go new file mode 100644 index 0000000..26a4bf2 --- /dev/null +++ b/internal/adapter/openai/handler_chat.go @@ -0,0 +1,156 @@ +package openai + +import ( + "context" + "encoding/json" + "io" + "net/http" + "time" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/deepseek" + openaifmt "ds2api/internal/format/openai" + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" +) + +func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { + if isVercelStreamReleaseRequest(r) { + h.handleVercelStreamRelease(w, r) + return 
+ } + if isVercelStreamPrepareRequest(r) { + h.handleVercelStreamPrepare(w, r) + return + } + + a, err := h.Auth.Determine(r) + if err != nil { + status := http.StatusUnauthorized + detail := err.Error() + if err == auth.ErrNoAccount { + status = http.StatusTooManyRequests + } + writeOpenAIError(w, status, detail) + return + } + defer h.Auth.Release(a) + r = r.WithContext(auth.WithAuth(r.Context(), a)) + + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeOpenAIError(w, http.StatusBadRequest, "invalid json") + return + } + stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r)) + if err != nil { + writeOpenAIError(w, http.StatusBadRequest, err.Error()) + return + } + + sessionID, err := h.DS.CreateSession(r.Context(), a, 3) + if err != nil { + if a.UseConfigToken { + writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.") + } else { + writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. 
If this should be a DS2API key, add it to config.keys first.") + } + return + } + pow, err := h.DS.GetPow(r.Context(), a, 3) + if err != nil { + writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).") + return + } + payload := stdReq.CompletionPayload(sessionID) + resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3) + if err != nil { + writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.") + return + } + if stdReq.Stream { + h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames) + return + } + h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames) +} + +func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) { + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + writeOpenAIError(w, resp.StatusCode, string(body)) + return + } + _ = ctx + result := sse.CollectStream(resp, thinkingEnabled, true) + + finalThinking := result.Thinking + finalText := result.Text + respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames) + writeJSON(w, http.StatusOK, respBody) +} + +func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) { + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + writeOpenAIError(w, resp.StatusCode, string(body)) + return + } + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache, no-transform") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Accel-Buffering", "no") + 
rc := http.NewResponseController(w) + _, canFlush := w.(http.Flusher) + if !canFlush { + config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered") + } + + created := time.Now().Unix() + bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled() + emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence() + initialType := "text" + if thinkingEnabled { + initialType = "thinking" + } + + streamRuntime := newChatStreamRuntime( + w, + rc, + canFlush, + completionID, + created, + model, + finalPrompt, + thinkingEnabled, + searchEnabled, + toolNames, + bufferToolContent, + emitEarlyToolDeltas, + ) + + streamengine.ConsumeSSE(streamengine.ConsumeConfig{ + Context: r.Context(), + Body: resp.Body, + ThinkingEnabled: thinkingEnabled, + InitialType: initialType, + KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second, + IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second, + MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount, + }, streamengine.ConsumeHooks{ + OnKeepAlive: func() { + streamRuntime.sendKeepAlive() + }, + OnParsed: streamRuntime.onParsed, + OnFinalize: func(reason streamengine.StopReason, _ error) { + if string(reason) == "content_filter" { + streamRuntime.finalize("content_filter") + return + } + streamRuntime.finalize("stop") + }, + }) +} diff --git a/internal/adapter/openai/handler_errors.go b/internal/adapter/openai/handler_errors.go new file mode 100644 index 0000000..62249d2 --- /dev/null +++ b/internal/adapter/openai/handler_errors.go @@ -0,0 +1,56 @@ +package openai + +import "net/http" + +func writeOpenAIError(w http.ResponseWriter, status int, message string) { + writeJSON(w, status, map[string]any{ + "error": map[string]any{ + "message": message, + "type": openAIErrorType(status), + "code": openAIErrorCode(status), + "param": nil, + }, + }) +} + +func openAIErrorType(status int) string { + switch status { + case http.StatusBadRequest: + return 
"invalid_request_error" + case http.StatusUnauthorized: + return "authentication_error" + case http.StatusForbidden: + return "permission_error" + case http.StatusTooManyRequests: + return "rate_limit_error" + case http.StatusServiceUnavailable: + return "service_unavailable_error" + default: + if status >= 500 { + return "api_error" + } + return "invalid_request_error" + } +} + +func openAIErrorCode(status int) string { + switch status { + case http.StatusBadRequest: + return "invalid_request" + case http.StatusUnauthorized: + return "authentication_failed" + case http.StatusForbidden: + return "forbidden" + case http.StatusTooManyRequests: + return "rate_limit_exceeded" + case http.StatusNotFound: + return "not_found" + case http.StatusServiceUnavailable: + return "service_unavailable" + default: + if status >= 500 { + return "internal_error" + } + return "invalid_request" + } +} diff --git a/internal/adapter/openai/handler_routes.go b/internal/adapter/openai/handler_routes.go new file mode 100644 index 0000000..a0cfcd6 --- /dev/null +++ b/internal/adapter/openai/handler_routes.go @@ -0,0 +1,57 @@ +package openai + +import ( + "net/http" + "strings" + "sync" + "time" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/auth" + "ds2api/internal/config" + "ds2api/internal/util" +) + +// writeJSON is a package-internal alias kept to avoid mass-renaming across +// every call-site in this package. 
+var writeJSON = util.WriteJSON + +type Handler struct { + Store ConfigReader + Auth AuthResolver + DS DeepSeekCaller + + leaseMu sync.Mutex + streamLeases map[string]streamLease + responsesMu sync.Mutex + responses *responseStore +} + +type streamLease struct { + Auth *auth.RequestAuth + ExpiresAt time.Time +} + +func RegisterRoutes(r chi.Router, h *Handler) { + r.Get("/v1/models", h.ListModels) + r.Get("/v1/models/{model_id}", h.GetModel) + r.Post("/v1/chat/completions", h.ChatCompletions) + r.Post("/v1/responses", h.Responses) + r.Get("/v1/responses/{response_id}", h.GetResponseByID) + r.Post("/v1/embeddings", h.Embeddings) +} + +func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, config.OpenAIModelsResponse()) +} + +func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) { + modelID := strings.TrimSpace(chi.URLParam(r, "model_id")) + model, ok := config.OpenAIModelByID(h.Store, modelID) + if !ok { + writeOpenAIError(w, http.StatusNotFound, "Model not found.") + return + } + writeJSON(w, http.StatusOK, model) +} diff --git a/internal/adapter/openai/handler_toolcall_format.go b/internal/adapter/openai/handler_toolcall_format.go new file mode 100644 index 0000000..d939c68 --- /dev/null +++ b/internal/adapter/openai/handler_toolcall_format.go @@ -0,0 +1,116 @@ +package openai + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/google/uuid" + + "ds2api/internal/util" +) + +func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) { + toolSchemas := make([]string, 0, len(tools)) + names := make([]string, 0, len(tools)) + for _, t := range tools { + tool, ok := t.(map[string]any) + if !ok { + continue + } + fn, _ := tool["function"].(map[string]any) + if len(fn) == 0 { + fn = tool + } + name, _ := fn["name"].(string) + desc, _ := fn["description"].(string) + schema, _ := fn["parameters"].(map[string]any) + if name == "" { + name = "unknown" + } + names = 
append(names, name) + if desc == "" { + desc = "No description available" + } + b, _ := json.Marshal(schema) + toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b))) + } + if len(toolSchemas) == 0 { + return messages, names + } + toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block." + + for i := range messages { + if messages[i]["role"] == "system" { + old, _ := messages[i]["content"].(string) + messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt) + return messages, names + } + } + messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...) 
+ return messages, names +} + +func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any { + if len(deltas) == 0 { + return nil + } + out := make([]map[string]any, 0, len(deltas)) + for _, d := range deltas { + if d.Name == "" && d.Arguments == "" { + continue + } + callID, ok := ids[d.Index] + if !ok || callID == "" { + callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") + ids[d.Index] = callID + } + item := map[string]any{ + "index": d.Index, + "id": callID, + "type": "function", + } + fn := map[string]any{} + if d.Name != "" { + fn["name"] = d.Name + } + if d.Arguments != "" { + fn["arguments"] = d.Arguments + } + if len(fn) > 0 { + item["function"] = fn + } + out = append(out, item) + } + return out +} + +func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any { + if len(calls) == 0 { + return nil + } + out := make([]map[string]any, 0, len(calls)) + for i, c := range calls { + callID := "" + if ids != nil { + callID = strings.TrimSpace(ids[i]) + } + if callID == "" { + callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") + if ids != nil { + ids[i] = callID + } + } + args, _ := json.Marshal(c.Input) + out = append(out, map[string]any{ + "index": i, + "id": callID, + "type": "function", + "function": map[string]any{ + "name": c.Name, + "arguments": string(args), + }, + }) + } + return out +} diff --git a/internal/adapter/openai/handler_toolcall_policy.go b/internal/adapter/openai/handler_toolcall_policy.go new file mode 100644 index 0000000..9f0e839 --- /dev/null +++ b/internal/adapter/openai/handler_toolcall_policy.go @@ -0,0 +1,25 @@ +package openai + +import "strings" + +func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) { + for k, v := range collectOpenAIChatPassThrough(req) { + payload[k] = v + } +} + +func (h *Handler) toolcallFeatureMatchEnabled() bool { + if h == nil || h.Store == nil { + return true + } + 
mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode())) + return mode == "" || mode == "feature_match" +} + +func (h *Handler) toolcallEarlyEmitHighConfidence() bool { + if h == nil || h.Store == nil { + return true + } + level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence())) + return level == "" || level == "high" +} diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index e71cafe..7b35e0c 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -2,7 +2,6 @@ package openai import ( "encoding/json" - "fmt" "io" "net/http" "strings" @@ -170,264 +169,3 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request, }, }) } - -func responsesMessagesFromRequest(req map[string]any) []any { - if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 { - return prependInstructionMessage(msgs, req["instructions"]) - } - if rawInput, ok := req["input"]; ok { - if msgs := normalizeResponsesInputAsMessages(rawInput); len(msgs) > 0 { - return prependInstructionMessage(msgs, req["instructions"]) - } - } - return nil -} - -func prependInstructionMessage(messages []any, instructions any) []any { - sys, _ := instructions.(string) - sys = strings.TrimSpace(sys) - if sys == "" { - return messages - } - out := make([]any, 0, len(messages)+1) - out = append(out, map[string]any{"role": "system", "content": sys}) - out = append(out, messages...) 
- return out -} - -func normalizeResponsesInputAsMessages(input any) []any { - switch v := input.(type) { - case string: - if strings.TrimSpace(v) == "" { - return nil - } - return []any{map[string]any{"role": "user", "content": v}} - case []any: - return normalizeResponsesInputArray(v) - case map[string]any: - if msg := normalizeResponsesInputItem(v); msg != nil { - return []any{msg} - } - if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" { - return []any{map[string]any{"role": "user", "content": txt}} - } - if content, ok := v["content"]; ok { - if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" { - return []any{map[string]any{"role": "user", "content": content}} - } - } - } - return nil -} - -func normalizeResponsesInputArray(items []any) []any { - if len(items) == 0 { - return nil - } - out := make([]any, 0, len(items)) - fallbackParts := make([]string, 0, len(items)) - flushFallback := func() { - if len(fallbackParts) == 0 { - return - } - out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")}) - fallbackParts = fallbackParts[:0] - } - - for _, item := range items { - switch x := item.(type) { - case map[string]any: - if msg := normalizeResponsesInputItem(x); msg != nil { - flushFallback() - out = append(out, msg) - continue - } - if s := normalizeResponsesFallbackPart(x); s != "" { - fallbackParts = append(fallbackParts, s) - } - default: - if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" { - fallbackParts = append(fallbackParts, s) - } - } - } - flushFallback() - if len(out) == 0 { - return nil - } - return out -} - -func normalizeResponsesInputItem(m map[string]any) map[string]any { - if m == nil { - return nil - } - - role := strings.ToLower(strings.TrimSpace(asString(m["role"]))) - if role != "" { - content := m["content"] - if content == nil { - if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { - content = txt - } - } - if content == nil { - return nil - } - 
return map[string]any{ - "role": role, - "content": content, - } - } - - itemType := strings.ToLower(strings.TrimSpace(asString(m["type"]))) - switch itemType { - case "message", "input_message": - content := m["content"] - if content == nil { - if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { - content = txt - } - } - if content == nil { - return nil - } - role := strings.ToLower(strings.TrimSpace(asString(m["role"]))) - if role == "" { - role = "user" - } - return map[string]any{ - "role": role, - "content": content, - } - case "function_call_output", "tool_result": - content := m["output"] - if content == nil { - content = m["content"] - } - if content == nil { - content = "" - } - out := map[string]any{ - "role": "tool", - "content": content, - } - if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" { - out["tool_call_id"] = callID - } else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" { - out["tool_call_id"] = callID - } - if name := strings.TrimSpace(asString(m["name"])); name != "" { - out["name"] = name - } else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" { - out["name"] = name - } - return out - case "function_call", "tool_call": - name := strings.TrimSpace(asString(m["name"])) - var fn map[string]any - if rawFn, ok := m["function"].(map[string]any); ok { - fn = rawFn - if name == "" { - name = strings.TrimSpace(asString(fn["name"])) - } - } - if name == "" { - return nil - } - - var argsRaw any - if v, ok := m["arguments"]; ok { - argsRaw = v - } else if v, ok := m["input"]; ok { - argsRaw = v - } - if argsRaw == nil && fn != nil { - if v, ok := fn["arguments"]; ok { - argsRaw = v - } else if v, ok := fn["input"]; ok { - argsRaw = v - } - } - - functionPayload := map[string]any{ - "name": name, - "arguments": stringifyToolCallArguments(argsRaw), - } - call := map[string]any{ - "type": "function", - "function": functionPayload, - } - if callID := 
strings.TrimSpace(asString(m["call_id"])); callID != "" { - call["id"] = callID - } else if callID = strings.TrimSpace(asString(m["id"])); callID != "" { - call["id"] = callID - } - return map[string]any{ - "role": "assistant", - "tool_calls": []any{call}, - } - case "input_text": - if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { - return map[string]any{ - "role": "user", - "content": txt, - } - } - } - - if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { - return map[string]any{ - "role": "user", - "content": txt, - } - } - if content, ok := m["content"]; ok { - if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" { - return map[string]any{ - "role": "user", - "content": content, - } - } - } - return nil -} - -func normalizeResponsesFallbackPart(m map[string]any) string { - if m == nil { - return "" - } - if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") { - if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { - return txt - } - } - if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { - return txt - } - if content, ok := m["content"]; ok { - if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" { - return normalized - } - } - return strings.TrimSpace(fmt.Sprintf("%v", m)) -} - -func stringifyToolCallArguments(v any) string { - switch x := v.(type) { - case nil: - return "{}" - case string: - s := strings.TrimSpace(x) - if s == "" { - return "{}" - } - return s - default: - b, err := json.Marshal(x) - if err != nil || len(b) == 0 { - return "{}" - } - return string(b) - } -} diff --git a/internal/adapter/openai/responses_input_items.go b/internal/adapter/openai/responses_input_items.go new file mode 100644 index 0000000..2a2dfc4 --- /dev/null +++ b/internal/adapter/openai/responses_input_items.go @@ -0,0 +1,181 @@ +package openai + +import ( + "encoding/json" + "fmt" + "strings" +) + +func normalizeResponsesInputItem(m 
map[string]any) map[string]any { + if m == nil { + return nil + } + + role := strings.ToLower(strings.TrimSpace(asString(m["role"]))) + if role != "" { + content := m["content"] + if content == nil { + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + content = txt + } + } + if content == nil { + return nil + } + return map[string]any{ + "role": role, + "content": content, + } + } + + itemType := strings.ToLower(strings.TrimSpace(asString(m["type"]))) + switch itemType { + case "message", "input_message": + content := m["content"] + if content == nil { + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + content = txt + } + } + if content == nil { + return nil + } + role := strings.ToLower(strings.TrimSpace(asString(m["role"]))) + if role == "" { + role = "user" + } + return map[string]any{ + "role": role, + "content": content, + } + case "function_call_output", "tool_result": + content := m["output"] + if content == nil { + content = m["content"] + } + if content == nil { + content = "" + } + out := map[string]any{ + "role": "tool", + "content": content, + } + if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" { + out["tool_call_id"] = callID + } else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" { + out["tool_call_id"] = callID + } + if name := strings.TrimSpace(asString(m["name"])); name != "" { + out["name"] = name + } else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" { + out["name"] = name + } + return out + case "function_call", "tool_call": + name := strings.TrimSpace(asString(m["name"])) + var fn map[string]any + if rawFn, ok := m["function"].(map[string]any); ok { + fn = rawFn + if name == "" { + name = strings.TrimSpace(asString(fn["name"])) + } + } + if name == "" { + return nil + } + + var argsRaw any + if v, ok := m["arguments"]; ok { + argsRaw = v + } else if v, ok := m["input"]; ok { + argsRaw = v + } + if argsRaw == nil && fn != nil { + if v, ok := 
fn["arguments"]; ok { + argsRaw = v + } else if v, ok := fn["input"]; ok { + argsRaw = v + } + } + + functionPayload := map[string]any{ + "name": name, + "arguments": stringifyToolCallArguments(argsRaw), + } + call := map[string]any{ + "type": "function", + "function": functionPayload, + } + if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" { + call["id"] = callID + } else if callID = strings.TrimSpace(asString(m["id"])); callID != "" { + call["id"] = callID + } + return map[string]any{ + "role": "assistant", + "tool_calls": []any{call}, + } + case "input_text": + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return map[string]any{ + "role": "user", + "content": txt, + } + } + } + + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return map[string]any{ + "role": "user", + "content": txt, + } + } + if content, ok := m["content"]; ok { + if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" { + return map[string]any{ + "role": "user", + "content": content, + } + } + } + return nil +} + +func normalizeResponsesFallbackPart(m map[string]any) string { + if m == nil { + return "" + } + if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") { + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return txt + } + } + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return txt + } + if content, ok := m["content"]; ok { + if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" { + return normalized + } + } + return strings.TrimSpace(fmt.Sprintf("%v", m)) +} + +func stringifyToolCallArguments(v any) string { + switch x := v.(type) { + case nil: + return "{}" + case string: + s := strings.TrimSpace(x) + if s == "" { + return "{}" + } + return s + default: + b, err := json.Marshal(x) + if err != nil || len(b) == 0 { + return "{}" + } + return string(b) + } +} diff --git 
a/internal/adapter/openai/responses_input_normalize.go b/internal/adapter/openai/responses_input_normalize.go new file mode 100644 index 0000000..13f9e1a --- /dev/null +++ b/internal/adapter/openai/responses_input_normalize.go @@ -0,0 +1,93 @@ +package openai + +import ( + "fmt" + "strings" +) + +func responsesMessagesFromRequest(req map[string]any) []any { + if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 { + return prependInstructionMessage(msgs, req["instructions"]) + } + if rawInput, ok := req["input"]; ok { + if msgs := normalizeResponsesInputAsMessages(rawInput); len(msgs) > 0 { + return prependInstructionMessage(msgs, req["instructions"]) + } + } + return nil +} + +func prependInstructionMessage(messages []any, instructions any) []any { + sys, _ := instructions.(string) + sys = strings.TrimSpace(sys) + if sys == "" { + return messages + } + out := make([]any, 0, len(messages)+1) + out = append(out, map[string]any{"role": "system", "content": sys}) + out = append(out, messages...) 
+ return out +} + +func normalizeResponsesInputAsMessages(input any) []any { + switch v := input.(type) { + case string: + if strings.TrimSpace(v) == "" { + return nil + } + return []any{map[string]any{"role": "user", "content": v}} + case []any: + return normalizeResponsesInputArray(v) + case map[string]any: + if msg := normalizeResponsesInputItem(v); msg != nil { + return []any{msg} + } + if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" { + return []any{map[string]any{"role": "user", "content": txt}} + } + if content, ok := v["content"]; ok { + if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" { + return []any{map[string]any{"role": "user", "content": content}} + } + } + } + return nil +} + +func normalizeResponsesInputArray(items []any) []any { + if len(items) == 0 { + return nil + } + out := make([]any, 0, len(items)) + fallbackParts := make([]string, 0, len(items)) + flushFallback := func() { + if len(fallbackParts) == 0 { + return + } + out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")}) + fallbackParts = fallbackParts[:0] + } + + for _, item := range items { + switch x := item.(type) { + case map[string]any: + if msg := normalizeResponsesInputItem(x); msg != nil { + flushFallback() + out = append(out, msg) + continue + } + if s := normalizeResponsesFallbackPart(x); s != "" { + fallbackParts = append(fallbackParts, s) + } + default: + if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" { + fallbackParts = append(fallbackParts, s) + } + } + } + flushFallback() + if len(out) == 0 { + return nil + } + return out +} diff --git a/internal/adapter/openai/responses_stream_runtime.go b/internal/adapter/openai/responses_stream_runtime.go deleted file mode 100644 index 050965c..0000000 --- a/internal/adapter/openai/responses_stream_runtime.go +++ /dev/null @@ -1,366 +0,0 @@ -package openai - -import ( - "encoding/json" - "net/http" - "sort" - "strings" - - openaifmt 
"ds2api/internal/format/openai" - "ds2api/internal/sse" - streamengine "ds2api/internal/stream" - "ds2api/internal/util" - - "github.com/google/uuid" -) - -type responsesStreamRuntime struct { - w http.ResponseWriter - rc *http.ResponseController - canFlush bool - - responseID string - model string - finalPrompt string - toolNames []string - - thinkingEnabled bool - searchEnabled bool - - bufferToolContent bool - emitEarlyToolDeltas bool - toolCallsEmitted bool - toolCallsDoneEmitted bool - - sieve toolStreamSieveState - thinkingSieve toolStreamSieveState - thinking strings.Builder - text strings.Builder - streamToolCallIDs map[int]string - streamFunctionIDs map[int]string - functionDone map[int]bool - toolCallsDoneSigs map[string]bool - reasoningItemID string - - persistResponse func(obj map[string]any) -} - -func newResponsesStreamRuntime( - w http.ResponseWriter, - rc *http.ResponseController, - canFlush bool, - responseID string, - model string, - finalPrompt string, - thinkingEnabled bool, - searchEnabled bool, - toolNames []string, - bufferToolContent bool, - emitEarlyToolDeltas bool, - persistResponse func(obj map[string]any), -) *responsesStreamRuntime { - return &responsesStreamRuntime{ - w: w, - rc: rc, - canFlush: canFlush, - responseID: responseID, - model: model, - finalPrompt: finalPrompt, - thinkingEnabled: thinkingEnabled, - searchEnabled: searchEnabled, - toolNames: toolNames, - bufferToolContent: bufferToolContent, - emitEarlyToolDeltas: emitEarlyToolDeltas, - streamToolCallIDs: map[int]string{}, - streamFunctionIDs: map[int]string{}, - functionDone: map[int]bool{}, - toolCallsDoneSigs: map[string]bool{}, - persistResponse: persistResponse, - } -} - -func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) { - b, _ := json.Marshal(payload) - _, _ = s.w.Write([]byte("event: " + event + "\n")) - _, _ = s.w.Write([]byte("data: ")) - _, _ = s.w.Write(b) - _, _ = s.w.Write([]byte("\n\n")) - if s.canFlush { - _ = s.rc.Flush() - } 
-} - -func (s *responsesStreamRuntime) sendCreated() { - s.sendEvent("response.created", openaifmt.BuildResponsesCreatedPayload(s.responseID, s.model)) -} - -func (s *responsesStreamRuntime) sendDone() { - _, _ = s.w.Write([]byte("data: [DONE]\n\n")) - if s.canFlush { - _ = s.rc.Flush() - } -} - -func (s *responsesStreamRuntime) finalize() { - finalThinking := s.thinking.String() - finalText := s.text.String() - if strings.TrimSpace(finalThinking) != "" { - s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking)) - } - if s.bufferToolContent { - s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true) - s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false) - } - // Compatibility fallback: some streams only emit incremental tool deltas. - // Ensure final function_call_arguments.done is emitted at least once. - if s.toolCallsEmitted { - detected := util.ParseToolCalls(finalText, s.toolNames) - if len(detected) == 0 { - detected = util.ParseToolCalls(finalThinking, s.toolNames) - } - if len(detected) > 0 { - if !s.toolCallsDoneEmitted { - s.emitToolCallsDone(detected) - } else { - s.emitFunctionCallDoneEvents(detected) - } - } - } - - obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames) - if s.toolCallsEmitted { - s.alignCompletedOutputCallIDs(obj) - } - if s.toolCallsEmitted { - obj["status"] = "completed" - } - if s.persistResponse != nil { - s.persistResponse(obj) - } - s.sendEvent("response.completed", openaifmt.BuildResponsesCompletedPayload(obj)) - s.sendDone() -} - -func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { - if !parsed.Parsed { - return streamengine.ParsedDecision{} - } - if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { - return streamengine.ParsedDecision{Stop: true} - } - - contentSeen := 
false - for _, p := range parsed.Parts { - if p.Text == "" { - continue - } - if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) { - continue - } - contentSeen = true - if p.Type == "thinking" { - if !s.thinkingEnabled { - continue - } - s.thinking.WriteString(p.Text) - s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text)) - s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text)) - if s.bufferToolContent { - s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false) - } - continue - } - - s.text.WriteString(p.Text) - if !s.bufferToolContent { - s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text)) - continue - } - s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true) - } - - return streamengine.ParsedDecision{ContentSeen: contentSeen} -} - -func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) { - for _, evt := range events { - if emitContent && evt.Content != "" { - s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content)) - } - if len(evt.ToolCallDeltas) > 0 { - if !s.emitEarlyToolDeltas { - continue - } - formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs) - if len(formatted) == 0 { - continue - } - s.toolCallsEmitted = true - s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted)) - s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas) - } - if len(evt.ToolCalls) > 0 { - s.emitToolCallsDone(evt.ToolCalls) - } - } -} - -func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) { - if len(calls) == 0 { - return - } - sig := toolCallListSignature(calls) - if sig != 
"" && s.toolCallsDoneSigs[sig] { - return - } - if sig != "" { - s.toolCallsDoneSigs[sig] = true - } - formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs) - if len(formatted) == 0 { - return - } - s.toolCallsEmitted = true - s.toolCallsDoneEmitted = true - s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted)) - s.emitFunctionCallDoneEvents(calls) -} - -func (s *responsesStreamRuntime) ensureReasoningItemID() string { - if strings.TrimSpace(s.reasoningItemID) != "" { - return s.reasoningItemID - } - s.reasoningItemID = "rs_" + strings.ReplaceAll(uuid.NewString(), "-", "") - return s.reasoningItemID -} - -func (s *responsesStreamRuntime) ensureFunctionItemID(index int) string { - if id, ok := s.streamFunctionIDs[index]; ok && strings.TrimSpace(id) != "" { - return id - } - id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "") - s.streamFunctionIDs[index] = id - return id -} - -func (s *responsesStreamRuntime) ensureToolCallID(index int) string { - if id, ok := s.streamToolCallIDs[index]; ok && strings.TrimSpace(id) != "" { - return id - } - id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") - s.streamToolCallIDs[index] = id - return id -} - -func (s *responsesStreamRuntime) functionOutputBaseIndex() int { - if strings.TrimSpace(s.thinking.String()) != "" { - return 1 - } - return 0 -} - -func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) { - for _, d := range deltas { - if strings.TrimSpace(d.Arguments) == "" { - continue - } - outputIndex := s.functionOutputBaseIndex() + d.Index - itemID := s.ensureFunctionItemID(outputIndex) - callID := s.ensureToolCallID(d.Index) - s.sendEvent( - "response.function_call_arguments.delta", - openaifmt.BuildResponsesFunctionCallArgumentsDeltaPayload(s.responseID, itemID, outputIndex, callID, d.Arguments), - ) - } -} - -func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls 
[]util.ParsedToolCall) { - base := s.functionOutputBaseIndex() - for idx, tc := range calls { - if strings.TrimSpace(tc.Name) == "" { - continue - } - outputIndex := base + idx - if s.functionDone[outputIndex] { - continue - } - itemID := s.ensureFunctionItemID(outputIndex) - callID := s.ensureToolCallID(idx) - argsBytes, _ := json.Marshal(tc.Input) - s.sendEvent( - "response.function_call_arguments.done", - openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)), - ) - s.functionDone[outputIndex] = true - } -} - -func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any) { - if obj == nil || len(s.streamToolCallIDs) == 0 { - return - } - output, _ := obj["output"].([]any) - if len(output) == 0 { - return - } - indices := make([]int, 0, len(s.streamToolCallIDs)) - for idx := range s.streamToolCallIDs { - indices = append(indices, idx) - } - sort.Ints(indices) - ordered := make([]string, 0, len(indices)) - for _, idx := range indices { - id := strings.TrimSpace(s.streamToolCallIDs[idx]) - if id == "" { - continue - } - ordered = append(ordered, id) - } - if len(ordered) == 0 { - return - } - - functionIdx := 0 - for _, item := range output { - m, _ := item.(map[string]any) - if m == nil { - continue - } - typ, _ := m["type"].(string) - switch typ { - case "function_call": - if functionIdx < len(ordered) { - m["call_id"] = ordered[functionIdx] - functionIdx++ - } - case "tool_calls": - tcArr, _ := m["tool_calls"].([]any) - for i, raw := range tcArr { - tc, _ := raw.(map[string]any) - if tc == nil { - continue - } - if i < len(ordered) { - tc["id"] = ordered[i] - } - } - } - } -} - -func toolCallListSignature(calls []util.ParsedToolCall) string { - if len(calls) == 0 { - return "" - } - var b strings.Builder - for i, tc := range calls { - if i > 0 { - b.WriteString("|") - } - b.WriteString(strings.TrimSpace(tc.Name)) - b.WriteString(":") - args, _ := json.Marshal(tc.Input) - 
b.Write(args) - } - return b.String() -} diff --git a/internal/adapter/openai/responses_stream_runtime_core.go b/internal/adapter/openai/responses_stream_runtime_core.go new file mode 100644 index 0000000..5aad11e --- /dev/null +++ b/internal/adapter/openai/responses_stream_runtime_core.go @@ -0,0 +1,157 @@ +package openai + +import ( + "net/http" + "strings" + + openaifmt "ds2api/internal/format/openai" + "ds2api/internal/sse" + streamengine "ds2api/internal/stream" + "ds2api/internal/util" +) + +type responsesStreamRuntime struct { + w http.ResponseWriter + rc *http.ResponseController + canFlush bool + + responseID string + model string + finalPrompt string + toolNames []string + + thinkingEnabled bool + searchEnabled bool + + bufferToolContent bool + emitEarlyToolDeltas bool + toolCallsEmitted bool + toolCallsDoneEmitted bool + + sieve toolStreamSieveState + thinkingSieve toolStreamSieveState + thinking strings.Builder + text strings.Builder + streamToolCallIDs map[int]string + streamFunctionIDs map[int]string + functionDone map[int]bool + toolCallsDoneSigs map[string]bool + reasoningItemID string + + persistResponse func(obj map[string]any) +} + +func newResponsesStreamRuntime( + w http.ResponseWriter, + rc *http.ResponseController, + canFlush bool, + responseID string, + model string, + finalPrompt string, + thinkingEnabled bool, + searchEnabled bool, + toolNames []string, + bufferToolContent bool, + emitEarlyToolDeltas bool, + persistResponse func(obj map[string]any), +) *responsesStreamRuntime { + return &responsesStreamRuntime{ + w: w, + rc: rc, + canFlush: canFlush, + responseID: responseID, + model: model, + finalPrompt: finalPrompt, + thinkingEnabled: thinkingEnabled, + searchEnabled: searchEnabled, + toolNames: toolNames, + bufferToolContent: bufferToolContent, + emitEarlyToolDeltas: emitEarlyToolDeltas, + streamToolCallIDs: map[int]string{}, + streamFunctionIDs: map[int]string{}, + functionDone: map[int]bool{}, + toolCallsDoneSigs: map[string]bool{}, + 
persistResponse: persistResponse, + } +} + +func (s *responsesStreamRuntime) finalize() { + finalThinking := s.thinking.String() + finalText := s.text.String() + if strings.TrimSpace(finalThinking) != "" { + s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking)) + } + if s.bufferToolContent { + s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true) + s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false) + } + // Compatibility fallback: some streams only emit incremental tool deltas. + // Ensure final function_call_arguments.done is emitted at least once. + if s.toolCallsEmitted { + detected := util.ParseToolCalls(finalText, s.toolNames) + if len(detected) == 0 { + detected = util.ParseToolCalls(finalThinking, s.toolNames) + } + if len(detected) > 0 { + if !s.toolCallsDoneEmitted { + s.emitToolCallsDone(detected) + } else { + s.emitFunctionCallDoneEvents(detected) + } + } + } + + obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames) + if s.toolCallsEmitted { + s.alignCompletedOutputCallIDs(obj) + } + if s.toolCallsEmitted { + obj["status"] = "completed" + } + if s.persistResponse != nil { + s.persistResponse(obj) + } + s.sendEvent("response.completed", openaifmt.BuildResponsesCompletedPayload(obj)) + s.sendDone() +} + +func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision { + if !parsed.Parsed { + return streamengine.ParsedDecision{} + } + if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop { + return streamengine.ParsedDecision{Stop: true} + } + + contentSeen := false + for _, p := range parsed.Parts { + if p.Text == "" { + continue + } + if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) { + continue + } + contentSeen = true + if p.Type == "thinking" { + if !s.thinkingEnabled { + continue + } 
+ s.thinking.WriteString(p.Text) + s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text)) + s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text)) + if s.bufferToolContent { + s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false) + } + continue + } + + s.text.WriteString(p.Text) + if !s.bufferToolContent { + s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text)) + continue + } + s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true) + } + + return streamengine.ParsedDecision{ContentSeen: contentSeen} +} diff --git a/internal/adapter/openai/responses_stream_runtime_events.go b/internal/adapter/openai/responses_stream_runtime_events.go new file mode 100644 index 0000000..fd36b6a --- /dev/null +++ b/internal/adapter/openai/responses_stream_runtime_events.go @@ -0,0 +1,52 @@ +package openai + +import ( + "encoding/json" + + openaifmt "ds2api/internal/format/openai" +) + +func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) { + b, _ := json.Marshal(payload) + _, _ = s.w.Write([]byte("event: " + event + "\n")) + _, _ = s.w.Write([]byte("data: ")) + _, _ = s.w.Write(b) + _, _ = s.w.Write([]byte("\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *responsesStreamRuntime) sendCreated() { + s.sendEvent("response.created", openaifmt.BuildResponsesCreatedPayload(s.responseID, s.model)) +} + +func (s *responsesStreamRuntime) sendDone() { + _, _ = s.w.Write([]byte("data: [DONE]\n\n")) + if s.canFlush { + _ = s.rc.Flush() + } +} + +func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) { + for _, evt := range events { + if emitContent && evt.Content != "" { + s.sendEvent("response.output_text.delta", 
openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content)) + } + if len(evt.ToolCallDeltas) > 0 { + if !s.emitEarlyToolDeltas { + continue + } + formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs) + if len(formatted) == 0 { + continue + } + s.toolCallsEmitted = true + s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted)) + s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas) + } + if len(evt.ToolCalls) > 0 { + s.emitToolCallsDone(evt.ToolCalls) + } + } +} diff --git a/internal/adapter/openai/responses_stream_runtime_toolcalls.go b/internal/adapter/openai/responses_stream_runtime_toolcalls.go new file mode 100644 index 0000000..7891425 --- /dev/null +++ b/internal/adapter/openai/responses_stream_runtime_toolcalls.go @@ -0,0 +1,172 @@ +package openai + +import ( + "encoding/json" + "sort" + "strings" + + openaifmt "ds2api/internal/format/openai" + "ds2api/internal/util" + + "github.com/google/uuid" +) + +func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) { + if len(calls) == 0 { + return + } + sig := toolCallListSignature(calls) + if sig != "" && s.toolCallsDoneSigs[sig] { + return + } + if sig != "" { + s.toolCallsDoneSigs[sig] = true + } + formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs) + if len(formatted) == 0 { + return + } + s.toolCallsEmitted = true + s.toolCallsDoneEmitted = true + s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted)) + s.emitFunctionCallDoneEvents(calls) +} + +func (s *responsesStreamRuntime) ensureReasoningItemID() string { + if strings.TrimSpace(s.reasoningItemID) != "" { + return s.reasoningItemID + } + s.reasoningItemID = "rs_" + strings.ReplaceAll(uuid.NewString(), "-", "") + return s.reasoningItemID +} + +func (s *responsesStreamRuntime) ensureFunctionItemID(index int) string { + if id, ok := 
s.streamFunctionIDs[index]; ok && strings.TrimSpace(id) != "" { + return id + } + id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "") + s.streamFunctionIDs[index] = id + return id +} + +func (s *responsesStreamRuntime) ensureToolCallID(index int) string { + if id, ok := s.streamToolCallIDs[index]; ok && strings.TrimSpace(id) != "" { + return id + } + id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") + s.streamToolCallIDs[index] = id + return id +} + +func (s *responsesStreamRuntime) functionOutputBaseIndex() int { + if strings.TrimSpace(s.thinking.String()) != "" { + return 1 + } + return 0 +} + +func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) { + for _, d := range deltas { + if strings.TrimSpace(d.Arguments) == "" { + continue + } + outputIndex := s.functionOutputBaseIndex() + d.Index + itemID := s.ensureFunctionItemID(outputIndex) + callID := s.ensureToolCallID(d.Index) + s.sendEvent( + "response.function_call_arguments.delta", + openaifmt.BuildResponsesFunctionCallArgumentsDeltaPayload(s.responseID, itemID, outputIndex, callID, d.Arguments), + ) + } +} + +func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) { + base := s.functionOutputBaseIndex() + for idx, tc := range calls { + if strings.TrimSpace(tc.Name) == "" { + continue + } + outputIndex := base + idx + if s.functionDone[outputIndex] { + continue + } + itemID := s.ensureFunctionItemID(outputIndex) + callID := s.ensureToolCallID(idx) + argsBytes, _ := json.Marshal(tc.Input) + s.sendEvent( + "response.function_call_arguments.done", + openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)), + ) + s.functionDone[outputIndex] = true + } +} + +func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any) { + if obj == nil || len(s.streamToolCallIDs) == 0 { + return + } + output, _ := obj["output"].([]any) + if len(output) == 0 { + 
return + } + indices := make([]int, 0, len(s.streamToolCallIDs)) + for idx := range s.streamToolCallIDs { + indices = append(indices, idx) + } + sort.Ints(indices) + ordered := make([]string, 0, len(indices)) + for _, idx := range indices { + id := strings.TrimSpace(s.streamToolCallIDs[idx]) + if id == "" { + continue + } + ordered = append(ordered, id) + } + if len(ordered) == 0 { + return + } + + functionIdx := 0 + for _, item := range output { + m, _ := item.(map[string]any) + if m == nil { + continue + } + typ, _ := m["type"].(string) + switch typ { + case "function_call": + if functionIdx < len(ordered) { + m["call_id"] = ordered[functionIdx] + functionIdx++ + } + case "tool_calls": + tcArr, _ := m["tool_calls"].([]any) + for i, raw := range tcArr { + tc, _ := raw.(map[string]any) + if tc == nil { + continue + } + if i < len(ordered) { + tc["id"] = ordered[i] + } + } + } + } +} + +func toolCallListSignature(calls []util.ParsedToolCall) string { + if len(calls) == 0 { + return "" + } + var b strings.Builder + for i, tc := range calls { + if i > 0 { + b.WriteString("|") + } + b.WriteString(strings.TrimSpace(tc.Name)) + b.WriteString(":") + args, _ := json.Marshal(tc.Input) + b.Write(args) + } + return b.String() +} diff --git a/internal/adapter/openai/tool_sieve.go b/internal/adapter/openai/tool_sieve.go deleted file mode 100644 index 9c46649..0000000 --- a/internal/adapter/openai/tool_sieve.go +++ /dev/null @@ -1,713 +0,0 @@ -package openai - -import ( - "strings" - - "ds2api/internal/util" -) - -type toolStreamSieveState struct { - pending strings.Builder - capture strings.Builder - capturing bool - recentTextTail string - disableDeltas bool - toolNameSent bool - toolName string - toolArgsStart int - toolArgsSent int - toolArgsString bool - toolArgsDone bool -} - -type toolStreamEvent struct { - Content string - ToolCalls []util.ParsedToolCall - ToolCallDeltas []toolCallDelta -} - -type toolCallDelta struct { - Index int - Name string - Arguments string -} - 
-const toolSieveCaptureLimit = 8 * 1024 -const toolSieveContextTailLimit = 256 - -func (s *toolStreamSieveState) resetIncrementalToolState() { - s.disableDeltas = false - s.toolNameSent = false - s.toolName = "" - s.toolArgsStart = -1 - s.toolArgsSent = -1 - s.toolArgsString = false - s.toolArgsDone = false -} - -func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent { - if state == nil { - return nil - } - if chunk != "" { - state.pending.WriteString(chunk) - } - events := make([]toolStreamEvent, 0, 2) - - for { - if state.capturing { - if state.pending.Len() > 0 { - state.capture.WriteString(state.pending.String()) - state.pending.Reset() - } - if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 { - events = append(events, toolStreamEvent{ToolCallDeltas: deltas}) - } - prefix, calls, suffix, ready := consumeToolCapture(state, toolNames) - if !ready { - if state.capture.Len() > toolSieveCaptureLimit { - content := state.capture.String() - state.capture.Reset() - state.capturing = false - state.resetIncrementalToolState() - state.noteText(content) - events = append(events, toolStreamEvent{Content: content}) - continue - } - break - } - state.capture.Reset() - state.capturing = false - state.resetIncrementalToolState() - if prefix != "" { - state.noteText(prefix) - events = append(events, toolStreamEvent{Content: prefix}) - } - if len(calls) > 0 { - events = append(events, toolStreamEvent{ToolCalls: calls}) - } - if suffix != "" { - state.pending.WriteString(suffix) - } - continue - } - - pending := state.pending.String() - if pending == "" { - break - } - start := findToolSegmentStart(pending) - if start >= 0 { - prefix := pending[:start] - if prefix != "" { - state.noteText(prefix) - events = append(events, toolStreamEvent{Content: prefix}) - } - state.pending.Reset() - state.capture.WriteString(pending[start:]) - state.capturing = true - state.resetIncrementalToolState() - continue - } - - safe, 
hold := splitSafeContentForToolDetection(pending) - if safe == "" { - break - } - state.pending.Reset() - state.pending.WriteString(hold) - state.noteText(safe) - events = append(events, toolStreamEvent{Content: safe}) - } - - return events -} - -func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent { - if state == nil { - return nil - } - events := processToolSieveChunk(state, "", toolNames) - if state.capturing { - consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames) - if ready { - if consumedPrefix != "" { - state.noteText(consumedPrefix) - events = append(events, toolStreamEvent{Content: consumedPrefix}) - } - if len(consumedCalls) > 0 { - events = append(events, toolStreamEvent{ToolCalls: consumedCalls}) - } - if consumedSuffix != "" { - state.noteText(consumedSuffix) - events = append(events, toolStreamEvent{Content: consumedSuffix}) - } - } else { - content := state.capture.String() - if content != "" { - state.noteText(content) - events = append(events, toolStreamEvent{Content: content}) - } - } - state.capture.Reset() - state.capturing = false - state.resetIncrementalToolState() - } - if state.pending.Len() > 0 { - content := state.pending.String() - state.noteText(content) - events = append(events, toolStreamEvent{Content: content}) - state.pending.Reset() - } - return events -} - -func splitSafeContentForToolDetection(s string) (safe, hold string) { - if s == "" { - return "", "" - } - suspiciousStart := findSuspiciousPrefixStart(s) - if suspiciousStart < 0 { - return s, "" - } - if suspiciousStart > 0 { - return s[:suspiciousStart], s[suspiciousStart:] - } - // If suspicious content starts at position 0, keep holding until we can - // parse a complete tool JSON block or reach stream flush. 
- return "", s -} - -func findSuspiciousPrefixStart(s string) int { - start := -1 - indices := []int{ - strings.LastIndex(s, "{"), - strings.LastIndex(s, "["), - strings.LastIndex(s, "```"), - } - for _, idx := range indices { - if idx > start { - start = idx - } - } - return start -} - -func findToolSegmentStart(s string) int { - if s == "" { - return -1 - } - lower := strings.ToLower(s) - offset := 0 - for { - keyRel := strings.Index(lower[offset:], "tool_calls") - if keyRel < 0 { - return -1 - } - keyIdx := offset + keyRel - start := strings.LastIndex(s[:keyIdx], "{") - if start < 0 { - start = keyIdx - } - if !insideCodeFence(s[:start]) { - return start - } - offset = keyIdx + len("tool_calls") - } -} - -func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) { - captured := state.capture.String() - if captured == "" { - return "", nil, "", false - } - lower := strings.ToLower(captured) - keyIdx := strings.Index(lower, "tool_calls") - if keyIdx < 0 { - return "", nil, "", false - } - start := strings.LastIndex(captured[:keyIdx], "{") - if start < 0 { - return "", nil, "", false - } - obj, end, ok := extractJSONObjectFrom(captured, start) - if !ok { - return "", nil, "", false - } - prefixPart := captured[:start] - suffixPart := captured[end:] - if insideCodeFence(state.recentTextTail + prefixPart) { - return captured, nil, "", true - } - parsed := util.ParseStandaloneToolCalls(obj, toolNames) - if len(parsed) == 0 { - return captured, nil, "", true - } - return prefixPart, parsed, suffixPart, true -} - -func extractJSONObjectFrom(text string, start int) (string, int, bool) { - if start < 0 || start >= len(text) || text[start] != '{' { - return "", 0, false - } - depth := 0 - quote := byte(0) - escaped := false - for i := start; i < len(text); i++ { - ch := text[i] - if quote != 0 { - if escaped { - escaped = false - continue - } - if ch == '\\' { - escaped = true - continue 
- } - if ch == quote { - quote = 0 - } - continue - } - if ch == '"' || ch == '\'' { - quote = ch - continue - } - if ch == '{' { - depth++ - continue - } - if ch == '}' { - depth-- - if depth == 0 { - end := i + 1 - return text[start:end], end, true - } - } - } - return "", 0, false -} - -func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta { - if state.disableDeltas { - return nil - } - captured := state.capture.String() - if captured == "" { - return nil - } - lower := strings.ToLower(captured) - keyIdx := strings.Index(lower, "tool_calls") - if keyIdx < 0 { - return nil - } - start := strings.LastIndex(captured[:keyIdx], "{") - if start < 0 { - return nil - } - if insideCodeFence(state.recentTextTail + captured[:start]) { - return nil - } - certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx) - if hasMultiple { - state.disableDeltas = true - return nil - } - if !certainSingle { - // In uncertain phases (e.g. first call arrived but array not closed yet), - // avoid speculative deltas and wait for final parsed tool_calls payload. 
- return nil - } - callStart, ok := findFirstToolCallObjectStart(captured, keyIdx) - if !ok { - return nil - } - deltas := make([]toolCallDelta, 0, 2) - if state.toolName == "" { - name, ok := extractToolCallName(captured, callStart) - if !ok || name == "" { - return nil - } - state.toolName = name - } - if state.toolArgsStart < 0 { - argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart) - if ok { - state.toolArgsString = stringMode - if stringMode { - state.toolArgsStart = argsStart + 1 - } else { - state.toolArgsStart = argsStart - } - state.toolArgsSent = state.toolArgsStart - } - } - if !state.toolNameSent { - if state.toolArgsStart < 0 { - return nil - } - state.toolNameSent = true - deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName}) - } - if state.toolArgsStart < 0 || state.toolArgsDone { - return deltas - } - end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString) - if !ok { - return deltas - } - if end > state.toolArgsSent { - deltas = append(deltas, toolCallDelta{ - Index: 0, - Arguments: captured[state.toolArgsSent:end], - }) - state.toolArgsSent = end - } - if complete { - state.toolArgsDone = true - } - return deltas -} - -func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) { - arrStart, ok := findToolCallsArrayStart(text, keyIdx) - if !ok { - return false, false - } - i := skipSpaces(text, arrStart+1) - if i >= len(text) || text[i] != '{' { - return false, false - } - count := 0 - depth := 0 - quote := byte(0) - escaped := false - for ; i < len(text); i++ { - ch := text[i] - if quote != 0 { - if escaped { - escaped = false - continue - } - if ch == '\\' { - escaped = true - continue - } - if ch == quote { - quote = 0 - } - continue - } - if ch == '"' || ch == '\'' { - quote = ch - continue - } - if ch == '{' { - if depth == 0 { - count++ - if count > 1 { - return false, true - } - } - depth++ - continue - } - if ch == '}' 
{ - if depth > 0 { - depth-- - } - continue - } - if ch == ',' && depth == 0 { - // top-level separator means at least one more tool call exists - // (or is expected). Treat as multi-call and stop incremental deltas. - return false, true - } - if ch == ']' && depth == 0 { - return count == 1, false - } - } - // array not closed yet: still uncertain whether more calls will appear - return false, false -} - -func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) { - arrStart, ok := findToolCallsArrayStart(text, keyIdx) - if !ok { - return -1, false - } - i := skipSpaces(text, arrStart+1) - if i >= len(text) || text[i] != '{' { - return -1, false - } - return i, true -} - -func findToolCallsArrayStart(text string, keyIdx int) (int, bool) { - i := keyIdx + len("tool_calls") - for i < len(text) && text[i] != ':' { - i++ - } - if i >= len(text) { - return -1, false - } - i = skipSpaces(text, i+1) - if i >= len(text) || text[i] != '[' { - return -1, false - } - return i, true -} - -func extractToolCallName(text string, callStart int) (string, bool) { - valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"}) - if !ok || valueStart >= len(text) || text[valueStart] != '"' { - fnStart, fnOK := findFunctionObjectStart(text, callStart) - if !fnOK { - return "", false - } - valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"}) - if !ok || valueStart >= len(text) || text[valueStart] != '"' { - return "", false - } - } - name, _, ok := parseJSONStringLiteral(text, valueStart) - if !ok { - return "", false - } - return name, true -} - -func findToolCallArgsStart(text string, callStart int) (int, bool, bool) { - keys := []string{"input", "arguments", "args", "parameters", "params"} - valueStart, ok := findObjectFieldValueStart(text, callStart, keys) - if !ok { - fnStart, fnOK := findFunctionObjectStart(text, callStart) - if !fnOK { - return -1, false, false - } - valueStart, ok = findObjectFieldValueStart(text, fnStart, 
keys) - if !ok { - return -1, false, false - } - } - if valueStart >= len(text) { - return -1, false, false - } - ch := text[valueStart] - if ch == '{' || ch == '[' { - return valueStart, false, true - } - if ch == '"' { - return valueStart, true, true - } - return -1, false, false -} - -func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) { - if start < 0 || start > len(text) { - return 0, false, false - } - if stringMode { - escaped := false - for i := start; i < len(text); i++ { - ch := text[i] - if escaped { - escaped = false - continue - } - if ch == '\\' { - escaped = true - continue - } - if ch == '"' { - return i, true, true - } - } - return len(text), false, true - } - if start >= len(text) { - return start, false, false - } - if text[start] != '{' && text[start] != '[' { - return 0, false, false - } - depth := 0 - quote := byte(0) - escaped := false - for i := start; i < len(text); i++ { - ch := text[i] - if quote != 0 { - if escaped { - escaped = false - continue - } - if ch == '\\' { - escaped = true - continue - } - if ch == quote { - quote = 0 - } - continue - } - if ch == '"' || ch == '\'' { - quote = ch - continue - } - if ch == '{' || ch == '[' { - depth++ - continue - } - if ch == '}' || ch == ']' { - depth-- - if depth == 0 { - return i + 1, true, true - } - } - } - return len(text), false, true -} - -func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) { - if objStart < 0 || objStart >= len(text) || text[objStart] != '{' { - return 0, false - } - depth := 0 - quote := byte(0) - escaped := false - for i := objStart; i < len(text); i++ { - ch := text[i] - if quote != 0 { - if escaped { - escaped = false - continue - } - if ch == '\\' { - escaped = true - continue - } - if ch == quote { - quote = 0 - } - continue - } - if ch == '"' || ch == '\'' { - if depth == 1 { - key, end, ok := parseJSONStringLiteral(text, i) - if !ok { - return 0, false - } - j := skipSpaces(text, end) - if j 
>= len(text) || text[j] != ':' { - i = end - 1 - continue - } - j = skipSpaces(text, j+1) - if j >= len(text) { - return 0, false - } - if containsKey(keys, key) { - return j, true - } - i = j - 1 - continue - } - quote = ch - continue - } - if ch == '{' { - depth++ - continue - } - if ch == '}' { - depth-- - if depth == 0 { - break - } - } - } - return 0, false -} - -func findFunctionObjectStart(text string, callStart int) (int, bool) { - valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"}) - if !ok || valueStart >= len(text) || text[valueStart] != '{' { - return -1, false - } - return valueStart, true -} - -func parseJSONStringLiteral(text string, start int) (string, int, bool) { - if start < 0 || start >= len(text) || text[start] != '"' { - return "", 0, false - } - var b strings.Builder - escaped := false - for i := start + 1; i < len(text); i++ { - ch := text[i] - if escaped { - b.WriteByte(ch) - escaped = false - continue - } - if ch == '\\' { - escaped = true - continue - } - if ch == '"' { - return b.String(), i + 1, true - } - b.WriteByte(ch) - } - return "", 0, false -} - -func containsKey(keys []string, value string) bool { - for _, k := range keys { - if k == value { - return true - } - } - return false -} - -func skipSpaces(text string, i int) int { - for i < len(text) { - switch text[i] { - case ' ', '\t', '\n', '\r': - i++ - default: - return i - } - } - return i -} - -func (s *toolStreamSieveState) noteText(content string) { - if strings.TrimSpace(content) == "" { - return - } - s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit) -} - -func appendTail(prev, next string, max int) string { - if max <= 0 { - return "" - } - combined := prev + next - if len(combined) <= max { - return combined - } - return combined[len(combined)-max:] -} - -func looksLikeToolExampleContext(text string) bool { - return insideCodeFence(text) -} - -func insideCodeFence(text string) bool { - if text == "" { - 
return false - } - return strings.Count(text, "```")%2 == 1 -} diff --git a/internal/adapter/openai/tool_sieve_core.go b/internal/adapter/openai/tool_sieve_core.go new file mode 100644 index 0000000..1bcf102 --- /dev/null +++ b/internal/adapter/openai/tool_sieve_core.go @@ -0,0 +1,208 @@ +package openai + +import ( + "strings" + + "ds2api/internal/util" +) + +func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent { + if state == nil { + return nil + } + if chunk != "" { + state.pending.WriteString(chunk) + } + events := make([]toolStreamEvent, 0, 2) + + for { + if state.capturing { + if state.pending.Len() > 0 { + state.capture.WriteString(state.pending.String()) + state.pending.Reset() + } + if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 { + events = append(events, toolStreamEvent{ToolCallDeltas: deltas}) + } + prefix, calls, suffix, ready := consumeToolCapture(state, toolNames) + if !ready { + if state.capture.Len() > toolSieveCaptureLimit { + content := state.capture.String() + state.capture.Reset() + state.capturing = false + state.resetIncrementalToolState() + state.noteText(content) + events = append(events, toolStreamEvent{Content: content}) + continue + } + break + } + state.capture.Reset() + state.capturing = false + state.resetIncrementalToolState() + if prefix != "" { + state.noteText(prefix) + events = append(events, toolStreamEvent{Content: prefix}) + } + if len(calls) > 0 { + events = append(events, toolStreamEvent{ToolCalls: calls}) + } + if suffix != "" { + state.pending.WriteString(suffix) + } + continue + } + + pending := state.pending.String() + if pending == "" { + break + } + start := findToolSegmentStart(pending) + if start >= 0 { + prefix := pending[:start] + if prefix != "" { + state.noteText(prefix) + events = append(events, toolStreamEvent{Content: prefix}) + } + state.pending.Reset() + state.capture.WriteString(pending[start:]) + state.capturing = true + 
state.resetIncrementalToolState() + continue + } + + safe, hold := splitSafeContentForToolDetection(pending) + if safe == "" { + break + } + state.pending.Reset() + state.pending.WriteString(hold) + state.noteText(safe) + events = append(events, toolStreamEvent{Content: safe}) + } + + return events +} + +func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent { + if state == nil { + return nil + } + events := processToolSieveChunk(state, "", toolNames) + if state.capturing { + consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames) + if ready { + if consumedPrefix != "" { + state.noteText(consumedPrefix) + events = append(events, toolStreamEvent{Content: consumedPrefix}) + } + if len(consumedCalls) > 0 { + events = append(events, toolStreamEvent{ToolCalls: consumedCalls}) + } + if consumedSuffix != "" { + state.noteText(consumedSuffix) + events = append(events, toolStreamEvent{Content: consumedSuffix}) + } + } else { + content := state.capture.String() + if content != "" { + state.noteText(content) + events = append(events, toolStreamEvent{Content: content}) + } + } + state.capture.Reset() + state.capturing = false + state.resetIncrementalToolState() + } + if state.pending.Len() > 0 { + content := state.pending.String() + state.noteText(content) + events = append(events, toolStreamEvent{Content: content}) + state.pending.Reset() + } + return events +} + +func splitSafeContentForToolDetection(s string) (safe, hold string) { + if s == "" { + return "", "" + } + suspiciousStart := findSuspiciousPrefixStart(s) + if suspiciousStart < 0 { + return s, "" + } + if suspiciousStart > 0 { + return s[:suspiciousStart], s[suspiciousStart:] + } + // If suspicious content starts at position 0, keep holding until we can + // parse a complete tool JSON block or reach stream flush. 
+ return "", s +} + +func findSuspiciousPrefixStart(s string) int { + start := -1 + indices := []int{ + strings.LastIndex(s, "{"), + strings.LastIndex(s, "["), + strings.LastIndex(s, "```"), + } + for _, idx := range indices { + if idx > start { + start = idx + } + } + return start +} + +func findToolSegmentStart(s string) int { + if s == "" { + return -1 + } + lower := strings.ToLower(s) + offset := 0 + for { + keyRel := strings.Index(lower[offset:], "tool_calls") + if keyRel < 0 { + return -1 + } + keyIdx := offset + keyRel + start := strings.LastIndex(s[:keyIdx], "{") + if start < 0 { + start = keyIdx + } + if !insideCodeFence(s[:start]) { + return start + } + offset = keyIdx + len("tool_calls") + } +} + +func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) { + captured := state.capture.String() + if captured == "" { + return "", nil, "", false + } + lower := strings.ToLower(captured) + keyIdx := strings.Index(lower, "tool_calls") + if keyIdx < 0 { + return "", nil, "", false + } + start := strings.LastIndex(captured[:keyIdx], "{") + if start < 0 { + return "", nil, "", false + } + obj, end, ok := extractJSONObjectFrom(captured, start) + if !ok { + return "", nil, "", false + } + prefixPart := captured[:start] + suffixPart := captured[end:] + if insideCodeFence(state.recentTextTail + prefixPart) { + return captured, nil, "", true + } + parsed := util.ParseStandaloneToolCalls(obj, toolNames) + if len(parsed) == 0 { + return captured, nil, "", true + } + return prefixPart, parsed, suffixPart, true +} diff --git a/internal/adapter/openai/tool_sieve_incremental.go b/internal/adapter/openai/tool_sieve_incremental.go new file mode 100644 index 0000000..ad0f901 --- /dev/null +++ b/internal/adapter/openai/tool_sieve_incremental.go @@ -0,0 +1,291 @@ +package openai + +import "strings" + +func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta { + if 
state.disableDeltas { + return nil + } + captured := state.capture.String() + if captured == "" { + return nil + } + lower := strings.ToLower(captured) + keyIdx := strings.Index(lower, "tool_calls") + if keyIdx < 0 { + return nil + } + start := strings.LastIndex(captured[:keyIdx], "{") + if start < 0 { + return nil + } + if insideCodeFence(state.recentTextTail + captured[:start]) { + return nil + } + certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx) + if hasMultiple { + state.disableDeltas = true + return nil + } + if !certainSingle { + // In uncertain phases (e.g. first call arrived but array not closed yet), + // avoid speculative deltas and wait for final parsed tool_calls payload. + return nil + } + callStart, ok := findFirstToolCallObjectStart(captured, keyIdx) + if !ok { + return nil + } + deltas := make([]toolCallDelta, 0, 2) + if state.toolName == "" { + name, ok := extractToolCallName(captured, callStart) + if !ok || name == "" { + return nil + } + state.toolName = name + } + if state.toolArgsStart < 0 { + argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart) + if ok { + state.toolArgsString = stringMode + if stringMode { + state.toolArgsStart = argsStart + 1 + } else { + state.toolArgsStart = argsStart + } + state.toolArgsSent = state.toolArgsStart + } + } + if !state.toolNameSent { + if state.toolArgsStart < 0 { + return nil + } + state.toolNameSent = true + deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName}) + } + if state.toolArgsStart < 0 || state.toolArgsDone { + return deltas + } + end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString) + if !ok { + return deltas + } + if end > state.toolArgsSent { + deltas = append(deltas, toolCallDelta{ + Index: 0, + Arguments: captured[state.toolArgsSent:end], + }) + state.toolArgsSent = end + } + if complete { + state.toolArgsDone = true + } + return deltas +} + +func 
classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) { + arrStart, ok := findToolCallsArrayStart(text, keyIdx) + if !ok { + return false, false + } + i := skipSpaces(text, arrStart+1) + if i >= len(text) || text[i] != '{' { + return false, false + } + count := 0 + depth := 0 + quote := byte(0) + escaped := false + for ; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + quote = ch + continue + } + if ch == '{' { + if depth == 0 { + count++ + if count > 1 { + return false, true + } + } + depth++ + continue + } + if ch == '}' { + if depth > 0 { + depth-- + } + continue + } + if ch == ',' && depth == 0 { + // top-level separator means at least one more tool call exists + // (or is expected). Treat as multi-call and stop incremental deltas. + return false, true + } + if ch == ']' && depth == 0 { + return count == 1, false + } + } + // array not closed yet: still uncertain whether more calls will appear + return false, false +} + +func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) { + arrStart, ok := findToolCallsArrayStart(text, keyIdx) + if !ok { + return -1, false + } + i := skipSpaces(text, arrStart+1) + if i >= len(text) || text[i] != '{' { + return -1, false + } + return i, true +} + +func findToolCallsArrayStart(text string, keyIdx int) (int, bool) { + i := keyIdx + len("tool_calls") + for i < len(text) && text[i] != ':' { + i++ + } + if i >= len(text) { + return -1, false + } + i = skipSpaces(text, i+1) + if i >= len(text) || text[i] != '[' { + return -1, false + } + return i, true +} + +func extractToolCallName(text string, callStart int) (string, bool) { + valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"}) + if !ok || valueStart >= len(text) || text[valueStart] != '"' { + fnStart, fnOK 
:= findFunctionObjectStart(text, callStart) + if !fnOK { + return "", false + } + valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"}) + if !ok || valueStart >= len(text) || text[valueStart] != '"' { + return "", false + } + } + name, _, ok := parseJSONStringLiteral(text, valueStart) + if !ok { + return "", false + } + return name, true +} + +func findToolCallArgsStart(text string, callStart int) (int, bool, bool) { + keys := []string{"input", "arguments", "args", "parameters", "params"} + valueStart, ok := findObjectFieldValueStart(text, callStart, keys) + if !ok { + fnStart, fnOK := findFunctionObjectStart(text, callStart) + if !fnOK { + return -1, false, false + } + valueStart, ok = findObjectFieldValueStart(text, fnStart, keys) + if !ok { + return -1, false, false + } + } + if valueStart >= len(text) { + return -1, false, false + } + ch := text[valueStart] + if ch == '{' || ch == '[' { + return valueStart, false, true + } + if ch == '"' { + return valueStart, true, true + } + return -1, false, false +} + +func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) { + if start < 0 || start > len(text) { + return 0, false, false + } + if stringMode { + escaped := false + for i := start; i < len(text); i++ { + ch := text[i] + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' { + return i, true, true + } + } + return len(text), false, true + } + if start >= len(text) { + return start, false, false + } + if text[start] != '{' && text[start] != '[' { + return 0, false, false + } + depth := 0 + quote := byte(0) + escaped := false + for i := start; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + quote = ch + continue + } + if ch == '{' || ch == '[' { + depth++ + continue + 
} + if ch == '}' || ch == ']' { + depth-- + if depth == 0 { + return i + 1, true, true + } + } + } + return len(text), false, true +} + +func findFunctionObjectStart(text string, callStart int) (int, bool) { + valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"}) + if !ok || valueStart >= len(text) || text[valueStart] != '{' { + return -1, false + } + return valueStart, true +} diff --git a/internal/adapter/openai/tool_sieve_jsonscan.go b/internal/adapter/openai/tool_sieve_jsonscan.go new file mode 100644 index 0000000..d3abcc5 --- /dev/null +++ b/internal/adapter/openai/tool_sieve_jsonscan.go @@ -0,0 +1,152 @@ +package openai + +import "strings" + +func extractJSONObjectFrom(text string, start int) (string, int, bool) { + if start < 0 || start >= len(text) || text[start] != '{' { + return "", 0, false + } + depth := 0 + quote := byte(0) + escaped := false + for i := start; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + quote = ch + continue + } + if ch == '{' { + depth++ + continue + } + if ch == '}' { + depth-- + if depth == 0 { + end := i + 1 + return text[start:end], end, true + } + } + } + return "", 0, false +} + +func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) { + if objStart < 0 || objStart >= len(text) || text[objStart] != '{' { + return 0, false + } + depth := 0 + quote := byte(0) + escaped := false + for i := objStart; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + if depth == 1 { + key, end, ok := parseJSONStringLiteral(text, i) + if !ok { + return 0, false + } + j := skipSpaces(text, end) + if j >= len(text) || 
text[j] != ':' { + i = end - 1 + continue + } + j = skipSpaces(text, j+1) + if j >= len(text) { + return 0, false + } + if containsKey(keys, key) { + return j, true + } + i = j - 1 + continue + } + quote = ch + continue + } + if ch == '{' { + depth++ + continue + } + if ch == '}' { + depth-- + if depth == 0 { + break + } + } + } + return 0, false +} + +func parseJSONStringLiteral(text string, start int) (string, int, bool) { + if start < 0 || start >= len(text) || text[start] != '"' { + return "", 0, false + } + var b strings.Builder + escaped := false + for i := start + 1; i < len(text); i++ { + ch := text[i] + if escaped { + b.WriteByte(ch) + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' { + return b.String(), i + 1, true + } + b.WriteByte(ch) + } + return "", 0, false +} + +func containsKey(keys []string, value string) bool { + for _, k := range keys { + if k == value { + return true + } + } + return false +} + +func skipSpaces(text string, i int) int { + for i < len(text) { + switch text[i] { + case ' ', '\t', '\n', '\r': + i++ + default: + return i + } + } + return i +} diff --git a/internal/adapter/openai/tool_sieve_state.go b/internal/adapter/openai/tool_sieve_state.go new file mode 100644 index 0000000..04699e6 --- /dev/null +++ b/internal/adapter/openai/tool_sieve_state.go @@ -0,0 +1,75 @@ +package openai + +import ( + "strings" + + "ds2api/internal/util" +) + +type toolStreamSieveState struct { + pending strings.Builder + capture strings.Builder + capturing bool + recentTextTail string + disableDeltas bool + toolNameSent bool + toolName string + toolArgsStart int + toolArgsSent int + toolArgsString bool + toolArgsDone bool +} + +type toolStreamEvent struct { + Content string + ToolCalls []util.ParsedToolCall + ToolCallDeltas []toolCallDelta +} + +type toolCallDelta struct { + Index int + Name string + Arguments string +} + +const toolSieveCaptureLimit = 8 * 1024 +const toolSieveContextTailLimit = 256 + 
+func (s *toolStreamSieveState) resetIncrementalToolState() { + s.disableDeltas = false + s.toolNameSent = false + s.toolName = "" + s.toolArgsStart = -1 + s.toolArgsSent = -1 + s.toolArgsString = false + s.toolArgsDone = false +} + +func (s *toolStreamSieveState) noteText(content string) { + if strings.TrimSpace(content) == "" { + return + } + s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit) +} + +func appendTail(prev, next string, max int) string { + if max <= 0 { + return "" + } + combined := prev + next + if len(combined) <= max { + return combined + } + return combined[len(combined)-max:] +} + +func looksLikeToolExampleContext(text string) bool { + return insideCodeFence(text) +} + +func insideCodeFence(text string) bool { + if text == "" { + return false + } + return strings.Count(text, "```")%2 == 1 +} diff --git a/internal/admin/handler_accounts_crud.go b/internal/admin/handler_accounts_crud.go new file mode 100644 index 0000000..daaa434 --- /dev/null +++ b/internal/admin/handler_accounts_crud.go @@ -0,0 +1,114 @@ +package admin + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/config" +) + +func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) { + page := intFromQuery(r, "page", 1) + pageSize := intFromQuery(r, "page_size", 10) + if page < 1 { + page = 1 + } + if pageSize < 1 { + pageSize = 1 + } + if pageSize > 100 { + pageSize = 100 + } + accounts := h.Store.Snapshot().Accounts + total := len(accounts) + reverseAccounts(accounts) + totalPages := 1 + if total > 0 { + totalPages = (total + pageSize - 1) / pageSize + } + start := (page - 1) * pageSize + if start > total { + start = total + } + end := start + pageSize + if end > total { + end = total + } + items := make([]map[string]any, 0, end-start) + for _, acc := range accounts[start:end] { + token := strings.TrimSpace(acc.Token) + preview := "" + if token != "" { + if len(token) > 20 { + 
preview = token[:20] + "..." + } else { + preview = token + } + } + items = append(items, map[string]any{ + "identifier": acc.Identifier(), + "email": acc.Email, + "mobile": acc.Mobile, + "has_password": acc.Password != "", + "has_token": token != "", + "token_preview": preview, + }) + } + writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages}) +} + +func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + acc := toAccount(req) + if acc.Identifier() == "" { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"}) + return + } + err := h.Store.Update(func(c *config.Config) error { + for _, a := range c.Accounts { + if acc.Email != "" && a.Email == acc.Email { + return fmt.Errorf("邮箱已存在") + } + if acc.Mobile != "" && a.Mobile == acc.Mobile { + return fmt.Errorf("手机号已存在") + } + } + c.Accounts = append(c.Accounts, acc) + return nil + }) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) +} + +func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) { + identifier := chi.URLParam(r, "identifier") + err := h.Store.Update(func(c *config.Config) error { + idx := -1 + for i, a := range c.Accounts { + if accountMatchesIdentifier(a, identifier) { + idx = i + break + } + } + if idx < 0 { + return fmt.Errorf("账号不存在") + } + c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...) 
+ return nil + }) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) +} diff --git a/internal/admin/handler_accounts_queue.go b/internal/admin/handler_accounts_queue.go new file mode 100644 index 0000000..108f802 --- /dev/null +++ b/internal/admin/handler_accounts_queue.go @@ -0,0 +1,7 @@ +package admin + +import "net/http" + +func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, h.Pool.Status()) +} diff --git a/internal/admin/handler_accounts.go b/internal/admin/handler_accounts_testing.go similarity index 69% rename from internal/admin/handler_accounts.go rename to internal/admin/handler_accounts_testing.go index 5cb88cc..2bd7706 100644 --- a/internal/admin/handler_accounts.go +++ b/internal/admin/handler_accounts_testing.go @@ -11,119 +11,11 @@ import ( "sync" "time" - "github.com/go-chi/chi/v5" - authn "ds2api/internal/auth" "ds2api/internal/config" "ds2api/internal/sse" ) -func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) { - page := intFromQuery(r, "page", 1) - pageSize := intFromQuery(r, "page_size", 10) - if page < 1 { - page = 1 - } - if pageSize < 1 { - pageSize = 1 - } - if pageSize > 100 { - pageSize = 100 - } - accounts := h.Store.Snapshot().Accounts - total := len(accounts) - reverseAccounts(accounts) - totalPages := 1 - if total > 0 { - totalPages = (total + pageSize - 1) / pageSize - } - start := (page - 1) * pageSize - if start > total { - start = total - } - end := start + pageSize - if end > total { - end = total - } - items := make([]map[string]any, 0, end-start) - for _, acc := range accounts[start:end] { - token := strings.TrimSpace(acc.Token) - preview := "" - if token != "" { - if len(token) > 20 { - preview = token[:20] + "..." 
- } else { - preview = token - } - } - items = append(items, map[string]any{ - "identifier": acc.Identifier(), - "email": acc.Email, - "mobile": acc.Mobile, - "has_password": acc.Password != "", - "has_token": token != "", - "token_preview": preview, - }) - } - writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages}) -} - -func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) { - var req map[string]any - _ = json.NewDecoder(r.Body).Decode(&req) - acc := toAccount(req) - if acc.Identifier() == "" { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"}) - return - } - err := h.Store.Update(func(c *config.Config) error { - for _, a := range c.Accounts { - if acc.Email != "" && a.Email == acc.Email { - return fmt.Errorf("邮箱已存在") - } - if acc.Mobile != "" && a.Mobile == acc.Mobile { - return fmt.Errorf("手机号已存在") - } - } - c.Accounts = append(c.Accounts, acc) - return nil - }) - if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) - return - } - h.Pool.Reset() - writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) -} - -func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) { - identifier := chi.URLParam(r, "identifier") - err := h.Store.Update(func(c *config.Config) error { - idx := -1 - for i, a := range c.Accounts { - if accountMatchesIdentifier(a, identifier) { - idx = i - break - } - } - if idx < 0 { - return fmt.Errorf("账号不存在") - } - c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...) 
- return nil - }) - if err != nil { - writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()}) - return - } - h.Pool.Reset() - writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)}) -} - -func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) { - writeJSON(w, http.StatusOK, h.Pool.Status()) -} - func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) { var req map[string]any _ = json.NewDecoder(r.Body).Decode(&req) diff --git a/internal/admin/handler_config.go b/internal/admin/handler_config.go deleted file mode 100644 index dfbd005..0000000 --- a/internal/admin/handler_config.go +++ /dev/null @@ -1,393 +0,0 @@ -package admin - -import ( - "crypto/md5" - "encoding/json" - "fmt" - "net/http" - "strings" - - "github.com/go-chi/chi/v5" - - "ds2api/internal/config" -) - -func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) { - snap := h.Store.Snapshot() - safe := map[string]any{ - "keys": snap.Keys, - "accounts": []map[string]any{}, - "claude_mapping": func() map[string]string { - if len(snap.ClaudeMapping) > 0 { - return snap.ClaudeMapping - } - return snap.ClaudeModelMap - }(), - } - accounts := make([]map[string]any, 0, len(snap.Accounts)) - for _, acc := range snap.Accounts { - token := strings.TrimSpace(acc.Token) - preview := "" - if token != "" { - if len(token) > 20 { - preview = token[:20] + "..." 
- } else { - preview = token - } - } - accounts = append(accounts, map[string]any{ - "identifier": acc.Identifier(), - "email": acc.Email, - "mobile": acc.Mobile, - "has_password": strings.TrimSpace(acc.Password) != "", - "has_token": token != "", - "token_preview": preview, - }) - } - safe["accounts"] = accounts - writeJSON(w, http.StatusOK, safe) -} - -func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) { - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) - return - } - old := h.Store.Snapshot() - err := h.Store.Update(func(c *config.Config) error { - if keys, ok := toStringSlice(req["keys"]); ok { - c.Keys = keys - } - if accountsRaw, ok := req["accounts"].([]any); ok { - existing := map[string]config.Account{} - for _, a := range old.Accounts { - existing[a.Identifier()] = a - } - accounts := make([]config.Account, 0, len(accountsRaw)) - for _, item := range accountsRaw { - m, ok := item.(map[string]any) - if !ok { - continue - } - acc := toAccount(m) - id := acc.Identifier() - if prev, ok := existing[id]; ok { - if strings.TrimSpace(acc.Password) == "" { - acc.Password = prev.Password - } - if strings.TrimSpace(acc.Token) == "" { - acc.Token = prev.Token - } - } - accounts = append(accounts, acc) - } - c.Accounts = accounts - } - if m, ok := req["claude_mapping"].(map[string]any); ok { - newMap := map[string]string{} - for k, v := range m { - newMap[k] = fmt.Sprintf("%v", v) - } - c.ClaudeMapping = newMap - } - return nil - }) - if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) - return - } - h.Pool.Reset() - writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "配置已更新"}) -} - -func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) { - var req map[string]any - _ = json.NewDecoder(r.Body).Decode(&req) - key, _ := req["key"].(string) - key = 
strings.TrimSpace(key) - if key == "" { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "Key 不能为空"}) - return - } - err := h.Store.Update(func(c *config.Config) error { - for _, k := range c.Keys { - if k == key { - return fmt.Errorf("Key 已存在") - } - } - c.Keys = append(c.Keys, key) - return nil - }) - if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) - return - } - writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)}) -} - -func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) { - key := chi.URLParam(r, "key") - err := h.Store.Update(func(c *config.Config) error { - idx := -1 - for i, k := range c.Keys { - if k == key { - idx = i - break - } - } - if idx < 0 { - return fmt.Errorf("Key 不存在") - } - c.Keys = append(c.Keys[:idx], c.Keys[idx+1:]...) - return nil - }) - if err != nil { - writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()}) - return - } - writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)}) -} - -func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) { - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "无效的 JSON 格式"}) - return - } - importedKeys, importedAccounts := 0, 0 - err := h.Store.Update(func(c *config.Config) error { - if keys, ok := req["keys"].([]any); ok { - existing := map[string]bool{} - for _, k := range c.Keys { - existing[k] = true - } - for _, k := range keys { - key := strings.TrimSpace(fmt.Sprintf("%v", k)) - if key == "" || existing[key] { - continue - } - c.Keys = append(c.Keys, key) - existing[key] = true - importedKeys++ - } - } - if accounts, ok := req["accounts"].([]any); ok { - existing := map[string]bool{} - for _, a := range c.Accounts { - existing[a.Identifier()] = true - } - for _, item := range accounts { - m, ok := 
item.(map[string]any) - if !ok { - continue - } - acc := toAccount(m) - id := acc.Identifier() - if id == "" || existing[id] { - continue - } - c.Accounts = append(c.Accounts, acc) - existing[id] = true - importedAccounts++ - } - } - return nil - }) - if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) - return - } - h.Pool.Reset() - writeJSON(w, http.StatusOK, map[string]any{"success": true, "imported_keys": importedKeys, "imported_accounts": importedAccounts}) -} - -func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) { - h.configExport(w, nil) -} - -func (h *Handler) configExport(w http.ResponseWriter, _ *http.Request) { - snap := h.Store.Snapshot() - jsonStr, b64, err := h.Store.ExportJSONAndBase64() - if err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) - return - } - writeJSON(w, http.StatusOK, map[string]any{ - "success": true, - "config": snap, - "json": jsonStr, - "base64": b64, - }) -} - -func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) { - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) - return - } - - mode := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("mode"))) - if mode == "" { - mode = strings.TrimSpace(strings.ToLower(fieldString(req, "mode"))) - } - if mode == "" { - mode = "merge" - } - if mode != "merge" && mode != "replace" { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "mode must be merge or replace"}) - return - } - - payload := req - if raw, ok := req["config"].(map[string]any); ok && len(raw) > 0 { - payload = raw - } - rawJSON, err := json.Marshal(payload) - if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid config payload"}) - return - } - var incoming config.Config - if err := json.Unmarshal(rawJSON, &incoming); err != nil 
{ - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) - return - } - - importedKeys, importedAccounts := 0, 0 - err = h.Store.Update(func(c *config.Config) error { - next := c.Clone() - if mode == "replace" { - next = incoming.Clone() - next.VercelSyncHash = c.VercelSyncHash - next.VercelSyncTime = c.VercelSyncTime - importedKeys = len(next.Keys) - importedAccounts = len(next.Accounts) - } else { - existingKeys := map[string]struct{}{} - for _, k := range next.Keys { - existingKeys[k] = struct{}{} - } - for _, k := range incoming.Keys { - key := strings.TrimSpace(k) - if key == "" { - continue - } - if _, ok := existingKeys[key]; ok { - continue - } - existingKeys[key] = struct{}{} - next.Keys = append(next.Keys, key) - importedKeys++ - } - - existingAccounts := map[string]struct{}{} - for _, acc := range next.Accounts { - existingAccounts[acc.Identifier()] = struct{}{} - } - for _, acc := range incoming.Accounts { - id := acc.Identifier() - if id == "" { - continue - } - if _, ok := existingAccounts[id]; ok { - continue - } - existingAccounts[id] = struct{}{} - next.Accounts = append(next.Accounts, acc) - importedAccounts++ - } - - if len(incoming.ClaudeMapping) > 0 { - if next.ClaudeMapping == nil { - next.ClaudeMapping = map[string]string{} - } - for k, v := range incoming.ClaudeMapping { - next.ClaudeMapping[k] = v - } - } - if len(incoming.ClaudeModelMap) > 0 { - if next.ClaudeModelMap == nil { - next.ClaudeModelMap = map[string]string{} - } - for k, v := range incoming.ClaudeModelMap { - next.ClaudeModelMap[k] = v - } - } - - if len(incoming.ModelAliases) > 0 { - if next.ModelAliases == nil { - next.ModelAliases = map[string]string{} - } - for k, v := range incoming.ModelAliases { - next.ModelAliases[k] = v - } - } - if strings.TrimSpace(incoming.Toolcall.Mode) != "" { - next.Toolcall.Mode = incoming.Toolcall.Mode - } - if strings.TrimSpace(incoming.Toolcall.EarlyEmitConfidence) != "" { - next.Toolcall.EarlyEmitConfidence = 
incoming.Toolcall.EarlyEmitConfidence - } - if incoming.Responses.StoreTTLSeconds > 0 { - next.Responses.StoreTTLSeconds = incoming.Responses.StoreTTLSeconds - } - if strings.TrimSpace(incoming.Embeddings.Provider) != "" { - next.Embeddings.Provider = incoming.Embeddings.Provider - } - if strings.TrimSpace(incoming.Admin.PasswordHash) != "" { - next.Admin.PasswordHash = incoming.Admin.PasswordHash - } - if incoming.Admin.JWTExpireHours > 0 { - next.Admin.JWTExpireHours = incoming.Admin.JWTExpireHours - } - if incoming.Admin.JWTValidAfterUnix > 0 { - next.Admin.JWTValidAfterUnix = incoming.Admin.JWTValidAfterUnix - } - if incoming.Runtime.AccountMaxInflight > 0 { - next.Runtime.AccountMaxInflight = incoming.Runtime.AccountMaxInflight - } - if incoming.Runtime.AccountMaxQueue > 0 { - next.Runtime.AccountMaxQueue = incoming.Runtime.AccountMaxQueue - } - if incoming.Runtime.GlobalMaxInflight > 0 { - next.Runtime.GlobalMaxInflight = incoming.Runtime.GlobalMaxInflight - } - } - - normalizeSettingsConfig(&next) - if err := validateSettingsConfig(next); err != nil { - return newRequestError(err.Error()) - } - - *c = next - return nil - }) - if err != nil { - if detail, ok := requestErrorDetail(err); ok { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": detail}) - return - } - writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) - return - } - - h.Pool.Reset() - writeJSON(w, http.StatusOK, map[string]any{ - "success": true, - "mode": mode, - "imported_keys": importedKeys, - "imported_accounts": importedAccounts, - "message": "config imported", - }) -} - -func (h *Handler) computeSyncHash() string { - snap := h.Store.Snapshot().Clone() - snap.VercelSyncHash = "" - snap.VercelSyncTime = 0 - b, _ := json.Marshal(snap) - sum := md5.Sum(b) - return fmt.Sprintf("%x", sum) -} diff --git a/internal/admin/handler_config_import.go b/internal/admin/handler_config_import.go new file mode 100644 index 0000000..674d8b2 --- /dev/null +++ 
b/internal/admin/handler_config_import.go @@ -0,0 +1,182 @@ +package admin + +import ( + "crypto/md5" + "encoding/json" + "fmt" + "net/http" + "strings" + + "ds2api/internal/config" +) + +func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + + mode := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("mode"))) + if mode == "" { + mode = strings.TrimSpace(strings.ToLower(fieldString(req, "mode"))) + } + if mode == "" { + mode = "merge" + } + if mode != "merge" && mode != "replace" { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "mode must be merge or replace"}) + return + } + + payload := req + if raw, ok := req["config"].(map[string]any); ok && len(raw) > 0 { + payload = raw + } + rawJSON, err := json.Marshal(payload) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid config payload"}) + return + } + var incoming config.Config + if err := json.Unmarshal(rawJSON, &incoming); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + + importedKeys, importedAccounts := 0, 0 + err = h.Store.Update(func(c *config.Config) error { + next := c.Clone() + if mode == "replace" { + next = incoming.Clone() + next.VercelSyncHash = c.VercelSyncHash + next.VercelSyncTime = c.VercelSyncTime + importedKeys = len(next.Keys) + importedAccounts = len(next.Accounts) + } else { + existingKeys := map[string]struct{}{} + for _, k := range next.Keys { + existingKeys[k] = struct{}{} + } + for _, k := range incoming.Keys { + key := strings.TrimSpace(k) + if key == "" { + continue + } + if _, ok := existingKeys[key]; ok { + continue + } + existingKeys[key] = struct{}{} + next.Keys = append(next.Keys, key) + importedKeys++ + } + + existingAccounts := map[string]struct{}{} + for _, acc := range 
next.Accounts { + existingAccounts[acc.Identifier()] = struct{}{} + } + for _, acc := range incoming.Accounts { + id := acc.Identifier() + if id == "" { + continue + } + if _, ok := existingAccounts[id]; ok { + continue + } + existingAccounts[id] = struct{}{} + next.Accounts = append(next.Accounts, acc) + importedAccounts++ + } + + if len(incoming.ClaudeMapping) > 0 { + if next.ClaudeMapping == nil { + next.ClaudeMapping = map[string]string{} + } + for k, v := range incoming.ClaudeMapping { + next.ClaudeMapping[k] = v + } + } + if len(incoming.ClaudeModelMap) > 0 { + if next.ClaudeModelMap == nil { + next.ClaudeModelMap = map[string]string{} + } + for k, v := range incoming.ClaudeModelMap { + next.ClaudeModelMap[k] = v + } + } + + if len(incoming.ModelAliases) > 0 { + if next.ModelAliases == nil { + next.ModelAliases = map[string]string{} + } + for k, v := range incoming.ModelAliases { + next.ModelAliases[k] = v + } + } + if strings.TrimSpace(incoming.Toolcall.Mode) != "" { + next.Toolcall.Mode = incoming.Toolcall.Mode + } + if strings.TrimSpace(incoming.Toolcall.EarlyEmitConfidence) != "" { + next.Toolcall.EarlyEmitConfidence = incoming.Toolcall.EarlyEmitConfidence + } + if incoming.Responses.StoreTTLSeconds > 0 { + next.Responses.StoreTTLSeconds = incoming.Responses.StoreTTLSeconds + } + if strings.TrimSpace(incoming.Embeddings.Provider) != "" { + next.Embeddings.Provider = incoming.Embeddings.Provider + } + if strings.TrimSpace(incoming.Admin.PasswordHash) != "" { + next.Admin.PasswordHash = incoming.Admin.PasswordHash + } + if incoming.Admin.JWTExpireHours > 0 { + next.Admin.JWTExpireHours = incoming.Admin.JWTExpireHours + } + if incoming.Admin.JWTValidAfterUnix > 0 { + next.Admin.JWTValidAfterUnix = incoming.Admin.JWTValidAfterUnix + } + if incoming.Runtime.AccountMaxInflight > 0 { + next.Runtime.AccountMaxInflight = incoming.Runtime.AccountMaxInflight + } + if incoming.Runtime.AccountMaxQueue > 0 { + next.Runtime.AccountMaxQueue = 
incoming.Runtime.AccountMaxQueue + } + if incoming.Runtime.GlobalMaxInflight > 0 { + next.Runtime.GlobalMaxInflight = incoming.Runtime.GlobalMaxInflight + } + } + + normalizeSettingsConfig(&next) + if err := validateSettingsConfig(next); err != nil { + return newRequestError(err.Error()) + } + + *c = next + return nil + }) + if err != nil { + if detail, ok := requestErrorDetail(err); ok { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": detail}) + return + } + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "mode": mode, + "imported_keys": importedKeys, + "imported_accounts": importedAccounts, + "message": "config imported", + }) +} + +func (h *Handler) computeSyncHash() string { + snap := h.Store.Snapshot().Clone() + snap.VercelSyncHash = "" + snap.VercelSyncTime = 0 + b, _ := json.Marshal(snap) + sum := md5.Sum(b) + return fmt.Sprintf("%x", sum) +} diff --git a/internal/admin/handler_config_read.go b/internal/admin/handler_config_read.go new file mode 100644 index 0000000..e32aabd --- /dev/null +++ b/internal/admin/handler_config_read.go @@ -0,0 +1,61 @@ +package admin + +import ( + "net/http" + "strings" +) + +func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) { + snap := h.Store.Snapshot() + safe := map[string]any{ + "keys": snap.Keys, + "accounts": []map[string]any{}, + "claude_mapping": func() map[string]string { + if len(snap.ClaudeMapping) > 0 { + return snap.ClaudeMapping + } + return snap.ClaudeModelMap + }(), + } + accounts := make([]map[string]any, 0, len(snap.Accounts)) + for _, acc := range snap.Accounts { + token := strings.TrimSpace(acc.Token) + preview := "" + if token != "" { + if len(token) > 20 { + preview = token[:20] + "..." 
+ } else { + preview = token + } + } + accounts = append(accounts, map[string]any{ + "identifier": acc.Identifier(), + "email": acc.Email, + "mobile": acc.Mobile, + "has_password": strings.TrimSpace(acc.Password) != "", + "has_token": token != "", + "token_preview": preview, + }) + } + safe["accounts"] = accounts + writeJSON(w, http.StatusOK, safe) +} + +func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) { + h.configExport(w, nil) +} + +func (h *Handler) configExport(w http.ResponseWriter, _ *http.Request) { + snap := h.Store.Snapshot() + jsonStr, b64, err := h.Store.ExportJSONAndBase64() + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "config": snap, + "json": jsonStr, + "base64": b64, + }) +} diff --git a/internal/admin/handler_config_write.go b/internal/admin/handler_config_write.go new file mode 100644 index 0000000..792e696 --- /dev/null +++ b/internal/admin/handler_config_write.go @@ -0,0 +1,166 @@ +package admin + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" + + "ds2api/internal/config" +) + +func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + old := h.Store.Snapshot() + err := h.Store.Update(func(c *config.Config) error { + if keys, ok := toStringSlice(req["keys"]); ok { + c.Keys = keys + } + if accountsRaw, ok := req["accounts"].([]any); ok { + existing := map[string]config.Account{} + for _, a := range old.Accounts { + existing[a.Identifier()] = a + } + accounts := make([]config.Account, 0, len(accountsRaw)) + for _, item := range accountsRaw { + m, ok := item.(map[string]any) + if !ok { + continue + } + acc := toAccount(m) + id := acc.Identifier() + if 
prev, ok := existing[id]; ok { + if strings.TrimSpace(acc.Password) == "" { + acc.Password = prev.Password + } + if strings.TrimSpace(acc.Token) == "" { + acc.Token = prev.Token + } + } + accounts = append(accounts, acc) + } + c.Accounts = accounts + } + if m, ok := req["claude_mapping"].(map[string]any); ok { + newMap := map[string]string{} + for k, v := range m { + newMap[k] = fmt.Sprintf("%v", v) + } + c.ClaudeMapping = newMap + } + return nil + }) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "配置已更新"}) +} + +func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) { + var req map[string]any + _ = json.NewDecoder(r.Body).Decode(&req) + key, _ := req["key"].(string) + key = strings.TrimSpace(key) + if key == "" { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "Key 不能为空"}) + return + } + err := h.Store.Update(func(c *config.Config) error { + for _, k := range c.Keys { + if k == key { + return fmt.Errorf("Key 已存在") + } + } + c.Keys = append(c.Keys, key) + return nil + }) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)}) +} + +func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) { + key := chi.URLParam(r, "key") + err := h.Store.Update(func(c *config.Config) error { + idx := -1 + for i, k := range c.Keys { + if k == key { + idx = i + break + } + } + if idx < 0 { + return fmt.Errorf("Key 不存在") + } + c.Keys = append(c.Keys[:idx], c.Keys[idx+1:]...) 
+ return nil + }) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()}) + return + } + writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)}) +} + +func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "无效的 JSON 格式"}) + return + } + importedKeys, importedAccounts := 0, 0 + err := h.Store.Update(func(c *config.Config) error { + if keys, ok := req["keys"].([]any); ok { + existing := map[string]bool{} + for _, k := range c.Keys { + existing[k] = true + } + for _, k := range keys { + key := strings.TrimSpace(fmt.Sprintf("%v", k)) + if key == "" || existing[key] { + continue + } + c.Keys = append(c.Keys, key) + existing[key] = true + importedKeys++ + } + } + if accounts, ok := req["accounts"].([]any); ok { + existing := map[string]bool{} + for _, a := range c.Accounts { + existing[a.Identifier()] = true + } + for _, item := range accounts { + m, ok := item.(map[string]any) + if !ok { + continue + } + acc := toAccount(m) + id := acc.Identifier() + if id == "" || existing[id] { + continue + } + c.Accounts = append(c.Accounts, acc) + existing[id] = true + importedAccounts++ + } + } + return nil + }) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + h.Pool.Reset() + writeJSON(w, http.StatusOK, map[string]any{"success": true, "imported_keys": importedKeys, "imported_accounts": importedAccounts}) +} diff --git a/internal/admin/handler_settings.go b/internal/admin/handler_settings.go deleted file mode 100644 index 06c234c..0000000 --- a/internal/admin/handler_settings.go +++ /dev/null @@ -1,321 +0,0 @@ -package admin - -import ( - "encoding/json" - "fmt" - "net/http" - "strings" - "time" - - authn "ds2api/internal/auth" - "ds2api/internal/config" -) - 
-func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) { - snap := h.Store.Snapshot() - recommended := defaultRuntimeRecommended(len(snap.Accounts), h.Store.RuntimeAccountMaxInflight()) - needsSync := config.IsVercel() && snap.VercelSyncHash != "" && snap.VercelSyncHash != h.computeSyncHash() - writeJSON(w, http.StatusOK, map[string]any{ - "success": true, - "admin": map[string]any{ - "has_password_hash": strings.TrimSpace(snap.Admin.PasswordHash) != "", - "jwt_expire_hours": h.Store.AdminJWTExpireHours(), - "jwt_valid_after_unix": snap.Admin.JWTValidAfterUnix, - "default_password_warning": authn.UsingDefaultAdminKey(h.Store), - }, - "runtime": map[string]any{ - "account_max_inflight": h.Store.RuntimeAccountMaxInflight(), - "account_max_queue": h.Store.RuntimeAccountMaxQueue(recommended), - "global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended), - }, - "toolcall": snap.Toolcall, - "responses": snap.Responses, - "embeddings": snap.Embeddings, - "claude_mapping": settingsClaudeMapping(snap), - "model_aliases": snap.ModelAliases, - "env_backed": h.Store.IsEnvBacked(), - "needs_vercel_sync": needsSync, - }) -} - -func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) { - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) - return - } - - adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req) - if err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) - return - } - if runtimeCfg != nil { - if err := validateMergedRuntimeSettings(h.Store.Snapshot().Runtime, runtimeCfg); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) - return - } - } - - if err := h.Store.Update(func(c *config.Config) error { - if adminCfg != nil { - if adminCfg.JWTExpireHours > 0 { - 
c.Admin.JWTExpireHours = adminCfg.JWTExpireHours - } - } - if runtimeCfg != nil { - if runtimeCfg.AccountMaxInflight > 0 { - c.Runtime.AccountMaxInflight = runtimeCfg.AccountMaxInflight - } - if runtimeCfg.AccountMaxQueue > 0 { - c.Runtime.AccountMaxQueue = runtimeCfg.AccountMaxQueue - } - if runtimeCfg.GlobalMaxInflight > 0 { - c.Runtime.GlobalMaxInflight = runtimeCfg.GlobalMaxInflight - } - } - if toolcallCfg != nil { - if strings.TrimSpace(toolcallCfg.Mode) != "" { - c.Toolcall.Mode = strings.TrimSpace(toolcallCfg.Mode) - } - if strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) != "" { - c.Toolcall.EarlyEmitConfidence = strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) - } - } - if responsesCfg != nil && responsesCfg.StoreTTLSeconds > 0 { - c.Responses.StoreTTLSeconds = responsesCfg.StoreTTLSeconds - } - if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" { - c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider) - } - if claudeMap != nil { - c.ClaudeMapping = claudeMap - c.ClaudeModelMap = nil - } - if aliasMap != nil { - c.ModelAliases = aliasMap - } - return nil - }); err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) - return - } - - h.applyRuntimeSettings() - needsSync := config.IsVercel() || h.Store.IsEnvBacked() - writeJSON(w, http.StatusOK, map[string]any{ - "success": true, - "message": "settings updated and hot reloaded", - "env_backed": h.Store.IsEnvBacked(), - "needs_vercel_sync": needsSync, - "manual_sync_message": "配置已保存。Vercel 部署请在 Vercel Sync 页面手动同步。", - }) -} - -func validateMergedRuntimeSettings(current config.RuntimeConfig, incoming *config.RuntimeConfig) error { - merged := current - if incoming != nil { - if incoming.AccountMaxInflight > 0 { - merged.AccountMaxInflight = incoming.AccountMaxInflight - } - if incoming.AccountMaxQueue > 0 { - merged.AccountMaxQueue = incoming.AccountMaxQueue - } - if incoming.GlobalMaxInflight > 0 { - merged.GlobalMaxInflight = 
incoming.GlobalMaxInflight - } - } - return validateRuntimeSettings(merged) -} - -func (h *Handler) updateSettingsPassword(w http.ResponseWriter, r *http.Request) { - var req map[string]any - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) - return - } - newPassword := strings.TrimSpace(fieldString(req, "new_password")) - if newPassword == "" { - newPassword = strings.TrimSpace(fieldString(req, "password")) - } - if len(newPassword) < 4 { - writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "new password must be at least 4 characters"}) - return - } - - now := time.Now().Unix() - hash := authn.HashAdminPassword(newPassword) - if err := h.Store.Update(func(c *config.Config) error { - c.Admin.PasswordHash = hash - c.Admin.JWTValidAfterUnix = now - return nil - }); err != nil { - writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) - return - } - - writeJSON(w, http.StatusOK, map[string]any{ - "success": true, - "message": "password updated", - "force_relogin": true, - "jwt_valid_after_unix": now, - }) -} - -func (h *Handler) applyRuntimeSettings() { - if h == nil || h.Store == nil || h.Pool == nil { - return - } - accountCount := len(h.Store.Accounts()) - maxPer := h.Store.RuntimeAccountMaxInflight() - recommended := defaultRuntimeRecommended(accountCount, maxPer) - maxQueue := h.Store.RuntimeAccountMaxQueue(recommended) - global := h.Store.RuntimeGlobalMaxInflight(recommended) - h.Pool.ApplyRuntimeLimits(maxPer, maxQueue, global) -} - -func defaultRuntimeRecommended(accountCount, maxPer int) int { - if maxPer <= 0 { - maxPer = 1 - } - if accountCount <= 0 { - return maxPer - } - return accountCount * maxPer -} - -func settingsClaudeMapping(c config.Config) map[string]string { - if len(c.ClaudeMapping) > 0 { - return c.ClaudeMapping - } - if len(c.ClaudeModelMap) > 0 { - return c.ClaudeModelMap - } - return map[string]string{"fast": 
"deepseek-chat", "slow": "deepseek-reasoner"} -} - -func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) { - var ( - adminCfg *config.AdminConfig - runtimeCfg *config.RuntimeConfig - toolcallCfg *config.ToolcallConfig - respCfg *config.ResponsesConfig - embCfg *config.EmbeddingsConfig - claudeMap map[string]string - aliasMap map[string]string - ) - - if raw, ok := req["admin"].(map[string]any); ok { - cfg := &config.AdminConfig{} - if v, exists := raw["jwt_expire_hours"]; exists { - n := intFrom(v) - if n < 1 || n > 720 { - return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720") - } - cfg.JWTExpireHours = n - } - adminCfg = cfg - } - - if raw, ok := req["runtime"].(map[string]any); ok { - cfg := &config.RuntimeConfig{} - if v, exists := raw["account_max_inflight"]; exists { - n := intFrom(v) - if n < 1 || n > 256 { - return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256") - } - cfg.AccountMaxInflight = n - } - if v, exists := raw["account_max_queue"]; exists { - n := intFrom(v) - if n < 1 || n > 200000 { - return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000") - } - cfg.AccountMaxQueue = n - } - if v, exists := raw["global_max_inflight"]; exists { - n := intFrom(v) - if n < 1 || n > 200000 { - return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000") - } - cfg.GlobalMaxInflight = n - } - if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight { - return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight") - } - runtimeCfg = cfg - } - - if raw, ok := 
req["toolcall"].(map[string]any); ok { - cfg := &config.ToolcallConfig{} - if v, exists := raw["mode"]; exists { - mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v))) - switch mode { - case "feature_match", "off": - cfg.Mode = mode - default: - return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off") - } - } - if v, exists := raw["early_emit_confidence"]; exists { - level := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v))) - switch level { - case "high", "low", "off": - cfg.EarlyEmitConfidence = level - default: - return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off") - } - } - toolcallCfg = cfg - } - - if raw, ok := req["responses"].(map[string]any); ok { - cfg := &config.ResponsesConfig{} - if v, exists := raw["store_ttl_seconds"]; exists { - n := intFrom(v) - if n < 30 || n > 86400 { - return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400") - } - cfg.StoreTTLSeconds = n - } - respCfg = cfg - } - - if raw, ok := req["embeddings"].(map[string]any); ok { - cfg := &config.EmbeddingsConfig{} - if v, exists := raw["provider"]; exists { - p := strings.TrimSpace(fmt.Sprintf("%v", v)) - if p == "" { - return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty") - } - cfg.Provider = p - } - embCfg = cfg - } - - if raw, ok := req["claude_mapping"].(map[string]any); ok { - claudeMap = map[string]string{} - for k, v := range raw { - key := strings.TrimSpace(k) - val := strings.TrimSpace(fmt.Sprintf("%v", v)) - if key == "" || val == "" { - continue - } - claudeMap[key] = val - } - } - - if raw, ok := req["model_aliases"].(map[string]any); ok { - aliasMap = map[string]string{} - for k, v := range raw { - key := strings.TrimSpace(k) - val := strings.TrimSpace(fmt.Sprintf("%v", v)) - if key == "" || val == "" { - continue - } - aliasMap[key] = val - } - } - - return 
adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil -} diff --git a/internal/admin/handler_settings_parse.go b/internal/admin/handler_settings_parse.go new file mode 100644 index 0000000..6c5b7ee --- /dev/null +++ b/internal/admin/handler_settings_parse.go @@ -0,0 +1,134 @@ +package admin + +import ( + "fmt" + "strings" + + "ds2api/internal/config" +) + +func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) { + var ( + adminCfg *config.AdminConfig + runtimeCfg *config.RuntimeConfig + toolcallCfg *config.ToolcallConfig + respCfg *config.ResponsesConfig + embCfg *config.EmbeddingsConfig + claudeMap map[string]string + aliasMap map[string]string + ) + + if raw, ok := req["admin"].(map[string]any); ok { + cfg := &config.AdminConfig{} + if v, exists := raw["jwt_expire_hours"]; exists { + n := intFrom(v) + if n < 1 || n > 720 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720") + } + cfg.JWTExpireHours = n + } + adminCfg = cfg + } + + if raw, ok := req["runtime"].(map[string]any); ok { + cfg := &config.RuntimeConfig{} + if v, exists := raw["account_max_inflight"]; exists { + n := intFrom(v) + if n < 1 || n > 256 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256") + } + cfg.AccountMaxInflight = n + } + if v, exists := raw["account_max_queue"]; exists { + n := intFrom(v) + if n < 1 || n > 200000 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000") + } + cfg.AccountMaxQueue = n + } + if v, exists := raw["global_max_inflight"]; exists { + n := intFrom(v) + if n < 1 || n > 200000 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000") + } + 
cfg.GlobalMaxInflight = n + } + if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight") + } + runtimeCfg = cfg + } + + if raw, ok := req["toolcall"].(map[string]any); ok { + cfg := &config.ToolcallConfig{} + if v, exists := raw["mode"]; exists { + mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v))) + switch mode { + case "feature_match", "off": + cfg.Mode = mode + default: + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off") + } + } + if v, exists := raw["early_emit_confidence"]; exists { + level := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v))) + switch level { + case "high", "low", "off": + cfg.EarlyEmitConfidence = level + default: + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off") + } + } + toolcallCfg = cfg + } + + if raw, ok := req["responses"].(map[string]any); ok { + cfg := &config.ResponsesConfig{} + if v, exists := raw["store_ttl_seconds"]; exists { + n := intFrom(v) + if n < 30 || n > 86400 { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400") + } + cfg.StoreTTLSeconds = n + } + respCfg = cfg + } + + if raw, ok := req["embeddings"].(map[string]any); ok { + cfg := &config.EmbeddingsConfig{} + if v, exists := raw["provider"]; exists { + p := strings.TrimSpace(fmt.Sprintf("%v", v)) + if p == "" { + return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty") + } + cfg.Provider = p + } + embCfg = cfg + } + + if raw, ok := req["claude_mapping"].(map[string]any); ok { + claudeMap = map[string]string{} + for k, v := range raw { + key := strings.TrimSpace(k) + val := strings.TrimSpace(fmt.Sprintf("%v", v)) + if key == "" || val == "" { + continue + } + 
claudeMap[key] = val + } + } + + if raw, ok := req["model_aliases"].(map[string]any); ok { + aliasMap = map[string]string{} + for k, v := range raw { + key := strings.TrimSpace(k) + val := strings.TrimSpace(fmt.Sprintf("%v", v)) + if key == "" || val == "" { + continue + } + aliasMap[key] = val + } + } + + return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil +} diff --git a/internal/admin/handler_settings_read.go b/internal/admin/handler_settings_read.go new file mode 100644 index 0000000..565519f --- /dev/null +++ b/internal/admin/handler_settings_read.go @@ -0,0 +1,36 @@ +package admin + +import ( + "net/http" + "strings" + + authn "ds2api/internal/auth" + "ds2api/internal/config" +) + +func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) { + snap := h.Store.Snapshot() + recommended := defaultRuntimeRecommended(len(snap.Accounts), h.Store.RuntimeAccountMaxInflight()) + needsSync := config.IsVercel() && snap.VercelSyncHash != "" && snap.VercelSyncHash != h.computeSyncHash() + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "admin": map[string]any{ + "has_password_hash": strings.TrimSpace(snap.Admin.PasswordHash) != "", + "jwt_expire_hours": h.Store.AdminJWTExpireHours(), + "jwt_valid_after_unix": snap.Admin.JWTValidAfterUnix, + "default_password_warning": authn.UsingDefaultAdminKey(h.Store), + }, + "runtime": map[string]any{ + "account_max_inflight": h.Store.RuntimeAccountMaxInflight(), + "account_max_queue": h.Store.RuntimeAccountMaxQueue(recommended), + "global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended), + }, + "toolcall": snap.Toolcall, + "responses": snap.Responses, + "embeddings": snap.Embeddings, + "claude_mapping": settingsClaudeMapping(snap), + "model_aliases": snap.ModelAliases, + "env_backed": h.Store.IsEnvBacked(), + "needs_vercel_sync": needsSync, + }) +} diff --git a/internal/admin/handler_settings_runtime.go b/internal/admin/handler_settings_runtime.go new file mode 
100644 index 0000000..6ff6902 --- /dev/null +++ b/internal/admin/handler_settings_runtime.go @@ -0,0 +1,51 @@ +package admin + +import "ds2api/internal/config" + +func validateMergedRuntimeSettings(current config.RuntimeConfig, incoming *config.RuntimeConfig) error { + merged := current + if incoming != nil { + if incoming.AccountMaxInflight > 0 { + merged.AccountMaxInflight = incoming.AccountMaxInflight + } + if incoming.AccountMaxQueue > 0 { + merged.AccountMaxQueue = incoming.AccountMaxQueue + } + if incoming.GlobalMaxInflight > 0 { + merged.GlobalMaxInflight = incoming.GlobalMaxInflight + } + } + return validateRuntimeSettings(merged) +} + +func (h *Handler) applyRuntimeSettings() { + if h == nil || h.Store == nil || h.Pool == nil { + return + } + accountCount := len(h.Store.Accounts()) + maxPer := h.Store.RuntimeAccountMaxInflight() + recommended := defaultRuntimeRecommended(accountCount, maxPer) + maxQueue := h.Store.RuntimeAccountMaxQueue(recommended) + global := h.Store.RuntimeGlobalMaxInflight(recommended) + h.Pool.ApplyRuntimeLimits(maxPer, maxQueue, global) +} + +func defaultRuntimeRecommended(accountCount, maxPer int) int { + if maxPer <= 0 { + maxPer = 1 + } + if accountCount <= 0 { + return maxPer + } + return accountCount * maxPer +} + +func settingsClaudeMapping(c config.Config) map[string]string { + if len(c.ClaudeMapping) > 0 { + return c.ClaudeMapping + } + if len(c.ClaudeModelMap) > 0 { + return c.ClaudeModelMap + } + return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"} +} diff --git a/internal/admin/handler_settings_write.go b/internal/admin/handler_settings_write.go new file mode 100644 index 0000000..c0076ea --- /dev/null +++ b/internal/admin/handler_settings_write.go @@ -0,0 +1,119 @@ +package admin + +import ( + "encoding/json" + "net/http" + "strings" + "time" + + authn "ds2api/internal/auth" + "ds2api/internal/config" +) + +func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) { + var req 
map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + + adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + if runtimeCfg != nil { + if err := validateMergedRuntimeSettings(h.Store.Snapshot().Runtime, runtimeCfg); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()}) + return + } + } + + if err := h.Store.Update(func(c *config.Config) error { + if adminCfg != nil { + if adminCfg.JWTExpireHours > 0 { + c.Admin.JWTExpireHours = adminCfg.JWTExpireHours + } + } + if runtimeCfg != nil { + if runtimeCfg.AccountMaxInflight > 0 { + c.Runtime.AccountMaxInflight = runtimeCfg.AccountMaxInflight + } + if runtimeCfg.AccountMaxQueue > 0 { + c.Runtime.AccountMaxQueue = runtimeCfg.AccountMaxQueue + } + if runtimeCfg.GlobalMaxInflight > 0 { + c.Runtime.GlobalMaxInflight = runtimeCfg.GlobalMaxInflight + } + } + if toolcallCfg != nil { + if strings.TrimSpace(toolcallCfg.Mode) != "" { + c.Toolcall.Mode = strings.TrimSpace(toolcallCfg.Mode) + } + if strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) != "" { + c.Toolcall.EarlyEmitConfidence = strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) + } + } + if responsesCfg != nil && responsesCfg.StoreTTLSeconds > 0 { + c.Responses.StoreTTLSeconds = responsesCfg.StoreTTLSeconds + } + if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" { + c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider) + } + if claudeMap != nil { + c.ClaudeMapping = claudeMap + c.ClaudeModelMap = nil + } + if aliasMap != nil { + c.ModelAliases = aliasMap + } + return nil + }); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + + 
h.applyRuntimeSettings() + needsSync := config.IsVercel() || h.Store.IsEnvBacked() + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "message": "settings updated and hot reloaded", + "env_backed": h.Store.IsEnvBacked(), + "needs_vercel_sync": needsSync, + "manual_sync_message": "配置已保存。Vercel 部署请在 Vercel Sync 页面手动同步。", + }) +} + +func (h *Handler) updateSettingsPassword(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"}) + return + } + newPassword := strings.TrimSpace(fieldString(req, "new_password")) + if newPassword == "" { + newPassword = strings.TrimSpace(fieldString(req, "password")) + } + if len(newPassword) < 4 { + writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "new password must be at least 4 characters"}) + return + } + + now := time.Now().Unix() + hash := authn.HashAdminPassword(newPassword) + if err := h.Store.Update(func(c *config.Config) error { + c.Admin.PasswordHash = hash + c.Admin.JWTValidAfterUnix = now + return nil + }); err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()}) + return + } + + writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "message": "password updated", + "force_relogin": true, + "jwt_valid_after_unix": now, + }) +} diff --git a/internal/config/account.go b/internal/config/account.go new file mode 100644 index 0000000..29a4947 --- /dev/null +++ b/internal/config/account.go @@ -0,0 +1,24 @@ +package config + +import ( + "crypto/sha256" + "encoding/hex" + "strings" +) + +func (a Account) Identifier() string { + if strings.TrimSpace(a.Email) != "" { + return strings.TrimSpace(a.Email) + } + if strings.TrimSpace(a.Mobile) != "" { + return strings.TrimSpace(a.Mobile) + } + // Backward compatibility: old configs may contain token-only accounts. 
+ // Use a stable non-sensitive synthetic id so they can still join the pool. + token := strings.TrimSpace(a.Token) + if token == "" { + return "" + } + sum := sha256.Sum256([]byte(token)) + return "token:" + hex.EncodeToString(sum[:8]) +} diff --git a/internal/config/codec.go b/internal/config/codec.go new file mode 100644 index 0000000..2a23e20 --- /dev/null +++ b/internal/config/codec.go @@ -0,0 +1,241 @@ +package config + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "slices" + "strings" +) + +func (c Config) MarshalJSON() ([]byte, error) { + m := map[string]any{} + for k, v := range c.AdditionalFields { + m[k] = v + } + if len(c.Keys) > 0 { + m["keys"] = c.Keys + } + if len(c.Accounts) > 0 { + m["accounts"] = c.Accounts + } + if len(c.ClaudeMapping) > 0 { + m["claude_mapping"] = c.ClaudeMapping + } + if len(c.ClaudeModelMap) > 0 { + m["claude_model_mapping"] = c.ClaudeModelMap + } + if len(c.ModelAliases) > 0 { + m["model_aliases"] = c.ModelAliases + } + if strings.TrimSpace(c.Admin.PasswordHash) != "" || c.Admin.JWTExpireHours > 0 || c.Admin.JWTValidAfterUnix > 0 { + m["admin"] = c.Admin + } + if c.Runtime.AccountMaxInflight > 0 || c.Runtime.AccountMaxQueue > 0 || c.Runtime.GlobalMaxInflight > 0 { + m["runtime"] = c.Runtime + } + if c.Compat.WideInputStrictOutput != nil { + m["compat"] = c.Compat + } + if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" { + m["toolcall"] = c.Toolcall + } + if c.Responses.StoreTTLSeconds > 0 { + m["responses"] = c.Responses + } + if strings.TrimSpace(c.Embeddings.Provider) != "" { + m["embeddings"] = c.Embeddings + } + if c.VercelSyncHash != "" { + m["_vercel_sync_hash"] = c.VercelSyncHash + } + if c.VercelSyncTime != 0 { + m["_vercel_sync_time"] = c.VercelSyncTime + } + return json.Marshal(m) +} + +func (c *Config) UnmarshalJSON(b []byte) error { + raw := map[string]json.RawMessage{} + if err := json.Unmarshal(b, &raw); err != nil { + return err + } + 
c.AdditionalFields = map[string]any{} + for k, v := range raw { + switch k { + case "keys": + if err := json.Unmarshal(v, &c.Keys); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "accounts": + if err := json.Unmarshal(v, &c.Accounts); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "claude_mapping": + if err := json.Unmarshal(v, &c.ClaudeMapping); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "claude_model_mapping": + if err := json.Unmarshal(v, &c.ClaudeModelMap); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "model_aliases": + if err := json.Unmarshal(v, &c.ModelAliases); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "admin": + if err := json.Unmarshal(v, &c.Admin); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "runtime": + if err := json.Unmarshal(v, &c.Runtime); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "compat": + if err := json.Unmarshal(v, &c.Compat); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "toolcall": + if err := json.Unmarshal(v, &c.Toolcall); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "responses": + if err := json.Unmarshal(v, &c.Responses); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "embeddings": + if err := json.Unmarshal(v, &c.Embeddings); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "_vercel_sync_hash": + if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + case "_vercel_sync_time": + if err := json.Unmarshal(v, &c.VercelSyncTime); err != nil { + return fmt.Errorf("invalid field %q: %w", k, err) + } + default: + var anyVal any + if err := json.Unmarshal(v, &anyVal); err == nil { + c.AdditionalFields[k] = anyVal + } + } + } + return nil +} + +func (c 
Config) Clone() Config { + clone := Config{ + Keys: slices.Clone(c.Keys), + Accounts: slices.Clone(c.Accounts), + ClaudeMapping: cloneStringMap(c.ClaudeMapping), + ClaudeModelMap: cloneStringMap(c.ClaudeModelMap), + ModelAliases: cloneStringMap(c.ModelAliases), + Admin: c.Admin, + Runtime: c.Runtime, + Compat: CompatConfig{ + WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput), + }, + Toolcall: c.Toolcall, + Responses: c.Responses, + Embeddings: c.Embeddings, + VercelSyncHash: c.VercelSyncHash, + VercelSyncTime: c.VercelSyncTime, + AdditionalFields: map[string]any{}, + } + for k, v := range c.AdditionalFields { + clone.AdditionalFields[k] = v + } + return clone +} + +func cloneStringMap(in map[string]string) map[string]string { + if len(in) == 0 { + return nil + } + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func cloneBoolPtr(in *bool) *bool { + if in == nil { + return nil + } + v := *in + return &v +} + +func parseConfigString(raw string) (Config, error) { + var cfg Config + candidates := []string{raw} + if normalized := normalizeConfigInput(raw); normalized != raw { + candidates = append(candidates, normalized) + } + for _, candidate := range candidates { + if err := json.Unmarshal([]byte(candidate), &cfg); err == nil { + return cfg, nil + } + } + + base64Input := candidates[len(candidates)-1] + decoded, err := decodeConfigBase64(base64Input) + if err != nil { + return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON: %w", err) + } + if err := json.Unmarshal(decoded, &cfg); err != nil { + return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON decoded JSON: %w", err) + } + return cfg, nil +} + +func normalizeConfigInput(raw string) string { + normalized := strings.TrimSpace(raw) + if normalized == "" { + return normalized + } + for { + changed := false + if len(normalized) >= 2 { + first := normalized[0] + last := normalized[len(normalized)-1] + if (first == '"' && last == '"') || (first == 
'\'' && last == '\'') { + normalized = strings.TrimSpace(normalized[1 : len(normalized)-1]) + changed = true + } + } + if strings.HasPrefix(strings.ToLower(normalized), "base64:") { + normalized = strings.TrimSpace(normalized[len("base64:"):]) + changed = true + } + if !changed { + break + } + } + return strings.TrimSpace(normalized) +} + +func decodeConfigBase64(raw string) ([]byte, error) { + encodings := []*base64.Encoding{ + base64.StdEncoding, + base64.RawStdEncoding, + base64.URLEncoding, + base64.RawURLEncoding, + } + var lastErr error + for _, enc := range encodings { + decoded, err := enc.DecodeString(raw) + if err == nil { + return decoded, nil + } + lastErr = err + } + if lastErr != nil { + return nil, lastErr + } + return nil, errors.New("base64 decode failed") +} diff --git a/internal/config/config.go b/internal/config/config.go index 3bc0409..4b281a2 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,63 +1,5 @@ package config -import ( - "crypto/sha256" - "encoding/base64" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "log/slog" - "os" - "path/filepath" - "slices" - "strconv" - "strings" - "sync" -) - -var Logger = newLogger() - -func newLogger() *slog.Logger { - level := new(slog.LevelVar) - switch strings.ToUpper(strings.TrimSpace(os.Getenv("LOG_LEVEL"))) { - case "DEBUG": - level.Set(slog.LevelDebug) - case "WARN": - level.Set(slog.LevelWarn) - case "ERROR": - level.Set(slog.LevelError) - default: - level.Set(slog.LevelInfo) - } - h := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level}) - return slog.New(h) -} - -type Account struct { - Email string `json:"email,omitempty"` - Mobile string `json:"mobile,omitempty"` - Password string `json:"password,omitempty"` - Token string `json:"token,omitempty"` -} - -func (a Account) Identifier() string { - if strings.TrimSpace(a.Email) != "" { - return strings.TrimSpace(a.Email) - } - if strings.TrimSpace(a.Mobile) != "" { - return strings.TrimSpace(a.Mobile) 
- } - // Backward compatibility: old configs may contain token-only accounts. - // Use a stable non-sensitive synthetic id so they can still join the pool. - token := strings.TrimSpace(a.Token) - if token == "" { - return "" - } - sum := sha256.Sum256([]byte(token)) - return "token:" + hex.EncodeToString(sum[:8]) -} - type Config struct { Keys []string `json:"keys,omitempty"` Accounts []Account `json:"accounts,omitempty"` @@ -75,6 +17,13 @@ type Config struct { AdditionalFields map[string]any `json:"-"` } +type Account struct { + Email string `json:"email,omitempty"` + Mobile string `json:"mobile,omitempty"` + Password string `json:"password,omitempty"` + Token string `json:"token,omitempty"` +} + type CompatConfig struct { WideInputStrictOutput *bool `json:"wide_input_strict_output,omitempty"` } @@ -103,641 +52,3 @@ type ResponsesConfig struct { type EmbeddingsConfig struct { Provider string `json:"provider,omitempty"` } - -func (c Config) MarshalJSON() ([]byte, error) { - m := map[string]any{} - for k, v := range c.AdditionalFields { - m[k] = v - } - if len(c.Keys) > 0 { - m["keys"] = c.Keys - } - if len(c.Accounts) > 0 { - m["accounts"] = c.Accounts - } - if len(c.ClaudeMapping) > 0 { - m["claude_mapping"] = c.ClaudeMapping - } - if len(c.ClaudeModelMap) > 0 { - m["claude_model_mapping"] = c.ClaudeModelMap - } - if len(c.ModelAliases) > 0 { - m["model_aliases"] = c.ModelAliases - } - if strings.TrimSpace(c.Admin.PasswordHash) != "" || c.Admin.JWTExpireHours > 0 || c.Admin.JWTValidAfterUnix > 0 { - m["admin"] = c.Admin - } - if c.Runtime.AccountMaxInflight > 0 || c.Runtime.AccountMaxQueue > 0 || c.Runtime.GlobalMaxInflight > 0 { - m["runtime"] = c.Runtime - } - if c.Compat.WideInputStrictOutput != nil { - m["compat"] = c.Compat - } - if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" { - m["toolcall"] = c.Toolcall - } - if c.Responses.StoreTTLSeconds > 0 { - m["responses"] = c.Responses - } - if 
strings.TrimSpace(c.Embeddings.Provider) != "" { - m["embeddings"] = c.Embeddings - } - if c.VercelSyncHash != "" { - m["_vercel_sync_hash"] = c.VercelSyncHash - } - if c.VercelSyncTime != 0 { - m["_vercel_sync_time"] = c.VercelSyncTime - } - return json.Marshal(m) -} - -func (c *Config) UnmarshalJSON(b []byte) error { - raw := map[string]json.RawMessage{} - if err := json.Unmarshal(b, &raw); err != nil { - return err - } - c.AdditionalFields = map[string]any{} - for k, v := range raw { - switch k { - case "keys": - if err := json.Unmarshal(v, &c.Keys); err != nil { - return fmt.Errorf("invalid field %q: %w", k, err) - } - case "accounts": - if err := json.Unmarshal(v, &c.Accounts); err != nil { - return fmt.Errorf("invalid field %q: %w", k, err) - } - case "claude_mapping": - if err := json.Unmarshal(v, &c.ClaudeMapping); err != nil { - return fmt.Errorf("invalid field %q: %w", k, err) - } - case "claude_model_mapping": - if err := json.Unmarshal(v, &c.ClaudeModelMap); err != nil { - return fmt.Errorf("invalid field %q: %w", k, err) - } - case "model_aliases": - if err := json.Unmarshal(v, &c.ModelAliases); err != nil { - return fmt.Errorf("invalid field %q: %w", k, err) - } - case "admin": - if err := json.Unmarshal(v, &c.Admin); err != nil { - return fmt.Errorf("invalid field %q: %w", k, err) - } - case "runtime": - if err := json.Unmarshal(v, &c.Runtime); err != nil { - return fmt.Errorf("invalid field %q: %w", k, err) - } - case "compat": - if err := json.Unmarshal(v, &c.Compat); err != nil { - return fmt.Errorf("invalid field %q: %w", k, err) - } - case "toolcall": - if err := json.Unmarshal(v, &c.Toolcall); err != nil { - return fmt.Errorf("invalid field %q: %w", k, err) - } - case "responses": - if err := json.Unmarshal(v, &c.Responses); err != nil { - return fmt.Errorf("invalid field %q: %w", k, err) - } - case "embeddings": - if err := json.Unmarshal(v, &c.Embeddings); err != nil { - return fmt.Errorf("invalid field %q: %w", k, err) - } - case 
"_vercel_sync_hash": - if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil { - return fmt.Errorf("invalid field %q: %w", k, err) - } - case "_vercel_sync_time": - if err := json.Unmarshal(v, &c.VercelSyncTime); err != nil { - return fmt.Errorf("invalid field %q: %w", k, err) - } - default: - var anyVal any - if err := json.Unmarshal(v, &anyVal); err == nil { - c.AdditionalFields[k] = anyVal - } - } - } - return nil -} - -func (c Config) Clone() Config { - clone := Config{ - Keys: slices.Clone(c.Keys), - Accounts: slices.Clone(c.Accounts), - ClaudeMapping: cloneStringMap(c.ClaudeMapping), - ClaudeModelMap: cloneStringMap(c.ClaudeModelMap), - ModelAliases: cloneStringMap(c.ModelAliases), - Admin: c.Admin, - Runtime: c.Runtime, - Compat: CompatConfig{ - WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput), - }, - Toolcall: c.Toolcall, - Responses: c.Responses, - Embeddings: c.Embeddings, - VercelSyncHash: c.VercelSyncHash, - VercelSyncTime: c.VercelSyncTime, - AdditionalFields: map[string]any{}, - } - for k, v := range c.AdditionalFields { - clone.AdditionalFields[k] = v - } - return clone -} - -func cloneStringMap(in map[string]string) map[string]string { - if len(in) == 0 { - return nil - } - out := make(map[string]string, len(in)) - for k, v := range in { - out[k] = v - } - return out -} - -func cloneBoolPtr(in *bool) *bool { - if in == nil { - return nil - } - v := *in - return &v -} - -type Store struct { - mu sync.RWMutex - cfg Config - path string - fromEnv bool - keyMap map[string]struct{} // O(1) API key lookup index - accMap map[string]int // O(1) account lookup: identifier -> slice index -} - -func BaseDir() string { - cwd, err := os.Getwd() - if err != nil { - return "." 
- } - return cwd -} - -func IsVercel() bool { - return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != "" -} - -func ResolvePath(envKey, defaultRel string) string { - raw := strings.TrimSpace(os.Getenv(envKey)) - if raw != "" { - if filepath.IsAbs(raw) { - return raw - } - return filepath.Join(BaseDir(), raw) - } - return filepath.Join(BaseDir(), defaultRel) -} - -func ConfigPath() string { - return ResolvePath("DS2API_CONFIG_PATH", "config.json") -} - -func WASMPath() string { - return ResolvePath("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm") -} - -func StaticAdminDir() string { - return ResolvePath("DS2API_STATIC_ADMIN_DIR", "static/admin") -} - -func LoadStore() *Store { - cfg, fromEnv, err := loadConfig() - if err != nil { - Logger.Warn("[config] load failed", "error", err) - } - if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 { - Logger.Warn("[config] empty config loaded") - } - s := &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv} - s.rebuildIndexes() - return s -} - -// rebuildIndexes must be called with the lock already held (or during init). -func (s *Store) rebuildIndexes() { - s.keyMap = make(map[string]struct{}, len(s.cfg.Keys)) - for _, k := range s.cfg.Keys { - s.keyMap[k] = struct{}{} - } - s.accMap = make(map[string]int, len(s.cfg.Accounts)) - for i, acc := range s.cfg.Accounts { - id := acc.Identifier() - if id != "" { - s.accMap[id] = i - } - } -} - -func loadConfig() (Config, bool, error) { - rawCfg := strings.TrimSpace(os.Getenv("DS2API_CONFIG_JSON")) - if rawCfg == "" { - rawCfg = strings.TrimSpace(os.Getenv("CONFIG_JSON")) - } - if rawCfg != "" { - cfg, err := parseConfigString(rawCfg) - return cfg, true, err - } - - content, err := os.ReadFile(ConfigPath()) - if err != nil { - if IsVercel() { - // Vercel one-click deploy may start without a writable/present config file. - // Keep an in-memory config so users can bootstrap via WebUI then sync env. 
- return Config{}, true, nil - } - return Config{}, false, err - } - var cfg Config - if err := json.Unmarshal(content, &cfg); err != nil { - return Config{}, false, err - } - if IsVercel() { - // Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors. - return cfg, true, nil - } - return cfg, false, nil -} - -func parseConfigString(raw string) (Config, error) { - var cfg Config - candidates := []string{raw} - if normalized := normalizeConfigInput(raw); normalized != raw { - candidates = append(candidates, normalized) - } - for _, candidate := range candidates { - if err := json.Unmarshal([]byte(candidate), &cfg); err == nil { - return cfg, nil - } - } - - base64Input := candidates[len(candidates)-1] - decoded, err := decodeConfigBase64(base64Input) - if err != nil { - return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON: %w", err) - } - if err := json.Unmarshal(decoded, &cfg); err != nil { - return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON decoded JSON: %w", err) - } - return cfg, nil -} - -func normalizeConfigInput(raw string) string { - normalized := strings.TrimSpace(raw) - if normalized == "" { - return normalized - } - for { - changed := false - if len(normalized) >= 2 { - first := normalized[0] - last := normalized[len(normalized)-1] - if (first == '"' && last == '"') || (first == '\'' && last == '\'') { - normalized = strings.TrimSpace(normalized[1 : len(normalized)-1]) - changed = true - } - } - if strings.HasPrefix(strings.ToLower(normalized), "base64:") { - normalized = strings.TrimSpace(normalized[len("base64:"):]) - changed = true - } - if !changed { - break - } - } - return strings.TrimSpace(normalized) -} - -func decodeConfigBase64(raw string) ([]byte, error) { - encodings := []*base64.Encoding{ - base64.StdEncoding, - base64.RawStdEncoding, - base64.URLEncoding, - base64.RawURLEncoding, - } - var lastErr error - for _, enc := range encodings { - decoded, err := enc.DecodeString(raw) - if err == nil { - return 
decoded, nil - } - lastErr = err - } - if lastErr != nil { - return nil, lastErr - } - return nil, errors.New("base64 decode failed") -} - -func (s *Store) Snapshot() Config { - s.mu.RLock() - defer s.mu.RUnlock() - return s.cfg.Clone() -} - -func (s *Store) HasAPIKey(k string) bool { - s.mu.RLock() - defer s.mu.RUnlock() - _, ok := s.keyMap[k] - return ok -} - -func (s *Store) Keys() []string { - s.mu.RLock() - defer s.mu.RUnlock() - return slices.Clone(s.cfg.Keys) -} - -func (s *Store) Accounts() []Account { - s.mu.RLock() - defer s.mu.RUnlock() - return slices.Clone(s.cfg.Accounts) -} - -func (s *Store) FindAccount(identifier string) (Account, bool) { - identifier = strings.TrimSpace(identifier) - s.mu.RLock() - defer s.mu.RUnlock() - if idx, ok := s.findAccountIndexLocked(identifier); ok { - return s.cfg.Accounts[idx], true - } - return Account{}, false -} - -func (s *Store) UpdateAccountToken(identifier, token string) error { - identifier = strings.TrimSpace(identifier) - s.mu.Lock() - defer s.mu.Unlock() - idx, ok := s.findAccountIndexLocked(identifier) - if !ok { - return errors.New("account not found") - } - oldID := s.cfg.Accounts[idx].Identifier() - s.cfg.Accounts[idx].Token = token - newID := s.cfg.Accounts[idx].Identifier() - // Keep historical aliases usable for long-lived queues while also adding - // the latest identifier after token refresh. 
- if identifier != "" { - s.accMap[identifier] = idx - } - if oldID != "" { - s.accMap[oldID] = idx - } - if newID != "" { - s.accMap[newID] = idx - } - return s.saveLocked() -} - -func (s *Store) Replace(cfg Config) error { - s.mu.Lock() - defer s.mu.Unlock() - s.cfg = cfg.Clone() - s.rebuildIndexes() - return s.saveLocked() -} - -func (s *Store) Update(mutator func(*Config) error) error { - s.mu.Lock() - defer s.mu.Unlock() - cfg := s.cfg.Clone() - if err := mutator(&cfg); err != nil { - return err - } - s.cfg = cfg - s.rebuildIndexes() - return s.saveLocked() -} - -func (s *Store) Save() error { - s.mu.Lock() - defer s.mu.Unlock() - if s.fromEnv { - Logger.Info("[save_config] source from env, skip write") - return nil - } - b, err := json.MarshalIndent(s.cfg, "", " ") - if err != nil { - return err - } - return os.WriteFile(s.path, b, 0o644) -} - -func (s *Store) saveLocked() error { - if s.fromEnv { - Logger.Info("[save_config] source from env, skip write") - return nil - } - b, err := json.MarshalIndent(s.cfg, "", " ") - if err != nil { - return err - } - return os.WriteFile(s.path, b, 0o644) -} - -// findAccountIndexLocked expects the store lock to already be held. -func (s *Store) findAccountIndexLocked(identifier string) (int, bool) { - if idx, ok := s.accMap[identifier]; ok && idx >= 0 && idx < len(s.cfg.Accounts) { - return idx, true - } - // Fallback for token-only accounts whose derived identifier changed after - // a token refresh; this preserves correctness on map misses. 
- for i, acc := range s.cfg.Accounts { - if acc.Identifier() == identifier { - return i, true - } - } - return -1, false -} - -func (s *Store) IsEnvBacked() bool { - s.mu.RLock() - defer s.mu.RUnlock() - return s.fromEnv -} - -func (s *Store) SetVercelSync(hash string, ts int64) error { - return s.Update(func(c *Config) error { - c.VercelSyncHash = hash - c.VercelSyncTime = ts - return nil - }) -} - -func (s *Store) ExportJSONAndBase64() (string, string, error) { - s.mu.RLock() - defer s.mu.RUnlock() - b, err := json.Marshal(s.cfg) - if err != nil { - return "", "", err - } - return string(b), base64.StdEncoding.EncodeToString(b), nil -} - -func (s *Store) ClaudeMapping() map[string]string { - s.mu.RLock() - defer s.mu.RUnlock() - if len(s.cfg.ClaudeModelMap) > 0 { - return cloneStringMap(s.cfg.ClaudeModelMap) - } - if len(s.cfg.ClaudeMapping) > 0 { - return cloneStringMap(s.cfg.ClaudeMapping) - } - return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"} -} - -func (s *Store) ModelAliases() map[string]string { - s.mu.RLock() - defer s.mu.RUnlock() - out := DefaultModelAliases() - for k, v := range s.cfg.ModelAliases { - key := strings.TrimSpace(lower(k)) - val := strings.TrimSpace(lower(v)) - if key == "" || val == "" { - continue - } - out[key] = val - } - return out -} - -func (s *Store) CompatWideInputStrictOutput() bool { - s.mu.RLock() - defer s.mu.RUnlock() - if s.cfg.Compat.WideInputStrictOutput == nil { - return true - } - return *s.cfg.Compat.WideInputStrictOutput -} - -func (s *Store) ToolcallMode() string { - s.mu.RLock() - defer s.mu.RUnlock() - mode := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.Mode)) - if mode == "" { - return "feature_match" - } - return mode -} - -func (s *Store) ToolcallEarlyEmitConfidence() string { - s.mu.RLock() - defer s.mu.RUnlock() - level := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.EarlyEmitConfidence)) - if level == "" { - return "high" - } - return level -} - -func (s *Store) 
ResponsesStoreTTLSeconds() int { - s.mu.RLock() - defer s.mu.RUnlock() - if s.cfg.Responses.StoreTTLSeconds > 0 { - return s.cfg.Responses.StoreTTLSeconds - } - return 900 -} - -func (s *Store) EmbeddingsProvider() string { - s.mu.RLock() - defer s.mu.RUnlock() - return strings.TrimSpace(s.cfg.Embeddings.Provider) -} - -func (s *Store) AdminPasswordHash() string { - s.mu.RLock() - defer s.mu.RUnlock() - return strings.TrimSpace(s.cfg.Admin.PasswordHash) -} - -func (s *Store) AdminJWTExpireHours() int { - s.mu.RLock() - defer s.mu.RUnlock() - if s.cfg.Admin.JWTExpireHours > 0 { - return s.cfg.Admin.JWTExpireHours - } - if raw := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); raw != "" { - if n, err := strconv.Atoi(raw); err == nil && n > 0 { - return n - } - } - return 24 -} - -func (s *Store) AdminJWTValidAfterUnix() int64 { - s.mu.RLock() - defer s.mu.RUnlock() - return s.cfg.Admin.JWTValidAfterUnix -} - -func (s *Store) RuntimeAccountMaxInflight() int { - s.mu.RLock() - defer s.mu.RUnlock() - if s.cfg.Runtime.AccountMaxInflight > 0 { - return s.cfg.Runtime.AccountMaxInflight - } - for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} { - raw := strings.TrimSpace(os.Getenv(key)) - if raw == "" { - continue - } - n, err := strconv.Atoi(raw) - if err == nil && n > 0 { - return n - } - } - return 2 -} - -func (s *Store) RuntimeAccountMaxQueue(defaultSize int) int { - s.mu.RLock() - defer s.mu.RUnlock() - if s.cfg.Runtime.AccountMaxQueue > 0 { - return s.cfg.Runtime.AccountMaxQueue - } - for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} { - raw := strings.TrimSpace(os.Getenv(key)) - if raw == "" { - continue - } - n, err := strconv.Atoi(raw) - if err == nil && n >= 0 { - return n - } - } - if defaultSize < 0 { - return 0 - } - return defaultSize -} - -func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int { - s.mu.RLock() - defer s.mu.RUnlock() - if s.cfg.Runtime.GlobalMaxInflight 
> 0 { - return s.cfg.Runtime.GlobalMaxInflight - } - for _, key := range []string{"DS2API_GLOBAL_MAX_INFLIGHT", "DS2API_MAX_INFLIGHT"} { - raw := strings.TrimSpace(os.Getenv(key)) - if raw == "" { - continue - } - n, err := strconv.Atoi(raw) - if err == nil && n > 0 { - return n - } - } - if defaultSize < 0 { - return 0 - } - return defaultSize -} diff --git a/internal/config/logger.go b/internal/config/logger.go new file mode 100644 index 0000000..8b2de91 --- /dev/null +++ b/internal/config/logger.go @@ -0,0 +1,25 @@ +package config + +import ( + "log/slog" + "os" + "strings" +) + +var Logger = newLogger() + +func newLogger() *slog.Logger { + level := new(slog.LevelVar) + switch strings.ToUpper(strings.TrimSpace(os.Getenv("LOG_LEVEL"))) { + case "DEBUG": + level.Set(slog.LevelDebug) + case "WARN": + level.Set(slog.LevelWarn) + case "ERROR": + level.Set(slog.LevelError) + default: + level.Set(slog.LevelInfo) + } + h := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level}) + return slog.New(h) +} diff --git a/internal/config/paths.go b/internal/config/paths.go new file mode 100644 index 0000000..23dfe54 --- /dev/null +++ b/internal/config/paths.go @@ -0,0 +1,42 @@ +package config + +import ( + "os" + "path/filepath" + "strings" +) + +func BaseDir() string { + cwd, err := os.Getwd() + if err != nil { + return "." 
+ } + return cwd +} + +func IsVercel() bool { + return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != "" +} + +func ResolvePath(envKey, defaultRel string) string { + raw := strings.TrimSpace(os.Getenv(envKey)) + if raw != "" { + if filepath.IsAbs(raw) { + return raw + } + return filepath.Join(BaseDir(), raw) + } + return filepath.Join(BaseDir(), defaultRel) +} + +func ConfigPath() string { + return ResolvePath("DS2API_CONFIG_PATH", "config.json") +} + +func WASMPath() string { + return ResolvePath("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm") +} + +func StaticAdminDir() string { + return ResolvePath("DS2API_STATIC_ADMIN_DIR", "static/admin") +} diff --git a/internal/config/store.go b/internal/config/store.go new file mode 100644 index 0000000..2e6fcaf --- /dev/null +++ b/internal/config/store.go @@ -0,0 +1,193 @@ +package config + +import ( + "encoding/base64" + "encoding/json" + "errors" + "os" + "slices" + "strings" + "sync" +) + +type Store struct { + mu sync.RWMutex + cfg Config + path string + fromEnv bool + keyMap map[string]struct{} // O(1) API key lookup index + accMap map[string]int // O(1) account lookup: identifier -> slice index +} + +func LoadStore() *Store { + cfg, fromEnv, err := loadConfig() + if err != nil { + Logger.Warn("[config] load failed", "error", err) + } + if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 { + Logger.Warn("[config] empty config loaded") + } + s := &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv} + s.rebuildIndexes() + return s +} + +func loadConfig() (Config, bool, error) { + rawCfg := strings.TrimSpace(os.Getenv("DS2API_CONFIG_JSON")) + if rawCfg == "" { + rawCfg = strings.TrimSpace(os.Getenv("CONFIG_JSON")) + } + if rawCfg != "" { + cfg, err := parseConfigString(rawCfg) + return cfg, true, err + } + + content, err := os.ReadFile(ConfigPath()) + if err != nil { + if IsVercel() { + // Vercel one-click deploy may start without a writable/present config file. 
+ // Keep an in-memory config so users can bootstrap via WebUI then sync env. + return Config{}, true, nil + } + return Config{}, false, err + } + var cfg Config + if err := json.Unmarshal(content, &cfg); err != nil { + return Config{}, false, err + } + if IsVercel() { + // Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors. + return cfg, true, nil + } + return cfg, false, nil +} + +func (s *Store) Snapshot() Config { + s.mu.RLock() + defer s.mu.RUnlock() + return s.cfg.Clone() +} + +func (s *Store) HasAPIKey(k string) bool { + s.mu.RLock() + defer s.mu.RUnlock() + _, ok := s.keyMap[k] + return ok +} + +func (s *Store) Keys() []string { + s.mu.RLock() + defer s.mu.RUnlock() + return slices.Clone(s.cfg.Keys) +} + +func (s *Store) Accounts() []Account { + s.mu.RLock() + defer s.mu.RUnlock() + return slices.Clone(s.cfg.Accounts) +} + +func (s *Store) FindAccount(identifier string) (Account, bool) { + identifier = strings.TrimSpace(identifier) + s.mu.RLock() + defer s.mu.RUnlock() + if idx, ok := s.findAccountIndexLocked(identifier); ok { + return s.cfg.Accounts[idx], true + } + return Account{}, false +} + +func (s *Store) UpdateAccountToken(identifier, token string) error { + identifier = strings.TrimSpace(identifier) + s.mu.Lock() + defer s.mu.Unlock() + idx, ok := s.findAccountIndexLocked(identifier) + if !ok { + return errors.New("account not found") + } + oldID := s.cfg.Accounts[idx].Identifier() + s.cfg.Accounts[idx].Token = token + newID := s.cfg.Accounts[idx].Identifier() + // Keep historical aliases usable for long-lived queues while also adding + // the latest identifier after token refresh. 
+ if identifier != "" { + s.accMap[identifier] = idx + } + if oldID != "" { + s.accMap[oldID] = idx + } + if newID != "" { + s.accMap[newID] = idx + } + return s.saveLocked() +} + +func (s *Store) Replace(cfg Config) error { + s.mu.Lock() + defer s.mu.Unlock() + s.cfg = cfg.Clone() + s.rebuildIndexes() + return s.saveLocked() +} + +func (s *Store) Update(mutator func(*Config) error) error { + s.mu.Lock() + defer s.mu.Unlock() + cfg := s.cfg.Clone() + if err := mutator(&cfg); err != nil { + return err + } + s.cfg = cfg + s.rebuildIndexes() + return s.saveLocked() +} + +func (s *Store) Save() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.fromEnv { + Logger.Info("[save_config] source from env, skip write") + return nil + } + b, err := json.MarshalIndent(s.cfg, "", " ") + if err != nil { + return err + } + return os.WriteFile(s.path, b, 0o644) +} + +func (s *Store) saveLocked() error { + if s.fromEnv { + Logger.Info("[save_config] source from env, skip write") + return nil + } + b, err := json.MarshalIndent(s.cfg, "", " ") + if err != nil { + return err + } + return os.WriteFile(s.path, b, 0o644) +} + +func (s *Store) IsEnvBacked() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.fromEnv +} + +func (s *Store) SetVercelSync(hash string, ts int64) error { + return s.Update(func(c *Config) error { + c.VercelSyncHash = hash + c.VercelSyncTime = ts + return nil + }) +} + +func (s *Store) ExportJSONAndBase64() (string, string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + b, err := json.Marshal(s.cfg) + if err != nil { + return "", "", err + } + return string(b), base64.StdEncoding.EncodeToString(b), nil +} diff --git a/internal/config/store_accessors.go b/internal/config/store_accessors.go new file mode 100644 index 0000000..f0c5938 --- /dev/null +++ b/internal/config/store_accessors.go @@ -0,0 +1,167 @@ +package config + +import ( + "os" + "strconv" + "strings" +) + +func (s *Store) ClaudeMapping() map[string]string { + s.mu.RLock() + defer s.mu.RUnlock() + 
if len(s.cfg.ClaudeModelMap) > 0 { + return cloneStringMap(s.cfg.ClaudeModelMap) + } + if len(s.cfg.ClaudeMapping) > 0 { + return cloneStringMap(s.cfg.ClaudeMapping) + } + return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"} +} + +func (s *Store) ModelAliases() map[string]string { + s.mu.RLock() + defer s.mu.RUnlock() + out := DefaultModelAliases() + for k, v := range s.cfg.ModelAliases { + key := strings.TrimSpace(lower(k)) + val := strings.TrimSpace(lower(v)) + if key == "" || val == "" { + continue + } + out[key] = val + } + return out +} + +func (s *Store) CompatWideInputStrictOutput() bool { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Compat.WideInputStrictOutput == nil { + return true + } + return *s.cfg.Compat.WideInputStrictOutput +} + +func (s *Store) ToolcallMode() string { + s.mu.RLock() + defer s.mu.RUnlock() + mode := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.Mode)) + if mode == "" { + return "feature_match" + } + return mode +} + +func (s *Store) ToolcallEarlyEmitConfidence() string { + s.mu.RLock() + defer s.mu.RUnlock() + level := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.EarlyEmitConfidence)) + if level == "" { + return "high" + } + return level +} + +func (s *Store) ResponsesStoreTTLSeconds() int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Responses.StoreTTLSeconds > 0 { + return s.cfg.Responses.StoreTTLSeconds + } + return 900 +} + +func (s *Store) EmbeddingsProvider() string { + s.mu.RLock() + defer s.mu.RUnlock() + return strings.TrimSpace(s.cfg.Embeddings.Provider) +} + +func (s *Store) AdminPasswordHash() string { + s.mu.RLock() + defer s.mu.RUnlock() + return strings.TrimSpace(s.cfg.Admin.PasswordHash) +} + +func (s *Store) AdminJWTExpireHours() int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Admin.JWTExpireHours > 0 { + return s.cfg.Admin.JWTExpireHours + } + if raw := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); raw != "" { + if n, err := strconv.Atoi(raw); err == nil 
&& n > 0 { + return n + } + } + return 24 +} + +func (s *Store) AdminJWTValidAfterUnix() int64 { + s.mu.RLock() + defer s.mu.RUnlock() + return s.cfg.Admin.JWTValidAfterUnix +} + +func (s *Store) RuntimeAccountMaxInflight() int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Runtime.AccountMaxInflight > 0 { + return s.cfg.Runtime.AccountMaxInflight + } + for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n > 0 { + return n + } + } + return 2 +} + +func (s *Store) RuntimeAccountMaxQueue(defaultSize int) int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Runtime.AccountMaxQueue > 0 { + return s.cfg.Runtime.AccountMaxQueue + } + for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n >= 0 { + return n + } + } + if defaultSize < 0 { + return 0 + } + return defaultSize +} + +func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int { + s.mu.RLock() + defer s.mu.RUnlock() + if s.cfg.Runtime.GlobalMaxInflight > 0 { + return s.cfg.Runtime.GlobalMaxInflight + } + for _, key := range []string{"DS2API_GLOBAL_MAX_INFLIGHT", "DS2API_MAX_INFLIGHT"} { + raw := strings.TrimSpace(os.Getenv(key)) + if raw == "" { + continue + } + n, err := strconv.Atoi(raw) + if err == nil && n > 0 { + return n + } + } + if defaultSize < 0 { + return 0 + } + return defaultSize +} diff --git a/internal/config/store_index.go b/internal/config/store_index.go new file mode 100644 index 0000000..7d0f62a --- /dev/null +++ b/internal/config/store_index.go @@ -0,0 +1,31 @@ +package config + +// rebuildIndexes must be called with the lock already held (or during init). 
+func (s *Store) rebuildIndexes() { + s.keyMap = make(map[string]struct{}, len(s.cfg.Keys)) + for _, k := range s.cfg.Keys { + s.keyMap[k] = struct{}{} + } + s.accMap = make(map[string]int, len(s.cfg.Accounts)) + for i, acc := range s.cfg.Accounts { + id := acc.Identifier() + if id != "" { + s.accMap[id] = i + } + } +} + +// findAccountIndexLocked expects the store lock to already be held. +func (s *Store) findAccountIndexLocked(identifier string) (int, bool) { + if idx, ok := s.accMap[identifier]; ok && idx >= 0 && idx < len(s.cfg.Accounts) { + return idx, true + } + // Fallback for token-only accounts whose derived identifier changed after + // a token refresh; this preserves correctness on map misses. + for i, acc := range s.cfg.Accounts { + if acc.Identifier() == identifier { + return i, true + } + } + return -1, false +} diff --git a/internal/deepseek/client.go b/internal/deepseek/client.go deleted file mode 100644 index 2ffe05d..0000000 --- a/internal/deepseek/client.go +++ /dev/null @@ -1,347 +0,0 @@ -package deepseek - -import ( - "bufio" - "bytes" - "compress/gzip" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "strings" - "time" - - "ds2api/internal/auth" - "ds2api/internal/config" - trans "ds2api/internal/deepseek/transport" - "ds2api/internal/devcapture" - "ds2api/internal/util" - - "github.com/andybalholm/brotli" -) - -// intFrom is a package-internal alias for the shared util version. 
-var intFrom = util.IntFrom - -type Client struct { - Store *config.Store - Auth *auth.Resolver - capture *devcapture.Store - regular trans.Doer - stream trans.Doer - fallback *http.Client - fallbackS *http.Client - powSolver *PowSolver - maxRetries int -} - -func NewClient(store *config.Store, resolver *auth.Resolver) *Client { - return &Client{ - Store: store, - Auth: resolver, - capture: devcapture.Global(), - regular: trans.New(60 * time.Second), - stream: trans.New(0), - fallback: &http.Client{Timeout: 60 * time.Second}, - fallbackS: &http.Client{Timeout: 0}, - powSolver: NewPowSolver(config.WASMPath()), - maxRetries: 3, - } -} - -func (c *Client) PreloadPow(ctx context.Context) error { - return c.powSolver.init(ctx) -} - -func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) { - payload := map[string]any{ - "password": strings.TrimSpace(acc.Password), - "device_id": "deepseek_to_api", - "os": "android", - } - if email := strings.TrimSpace(acc.Email); email != "" { - payload["email"] = email - } else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" { - payload["mobile"] = mobile - payload["area_code"] = nil - } else { - return "", errors.New("missing email/mobile") - } - resp, err := c.postJSON(ctx, c.regular, DeepSeekLoginURL, BaseHeaders, payload) - if err != nil { - return "", err - } - code := intFrom(resp["code"]) - if code != 0 { - return "", fmt.Errorf("login failed: %v", resp["msg"]) - } - data, _ := resp["data"].(map[string]any) - if intFrom(data["biz_code"]) != 0 { - return "", fmt.Errorf("login failed: %v", data["biz_msg"]) - } - bizData, _ := data["biz_data"].(map[string]any) - user, _ := bizData["user"].(map[string]any) - token, _ := user["token"].(string) - if strings.TrimSpace(token) == "" { - return "", errors.New("missing login token") - } - return token, nil -} - -func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) { - if maxAttempts <= 0 { - maxAttempts = 
c.maxRetries - } - attempts := 0 - refreshed := false - for attempts < maxAttempts { - headers := c.authHeaders(a.DeepSeekToken) - resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"}) - if err != nil { - config.Logger.Warn("[create_session] request error", "error", err, "account", a.AccountID) - attempts++ - continue - } - code := intFrom(resp["code"]) - if status == http.StatusOK && code == 0 { - data, _ := resp["data"].(map[string]any) - bizData, _ := data["biz_data"].(map[string]any) - sessionID, _ := bizData["id"].(string) - if sessionID != "" { - return sessionID, nil - } - } - msg, _ := resp["msg"].(string) - config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID) - if a.UseConfigToken { - if isTokenInvalid(status, code, msg) && !refreshed { - if c.Auth.RefreshToken(ctx, a) { - refreshed = true - continue - } - } - if c.Auth.SwitchAccount(ctx, a) { - refreshed = false - attempts++ - continue - } - } - attempts++ - } - return "", errors.New("create session failed") -} - -func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) { - if maxAttempts <= 0 { - maxAttempts = c.maxRetries - } - attempts := 0 - for attempts < maxAttempts { - headers := c.authHeaders(a.DeepSeekToken) - resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"}) - if err != nil { - config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID) - attempts++ - continue - } - code := intFrom(resp["code"]) - if status == http.StatusOK && code == 0 { - data, _ := resp["data"].(map[string]any) - bizData, _ := data["biz_data"].(map[string]any) - challenge, _ := bizData["challenge"].(map[string]any) - answer, err := c.powSolver.Compute(ctx, challenge) - if err != nil { - attempts++ 
- continue - } - return BuildPowHeader(challenge, answer) - } - msg, _ := resp["msg"].(string) - config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID) - if a.UseConfigToken { - if isTokenInvalid(status, code, msg) { - if c.Auth.RefreshToken(ctx, a) { - continue - } - } - if c.Auth.SwitchAccount(ctx, a) { - attempts++ - continue - } - } - attempts++ - } - return "", errors.New("get pow failed") -} - -func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) { - if maxAttempts <= 0 { - maxAttempts = c.maxRetries - } - headers := c.authHeaders(a.DeepSeekToken) - headers["x-ds-pow-response"] = powResp - captureSession := c.capture.Start("deepseek_completion", DeepSeekCompletionURL, a.AccountID, payload) - attempts := 0 - for attempts < maxAttempts { - resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload) - if err != nil { - attempts++ - time.Sleep(time.Second) - continue - } - if resp.StatusCode == http.StatusOK { - if captureSession != nil { - resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode) - } - return resp, nil - } - if captureSession != nil { - resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode) - } - _ = resp.Body.Close() - attempts++ - time.Sleep(time.Second) - } - return nil, errors.New("completion failed") -} - -func (c *Client) postJSON(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, error) { - body, status, err := c.postJSONWithStatus(ctx, doer, url, headers, payload) - if err != nil { - return nil, err - } - if status == 0 { - return nil, errors.New("request failed") - } - return body, nil -} - -func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, int, error) { - b, err := 
json.Marshal(payload) - if err != nil { - return nil, 0, err - } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) - if err != nil { - return nil, 0, err - } - for k, v := range headers { - req.Header.Set(k, v) - } - resp, err := doer.Do(req) - if err != nil { - config.Logger.Warn("[deepseek] fingerprint request failed, fallback to std transport", "url", url, "error", err) - req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) - if reqErr != nil { - return nil, 0, err - } - for k, v := range headers { - req2.Header.Set(k, v) - } - resp, err = c.fallback.Do(req2) - if err != nil { - return nil, 0, err - } - } - defer resp.Body.Close() - payloadBytes, err := readResponseBody(resp) - if err != nil { - return nil, resp.StatusCode, err - } - out := map[string]any{} - if len(payloadBytes) > 0 { - if err := json.Unmarshal(payloadBytes, &out); err != nil { - config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes)) - } - } - return out, resp.StatusCode, nil -} - -func (c *Client) streamPost(ctx context.Context, url string, headers map[string]string, payload any) (*http.Response, error) { - b, err := json.Marshal(payload) - if err != nil { - return nil, err - } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) - if err != nil { - return nil, err - } - for k, v := range headers { - req.Header.Set(k, v) - } - resp, err := c.stream.Do(req) - if err != nil { - config.Logger.Warn("[deepseek] fingerprint stream request failed, fallback to std transport", "url", url, "error", err) - req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) - if reqErr != nil { - return nil, err - } - for k, v := range headers { - req2.Header.Set(k, v) - } - return c.fallbackS.Do(req2) - } - return resp, nil -} - -func (c *Client) 
authHeaders(token string) map[string]string { - headers := make(map[string]string, len(BaseHeaders)+1) - for k, v := range BaseHeaders { - headers[k] = v - } - headers["authorization"] = "Bearer " + token - return headers -} - -func isTokenInvalid(status int, code int, msg string) bool { - msg = strings.ToLower(msg) - if status == http.StatusUnauthorized || status == http.StatusForbidden { - return true - } - if code == 40001 || code == 40002 || code == 40003 { - return true - } - return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized") -} - -func readResponseBody(resp *http.Response) ([]byte, error) { - encoding := strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Encoding"))) - var reader io.Reader = resp.Body - switch encoding { - case "gzip": - gz, err := gzip.NewReader(resp.Body) - if err != nil { - return nil, err - } - defer gz.Close() - reader = gz - case "br": - reader = brotli.NewReader(resp.Body) - } - return io.ReadAll(reader) -} - -func preview(b []byte) string { - s := strings.TrimSpace(string(b)) - if len(s) > 160 { - return s[:160] - } - return s -} - -func ScanSSELines(resp *http.Response, onLine func([]byte) bool) error { - scanner := bufio.NewScanner(resp.Body) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 2*1024*1024) - for scanner.Scan() { - if !onLine(scanner.Bytes()) { - break - } - } - if err := scanner.Err(); err != nil { - return err - } - return nil -} diff --git a/internal/deepseek/client_auth.go b/internal/deepseek/client_auth.go new file mode 100644 index 0000000..820acaf --- /dev/null +++ b/internal/deepseek/client_auth.go @@ -0,0 +1,153 @@ +package deepseek + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + + "ds2api/internal/auth" + "ds2api/internal/config" +) + +func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) { + payload := map[string]any{ + "password": strings.TrimSpace(acc.Password), + "device_id": "deepseek_to_api", + "os": "android", + } 
+ if email := strings.TrimSpace(acc.Email); email != "" { + payload["email"] = email + } else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" { + payload["mobile"] = mobile + payload["area_code"] = nil + } else { + return "", errors.New("missing email/mobile") + } + resp, err := c.postJSON(ctx, c.regular, DeepSeekLoginURL, BaseHeaders, payload) + if err != nil { + return "", err + } + code := intFrom(resp["code"]) + if code != 0 { + return "", fmt.Errorf("login failed: %v", resp["msg"]) + } + data, _ := resp["data"].(map[string]any) + if intFrom(data["biz_code"]) != 0 { + return "", fmt.Errorf("login failed: %v", data["biz_msg"]) + } + bizData, _ := data["biz_data"].(map[string]any) + user, _ := bizData["user"].(map[string]any) + token, _ := user["token"].(string) + if strings.TrimSpace(token) == "" { + return "", errors.New("missing login token") + } + return token, nil +} + +func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) { + if maxAttempts <= 0 { + maxAttempts = c.maxRetries + } + attempts := 0 + refreshed := false + for attempts < maxAttempts { + headers := c.authHeaders(a.DeepSeekToken) + resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"}) + if err != nil { + config.Logger.Warn("[create_session] request error", "error", err, "account", a.AccountID) + attempts++ + continue + } + code := intFrom(resp["code"]) + if status == http.StatusOK && code == 0 { + data, _ := resp["data"].(map[string]any) + bizData, _ := data["biz_data"].(map[string]any) + sessionID, _ := bizData["id"].(string) + if sessionID != "" { + return sessionID, nil + } + } + msg, _ := resp["msg"].(string) + config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID) + if a.UseConfigToken { + if isTokenInvalid(status, code, msg) && !refreshed { + if 
c.Auth.RefreshToken(ctx, a) { + refreshed = true + continue + } + } + if c.Auth.SwitchAccount(ctx, a) { + refreshed = false + attempts++ + continue + } + } + attempts++ + } + return "", errors.New("create session failed") +} + +func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) { + if maxAttempts <= 0 { + maxAttempts = c.maxRetries + } + attempts := 0 + for attempts < maxAttempts { + headers := c.authHeaders(a.DeepSeekToken) + resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"}) + if err != nil { + config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID) + attempts++ + continue + } + code := intFrom(resp["code"]) + if status == http.StatusOK && code == 0 { + data, _ := resp["data"].(map[string]any) + bizData, _ := data["biz_data"].(map[string]any) + challenge, _ := bizData["challenge"].(map[string]any) + answer, err := c.powSolver.Compute(ctx, challenge) + if err != nil { + attempts++ + continue + } + return BuildPowHeader(challenge, answer) + } + msg, _ := resp["msg"].(string) + config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID) + if a.UseConfigToken { + if isTokenInvalid(status, code, msg) { + if c.Auth.RefreshToken(ctx, a) { + continue + } + } + if c.Auth.SwitchAccount(ctx, a) { + attempts++ + continue + } + } + attempts++ + } + return "", errors.New("get pow failed") +} + +func (c *Client) authHeaders(token string) map[string]string { + headers := make(map[string]string, len(BaseHeaders)+1) + for k, v := range BaseHeaders { + headers[k] = v + } + headers["authorization"] = "Bearer " + token + return headers +} + +func isTokenInvalid(status int, code int, msg string) bool { + msg = strings.ToLower(msg) + if status == http.StatusUnauthorized || status == http.StatusForbidden { + return true + } + if code 
== 40001 || code == 40002 || code == 40003 { + return true + } + return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized") +} diff --git a/internal/deepseek/client_completion.go b/internal/deepseek/client_completion.go new file mode 100644 index 0000000..051bffe --- /dev/null +++ b/internal/deepseek/client_completion.go @@ -0,0 +1,71 @@ +package deepseek + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "time" + + "ds2api/internal/auth" + "ds2api/internal/config" +) + +func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) { + if maxAttempts <= 0 { + maxAttempts = c.maxRetries + } + headers := c.authHeaders(a.DeepSeekToken) + headers["x-ds-pow-response"] = powResp + captureSession := c.capture.Start("deepseek_completion", DeepSeekCompletionURL, a.AccountID, payload) + attempts := 0 + for attempts < maxAttempts { + resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload) + if err != nil { + attempts++ + time.Sleep(time.Second) + continue + } + if resp.StatusCode == http.StatusOK { + if captureSession != nil { + resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode) + } + return resp, nil + } + if captureSession != nil { + resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode) + } + _ = resp.Body.Close() + attempts++ + time.Sleep(time.Second) + } + return nil, errors.New("completion failed") +} + +func (c *Client) streamPost(ctx context.Context, url string, headers map[string]string, payload any) (*http.Response, error) { + b, err := json.Marshal(payload) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) + if err != nil { + return nil, err + } + for k, v := range headers { + req.Header.Set(k, v) + } + resp, err := c.stream.Do(req) + if err != nil { + config.Logger.Warn("[deepseek] fingerprint stream request 
failed, fallback to std transport", "url", url, "error", err) + req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) + if reqErr != nil { + return nil, err + } + for k, v := range headers { + req2.Header.Set(k, v) + } + return c.fallbackS.Do(req2) + } + return resp, nil +} diff --git a/internal/deepseek/client_core.go b/internal/deepseek/client_core.go new file mode 100644 index 0000000..cda6edc --- /dev/null +++ b/internal/deepseek/client_core.go @@ -0,0 +1,46 @@ +package deepseek + +import ( + "context" + "net/http" + "time" + + "ds2api/internal/auth" + "ds2api/internal/config" + trans "ds2api/internal/deepseek/transport" + "ds2api/internal/devcapture" + "ds2api/internal/util" +) + +// intFrom is a package-internal alias for the shared util version. +var intFrom = util.IntFrom + +type Client struct { + Store *config.Store + Auth *auth.Resolver + capture *devcapture.Store + regular trans.Doer + stream trans.Doer + fallback *http.Client + fallbackS *http.Client + powSolver *PowSolver + maxRetries int +} + +func NewClient(store *config.Store, resolver *auth.Resolver) *Client { + return &Client{ + Store: store, + Auth: resolver, + capture: devcapture.Global(), + regular: trans.New(60 * time.Second), + stream: trans.New(0), + fallback: &http.Client{Timeout: 60 * time.Second}, + fallbackS: &http.Client{Timeout: 0}, + powSolver: NewPowSolver(config.WASMPath()), + maxRetries: 3, + } +} + +func (c *Client) PreloadPow(ctx context.Context) error { + return c.powSolver.init(ctx) +} diff --git a/internal/deepseek/client_http_helpers.go b/internal/deepseek/client_http_helpers.go new file mode 100644 index 0000000..05de224 --- /dev/null +++ b/internal/deepseek/client_http_helpers.go @@ -0,0 +1,51 @@ +package deepseek + +import ( + "bufio" + "compress/gzip" + "io" + "net/http" + "strings" + + "github.com/andybalholm/brotli" +) + +func readResponseBody(resp *http.Response) ([]byte, error) { + encoding := 
strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Encoding"))) + var reader io.Reader = resp.Body + switch encoding { + case "gzip": + gz, err := gzip.NewReader(resp.Body) + if err != nil { + return nil, err + } + defer gz.Close() + reader = gz + case "br": + reader = brotli.NewReader(resp.Body) + } + return io.ReadAll(reader) +} + +func preview(b []byte) string { + s := strings.TrimSpace(string(b)) + if len(s) > 160 { + return s[:160] + } + return s +} + +func ScanSSELines(resp *http.Response, onLine func([]byte) bool) error { + scanner := bufio.NewScanner(resp.Body) + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 2*1024*1024) + for scanner.Scan() { + if !onLine(scanner.Bytes()) { + break + } + } + if err := scanner.Err(); err != nil { + return err + } + return nil +} diff --git a/internal/deepseek/client_http_json.go b/internal/deepseek/client_http_json.go new file mode 100644 index 0000000..6d3599d --- /dev/null +++ b/internal/deepseek/client_http_json.go @@ -0,0 +1,64 @@ +package deepseek + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + + "ds2api/internal/config" + trans "ds2api/internal/deepseek/transport" +) + +func (c *Client) postJSON(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, error) { + body, status, err := c.postJSONWithStatus(ctx, doer, url, headers, payload) + if err != nil { + return nil, err + } + if status == 0 { + return nil, errors.New("request failed") + } + return body, nil +} + +func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, int, error) { + b, err := json.Marshal(payload) + if err != nil { + return nil, 0, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) + if err != nil { + return nil, 0, err + } + for k, v := range headers { + req.Header.Set(k, v) + } + resp, err := doer.Do(req) + if err != nil { 
+ config.Logger.Warn("[deepseek] fingerprint request failed, fallback to std transport", "url", url, "error", err) + req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b)) + if reqErr != nil { + return nil, 0, err + } + for k, v := range headers { + req2.Header.Set(k, v) + } + resp, err = c.fallback.Do(req2) + if err != nil { + return nil, 0, err + } + } + defer resp.Body.Close() + payloadBytes, err := readResponseBody(resp) + if err != nil { + return nil, resp.StatusCode, err + } + out := map[string]any{} + if len(payloadBytes) > 0 { + if err := json.Unmarshal(payloadBytes, &out); err != nil { + config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes)) + } + } + return out, resp.StatusCode, nil +} diff --git a/internal/format/openai/render.go b/internal/format/openai/render.go deleted file mode 100644 index 2107d4e..0000000 --- a/internal/format/openai/render.go +++ /dev/null @@ -1,307 +0,0 @@ -package openai - -import ( - "encoding/json" - "strings" - "time" - - "github.com/google/uuid" - - "ds2api/internal/util" -) - -func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { - detected := util.ParseToolCalls(finalText, toolNames) - finishReason := "stop" - messageObj := map[string]any{"role": "assistant", "content": finalText} - if strings.TrimSpace(finalThinking) != "" { - messageObj["reasoning_content"] = finalThinking - } - if len(detected) > 0 { - finishReason = "tool_calls" - messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected) - messageObj["content"] = nil - } - promptTokens := util.EstimateTokens(finalPrompt) - reasoningTokens := util.EstimateTokens(finalThinking) - completionTokens := util.EstimateTokens(finalText) - - return map[string]any{ - "id": completionID, - "object": "chat.completion", - "created": 
time.Now().Unix(), - "model": model, - "choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}}, - "usage": map[string]any{ - "prompt_tokens": promptTokens, - "completion_tokens": reasoningTokens + completionTokens, - "total_tokens": promptTokens + reasoningTokens + completionTokens, - "completion_tokens_details": map[string]any{ - "reasoning_tokens": reasoningTokens, - }, - }, - } -} - -func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { - // Align responses tool-call semantics with chat/completions: - // mixed prose + tool_call payloads should still be interpreted as tool calls. - detected := util.ParseToolCalls(finalText, toolNames) - if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" { - detected = util.ParseToolCalls(finalThinking, toolNames) - } - exposedOutputText := finalText - output := make([]any, 0, 2) - if len(detected) > 0 { - exposedOutputText = "" - if strings.TrimSpace(finalThinking) != "" { - output = append(output, map[string]any{ - "type": "reasoning", - "text": finalThinking, - }) - } - formatted := util.FormatOpenAIToolCalls(detected) - output = append(output, toResponsesFunctionCallItems(formatted)...) - output = append(output, map[string]any{ - "type": "tool_calls", - "tool_calls": formatted, - }) - } else { - content := make([]any, 0, 2) - if finalThinking != "" { - content = append([]any{map[string]any{ - "type": "reasoning", - "text": finalThinking, - }}, content...) 
- } - if strings.TrimSpace(finalText) != "" { - content = append(content, map[string]any{ - "type": "output_text", - "text": finalText, - }) - } - if strings.TrimSpace(finalText) == "" && strings.TrimSpace(finalThinking) != "" { - exposedOutputText = finalThinking - } - output = append(output, map[string]any{ - "type": "message", - "id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""), - "role": "assistant", - "content": content, - }) - } - promptTokens := util.EstimateTokens(finalPrompt) - reasoningTokens := util.EstimateTokens(finalThinking) - completionTokens := util.EstimateTokens(finalText) - return map[string]any{ - "id": responseID, - "type": "response", - "object": "response", - "created_at": time.Now().Unix(), - "status": "completed", - "model": model, - "output": output, - "output_text": exposedOutputText, - "usage": map[string]any{ - "input_tokens": promptTokens, - "output_tokens": reasoningTokens + completionTokens, - "total_tokens": promptTokens + reasoningTokens + completionTokens, - }, - } -} - -func toResponsesFunctionCallItems(toolCalls []map[string]any) []any { - if len(toolCalls) == 0 { - return nil - } - out := make([]any, 0, len(toolCalls)) - for _, tc := range toolCalls { - callID, _ := tc["id"].(string) - if strings.TrimSpace(callID) == "" { - callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") - } - name := "" - args := "{}" - if fn, ok := tc["function"].(map[string]any); ok { - if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" { - name = n - } - if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" { - args = a - } - } - out = append(out, map[string]any{ - "id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""), - "type": "function_call", - "call_id": callID, - "name": name, - "arguments": normalizeJSONString(args), - "status": "completed", - }) - } - return out -} - -func normalizeJSONString(raw string) string { - s := strings.TrimSpace(raw) - if s == "" { - return "{}" - } - var v any - if err := 
json.Unmarshal([]byte(s), &v); err != nil { - return raw - } - b, err := json.Marshal(v) - if err != nil { - return raw - } - return string(b) -} - -func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any { - return map[string]any{ - "delta": delta, - "index": index, - } -} - -func BuildChatStreamFinishChoice(index int, finishReason string) map[string]any { - return map[string]any{ - "delta": map[string]any{}, - "index": index, - "finish_reason": finishReason, - } -} - -func BuildChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any { - out := map[string]any{ - "id": completionID, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": choices, - } - if len(usage) > 0 { - out["usage"] = usage - } - return out -} - -func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any { - promptTokens := util.EstimateTokens(finalPrompt) - reasoningTokens := util.EstimateTokens(finalThinking) - completionTokens := util.EstimateTokens(finalText) - return map[string]any{ - "prompt_tokens": promptTokens, - "completion_tokens": reasoningTokens + completionTokens, - "total_tokens": promptTokens + reasoningTokens + completionTokens, - "completion_tokens_details": map[string]any{ - "reasoning_tokens": reasoningTokens, - }, - } -} - -func BuildResponsesCreatedPayload(responseID, model string) map[string]any { - return map[string]any{ - "type": "response.created", - "id": responseID, - "response_id": responseID, - "object": "response", - "model": model, - "status": "in_progress", - } -} - -func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any { - return map[string]any{ - "type": "response.output_text.delta", - "id": responseID, - "response_id": responseID, - "delta": delta, - } -} - -func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any { - return map[string]any{ - "type": 
"response.reasoning.delta", - "id": responseID, - "response_id": responseID, - "delta": delta, - } -} - -func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any { - return map[string]any{ - "type": "response.reasoning_text.delta", - "id": responseID, - "response_id": responseID, - "item_id": itemID, - "output_index": outputIndex, - "content_index": contentIndex, - "delta": delta, - } -} - -func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any { - return map[string]any{ - "type": "response.reasoning_text.done", - "id": responseID, - "response_id": responseID, - "item_id": itemID, - "output_index": outputIndex, - "content_index": contentIndex, - "text": text, - } -} - -func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any { - return map[string]any{ - "type": "response.output_tool_call.delta", - "id": responseID, - "response_id": responseID, - "tool_calls": toolCalls, - } -} - -func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any { - return map[string]any{ - "type": "response.output_tool_call.done", - "id": responseID, - "response_id": responseID, - "tool_calls": toolCalls, - } -} - -func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any { - return map[string]any{ - "type": "response.function_call_arguments.delta", - "id": responseID, - "response_id": responseID, - "item_id": itemID, - "output_index": outputIndex, - "call_id": callID, - "delta": delta, - } -} - -func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, outputIndex int, callID, name, arguments string) map[string]any { - return map[string]any{ - "type": "response.function_call_arguments.done", - "id": responseID, - "response_id": responseID, - "item_id": itemID, 
- "output_index": outputIndex, - "call_id": callID, - "name": name, - "arguments": normalizeJSONString(arguments), - } -} - -func BuildResponsesCompletedPayload(response map[string]any) map[string]any { - responseID, _ := response["id"].(string) - return map[string]any{ - "type": "response.completed", - "response_id": responseID, - "response": response, - } -} diff --git a/internal/format/openai/render_chat.go b/internal/format/openai/render_chat.go new file mode 100644 index 0000000..1e58fbd --- /dev/null +++ b/internal/format/openai/render_chat.go @@ -0,0 +1,60 @@ +package openai + +import ( + "strings" + "time" + + "ds2api/internal/util" +) + +func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + detected := util.ParseToolCalls(finalText, toolNames) + finishReason := "stop" + messageObj := map[string]any{"role": "assistant", "content": finalText} + if strings.TrimSpace(finalThinking) != "" { + messageObj["reasoning_content"] = finalThinking + } + if len(detected) > 0 { + finishReason = "tool_calls" + messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected) + messageObj["content"] = nil + } + + return map[string]any{ + "id": completionID, + "object": "chat.completion", + "created": time.Now().Unix(), + "model": model, + "choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}}, + "usage": BuildChatUsage(finalPrompt, finalThinking, finalText), + } +} + +func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any { + return map[string]any{ + "delta": delta, + "index": index, + } +} + +func BuildChatStreamFinishChoice(index int, finishReason string) map[string]any { + return map[string]any{ + "delta": map[string]any{}, + "index": index, + "finish_reason": finishReason, + } +} + +func BuildChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any { + out := 
map[string]any{ + "id": completionID, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": choices, + } + if len(usage) > 0 { + out["usage"] = usage + } + return out +} diff --git a/internal/format/openai/render_responses.go b/internal/format/openai/render_responses.go new file mode 100644 index 0000000..4fd17c3 --- /dev/null +++ b/internal/format/openai/render_responses.go @@ -0,0 +1,119 @@ +package openai + +import ( + "encoding/json" + "strings" + "time" + + "github.com/google/uuid" + + "ds2api/internal/util" +) + +func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { + // Align responses tool-call semantics with chat/completions: + // mixed prose + tool_call payloads should still be interpreted as tool calls. + detected := util.ParseToolCalls(finalText, toolNames) + if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" { + detected = util.ParseToolCalls(finalThinking, toolNames) + } + exposedOutputText := finalText + output := make([]any, 0, 2) + if len(detected) > 0 { + exposedOutputText = "" + if strings.TrimSpace(finalThinking) != "" { + output = append(output, map[string]any{ + "type": "reasoning", + "text": finalThinking, + }) + } + formatted := util.FormatOpenAIToolCalls(detected) + output = append(output, toResponsesFunctionCallItems(formatted)...) + output = append(output, map[string]any{ + "type": "tool_calls", + "tool_calls": formatted, + }) + } else { + content := make([]any, 0, 2) + if finalThinking != "" { + content = append([]any{map[string]any{ + "type": "reasoning", + "text": finalThinking, + }}, content...) 
+ } + if strings.TrimSpace(finalText) != "" { + content = append(content, map[string]any{ + "type": "output_text", + "text": finalText, + }) + } + if strings.TrimSpace(finalText) == "" && strings.TrimSpace(finalThinking) != "" { + exposedOutputText = finalThinking + } + output = append(output, map[string]any{ + "type": "message", + "id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "role": "assistant", + "content": content, + }) + } + return map[string]any{ + "id": responseID, + "type": "response", + "object": "response", + "created_at": time.Now().Unix(), + "status": "completed", + "model": model, + "output": output, + "output_text": exposedOutputText, + "usage": BuildResponsesUsage(finalPrompt, finalThinking, finalText), + } +} + +func toResponsesFunctionCallItems(toolCalls []map[string]any) []any { + if len(toolCalls) == 0 { + return nil + } + out := make([]any, 0, len(toolCalls)) + for _, tc := range toolCalls { + callID, _ := tc["id"].(string) + if strings.TrimSpace(callID) == "" { + callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") + } + name := "" + args := "{}" + if fn, ok := tc["function"].(map[string]any); ok { + if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" { + name = n + } + if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" { + args = a + } + } + out = append(out, map[string]any{ + "id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "type": "function_call", + "call_id": callID, + "name": name, + "arguments": normalizeJSONString(args), + "status": "completed", + }) + } + return out +} + +func normalizeJSONString(raw string) string { + s := strings.TrimSpace(raw) + if s == "" { + return "{}" + } + var v any + if err := json.Unmarshal([]byte(s), &v); err != nil { + return raw + } + b, err := json.Marshal(v) + if err != nil { + return raw + } + return string(b) +} diff --git a/internal/format/openai/render_stream_events.go b/internal/format/openai/render_stream_events.go new file mode 100644 
index 0000000..cc62604 --- /dev/null +++ b/internal/format/openai/render_stream_events.go @@ -0,0 +1,106 @@ +package openai + +func BuildResponsesCreatedPayload(responseID, model string) map[string]any { + return map[string]any{ + "type": "response.created", + "id": responseID, + "response_id": responseID, + "object": "response", + "model": model, + "status": "in_progress", + } +} + +func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any { + return map[string]any{ + "type": "response.output_text.delta", + "id": responseID, + "response_id": responseID, + "delta": delta, + } +} + +func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any { + return map[string]any{ + "type": "response.reasoning.delta", + "id": responseID, + "response_id": responseID, + "delta": delta, + } +} + +func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any { + return map[string]any{ + "type": "response.reasoning_text.delta", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "content_index": contentIndex, + "delta": delta, + } +} + +func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any { + return map[string]any{ + "type": "response.reasoning_text.done", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "content_index": contentIndex, + "text": text, + } +} + +func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any { + return map[string]any{ + "type": "response.output_tool_call.delta", + "id": responseID, + "response_id": responseID, + "tool_calls": toolCalls, + } +} + +func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any { + return map[string]any{ + "type": "response.output_tool_call.done", + "id": 
responseID, + "response_id": responseID, + "tool_calls": toolCalls, + } +} + +func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any { + return map[string]any{ + "type": "response.function_call_arguments.delta", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "call_id": callID, + "delta": delta, + } +} + +func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, outputIndex int, callID, name, arguments string) map[string]any { + return map[string]any{ + "type": "response.function_call_arguments.done", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "call_id": callID, + "name": name, + "arguments": normalizeJSONString(arguments), + } +} + +func BuildResponsesCompletedPayload(response map[string]any) map[string]any { + responseID, _ := response["id"].(string) + return map[string]any{ + "type": "response.completed", + "response_id": responseID, + "response": response, + } +} diff --git a/internal/format/openai/render_usage.go b/internal/format/openai/render_usage.go new file mode 100644 index 0000000..b328d20 --- /dev/null +++ b/internal/format/openai/render_usage.go @@ -0,0 +1,28 @@ +package openai + +import "ds2api/internal/util" + +func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any { + promptTokens := util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + return map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + "completion_tokens_details": map[string]any{ + "reasoning_tokens": reasoningTokens, + }, + } +} + +func BuildResponsesUsage(finalPrompt, finalThinking, finalText string) map[string]any { + promptTokens := 
util.EstimateTokens(finalPrompt) + reasoningTokens := util.EstimateTokens(finalThinking) + completionTokens := util.EstimateTokens(finalText) + return map[string]any{ + "input_tokens": promptTokens, + "output_tokens": reasoningTokens + completionTokens, + "total_tokens": promptTokens + reasoningTokens + completionTokens, + } +} diff --git a/internal/testsuite/edge_cases.go b/internal/testsuite/edge_cases.go index cba0b5a..50bc8ac 100644 --- a/internal/testsuite/edge_cases.go +++ b/internal/testsuite/edge_cases.go @@ -1,7 +1,6 @@ package testsuite import ( - "bytes" "context" "encoding/json" "fmt" @@ -125,72 +124,6 @@ func (r *Runner) caseStreamAbortRelease(ctx context.Context, cc *caseContext) er return nil } -func (cc *caseContext) abortStreamRequest(ctx context.Context, spec requestSpec) error { - cc.seq++ - traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), cc.seq) - cc.traceIDsSet[traceID] = struct{}{} - fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID) - if err != nil { - return err - } - headers := map[string]string{} - for k, v := range spec.Headers { - headers[k] = v - } - headers["X-Ds2-Test-Trace"] = traceID - bodyBytes, _ := json.Marshal(spec.Body) - headers["Content-Type"] = "application/json" - cc.requests = append(cc.requests, requestLog{ - Seq: cc.seq, - Attempt: 1, - TraceID: traceID, - Method: spec.Method, - URL: fullURL, - Headers: headers, - Body: spec.Body, - Timestamp: time.Now().Format(time.RFC3339Nano), - }) - - reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout) - defer cancel() - req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes)) - if err != nil { - return err - } - for k, v := range headers { - req.Header.Set(k, v) - } - start := time.Now() - resp, err := cc.runner.httpClient.Do(req) - if err != nil { - cc.responses = append(cc.responses, responseLog{ - Seq: cc.seq, - Attempt: 1, - TraceID: traceID, - StatusCode: 0, - DurationMS: 
time.Since(start).Milliseconds(), - NetworkErr: err.Error(), - ReceivedAt: time.Now().Format(time.RFC3339Nano), - }) - return err - } - defer resp.Body.Close() - buf := make([]byte, 512) - _, _ = resp.Body.Read(buf) - _ = resp.Body.Close() - cc.responses = append(cc.responses, responseLog{ - Seq: cc.seq, - Attempt: 1, - TraceID: traceID, - StatusCode: resp.StatusCode, - Headers: resp.Header, - BodyText: "aborted_after_first_chunk", - DurationMS: time.Since(start).Milliseconds(), - ReceivedAt: time.Now().Format(time.RFC3339Nano), - }) - return nil -} - func (r *Runner) caseToolcallStreamMixed(ctx context.Context, cc *caseContext) error { payload := toolcallPayload(true) payload["messages"] = []map[string]any{ @@ -293,167 +226,6 @@ func (r *Runner) caseSSEJSONIntegrity(ctx context.Context, cc *caseContext) erro return nil } -func (r *Runner) caseInvalidModel(ctx context.Context, cc *caseContext) error { - resp, err := cc.requestOnce(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: map[string]any{ - "model": "deepseek-not-exists", - "messages": []map[string]any{ - {"role": "user", "content": "hi"}, - }, - "stream": false, - }, - Retryable: false, - }, 1) - if err != nil { - return err - } - cc.assert("status_503", resp.StatusCode == http.StatusServiceUnavailable, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - e, _ := m["error"].(map[string]any) - cc.assert("error_type_service_unavailable", asString(e["type"]) == "service_unavailable_error", fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseMissingMessages(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: map[string]any{ - "model": 
"deepseek-chat", - "stream": false, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_400", resp.StatusCode == http.StatusBadRequest, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - e, _ := m["error"].(map[string]any) - cc.assert("error_type_invalid_request", asString(e["type"]) == "invalid_request_error", fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseAdminUnauthorized(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodGet, - Path: "/admin/config", - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode)) - return nil -} - -func (r *Runner) caseTokenRefreshManagedAccount(ctx context.Context, cc *caseContext) error { - if len(r.configRaw.Accounts) == 0 { - cc.assert("account_present", false, "no account in config") - return nil - } - acc := r.configRaw.Accounts[0] - id := strings.TrimSpace(acc.Email) - if id == "" { - id = strings.TrimSpace(acc.Mobile) - } - if id == "" { - cc.assert("account_identifier", false, "first account has no identifier") - return nil - } - if strings.TrimSpace(acc.Password) == "" { - r.warnings = append(r.warnings, "token refresh edge case skipped strict check: first account password empty") - cc.assert("account_password_present", true, "skipped strict refresh check due empty password") - return nil - } - invalidToken := "invalid-testsuite-refresh-token-" + sanitizeID(r.runID) - update := map[string]any{ - "keys": r.configRaw.Keys, - "accounts": []map[string]any{ - { - "email": acc.Email, - "mobile": acc.Mobile, - "password": acc.Password, - "token": invalidToken, - }, - }, - } - updResp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/admin/config", - Headers: map[string]string{ - "Authorization": "Bearer " + 
r.adminJWT, - }, - Body: update, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("update_config_status_200", updResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", updResp.StatusCode)) - - chatResp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - "X-Ds2-Target-Account": id, - }, - Body: map[string]any{ - "model": "deepseek-chat", - "messages": []map[string]any{ - {"role": "user", "content": "token refresh test"}, - }, - "stream": false, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("chat_status_200", chatResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d body=%s", chatResp.StatusCode, string(chatResp.Body))) - - cfgResp, err := cc.request(ctx, requestSpec{ - Method: http.MethodGet, - Path: "/admin/config", - Headers: map[string]string{ - "Authorization": "Bearer " + r.adminJWT, - }, - Retryable: true, - }) - if err != nil { - return err - } - var cfg map[string]any - _ = json.Unmarshal(cfgResp.Body, &cfg) - accounts, _ := cfg["accounts"].([]any) - preview := "" - hasToken := false - for _, item := range accounts { - m, _ := item.(map[string]any) - e := asString(m["email"]) - mo := asString(m["mobile"]) - if e == acc.Email && mo == acc.Mobile { - preview = asString(m["token_preview"]) - hasToken, _ = m["has_token"].(bool) - break - } - } - cc.assert("has_token_after_refresh", hasToken, fmt.Sprintf("config=%s", string(cfgResp.Body))) - cc.assert("token_preview_changed_from_invalid", !strings.HasPrefix(preview, invalidToken[:20]), fmt.Sprintf("preview=%s invalid_prefix=%s", preview, invalidToken[:20])) - return nil -} - func (r *Runner) fetchQueueStatus(ctx context.Context, cc *caseContext) (map[string]any, error) { resp, err := cc.request(ctx, requestSpec{ Method: http.MethodGet, diff --git a/internal/testsuite/edge_cases_abort.go b/internal/testsuite/edge_cases_abort.go new file mode 
100644 index 0000000..2cc1fc1 --- /dev/null +++ b/internal/testsuite/edge_cases_abort.go @@ -0,0 +1,76 @@ +package testsuite + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "time" +) + +func (cc *caseContext) abortStreamRequest(ctx context.Context, spec requestSpec) error { + cc.seq++ + traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), cc.seq) + cc.traceIDsSet[traceID] = struct{}{} + fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID) + if err != nil { + return err + } + headers := map[string]string{} + for k, v := range spec.Headers { + headers[k] = v + } + headers["X-Ds2-Test-Trace"] = traceID + bodyBytes, _ := json.Marshal(spec.Body) + headers["Content-Type"] = "application/json" + cc.requests = append(cc.requests, requestLog{ + Seq: cc.seq, + Attempt: 1, + TraceID: traceID, + Method: spec.Method, + URL: fullURL, + Headers: headers, + Body: spec.Body, + Timestamp: time.Now().Format(time.RFC3339Nano), + }) + + reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout) + defer cancel() + req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes)) + if err != nil { + return err + } + for k, v := range headers { + req.Header.Set(k, v) + } + start := time.Now() + resp, err := cc.runner.httpClient.Do(req) + if err != nil { + cc.responses = append(cc.responses, responseLog{ + Seq: cc.seq, + Attempt: 1, + TraceID: traceID, + StatusCode: 0, + DurationMS: time.Since(start).Milliseconds(), + NetworkErr: err.Error(), + ReceivedAt: time.Now().Format(time.RFC3339Nano), + }) + return err + } + defer resp.Body.Close() + buf := make([]byte, 512) + _, _ = resp.Body.Read(buf) + _ = resp.Body.Close() + cc.responses = append(cc.responses, responseLog{ + Seq: cc.seq, + Attempt: 1, + TraceID: traceID, + StatusCode: resp.StatusCode, + Headers: resp.Header, + BodyText: "aborted_after_first_chunk", + DurationMS: time.Since(start).Milliseconds(), + ReceivedAt: 
time.Now().Format(time.RFC3339Nano), + }) + return nil +} diff --git a/internal/testsuite/edge_cases_error_contract.go b/internal/testsuite/edge_cases_error_contract.go new file mode 100644 index 0000000..d65ce6d --- /dev/null +++ b/internal/testsuite/edge_cases_error_contract.go @@ -0,0 +1,170 @@ +package testsuite + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" +) + +func (r *Runner) caseInvalidModel(ctx context.Context, cc *caseContext) error { + resp, err := cc.requestOnce(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-not-exists", + "messages": []map[string]any{ + {"role": "user", "content": "hi"}, + }, + "stream": false, + }, + Retryable: false, + }, 1) + if err != nil { + return err + } + cc.assert("status_503", resp.StatusCode == http.StatusServiceUnavailable, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + e, _ := m["error"].(map[string]any) + cc.assert("error_type_service_unavailable", asString(e["type"]) == "service_unavailable_error", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func (r *Runner) caseMissingMessages(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-chat", + "stream": false, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_400", resp.StatusCode == http.StatusBadRequest, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + e, _ := m["error"].(map[string]any) + cc.assert("error_type_invalid_request", asString(e["type"]) == "invalid_request_error", fmt.Sprintf("body=%s", string(resp.Body))) + 
return nil +} + +func (r *Runner) caseAdminUnauthorized(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/admin/config", + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode)) + return nil +} + +func (r *Runner) caseTokenRefreshManagedAccount(ctx context.Context, cc *caseContext) error { + if len(r.configRaw.Accounts) == 0 { + cc.assert("account_present", false, "no account in config") + return nil + } + acc := r.configRaw.Accounts[0] + id := strings.TrimSpace(acc.Email) + if id == "" { + id = strings.TrimSpace(acc.Mobile) + } + if id == "" { + cc.assert("account_identifier", false, "first account has no identifier") + return nil + } + if strings.TrimSpace(acc.Password) == "" { + r.warnings = append(r.warnings, "token refresh edge case skipped strict check: first account password empty") + cc.assert("account_password_present", true, "skipped strict refresh check due empty password") + return nil + } + invalidToken := "invalid-testsuite-refresh-token-" + sanitizeID(r.runID) + update := map[string]any{ + "keys": r.configRaw.Keys, + "accounts": []map[string]any{ + { + "email": acc.Email, + "mobile": acc.Mobile, + "password": acc.Password, + "token": invalidToken, + }, + }, + } + updResp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/admin/config", + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Body: update, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("update_config_status_200", updResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", updResp.StatusCode)) + + chatResp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + "X-Ds2-Target-Account": id, + }, + Body: map[string]any{ 
+ "model": "deepseek-chat", + "messages": []map[string]any{ + {"role": "user", "content": "token refresh test"}, + }, + "stream": false, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("chat_status_200", chatResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d body=%s", chatResp.StatusCode, string(chatResp.Body))) + + cfgResp, err := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/admin/config", + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Retryable: true, + }) + if err != nil { + return err + } + var cfg map[string]any + _ = json.Unmarshal(cfgResp.Body, &cfg) + accounts, _ := cfg["accounts"].([]any) + preview := "" + hasToken := false + for _, item := range accounts { + m, _ := item.(map[string]any) + e := asString(m["email"]) + mo := asString(m["mobile"]) + if e == acc.Email && mo == acc.Mobile { + preview = asString(m["token_preview"]) + hasToken, _ = m["has_token"].(bool) + break + } + } + cc.assert("has_token_after_refresh", hasToken, fmt.Sprintf("config=%s", string(cfgResp.Body))) + cc.assert("token_preview_changed_from_invalid", !strings.HasPrefix(preview, invalidToken[:20]), fmt.Sprintf("preview=%s invalid_prefix=%s", preview, invalidToken[:20])) + return nil +} diff --git a/internal/testsuite/runner.go b/internal/testsuite/runner.go deleted file mode 100644 index 33e7580..0000000 --- a/internal/testsuite/runner.go +++ /dev/null @@ -1,1766 +0,0 @@ -package testsuite - -import ( - "bytes" - "context" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "io" - "net" - "net/http" - "net/url" - "os" - "os/exec" - "path/filepath" - "runtime" - "sort" - "strconv" - "strings" - "sync" - "time" -) - -type Options struct { - ConfigPath string - AdminKey string - OutputDir string - Port int - Timeout time.Duration - Retries int - NoPreflight bool - MaxKeepRuns int -} - -type runSummary struct { - RunID string `json:"run_id"` - StartedAt string `json:"started_at"` 
- EndedAt string `json:"ended_at"` - DurationMS int64 `json:"duration_ms"` - Stats map[string]any `json:"stats"` - Environment map[string]any `json:"environment"` - Cases []caseResult `json:"cases"` - Warnings []string `json:"warnings,omitempty"` -} - -type caseResult struct { - CaseID string `json:"case_id"` - Passed bool `json:"passed"` - DurationMS int64 `json:"duration_ms"` - TraceIDs []string `json:"trace_ids"` - StatusCodes []int `json:"status_codes"` - Error string `json:"error,omitempty"` - ArtifactPath string `json:"artifact_path"` - Assertions []assertionResult `json:"assertions"` -} - -type assertionResult struct { - Name string `json:"name"` - Passed bool `json:"passed"` - Detail string `json:"detail,omitempty"` -} - -type requestLog struct { - Seq int `json:"seq"` - Attempt int `json:"attempt"` - TraceID string `json:"trace_id"` - Method string `json:"method"` - URL string `json:"url"` - Headers map[string]string `json:"headers"` - Body any `json:"body,omitempty"` - Timestamp string `json:"timestamp"` -} - -type responseLog struct { - Seq int `json:"seq"` - Attempt int `json:"attempt"` - TraceID string `json:"trace_id"` - StatusCode int `json:"status_code"` - Headers map[string][]string `json:"headers"` - BodyText string `json:"body_text"` - DurationMS int64 `json:"duration_ms"` - NetworkErr string `json:"network_error,omitempty"` - ReceivedAt string `json:"received_at"` -} - -type caseContext struct { - runner *Runner - id string - dir string - startedAt time.Time - mu sync.Mutex - seq int - assertions []assertionResult - requests []requestLog - responses []responseLog - streamRaw strings.Builder - traceIDsSet map[string]struct{} -} - -type requestSpec struct { - Method string - Path string - Headers map[string]string - Body any - Stream bool - Retryable bool -} - -type responseResult struct { - StatusCode int - Headers http.Header - Body []byte - TraceID string - URL string -} - -type Runner struct { - opts Options - - runID string - runDir string - 
serverLog string - preflightLog string - - baseURL string - httpClient *http.Client - serverCmd *exec.Cmd - serverLogFd *os.File - - configCopyPath string - originalConfigPath string - originalConfigHash string - - configRaw runConfig - apiKey string - adminKey string - adminJWT string - accountID string - - warnings []string - results []caseResult -} - -type runConfig struct { - Keys []string `json:"keys"` - Accounts []struct { - Email string `json:"email,omitempty"` - Mobile string `json:"mobile,omitempty"` - Password string `json:"password,omitempty"` - Token string `json:"token,omitempty"` - } `json:"accounts"` -} - -func DefaultOptions() Options { - return Options{ - ConfigPath: "config.json", - AdminKey: strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")), - OutputDir: "artifacts/testsuite", - Port: 0, - Timeout: 120 * time.Second, - Retries: 2, - NoPreflight: false, - MaxKeepRuns: 5, - } -} - -func Run(ctx context.Context, opts Options) error { - r, err := newRunner(opts) - if err != nil { - return err - } - start := time.Now() - defer func() { - _ = r.stopServer() - }() - - if err := r.prepareRunDir(); err != nil { - return err - } - - if !r.opts.NoPreflight { - if err := r.runPreflight(ctx); err != nil { - _ = r.writeSummary(start, time.Now()) - return err - } - } - - if err := r.prepareConfigIsolation(); err != nil { - _ = r.writeSummary(start, time.Now()) - return err - } - - if err := r.startServer(ctx); err != nil { - _ = r.writeSummary(start, time.Now()) - return err - } - - if err := r.prepareAuth(ctx); err != nil { - r.warnings = append(r.warnings, "auth prepare failed: "+err.Error()) - } - - for _, c := range r.cases() { - r.runCase(ctx, c) - } - - if err := r.ensureOriginalConfigUntouched(); err != nil { - r.warnings = append(r.warnings, err.Error()) - } - - end := time.Now() - if err := r.writeSummary(start, end); err != nil { - return err - } - - // Prune old test runs, keeping only the most recent N. 
- if err := r.pruneOldRuns(); err != nil { - r.warnings = append(r.warnings, "prune old runs: "+err.Error()) - } - - failed := 0 - for _, cs := range r.results { - if !cs.Passed { - failed++ - } - } - if failed > 0 { - return fmt.Errorf("testsuite failed: %d case(s) failed, see %s", failed, filepath.Join(r.runDir, "summary.md")) - } - return nil -} - -func newRunner(opts Options) (*Runner, error) { - if strings.TrimSpace(opts.ConfigPath) == "" { - opts.ConfigPath = "config.json" - } - if strings.TrimSpace(opts.OutputDir) == "" { - opts.OutputDir = "artifacts/testsuite" - } - if opts.Timeout <= 0 { - opts.Timeout = 120 * time.Second - } - if opts.Retries < 0 { - opts.Retries = 0 - } - adminKey := strings.TrimSpace(opts.AdminKey) - if adminKey == "" { - adminKey = strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")) - } - if adminKey == "" { - adminKey = "admin" - } - opts.AdminKey = adminKey - - return &Runner{ - opts: opts, - httpClient: &http.Client{ - Timeout: 0, - }, - runID: time.Now().UTC().Format("20060102T150405Z"), - adminKey: adminKey, - }, nil -} - -func (r *Runner) prepareRunDir() error { - r.runDir = filepath.Join(r.opts.OutputDir, r.runID) - if err := os.MkdirAll(r.runDir, 0o755); err != nil { - return err - } - if err := os.MkdirAll(filepath.Join(r.runDir, "cases"), 0o755); err != nil { - return err - } - r.serverLog = filepath.Join(r.runDir, "server.log") - r.preflightLog = filepath.Join(r.runDir, "preflight.log") - return nil -} - -// pruneOldRuns removes old test run directories, keeping the most recent MaxKeepRuns. -// Run IDs use the format "20060102T150405Z", so alphabetical order == chronological order. -func (r *Runner) pruneOldRuns() error { - keep := r.opts.MaxKeepRuns - if keep <= 0 { - return nil // 0 or negative means no pruning - } - - entries, err := os.ReadDir(r.opts.OutputDir) - if err != nil { - return err - } - - // Collect only directories (each run is a directory). 
- var runDirs []string - for _, e := range entries { - if !e.IsDir() { - continue - } - runDirs = append(runDirs, e.Name()) - } - - sort.Strings(runDirs) - - if len(runDirs) <= keep { - return nil - } - - // Remove oldest runs (those at the beginning of the sorted list). - toRemove := runDirs[:len(runDirs)-keep] - var errs []string - for _, name := range toRemove { - dirPath := filepath.Join(r.opts.OutputDir, name) - if err := os.RemoveAll(dirPath); err != nil { - errs = append(errs, fmt.Sprintf("remove %s: %v", name, err)) - } else { - fmt.Fprintf(os.Stdout, "pruned old test run: %s\n", name) - } - } - - if len(errs) > 0 { - return errors.New(strings.Join(errs, "; ")) - } - return nil -} - -func (r *Runner) runPreflight(ctx context.Context) error { - steps := [][]string{ - {"go", "test", "./...", "-count=1"}, - {"node", "--check", "api/chat-stream.js"}, - {"node", "--check", "api/helpers/stream-tool-sieve.js"}, - {"node", "--test", "api/helpers/stream-tool-sieve.test.js", "api/chat-stream.test.js", "api/compat/js_compat_test.js"}, - {"npm", "run", "build", "--prefix", "webui"}, - } - f, err := os.OpenFile(r.preflightLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) - if err != nil { - return err - } - defer f.Close() - for _, step := range steps { - if _, err := fmt.Fprintf(f, "\n$ %s\n", strings.Join(step, " ")); err != nil { - return err - } - cmd := exec.CommandContext(ctx, step[0], step[1:]...) 
- cmd.Stdout = f - cmd.Stderr = f - if err := cmd.Run(); err != nil { - return fmt.Errorf("preflight failed at `%s`: %w", strings.Join(step, " "), err) - } - } - return nil -} - -func (r *Runner) prepareConfigIsolation() error { - abs, err := filepath.Abs(r.opts.ConfigPath) - if err != nil { - return err - } - r.originalConfigPath = abs - raw, err := os.ReadFile(abs) - if err != nil { - return err - } - sum := sha256.Sum256(raw) - r.originalConfigHash = hex.EncodeToString(sum[:]) - - tmpDir := filepath.Join(r.runDir, "tmp") - if err := os.MkdirAll(tmpDir, 0o755); err != nil { - return err - } - r.configCopyPath = filepath.Join(tmpDir, "config.json") - if err := os.WriteFile(r.configCopyPath, raw, 0o644); err != nil { - return err - } - var cfg runConfig - if err := json.Unmarshal(raw, &cfg); err != nil { - return fmt.Errorf("parse config failed: %w", err) - } - r.configRaw = cfg - if len(cfg.Keys) > 0 { - r.apiKey = strings.TrimSpace(cfg.Keys[0]) - } - for _, acc := range cfg.Accounts { - id := strings.TrimSpace(acc.Email) - if id == "" { - id = strings.TrimSpace(acc.Mobile) - } - if id != "" { - r.accountID = id - break - } - } - return nil -} - -func (r *Runner) startServer(ctx context.Context) error { - port := r.opts.Port - if port <= 0 { - p, err := findFreePort() - if err != nil { - return err - } - port = p - } - r.baseURL = "http://127.0.0.1:" + strconv.Itoa(port) - - logFd, err := os.OpenFile(r.serverLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) - if err != nil { - return err - } - r.serverLogFd = logFd - cmd := exec.CommandContext(ctx, "go", "run", "./cmd/ds2api") - cmd.Stdout = logFd - cmd.Stderr = logFd - cmd.Env = prepareServerEnv(os.Environ(), map[string]string{ - "PORT": strconv.Itoa(port), - "DS2API_CONFIG_PATH": r.configCopyPath, - "DS2API_AUTO_BUILD_WEBUI": "false", - "DS2API_CONFIG_JSON": "", - "CONFIG_JSON": "", - }) - if err := cmd.Start(); err != nil { - _ = logFd.Close() - return err - } - r.serverCmd = cmd - - deadline := 
time.Now().Add(90 * time.Second) - for time.Now().Before(deadline) { - if r.ping("/healthz") == nil && r.ping("/readyz") == nil { - return nil - } - time.Sleep(500 * time.Millisecond) - } - return errors.New("server readiness timeout") -} - -func (r *Runner) stopServer() error { - var errs []string - if r.serverCmd != nil && r.serverCmd.Process != nil { - _ = r.serverCmd.Process.Signal(os.Interrupt) - done := make(chan error, 1) - go func() { done <- r.serverCmd.Wait() }() - select { - case <-time.After(5 * time.Second): - _ = r.serverCmd.Process.Kill() - <-done - case <-done: - } - } - if r.serverLogFd != nil { - if err := r.serverLogFd.Close(); err != nil { - errs = append(errs, err.Error()) - } - } - if len(errs) > 0 { - return errors.New(strings.Join(errs, "; ")) - } - return nil -} - -func (r *Runner) ping(path string) error { - resp, err := r.httpClient.Get(r.baseURL + path) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("status=%d", resp.StatusCode) - } - return nil -} - -func (r *Runner) prepareAuth(ctx context.Context) error { - reqBody := map[string]any{ - "admin_key": r.adminKey, - "expire_hours": 24, - } - resp, err := r.doSimpleJSON(ctx, http.MethodPost, "/admin/login", nil, reqBody) - if err != nil { - return err - } - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("admin login status=%d body=%s", resp.StatusCode, string(resp.Body)) - } - var m map[string]any - if err := json.Unmarshal(resp.Body, &m); err != nil { - return err - } - token, _ := m["token"].(string) - if strings.TrimSpace(token) == "" { - return errors.New("empty admin jwt token") - } - r.adminJWT = token - return nil -} - -func (r *Runner) ensureOriginalConfigUntouched() error { - raw, err := os.ReadFile(r.originalConfigPath) - if err != nil { - return err - } - sum := sha256.Sum256(raw) - current := hex.EncodeToString(sum[:]) - if current != r.originalConfigHash { - return fmt.Errorf("original config 
changed unexpectedly: %s", r.originalConfigPath) - } - return nil -} - -func (r *Runner) runCase(ctx context.Context, c caseDef) { - caseDir := filepath.Join(r.runDir, "cases", c.ID) - _ = os.MkdirAll(caseDir, 0o755) - cc := &caseContext{ - runner: r, - id: c.ID, - dir: caseDir, - startedAt: time.Now(), - traceIDsSet: map[string]struct{}{}, - } - err := c.Run(ctx, cc) - duration := time.Since(cc.startedAt).Milliseconds() - - if err != nil { - cc.assertions = append(cc.assertions, assertionResult{ - Name: "case_error", - Passed: false, - Detail: err.Error(), - }) - } - passed := err == nil - for _, a := range cc.assertions { - if !a.Passed { - passed = false - break - } - } - - traceIDs := make([]string, 0, len(cc.traceIDsSet)) - for t := range cc.traceIDsSet { - traceIDs = append(traceIDs, t) - } - sort.Strings(traceIDs) - statuses := uniqueStatusCodes(cc.responses) - cs := caseResult{ - CaseID: c.ID, - Passed: passed, - DurationMS: duration, - TraceIDs: traceIDs, - StatusCodes: statuses, - ArtifactPath: caseDir, - Assertions: cc.assertions, - } - if err != nil { - cs.Error = err.Error() - } - _ = cc.flushArtifacts(cs) - r.results = append(r.results, cs) -} - -func (cc *caseContext) assert(name string, ok bool, detail string) { - cc.mu.Lock() - defer cc.mu.Unlock() - cc.assertions = append(cc.assertions, assertionResult{ - Name: name, - Passed: ok, - Detail: detail, - }) -} - -func (cc *caseContext) request(ctx context.Context, spec requestSpec) (*responseResult, error) { - retries := cc.runner.opts.Retries - if !spec.Retryable { - retries = 0 - } - var lastErr error - for attempt := 1; attempt <= retries+1; attempt++ { - resp, err := cc.requestOnce(ctx, spec, attempt) - if err == nil && resp.StatusCode < 500 { - return resp, nil - } - if err != nil { - lastErr = err - } else if resp.StatusCode >= 500 { - lastErr = fmt.Errorf("status=%d", resp.StatusCode) - } - if attempt <= retries { - sleep := time.Duration(300*(1<<(attempt-1))) * time.Millisecond - 
time.Sleep(sleep) - } - } - return nil, lastErr -} - -func (cc *caseContext) requestOnce(ctx context.Context, spec requestSpec, attempt int) (*responseResult, error) { - cc.mu.Lock() - cc.seq++ - seq := cc.seq - traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), seq) - cc.traceIDsSet[traceID] = struct{}{} - cc.mu.Unlock() - - fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID) - if err != nil { - return nil, err - } - - headers := map[string]string{} - for k, v := range spec.Headers { - headers[k] = v - } - headers["X-Ds2-Test-Trace"] = traceID - - var bodyBytes []byte - var bodyAny any - if spec.Body != nil { - b, err := json.Marshal(spec.Body) - if err != nil { - return nil, err - } - bodyBytes = b - bodyAny = spec.Body - headers["Content-Type"] = "application/json" - } - cc.mu.Lock() - cc.requests = append(cc.requests, requestLog{ - Seq: seq, - Attempt: attempt, - TraceID: traceID, - Method: spec.Method, - URL: fullURL, - Headers: headers, - Body: bodyAny, - Timestamp: time.Now().Format(time.RFC3339Nano), - }) - cc.mu.Unlock() - - reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout) - defer cancel() - req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes)) - if err != nil { - return nil, err - } - for k, v := range headers { - req.Header.Set(k, v) - } - start := time.Now() - resp, err := cc.runner.httpClient.Do(req) - if err != nil { - cc.mu.Lock() - cc.responses = append(cc.responses, responseLog{ - Seq: seq, - Attempt: attempt, - TraceID: traceID, - StatusCode: 0, - DurationMS: time.Since(start).Milliseconds(), - NetworkErr: err.Error(), - ReceivedAt: time.Now().Format(time.RFC3339Nano), - }) - cc.mu.Unlock() - return nil, err - } - defer resp.Body.Close() - body, _ := io.ReadAll(resp.Body) - - cc.mu.Lock() - cc.responses = append(cc.responses, responseLog{ - Seq: seq, - Attempt: attempt, - TraceID: traceID, - StatusCode: resp.StatusCode, - Headers: resp.Header, - 
BodyText: string(body), - DurationMS: time.Since(start).Milliseconds(), - ReceivedAt: time.Now().Format(time.RFC3339Nano), - }) - - if spec.Stream { - cc.streamRaw.WriteString(fmt.Sprintf("### trace=%s url=%s\n", traceID, fullURL)) - cc.streamRaw.Write(body) - cc.streamRaw.WriteString("\n\n") - } - cc.mu.Unlock() - - return &responseResult{ - StatusCode: resp.StatusCode, - Headers: resp.Header, - Body: body, - TraceID: traceID, - URL: fullURL, - }, nil -} - -func (cc *caseContext) flushArtifacts(cs caseResult) error { - requestPath := filepath.Join(cc.dir, "request.json") - headersPath := filepath.Join(cc.dir, "response.headers") - bodyPath := filepath.Join(cc.dir, "response.body") - streamPath := filepath.Join(cc.dir, "stream.raw") - assertPath := filepath.Join(cc.dir, "assertions.json") - metaPath := filepath.Join(cc.dir, "meta.json") - - if err := writeJSONFile(requestPath, cc.requests); err != nil { - return err - } - respHeaders := make([]map[string]any, 0, len(cc.responses)) - respBodies := make([]map[string]any, 0, len(cc.responses)) - for _, r := range cc.responses { - respHeaders = append(respHeaders, map[string]any{ - "seq": r.Seq, - "attempt": r.Attempt, - "trace_id": r.TraceID, - "status_code": r.StatusCode, - "headers": r.Headers, - }) - respBodies = append(respBodies, map[string]any{ - "seq": r.Seq, - "attempt": r.Attempt, - "trace_id": r.TraceID, - "status_code": r.StatusCode, - "body_text": r.BodyText, - "network_error": r.NetworkErr, - "duration_ms": r.DurationMS, - }) - } - if err := writeJSONFile(headersPath, respHeaders); err != nil { - return err - } - if err := writeJSONFile(bodyPath, respBodies); err != nil { - return err - } - if err := os.WriteFile(streamPath, []byte(cc.streamRaw.String()), 0o644); err != nil { - return err - } - if err := writeJSONFile(assertPath, cc.assertions); err != nil { - return err - } - meta := map[string]any{ - "case_id": cs.CaseID, - "trace_id": strings.Join(cs.TraceIDs, ","), - "attempt": len(cc.responses), - 
"duration_ms": cs.DurationMS, - "status": map[bool]string{true: "passed", false: "failed"}[cs.Passed], - "status_codes": cs.StatusCodes, - "assertions": cs.Assertions, - "artifact_path": cs.ArtifactPath, - } - return writeJSONFile(metaPath, meta) -} - -type caseDef struct { - ID string - Run func(context.Context, *caseContext) error -} - -func (r *Runner) cases() []caseDef { - return []caseDef{ - {ID: "healthz_ok", Run: r.caseHealthz}, - {ID: "readyz_ok", Run: r.caseReadyz}, - {ID: "models_openai", Run: r.caseModelsOpenAI}, - {ID: "model_openai_by_id", Run: r.caseModelOpenAIByID}, - {ID: "models_claude", Run: r.caseModelsClaude}, - {ID: "admin_login_verify", Run: r.caseAdminLoginVerify}, - {ID: "admin_queue_status", Run: r.caseAdminQueueStatus}, - {ID: "chat_nonstream_basic", Run: r.caseChatNonstream}, - {ID: "chat_stream_basic", Run: r.caseChatStream}, - {ID: "responses_nonstream_basic", Run: r.caseResponsesNonstream}, - {ID: "responses_stream_basic", Run: r.caseResponsesStream}, - {ID: "embeddings_contract", Run: r.caseEmbeddings}, - {ID: "reasoner_stream", Run: r.caseReasonerStream}, - {ID: "toolcall_nonstream", Run: r.caseToolcallNonstream}, - {ID: "toolcall_stream", Run: r.caseToolcallStream}, - {ID: "anthropic_messages_nonstream", Run: r.caseAnthropicNonstream}, - {ID: "anthropic_messages_stream", Run: r.caseAnthropicStream}, - {ID: "anthropic_count_tokens", Run: r.caseAnthropicCountTokens}, - {ID: "admin_account_test_single", Run: r.caseAdminAccountTest}, - {ID: "concurrency_burst", Run: r.caseConcurrencyBurst}, - {ID: "concurrency_threshold_limit", Run: r.caseConcurrencyThresholdLimit}, - {ID: "stream_abort_release", Run: r.caseStreamAbortRelease}, - {ID: "toolcall_stream_mixed", Run: r.caseToolcallStreamMixed}, - {ID: "sse_json_integrity", Run: r.caseSSEJSONIntegrity}, - {ID: "error_contract_invalid_model", Run: r.caseInvalidModel}, - {ID: "error_contract_missing_messages", Run: r.caseMissingMessages}, - {ID: "admin_unauthorized_contract", Run: 
r.caseAdminUnauthorized}, - {ID: "config_write_isolated", Run: r.caseConfigWriteIsolated}, - {ID: "token_refresh_managed_account", Run: r.caseTokenRefreshManagedAccount}, - {ID: "error_contract_invalid_key", Run: r.caseInvalidKey}, - } -} - -func (r *Runner) caseHealthz(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/healthz", Retryable: true}) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseReadyz(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/readyz", Retryable: true}) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseModelsOpenAI(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models", Retryable: true}) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - ids := extractModelIDs(resp.Body) - cc.assert("has_deepseek_chat", contains(ids, "deepseek-chat"), strings.Join(ids, ",")) - cc.assert("has_deepseek_reasoner", contains(ids, "deepseek-reasoner"), strings.Join(ids, ",")) - return nil -} - -func (r *Runner) caseModelOpenAIByID(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models/gpt-4o", Retryable: true}) - if err 
!= nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - cc.assert("object_model", asString(m["object"]) == "model", fmt.Sprintf("body=%s", string(resp.Body))) - cc.assert("id_deepseek_chat", asString(m["id"]) == "deepseek-chat", fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseModelsClaude(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/anthropic/v1/models", Retryable: true}) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - ids := extractModelIDs(resp.Body) - cc.assert("non_empty", len(ids) > 0, fmt.Sprintf("models=%v", ids)) - return nil -} - -func (r *Runner) caseAdminLoginVerify(ctx context.Context, cc *caseContext) error { - loginResp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/admin/login", - Body: map[string]any{"admin_key": r.adminKey, "expire_hours": 24}, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("login_status_200", loginResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", loginResp.StatusCode)) - var payload map[string]any - _ = json.Unmarshal(loginResp.Body, &payload) - token := asString(payload["token"]) - cc.assert("token_exists", token != "", fmt.Sprintf("body=%s", string(loginResp.Body))) - if token == "" { - return nil - } - verifyResp, err := cc.request(ctx, requestSpec{ - Method: http.MethodGet, - Path: "/admin/verify", - Headers: map[string]string{ - "Authorization": "Bearer " + token, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("verify_status_200", verifyResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", verifyResp.StatusCode)) - var v map[string]any - _ = json.Unmarshal(verifyResp.Body, &v) - valid, _ := 
v["valid"].(bool) - cc.assert("verify_valid_true", valid, fmt.Sprintf("body=%s", string(verifyResp.Body))) - return nil -} - -func (r *Runner) caseAdminQueueStatus(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodGet, - Path: "/admin/queue/status", - Headers: map[string]string{ - "Authorization": "Bearer " + r.adminJWT, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - _, hasRec := m["recommended_concurrency"] - _, hasQueue := m["max_queue_size"] - cc.assert("has_recommended_concurrency", hasRec, fmt.Sprintf("body=%s", string(resp.Body))) - cc.assert("has_max_queue_size", hasQueue, fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseChatNonstream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: map[string]any{ - "model": "deepseek-chat", - "messages": []map[string]any{ - {"role": "user", "content": "请简单回复一句话"}, - }, - "stream": false, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - cc.assert("object_chat_completion", asString(m["object"]) == "chat.completion", fmt.Sprintf("body=%s", string(resp.Body))) - choices, _ := m["choices"].([]any) - cc.assert("choices_non_empty", len(choices) > 0, fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseChatStream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: 
map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: map[string]any{ - "model": "deepseek-chat", - "messages": []map[string]any{ - {"role": "user", "content": "请流式回复一句话"}, - }, - "stream": true, - }, - Stream: true, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - frames, done := parseSSEFrames(resp.Body) - cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames))) - cc.assert("done_terminated", done, "expected [DONE]") - return nil -} - -func (r *Runner) caseResponsesNonstream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/responses", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: map[string]any{ - "model": "gpt-4o", - "input": "请简要回答 hello", - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - cc.assert("object_response", asString(m["object"]) == "response", fmt.Sprintf("body=%s", string(resp.Body))) - responseID := asString(m["id"]) - cc.assert("response_id_present", responseID != "", fmt.Sprintf("body=%s", string(resp.Body))) - if responseID != "" { - getResp, getErr := cc.request(ctx, requestSpec{ - Method: http.MethodGet, - Path: "/v1/responses/" + responseID, - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Retryable: true, - }) - if getErr != nil { - return getErr - } - cc.assert("get_status_200", getResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", getResp.StatusCode)) - } - return nil -} - -func (r *Runner) caseResponsesStream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/responses", - 
Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: map[string]any{ - "model": "gpt-4o", - "input": "请流式回答 hello", - "stream": true, - }, - Stream: true, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - frames, done := parseSSEFrames(resp.Body) - cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames))) - hasCreated := false - hasCompleted := false - for _, f := range frames { - switch asString(f["type"]) { - case "response.created": - hasCreated = true - case "response.completed": - hasCompleted = true - } - } - cc.assert("has_response_created", hasCreated, fmt.Sprintf("body=%s", string(resp.Body))) - cc.assert("has_response_completed", hasCompleted, fmt.Sprintf("body=%s", string(resp.Body))) - cc.assert("done_terminated", done, "expected [DONE]") - return nil -} - -func (r *Runner) caseEmbeddings(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/embeddings", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: map[string]any{ - "model": "gpt-4o", - "input": []string{"hello", "world"}, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200_or_501", resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNotImplemented, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - if resp.StatusCode == http.StatusOK { - cc.assert("object_list", asString(m["object"]) == "list", fmt.Sprintf("body=%s", string(resp.Body))) - data, _ := m["data"].([]any) - cc.assert("data_non_empty", len(data) > 0, fmt.Sprintf("body=%s", string(resp.Body))) - return nil - } - errObj, _ := m["error"].(map[string]any) - _, hasCode := errObj["code"] - _, hasParam := errObj["param"] - cc.assert("error_has_code", hasCode, 
fmt.Sprintf("body=%s", string(resp.Body))) - cc.assert("error_has_param", hasParam, fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseReasonerStream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: map[string]any{ - "model": "deepseek-reasoner", - "messages": []map[string]any{ - {"role": "user", "content": "先思考后回答:1+1"}, - }, - "stream": true, - }, - Stream: true, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - frames, done := parseSSEFrames(resp.Body) - hasReasoning := false - for _, f := range frames { - choices, _ := f["choices"].([]any) - for _, c := range choices { - ch, _ := c.(map[string]any) - delta, _ := ch["delta"].(map[string]any) - if asString(delta["reasoning_content"]) != "" { - hasReasoning = true - } - } - } - cc.assert("has_reasoning_content", hasReasoning, "reasoning_content not found") - cc.assert("done_terminated", done, "expected [DONE]") - return nil -} - -func (r *Runner) caseToolcallNonstream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: toolcallPayload(false), - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - choices, _ := m["choices"].([]any) - if len(choices) == 0 { - cc.assert("choices_non_empty", false, fmt.Sprintf("body=%s", string(resp.Body))) - return nil - } - c0, _ := choices[0].(map[string]any) - cc.assert("finish_reason_tool_calls", 
asString(c0["finish_reason"]) == "tool_calls", fmt.Sprintf("body=%s", string(resp.Body))) - msg, _ := c0["message"].(map[string]any) - tc, _ := msg["tool_calls"].([]any) - cc.assert("tool_calls_present", len(tc) > 0, fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseToolcallStream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: toolcallPayload(true), - Stream: true, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - frames, done := parseSSEFrames(resp.Body) - hasTool := false - rawLeak := false - for _, f := range frames { - choices, _ := f["choices"].([]any) - for _, c := range choices { - ch, _ := c.(map[string]any) - delta, _ := ch["delta"].(map[string]any) - if _, ok := delta["tool_calls"]; ok { - hasTool = true - } - content := asString(delta["content"]) - if strings.Contains(strings.ToLower(content), `"tool_calls"`) { - rawLeak = true - } - } - } - cc.assert("tool_calls_delta_present", hasTool, "tool_calls delta missing") - cc.assert("no_raw_tool_json_leak", !rawLeak, "raw tool_calls JSON leaked in content") - cc.assert("done_terminated", done, "expected [DONE]") - return nil -} - -func (r *Runner) caseAnthropicNonstream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/anthropic/v1/messages", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - "anthropic-version": "2023-06-01", - "content-type": "application/json", - }, - Body: map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{ - {"role": "user", "content": "hello"}, - }, - "stream": false, - }, - Retryable: true, - }) - if err != nil { - return err - } - 
cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - cc.assert("type_message", asString(m["type"]) == "message", fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseAnthropicStream(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/anthropic/v1/messages", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - "anthropic-version": "2023-06-01", - "content-type": "application/json", - }, - Body: map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{ - {"role": "user", "content": "stream hello"}, - }, - "stream": true, - }, - Stream: true, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - events := parseClaudeStreamEvents(resp.Body) - cc.assert("has_message_start", contains(events, "message_start"), fmt.Sprintf("events=%v", events)) - cc.assert("has_message_stop", contains(events, "message_stop"), fmt.Sprintf("events=%v", events)) - return nil -} - -func (r *Runner) caseAnthropicCountTokens(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/anthropic/v1/messages/count_tokens", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - "anthropic-version": "2023-06-01", - "content-type": "application/json", - }, - Body: map[string]any{ - "model": "claude-sonnet-4-5", - "messages": []map[string]any{ - {"role": "user", "content": "count me"}, - }, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - v := toInt(m["input_tokens"]) - 
cc.assert("input_tokens_gt_zero", v > 0, fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseAdminAccountTest(ctx context.Context, cc *caseContext) error { - if strings.TrimSpace(r.accountID) == "" { - cc.assert("account_present", false, "no account in config") - return nil - } - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/admin/accounts/test", - Headers: map[string]string{ - "Authorization": "Bearer " + r.adminJWT, - }, - Body: map[string]any{ - "identifier": r.accountID, - "model": "deepseek-chat", - "message": "ping", - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - ok, _ := m["success"].(bool) - cc.assert("success_true", ok, fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) caseConcurrencyBurst(ctx context.Context, cc *caseContext) error { - accountCount := len(r.configRaw.Accounts) - n := accountCount*2 + 2 - if n < 2 { - n = 2 - } - type one struct { - Status int - Err string - } - results := make([]one, n) - var wg sync.WaitGroup - for i := 0; i < n; i++ { - wg.Add(1) - go func(idx int) { - defer wg.Done() - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer " + r.apiKey, - }, - Body: map[string]any{ - "model": "deepseek-chat", - "messages": []map[string]any{ - {"role": "user", "content": fmt.Sprintf("并发请求 #%d,请回复ok", idx)}, - }, - "stream": true, - }, - Stream: true, - Retryable: true, - }) - if err != nil { - results[idx] = one{Err: err.Error()} - return - } - results[idx] = one{Status: resp.StatusCode} - }(i) - } - wg.Wait() - - dist := map[int]int{} - success := 0 - for _, it := range results { - if it.Status > 0 { - dist[it.Status]++ - if it.Status == http.StatusOK { - success++ 
- } - } - } - cc.assert("success_gt_zero", success > 0, fmt.Sprintf("distribution=%v", dist)) - _, has5xx := has5xx(dist) - cc.assert("no_5xx", !has5xx, fmt.Sprintf("distribution=%v", dist)) - if err := r.ping("/healthz"); err != nil { - cc.assert("server_alive", false, err.Error()) - } else { - cc.assert("server_alive", true, "") - } - return nil -} - -func (r *Runner) caseConfigWriteIsolated(ctx context.Context, cc *caseContext) error { - k := "testsuite-temp-" + sanitizeID(r.runID) - add, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/admin/keys", - Headers: map[string]string{ - "Authorization": "Bearer " + r.adminJWT, - }, - Body: map[string]any{"key": k}, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("add_key_status_200", add.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", add.StatusCode)) - - cfg1, err := cc.request(ctx, requestSpec{ - Method: http.MethodGet, - Path: "/admin/config", - Headers: map[string]string{ - "Authorization": "Bearer " + r.adminJWT, - }, - Retryable: true, - }) - if err != nil { - return err - } - containsAdded := strings.Contains(string(cfg1.Body), k) - cc.assert("key_present_in_isolated_config", containsAdded, "added key not found in isolated config") - - delPath := "/admin/keys/" + url.PathEscape(k) - del, err := cc.request(ctx, requestSpec{ - Method: http.MethodDelete, - Path: delPath, - Headers: map[string]string{ - "Authorization": "Bearer " + r.adminJWT, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("delete_key_status_200", del.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", del.StatusCode)) - - cfg2, err := cc.request(ctx, requestSpec{ - Method: http.MethodGet, - Path: "/admin/config", - Headers: map[string]string{ - "Authorization": "Bearer " + r.adminJWT, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("key_removed_in_isolated_config", !strings.Contains(string(cfg2.Body), k), "temporary key still present") - 
- if err := r.ensureOriginalConfigUntouched(); err != nil { - cc.assert("original_config_unchanged", false, err.Error()) - } else { - cc.assert("original_config_unchanged", true, "") - } - return nil -} - -func (r *Runner) caseInvalidKey(ctx context.Context, cc *caseContext) error { - resp, err := cc.request(ctx, requestSpec{ - Method: http.MethodPost, - Path: "/v1/chat/completions", - Headers: map[string]string{ - "Authorization": "Bearer invalid-testsuite-key-" + sanitizeID(r.runID), - }, - Body: map[string]any{ - "model": "deepseek-chat", - "messages": []map[string]any{ - {"role": "user", "content": "hi"}, - }, - "stream": false, - }, - Retryable: true, - }) - if err != nil { - return err - } - cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode)) - var m map[string]any - _ = json.Unmarshal(resp.Body, &m) - e, _ := m["error"].(map[string]any) - cc.assert("error_object_present", len(e) > 0, fmt.Sprintf("body=%s", string(resp.Body))) - cc.assert("error_message_present", asString(e["message"]) != "", fmt.Sprintf("body=%s", string(resp.Body))) - return nil -} - -func (r *Runner) doSimpleJSON(ctx context.Context, method, path string, headers map[string]string, body any) (*responseResult, error) { - cc := &caseContext{ - runner: r, - id: "auth_prepare", - traceIDsSet: map[string]struct{}{}, - } - return cc.request(ctx, requestSpec{ - Method: method, - Path: path, - Headers: headers, - Body: body, - Retryable: true, - }) -} - -func (r *Runner) writeSummary(start, end time.Time) error { - passed := 0 - failed := 0 - for _, cs := range r.results { - if cs.Passed { - passed++ - } else { - failed++ - } - } - summary := runSummary{ - RunID: r.runID, - StartedAt: start.Format(time.RFC3339Nano), - EndedAt: end.Format(time.RFC3339Nano), - DurationMS: end.Sub(start).Milliseconds(), - Stats: map[string]any{ - "total": len(r.results), - "passed": passed, - "failed": failed, - }, - Environment: map[string]any{ - 
"go_version": runtime.Version(), - "os": runtime.GOOS, - "arch": runtime.GOARCH, - "base_url": r.baseURL, - "config_source": r.originalConfigPath, - "config_isolated": r.configCopyPath, - "server_log": r.serverLog, - "preflight_log": r.preflightLog, - "retries": r.opts.Retries, - "timeout_seconds": int(r.opts.Timeout.Seconds()), - }, - Cases: r.results, - Warnings: r.warnings, - } - if err := writeJSONFile(filepath.Join(r.runDir, "summary.json"), summary); err != nil { - return err - } - return os.WriteFile(filepath.Join(r.runDir, "summary.md"), []byte(r.summaryMarkdown(summary)), 0o644) -} - -func (r *Runner) summaryMarkdown(s runSummary) string { - var b strings.Builder - b.WriteString("# DS2API Live Testsuite Summary\n\n") - b.WriteString("**Sensitive Notice:** this run stores full raw request/response logs. Do not share artifacts publicly.\n\n") - fmt.Fprintf(&b, "- Run ID: `%s`\n", s.RunID) - fmt.Fprintf(&b, "- Started: `%s`\n", s.StartedAt) - fmt.Fprintf(&b, "- Ended: `%s`\n", s.EndedAt) - fmt.Fprintf(&b, "- Duration: `%d ms`\n", s.DurationMS) - fmt.Fprintf(&b, "- Passed/Failed: `%d/%d`\n\n", s.Stats["passed"], s.Stats["failed"]) - if len(s.Warnings) > 0 { - b.WriteString("## Warnings\n\n") - for _, w := range s.Warnings { - fmt.Fprintf(&b, "- %s\n", w) - } - b.WriteString("\n") - } - b.WriteString("## Failed Cases\n\n") - hasFailed := false - for _, c := range s.Cases { - if c.Passed { - continue - } - hasFailed = true - fmt.Fprintf(&b, "- `%s`: %s\n", c.CaseID, c.Error) - if len(c.TraceIDs) > 0 { - fmt.Fprintf(&b, " - trace_ids: `%s`\n", strings.Join(c.TraceIDs, ", ")) - fmt.Fprintf(&b, " - grep: `rg \"%s\" %s`\n", c.TraceIDs[0], filepath.Join(r.runDir, "server.log")) - } - fmt.Fprintf(&b, " - artifact: `%s`\n", c.ArtifactPath) - } - if !hasFailed { - b.WriteString("- none\n") - } - b.WriteString("\n## Case Table\n\n") - b.WriteString("| case_id | status | duration_ms | statuses | artifact |\n") - b.WriteString("|---|---:|---:|---|---|\n") - for _, c := 
range s.Cases { - status := "PASS" - if !c.Passed { - status = "FAIL" - } - fmt.Fprintf(&b, "| %s | %s | %d | %v | `%s` |\n", c.CaseID, status, c.DurationMS, c.StatusCodes, c.ArtifactPath) - } - return b.String() -} - -func toolcallPayload(stream bool) map[string]any { - return map[string]any{ - "model": "deepseek-chat", - "messages": []map[string]any{ - { - "role": "user", - "content": "你必须调用工具 search 查询 golang,并仅返回工具调用。", - }, - }, - "tools": []map[string]any{ - { - "type": "function", - "function": map[string]any{ - "name": "search", - "description": "search documents", - "parameters": map[string]any{ - "type": "object", - "properties": map[string]any{ - "q": map[string]any{ - "type": "string", - }, - }, - "required": []string{"q"}, - }, - }, - }, - }, - "stream": stream, - } -} - -func parseSSEFrames(body []byte) ([]map[string]any, bool) { - lines := strings.Split(string(body), "\n") - frames := make([]map[string]any, 0, len(lines)) - done := false - for _, line := range lines { - line = strings.TrimSpace(line) - if !strings.HasPrefix(line, "data:") { - continue - } - payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) - if payload == "" { - continue - } - if payload == "[DONE]" { - done = true - continue - } - var m map[string]any - if err := json.Unmarshal([]byte(payload), &m); err == nil { - frames = append(frames, m) - } - } - return frames, done -} - -func parseClaudeStreamEvents(body []byte) []string { - events := []string{} - seen := map[string]bool{} - lines := strings.Split(string(body), "\n") - for _, line := range lines { - line = strings.TrimSpace(line) - if !strings.HasPrefix(line, "data:") { - continue - } - payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) - if payload == "" { - continue - } - var m map[string]any - if err := json.Unmarshal([]byte(payload), &m); err != nil { - continue - } - t := asString(m["type"]) - if t == "" || seen[t] { - continue - } - seen[t] = true - events = append(events, t) - } - return events 
-} - -func extractModelIDs(body []byte) []string { - var m map[string]any - if err := json.Unmarshal(body, &m); err != nil { - return nil - } - out := []string{} - data, _ := m["data"].([]any) - for _, it := range data { - item, _ := it.(map[string]any) - id := asString(item["id"]) - if id != "" { - out = append(out, id) - } - } - return out -} - -func withTraceQuery(rawURL, traceID string) (string, error) { - u, err := url.Parse(rawURL) - if err != nil { - return "", err - } - q := u.Query() - q.Set("__trace_id", traceID) - u.RawQuery = q.Encode() - return u.String(), nil -} - -func writeJSONFile(path string, v any) error { - b, err := json.MarshalIndent(v, "", " ") - if err != nil { - return err - } - return os.WriteFile(path, b, 0o644) -} - -func prepareServerEnv(base []string, overrides map[string]string) []string { - out := make([]string, 0, len(base)+len(overrides)) - skip := map[string]struct{}{} - for k := range overrides { - skip[k] = struct{}{} - } - for _, e := range base { - parts := strings.SplitN(e, "=", 2) - if len(parts) != 2 { - continue - } - if _, ok := skip[parts[0]]; ok { - continue - } - out = append(out, e) - } - for k, v := range overrides { - out = append(out, k+"="+v) - } - return out -} - -func findFreePort() (int, error) { - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return 0, err - } - defer ln.Close() - addr, ok := ln.Addr().(*net.TCPAddr) - if !ok { - return 0, errors.New("failed to detect tcp port") - } - return addr.Port, nil -} - -func uniqueStatusCodes(in []responseLog) []int { - set := map[int]struct{}{} - for _, it := range in { - if it.StatusCode > 0 { - set[it.StatusCode] = struct{}{} - } - } - out := make([]int, 0, len(set)) - for k := range set { - out = append(out, k) - } - sort.Ints(out) - return out -} - -func has5xx(dist map[int]int) (int, bool) { - for k := range dist { - if k >= 500 { - return k, true - } - } - return 0, false -} - -func sanitizeID(s string) string { - s = strings.ReplaceAll(s, 
":", "_") - s = strings.ReplaceAll(s, "/", "_") - s = strings.ReplaceAll(s, " ", "_") - return s -} - -func asString(v any) string { - if v == nil { - return "" - } - switch x := v.(type) { - case string: - return strings.TrimSpace(x) - default: - return strings.TrimSpace(fmt.Sprintf("%v", v)) - } -} - -func toInt(v any) int { - switch x := v.(type) { - case float64: - return int(x) - case float32: - return int(x) - case int: - return x - case int64: - return int(x) - default: - return 0 - } -} - -func contains(xs []string, target string) bool { - for _, x := range xs { - if x == target { - return true - } - } - return false -} diff --git a/internal/testsuite/runner_cases_admin.go b/internal/testsuite/runner_cases_admin.go new file mode 100644 index 0000000..d66adea --- /dev/null +++ b/internal/testsuite/runner_cases_admin.go @@ -0,0 +1,161 @@ +package testsuite + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" +) + +func (r *Runner) caseAdminLoginVerify(ctx context.Context, cc *caseContext) error { + loginResp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/admin/login", + Body: map[string]any{"admin_key": r.adminKey, "expire_hours": 24}, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("login_status_200", loginResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", loginResp.StatusCode)) + var payload map[string]any + _ = json.Unmarshal(loginResp.Body, &payload) + token := asString(payload["token"]) + cc.assert("token_exists", token != "", fmt.Sprintf("body=%s", string(loginResp.Body))) + if token == "" { + return nil + } + verifyResp, err := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/admin/verify", + Headers: map[string]string{ + "Authorization": "Bearer " + token, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("verify_status_200", verifyResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", verifyResp.StatusCode)) + var 
v map[string]any + _ = json.Unmarshal(verifyResp.Body, &v) + valid, _ := v["valid"].(bool) + cc.assert("verify_valid_true", valid, fmt.Sprintf("body=%s", string(verifyResp.Body))) + return nil +} + +func (r *Runner) caseAdminQueueStatus(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/admin/queue/status", + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + _, hasRec := m["recommended_concurrency"] + _, hasQueue := m["max_queue_size"] + cc.assert("has_recommended_concurrency", hasRec, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("has_max_queue_size", hasQueue, fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} +func (r *Runner) caseAdminAccountTest(ctx context.Context, cc *caseContext) error { + if strings.TrimSpace(r.accountID) == "" { + cc.assert("account_present", false, "no account in config") + return nil + } + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/admin/accounts/test", + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Body: map[string]any{ + "identifier": r.accountID, + "model": "deepseek-chat", + "message": "ping", + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + ok, _ := m["success"].(bool) + cc.assert("success_true", ok, fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} +func (r *Runner) caseConfigWriteIsolated(ctx context.Context, cc *caseContext) error { + k := "testsuite-temp-" + sanitizeID(r.runID) + add, err := cc.request(ctx, requestSpec{ + Method: 
http.MethodPost, + Path: "/admin/keys", + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Body: map[string]any{"key": k}, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("add_key_status_200", add.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", add.StatusCode)) + + cfg1, err := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/admin/config", + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Retryable: true, + }) + if err != nil { + return err + } + containsAdded := strings.Contains(string(cfg1.Body), k) + cc.assert("key_present_in_isolated_config", containsAdded, "added key not found in isolated config") + + delPath := "/admin/keys/" + url.PathEscape(k) + del, err := cc.request(ctx, requestSpec{ + Method: http.MethodDelete, + Path: delPath, + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("delete_key_status_200", del.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", del.StatusCode)) + + cfg2, err := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/admin/config", + Headers: map[string]string{ + "Authorization": "Bearer " + r.adminJWT, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("key_removed_in_isolated_config", !strings.Contains(string(cfg2.Body), k), "temporary key still present") + + if err := r.ensureOriginalConfigUntouched(); err != nil { + cc.assert("original_config_unchanged", false, err.Error()) + } else { + cc.assert("original_config_unchanged", true, "") + } + return nil +} diff --git a/internal/testsuite/runner_cases_claude.go b/internal/testsuite/runner_cases_claude.go new file mode 100644 index 0000000..590e524 --- /dev/null +++ b/internal/testsuite/runner_cases_claude.go @@ -0,0 +1,103 @@ +package testsuite + +import ( + "context" + "encoding/json" + "fmt" + "net/http" +) + +func (r *Runner) 
caseModelsClaude(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/anthropic/v1/models", Retryable: true}) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + ids := extractModelIDs(resp.Body) + cc.assert("non_empty", len(ids) > 0, fmt.Sprintf("models=%v", ids)) + return nil +} +func (r *Runner) caseAnthropicNonstream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/anthropic/v1/messages", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + Body: map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{ + {"role": "user", "content": "hello"}, + }, + "stream": false, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("type_message", asString(m["type"]) == "message", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func (r *Runner) caseAnthropicStream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/anthropic/v1/messages", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + Body: map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{ + {"role": "user", "content": "stream hello"}, + }, + "stream": true, + }, + Stream: true, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + events := 
parseClaudeStreamEvents(resp.Body) + cc.assert("has_message_start", contains(events, "message_start"), fmt.Sprintf("events=%v", events)) + cc.assert("has_message_stop", contains(events, "message_stop"), fmt.Sprintf("events=%v", events)) + return nil +} + +func (r *Runner) caseAnthropicCountTokens(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/anthropic/v1/messages/count_tokens", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + "anthropic-version": "2023-06-01", + "content-type": "application/json", + }, + Body: map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []map[string]any{ + {"role": "user", "content": "count me"}, + }, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + v := toInt(m["input_tokens"]) + cc.assert("input_tokens_gt_zero", v > 0, fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} diff --git a/internal/testsuite/runner_cases_openai.go b/internal/testsuite/runner_cases_openai.go new file mode 100644 index 0000000..4ca2e40 --- /dev/null +++ b/internal/testsuite/runner_cases_openai.go @@ -0,0 +1,221 @@ +package testsuite + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" +) + +func (r *Runner) caseHealthz(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/healthz", Retryable: true}) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func (r *Runner) caseReadyz(ctx context.Context, cc *caseContext) error { 
+ resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/readyz", Retryable: true}) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func (r *Runner) caseModelsOpenAI(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models", Retryable: true}) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + ids := extractModelIDs(resp.Body) + cc.assert("has_deepseek_chat", contains(ids, "deepseek-chat"), strings.Join(ids, ",")) + cc.assert("has_deepseek_reasoner", contains(ids, "deepseek-reasoner"), strings.Join(ids, ",")) + return nil +} + +func (r *Runner) caseModelOpenAIByID(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models/gpt-4o", Retryable: true}) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("object_model", asString(m["object"]) == "model", fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("id_deepseek_chat", asString(m["id"]) == "deepseek-chat", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} +func (r *Runner) caseChatNonstream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-chat", + "messages": []map[string]any{ + {"role": "user", "content": 
"请简单回复一句话"}, + }, + "stream": false, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("object_chat_completion", asString(m["object"]) == "chat.completion", fmt.Sprintf("body=%s", string(resp.Body))) + choices, _ := m["choices"].([]any) + cc.assert("choices_non_empty", len(choices) > 0, fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func (r *Runner) caseChatStream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-chat", + "messages": []map[string]any{ + {"role": "user", "content": "请流式回复一句话"}, + }, + "stream": true, + }, + Stream: true, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + frames, done := parseSSEFrames(resp.Body) + cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames))) + cc.assert("done_terminated", done, "expected [DONE]") + return nil +} + +func (r *Runner) caseResponsesNonstream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/responses", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "gpt-4o", + "input": "请简要回答 hello", + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + cc.assert("object_response", asString(m["object"]) == "response", fmt.Sprintf("body=%s", string(resp.Body))) + 
responseID := asString(m["id"]) + cc.assert("response_id_present", responseID != "", fmt.Sprintf("body=%s", string(resp.Body))) + if responseID != "" { + getResp, getErr := cc.request(ctx, requestSpec{ + Method: http.MethodGet, + Path: "/v1/responses/" + responseID, + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Retryable: true, + }) + if getErr != nil { + return getErr + } + cc.assert("get_status_200", getResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", getResp.StatusCode)) + } + return nil +} + +func (r *Runner) caseResponsesStream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/responses", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "gpt-4o", + "input": "请流式回答 hello", + "stream": true, + }, + Stream: true, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + frames, done := parseSSEFrames(resp.Body) + cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames))) + hasCreated := false + hasCompleted := false + for _, f := range frames { + switch asString(f["type"]) { + case "response.created": + hasCreated = true + case "response.completed": + hasCompleted = true + } + } + cc.assert("has_response_created", hasCreated, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("has_response_completed", hasCompleted, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("done_terminated", done, "expected [DONE]") + return nil +} + +func (r *Runner) caseEmbeddings(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/embeddings", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "gpt-4o", + "input": []string{"hello", 
"world"}, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200_or_501", resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNotImplemented, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + if resp.StatusCode == http.StatusOK { + cc.assert("object_list", asString(m["object"]) == "list", fmt.Sprintf("body=%s", string(resp.Body))) + data, _ := m["data"].([]any) + cc.assert("data_non_empty", len(data) > 0, fmt.Sprintf("body=%s", string(resp.Body))) + return nil + } + errObj, _ := m["error"].(map[string]any) + _, hasCode := errObj["code"] + _, hasParam := errObj["param"] + cc.assert("error_has_code", hasCode, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("error_has_param", hasParam, fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} diff --git a/internal/testsuite/runner_cases_openai_advanced.go b/internal/testsuite/runner_cases_openai_advanced.go new file mode 100644 index 0000000..34e9f01 --- /dev/null +++ b/internal/testsuite/runner_cases_openai_advanced.go @@ -0,0 +1,236 @@ +package testsuite + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" +) + +func (r *Runner) caseReasonerStream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-reasoner", + "messages": []map[string]any{ + {"role": "user", "content": "先思考后回答:1+1"}, + }, + "stream": true, + }, + Stream: true, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + frames, done := parseSSEFrames(resp.Body) + hasReasoning := false + for _, f := range frames { + choices, _ := f["choices"].([]any) + for _, c := range choices { + ch, _ := 
c.(map[string]any) + delta, _ := ch["delta"].(map[string]any) + if asString(delta["reasoning_content"]) != "" { + hasReasoning = true + } + } + } + cc.assert("has_reasoning_content", hasReasoning, "reasoning_content not found") + cc.assert("done_terminated", done, "expected [DONE]") + return nil +} + +func (r *Runner) caseToolcallNonstream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: toolcallPayload(false), + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + choices, _ := m["choices"].([]any) + if len(choices) == 0 { + cc.assert("choices_non_empty", false, fmt.Sprintf("body=%s", string(resp.Body))) + return nil + } + c0, _ := choices[0].(map[string]any) + cc.assert("finish_reason_tool_calls", asString(c0["finish_reason"]) == "tool_calls", fmt.Sprintf("body=%s", string(resp.Body))) + msg, _ := c0["message"].(map[string]any) + tc, _ := msg["tool_calls"].([]any) + cc.assert("tool_calls_present", len(tc) > 0, fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func (r *Runner) caseToolcallStream(ctx context.Context, cc *caseContext) error { + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: toolcallPayload(true), + Stream: true, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode)) + frames, done := parseSSEFrames(resp.Body) + hasTool := false + rawLeak := false + for _, f := range frames { + choices, _ := f["choices"].([]any) + for _, c := range choices { + ch, _ := 
c.(map[string]any) + delta, _ := ch["delta"].(map[string]any) + if _, ok := delta["tool_calls"]; ok { + hasTool = true + } + content := asString(delta["content"]) + if strings.Contains(strings.ToLower(content), `"tool_calls"`) { + rawLeak = true + } + } + } + cc.assert("tool_calls_delta_present", hasTool, "tool_calls delta missing") + cc.assert("no_raw_tool_json_leak", !rawLeak, "raw tool_calls JSON leaked in content") + cc.assert("done_terminated", done, "expected [DONE]") + return nil +} + +func (r *Runner) caseConcurrencyBurst(ctx context.Context, cc *caseContext) error { + accountCount := len(r.configRaw.Accounts) + n := accountCount*2 + 2 + if n < 2 { + n = 2 + } + type one struct { + Status int + Err string + } + results := make([]one, n) + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + resp, err := cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer " + r.apiKey, + }, + Body: map[string]any{ + "model": "deepseek-chat", + "messages": []map[string]any{ + {"role": "user", "content": fmt.Sprintf("并发请求 #%d,请回复ok", idx)}, + }, + "stream": true, + }, + Stream: true, + Retryable: true, + }) + if err != nil { + results[idx] = one{Err: err.Error()} + return + } + results[idx] = one{Status: resp.StatusCode} + }(i) + } + wg.Wait() + + dist := map[int]int{} + success := 0 + for _, it := range results { + if it.Status > 0 { + dist[it.Status]++ + if it.Status == http.StatusOK { + success++ + } + } + } + cc.assert("success_gt_zero", success > 0, fmt.Sprintf("distribution=%v", dist)) + _, has5xx := has5xx(dist) + cc.assert("no_5xx", !has5xx, fmt.Sprintf("distribution=%v", dist)) + if err := r.ping("/healthz"); err != nil { + cc.assert("server_alive", false, err.Error()) + } else { + cc.assert("server_alive", true, "") + } + return nil +} + +func (r *Runner) caseInvalidKey(ctx context.Context, cc *caseContext) error { + resp, err 
:= cc.request(ctx, requestSpec{ + Method: http.MethodPost, + Path: "/v1/chat/completions", + Headers: map[string]string{ + "Authorization": "Bearer invalid-testsuite-key-" + sanitizeID(r.runID), + }, + Body: map[string]any{ + "model": "deepseek-chat", + "messages": []map[string]any{ + {"role": "user", "content": "hi"}, + }, + "stream": false, + }, + Retryable: true, + }) + if err != nil { + return err + } + cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode)) + var m map[string]any + _ = json.Unmarshal(resp.Body, &m) + e, _ := m["error"].(map[string]any) + cc.assert("error_object_present", len(e) > 0, fmt.Sprintf("body=%s", string(resp.Body))) + cc.assert("error_message_present", asString(e["message"]) != "", fmt.Sprintf("body=%s", string(resp.Body))) + return nil +} + +func toolcallPayload(stream bool) map[string]any { + return map[string]any{ + "model": "deepseek-chat", + "messages": []map[string]any{ + { + "role": "user", + "content": "你必须调用工具 search 查询 golang,并仅返回工具调用。", + }, + }, + "tools": []map[string]any{ + { + "type": "function", + "function": map[string]any{ + "name": "search", + "description": "search documents", + "parameters": map[string]any{ + "type": "object", + "properties": map[string]any{ + "q": map[string]any{ + "type": "string", + }, + }, + "required": []string{"q"}, + }, + }, + }, + }, + "stream": stream, + } +} diff --git a/internal/testsuite/runner_core.go b/internal/testsuite/runner_core.go new file mode 100644 index 0000000..06eafa5 --- /dev/null +++ b/internal/testsuite/runner_core.go @@ -0,0 +1,290 @@ +package testsuite + +import ( + "context" + "fmt" + "net/http" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" + "sync" + "time" +) + +type Options struct { + ConfigPath string + AdminKey string + OutputDir string + Port int + Timeout time.Duration + Retries int + NoPreflight bool + MaxKeepRuns int +} + +type runSummary struct { + RunID string `json:"run_id"` + StartedAt 
string `json:"started_at"` + EndedAt string `json:"ended_at"` + DurationMS int64 `json:"duration_ms"` + Stats map[string]any `json:"stats"` + Environment map[string]any `json:"environment"` + Cases []caseResult `json:"cases"` + Warnings []string `json:"warnings,omitempty"` +} + +type caseResult struct { + CaseID string `json:"case_id"` + Passed bool `json:"passed"` + DurationMS int64 `json:"duration_ms"` + TraceIDs []string `json:"trace_ids"` + StatusCodes []int `json:"status_codes"` + Error string `json:"error,omitempty"` + ArtifactPath string `json:"artifact_path"` + Assertions []assertionResult `json:"assertions"` +} + +type assertionResult struct { + Name string `json:"name"` + Passed bool `json:"passed"` + Detail string `json:"detail,omitempty"` +} + +type requestLog struct { + Seq int `json:"seq"` + Attempt int `json:"attempt"` + TraceID string `json:"trace_id"` + Method string `json:"method"` + URL string `json:"url"` + Headers map[string]string `json:"headers"` + Body any `json:"body,omitempty"` + Timestamp string `json:"timestamp"` +} + +type responseLog struct { + Seq int `json:"seq"` + Attempt int `json:"attempt"` + TraceID string `json:"trace_id"` + StatusCode int `json:"status_code"` + Headers map[string][]string `json:"headers"` + BodyText string `json:"body_text"` + DurationMS int64 `json:"duration_ms"` + NetworkErr string `json:"network_error,omitempty"` + ReceivedAt string `json:"received_at"` +} + +type caseContext struct { + runner *Runner + id string + dir string + startedAt time.Time + mu sync.Mutex + seq int + assertions []assertionResult + requests []requestLog + responses []responseLog + streamRaw strings.Builder + traceIDsSet map[string]struct{} +} + +type requestSpec struct { + Method string + Path string + Headers map[string]string + Body any + Stream bool + Retryable bool +} + +type responseResult struct { + StatusCode int + Headers http.Header + Body []byte + TraceID string + URL string +} + +type Runner struct { + opts Options + + 
runID string + runDir string + serverLog string + preflightLog string + + baseURL string + httpClient *http.Client + serverCmd *exec.Cmd + serverLogFd *os.File + + configCopyPath string + originalConfigPath string + originalConfigHash string + + configRaw runConfig + apiKey string + adminKey string + adminJWT string + accountID string + + warnings []string + results []caseResult +} + +type runConfig struct { + Keys []string `json:"keys"` + Accounts []struct { + Email string `json:"email,omitempty"` + Mobile string `json:"mobile,omitempty"` + Password string `json:"password,omitempty"` + Token string `json:"token,omitempty"` + } `json:"accounts"` +} + +func Run(ctx context.Context, opts Options) error { + r, err := newRunner(opts) + if err != nil { + return err + } + start := time.Now() + defer func() { + _ = r.stopServer() + }() + + if err := r.prepareRunDir(); err != nil { + return err + } + + if !r.opts.NoPreflight { + if err := r.runPreflight(ctx); err != nil { + _ = r.writeSummary(start, time.Now()) + return err + } + } + + if err := r.prepareConfigIsolation(); err != nil { + _ = r.writeSummary(start, time.Now()) + return err + } + + if err := r.startServer(ctx); err != nil { + _ = r.writeSummary(start, time.Now()) + return err + } + + if err := r.prepareAuth(ctx); err != nil { + r.warnings = append(r.warnings, "auth prepare failed: "+err.Error()) + } + + for _, c := range r.cases() { + r.runCase(ctx, c) + } + + if err := r.ensureOriginalConfigUntouched(); err != nil { + r.warnings = append(r.warnings, err.Error()) + } + + end := time.Now() + if err := r.writeSummary(start, end); err != nil { + return err + } + + // Prune old test runs, keeping only the most recent N. 
+ if err := r.pruneOldRuns(); err != nil { + r.warnings = append(r.warnings, "prune old runs: "+err.Error()) + } + + failed := 0 + for _, cs := range r.results { + if !cs.Passed { + failed++ + } + } + if failed > 0 { + return fmt.Errorf("testsuite failed: %d case(s) failed, see %s", failed, filepath.Join(r.runDir, "summary.md")) + } + return nil +} + +func newRunner(opts Options) (*Runner, error) { + if strings.TrimSpace(opts.ConfigPath) == "" { + opts.ConfigPath = "config.json" + } + if strings.TrimSpace(opts.OutputDir) == "" { + opts.OutputDir = "artifacts/testsuite" + } + if opts.Timeout <= 0 { + opts.Timeout = 120 * time.Second + } + if opts.Retries < 0 { + opts.Retries = 0 + } + adminKey := strings.TrimSpace(opts.AdminKey) + if adminKey == "" { + adminKey = strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")) + } + if adminKey == "" { + adminKey = "admin" + } + opts.AdminKey = adminKey + + return &Runner{ + opts: opts, + httpClient: &http.Client{ + Timeout: 0, + }, + runID: time.Now().UTC().Format("20060102T150405Z"), + adminKey: adminKey, + }, nil +} +func (r *Runner) runCase(ctx context.Context, c caseDef) { + caseDir := filepath.Join(r.runDir, "cases", c.ID) + _ = os.MkdirAll(caseDir, 0o755) + cc := &caseContext{ + runner: r, + id: c.ID, + dir: caseDir, + startedAt: time.Now(), + traceIDsSet: map[string]struct{}{}, + } + err := c.Run(ctx, cc) + duration := time.Since(cc.startedAt).Milliseconds() + + if err != nil { + cc.assertions = append(cc.assertions, assertionResult{ + Name: "case_error", + Passed: false, + Detail: err.Error(), + }) + } + passed := err == nil + for _, a := range cc.assertions { + if !a.Passed { + passed = false + break + } + } + + traceIDs := make([]string, 0, len(cc.traceIDsSet)) + for t := range cc.traceIDsSet { + traceIDs = append(traceIDs, t) + } + sort.Strings(traceIDs) + statuses := uniqueStatusCodes(cc.responses) + cs := caseResult{ + CaseID: c.ID, + Passed: passed, + DurationMS: duration, + TraceIDs: traceIDs, + StatusCodes: 
statuses, + ArtifactPath: caseDir, + Assertions: cc.assertions, + } + if err != nil { + cs.Error = err.Error() + } + _ = cc.flushArtifacts(cs) + r.results = append(r.results, cs) +} diff --git a/internal/testsuite/runner_defaults.go b/internal/testsuite/runner_defaults.go new file mode 100644 index 0000000..ab30bf1 --- /dev/null +++ b/internal/testsuite/runner_defaults.go @@ -0,0 +1,20 @@ +package testsuite + +import ( + "os" + "strings" + "time" +) + +func DefaultOptions() Options { + return Options{ + ConfigPath: "config.json", + AdminKey: strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")), + OutputDir: "artifacts/testsuite", + Port: 0, + Timeout: 120 * time.Second, + Retries: 2, + NoPreflight: false, + MaxKeepRuns: 5, + } +} diff --git a/internal/testsuite/runner_env.go b/internal/testsuite/runner_env.go new file mode 100644 index 0000000..3ae0ba4 --- /dev/null +++ b/internal/testsuite/runner_env.go @@ -0,0 +1,261 @@ +package testsuite + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "os/exec" + "path/filepath" + "sort" + "strconv" + "strings" + "time" +) + +func (r *Runner) prepareRunDir() error { + r.runDir = filepath.Join(r.opts.OutputDir, r.runID) + if err := os.MkdirAll(r.runDir, 0o755); err != nil { + return err + } + if err := os.MkdirAll(filepath.Join(r.runDir, "cases"), 0o755); err != nil { + return err + } + r.serverLog = filepath.Join(r.runDir, "server.log") + r.preflightLog = filepath.Join(r.runDir, "preflight.log") + return nil +} + +// pruneOldRuns removes old test run directories, keeping the most recent MaxKeepRuns. +// Run IDs use the format "20060102T150405Z", so alphabetical order == chronological order. 
+func (r *Runner) pruneOldRuns() error { + keep := r.opts.MaxKeepRuns + if keep <= 0 { + return nil // 0 or negative means no pruning + } + + entries, err := os.ReadDir(r.opts.OutputDir) + if err != nil { + return err + } + + // Collect only directories (each run is a directory). + var runDirs []string + for _, e := range entries { + if !e.IsDir() { + continue + } + runDirs = append(runDirs, e.Name()) + } + + sort.Strings(runDirs) + + if len(runDirs) <= keep { + return nil + } + + // Remove oldest runs (those at the beginning of the sorted list). + toRemove := runDirs[:len(runDirs)-keep] + var errs []string + for _, name := range toRemove { + dirPath := filepath.Join(r.opts.OutputDir, name) + if err := os.RemoveAll(dirPath); err != nil { + errs = append(errs, fmt.Sprintf("remove %s: %v", name, err)) + } else { + fmt.Fprintf(os.Stdout, "pruned old test run: %s\n", name) + } + } + + if len(errs) > 0 { + return errors.New(strings.Join(errs, "; ")) + } + return nil +} + +func (r *Runner) runPreflight(ctx context.Context) error { + steps := [][]string{ + {"go", "test", "./...", "-count=1"}, + {"node", "--check", "api/chat-stream.js"}, + {"node", "--check", "api/helpers/stream-tool-sieve.js"}, + {"node", "--test", "api/helpers/stream-tool-sieve.test.js", "api/chat-stream.test.js", "api/compat/js_compat_test.js"}, + {"npm", "run", "build", "--prefix", "webui"}, + } + f, err := os.OpenFile(r.preflightLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + return err + } + defer f.Close() + for _, step := range steps { + if _, err := fmt.Fprintf(f, "\n$ %s\n", strings.Join(step, " ")); err != nil { + return err + } + cmd := exec.CommandContext(ctx, step[0], step[1:]...) 
+ cmd.Stdout = f + cmd.Stderr = f + if err := cmd.Run(); err != nil { + return fmt.Errorf("preflight failed at `%s`: %w", strings.Join(step, " "), err) + } + } + return nil +} + +func (r *Runner) prepareConfigIsolation() error { + abs, err := filepath.Abs(r.opts.ConfigPath) + if err != nil { + return err + } + r.originalConfigPath = abs + raw, err := os.ReadFile(abs) + if err != nil { + return err + } + sum := sha256.Sum256(raw) + r.originalConfigHash = hex.EncodeToString(sum[:]) + + tmpDir := filepath.Join(r.runDir, "tmp") + if err := os.MkdirAll(tmpDir, 0o755); err != nil { + return err + } + r.configCopyPath = filepath.Join(tmpDir, "config.json") + if err := os.WriteFile(r.configCopyPath, raw, 0o644); err != nil { + return err + } + var cfg runConfig + if err := json.Unmarshal(raw, &cfg); err != nil { + return fmt.Errorf("parse config failed: %w", err) + } + r.configRaw = cfg + if len(cfg.Keys) > 0 { + r.apiKey = strings.TrimSpace(cfg.Keys[0]) + } + for _, acc := range cfg.Accounts { + id := strings.TrimSpace(acc.Email) + if id == "" { + id = strings.TrimSpace(acc.Mobile) + } + if id != "" { + r.accountID = id + break + } + } + return nil +} + +func (r *Runner) startServer(ctx context.Context) error { + port := r.opts.Port + if port <= 0 { + p, err := findFreePort() + if err != nil { + return err + } + port = p + } + r.baseURL = "http://127.0.0.1:" + strconv.Itoa(port) + + logFd, err := os.OpenFile(r.serverLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644) + if err != nil { + return err + } + r.serverLogFd = logFd + cmd := exec.CommandContext(ctx, "go", "run", "./cmd/ds2api") + cmd.Stdout = logFd + cmd.Stderr = logFd + cmd.Env = prepareServerEnv(os.Environ(), map[string]string{ + "PORT": strconv.Itoa(port), + "DS2API_CONFIG_PATH": r.configCopyPath, + "DS2API_AUTO_BUILD_WEBUI": "false", + "DS2API_CONFIG_JSON": "", + "CONFIG_JSON": "", + }) + if err := cmd.Start(); err != nil { + _ = logFd.Close() + return err + } + r.serverCmd = cmd + + deadline := 
time.Now().Add(90 * time.Second) + for time.Now().Before(deadline) { + if r.ping("/healthz") == nil && r.ping("/readyz") == nil { + return nil + } + time.Sleep(500 * time.Millisecond) + } + return errors.New("server readiness timeout") +} + +func (r *Runner) stopServer() error { + var errs []string + if r.serverCmd != nil && r.serverCmd.Process != nil { + _ = r.serverCmd.Process.Signal(os.Interrupt) + done := make(chan error, 1) + go func() { done <- r.serverCmd.Wait() }() + select { + case <-time.After(5 * time.Second): + _ = r.serverCmd.Process.Kill() + <-done + case <-done: + } + } + if r.serverLogFd != nil { + if err := r.serverLogFd.Close(); err != nil { + errs = append(errs, err.Error()) + } + } + if len(errs) > 0 { + return errors.New(strings.Join(errs, "; ")) + } + return nil +} + +func (r *Runner) ping(path string) error { + resp, err := r.httpClient.Get(r.baseURL + path) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("status=%d", resp.StatusCode) + } + return nil +} + +func (r *Runner) prepareAuth(ctx context.Context) error { + reqBody := map[string]any{ + "admin_key": r.adminKey, + "expire_hours": 24, + } + resp, err := r.doSimpleJSON(ctx, http.MethodPost, "/admin/login", nil, reqBody) + if err != nil { + return err + } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("admin login status=%d body=%s", resp.StatusCode, string(resp.Body)) + } + var m map[string]any + if err := json.Unmarshal(resp.Body, &m); err != nil { + return err + } + token, _ := m["token"].(string) + if strings.TrimSpace(token) == "" { + return errors.New("empty admin jwt token") + } + r.adminJWT = token + return nil +} + +func (r *Runner) ensureOriginalConfigUntouched() error { + raw, err := os.ReadFile(r.originalConfigPath) + if err != nil { + return err + } + sum := sha256.Sum256(raw) + current := hex.EncodeToString(sum[:]) + if current != r.originalConfigHash { + return fmt.Errorf("original config 
changed unexpectedly: %s", r.originalConfigPath) + } + return nil +} diff --git a/internal/testsuite/runner_http.go b/internal/testsuite/runner_http.go new file mode 100644 index 0000000..d98c60a --- /dev/null +++ b/internal/testsuite/runner_http.go @@ -0,0 +1,217 @@ +package testsuite + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" +) + +func (cc *caseContext) assert(name string, ok bool, detail string) { + cc.mu.Lock() + defer cc.mu.Unlock() + cc.assertions = append(cc.assertions, assertionResult{ + Name: name, + Passed: ok, + Detail: detail, + }) +} + +func (cc *caseContext) request(ctx context.Context, spec requestSpec) (*responseResult, error) { + retries := cc.runner.opts.Retries + if !spec.Retryable { + retries = 0 + } + var lastErr error + for attempt := 1; attempt <= retries+1; attempt++ { + resp, err := cc.requestOnce(ctx, spec, attempt) + if err == nil && resp.StatusCode < 500 { + return resp, nil + } + if err != nil { + lastErr = err + } else if resp.StatusCode >= 500 { + lastErr = fmt.Errorf("status=%d", resp.StatusCode) + } + if attempt <= retries { + sleep := time.Duration(300*(1<<(attempt-1))) * time.Millisecond + time.Sleep(sleep) + } + } + return nil, lastErr +} + +func (cc *caseContext) requestOnce(ctx context.Context, spec requestSpec, attempt int) (*responseResult, error) { + cc.mu.Lock() + cc.seq++ + seq := cc.seq + traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), seq) + cc.traceIDsSet[traceID] = struct{}{} + cc.mu.Unlock() + + fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID) + if err != nil { + return nil, err + } + + headers := map[string]string{} + for k, v := range spec.Headers { + headers[k] = v + } + headers["X-Ds2-Test-Trace"] = traceID + + var bodyBytes []byte + var bodyAny any + if spec.Body != nil { + b, err := json.Marshal(spec.Body) + if err != nil { + return nil, err + } + bodyBytes = b + bodyAny = spec.Body 
+ headers["Content-Type"] = "application/json" + } + cc.mu.Lock() + cc.requests = append(cc.requests, requestLog{ + Seq: seq, + Attempt: attempt, + TraceID: traceID, + Method: spec.Method, + URL: fullURL, + Headers: headers, + Body: bodyAny, + Timestamp: time.Now().Format(time.RFC3339Nano), + }) + cc.mu.Unlock() + + reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout) + defer cancel() + req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, err + } + for k, v := range headers { + req.Header.Set(k, v) + } + start := time.Now() + resp, err := cc.runner.httpClient.Do(req) + if err != nil { + cc.mu.Lock() + cc.responses = append(cc.responses, responseLog{ + Seq: seq, + Attempt: attempt, + TraceID: traceID, + StatusCode: 0, + DurationMS: time.Since(start).Milliseconds(), + NetworkErr: err.Error(), + ReceivedAt: time.Now().Format(time.RFC3339Nano), + }) + cc.mu.Unlock() + return nil, err + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + + cc.mu.Lock() + cc.responses = append(cc.responses, responseLog{ + Seq: seq, + Attempt: attempt, + TraceID: traceID, + StatusCode: resp.StatusCode, + Headers: resp.Header, + BodyText: string(body), + DurationMS: time.Since(start).Milliseconds(), + ReceivedAt: time.Now().Format(time.RFC3339Nano), + }) + + if spec.Stream { + cc.streamRaw.WriteString(fmt.Sprintf("### trace=%s url=%s\n", traceID, fullURL)) + cc.streamRaw.Write(body) + cc.streamRaw.WriteString("\n\n") + } + cc.mu.Unlock() + + return &responseResult{ + StatusCode: resp.StatusCode, + Headers: resp.Header, + Body: body, + TraceID: traceID, + URL: fullURL, + }, nil +} + +func (cc *caseContext) flushArtifacts(cs caseResult) error { + requestPath := filepath.Join(cc.dir, "request.json") + headersPath := filepath.Join(cc.dir, "response.headers") + bodyPath := filepath.Join(cc.dir, "response.body") + streamPath := filepath.Join(cc.dir, "stream.raw") + assertPath := 
filepath.Join(cc.dir, "assertions.json") + metaPath := filepath.Join(cc.dir, "meta.json") + + if err := writeJSONFile(requestPath, cc.requests); err != nil { + return err + } + respHeaders := make([]map[string]any, 0, len(cc.responses)) + respBodies := make([]map[string]any, 0, len(cc.responses)) + for _, r := range cc.responses { + respHeaders = append(respHeaders, map[string]any{ + "seq": r.Seq, + "attempt": r.Attempt, + "trace_id": r.TraceID, + "status_code": r.StatusCode, + "headers": r.Headers, + }) + respBodies = append(respBodies, map[string]any{ + "seq": r.Seq, + "attempt": r.Attempt, + "trace_id": r.TraceID, + "status_code": r.StatusCode, + "body_text": r.BodyText, + "network_error": r.NetworkErr, + "duration_ms": r.DurationMS, + }) + } + if err := writeJSONFile(headersPath, respHeaders); err != nil { + return err + } + if err := writeJSONFile(bodyPath, respBodies); err != nil { + return err + } + if err := os.WriteFile(streamPath, []byte(cc.streamRaw.String()), 0o644); err != nil { + return err + } + if err := writeJSONFile(assertPath, cc.assertions); err != nil { + return err + } + meta := map[string]any{ + "case_id": cs.CaseID, + "trace_id": strings.Join(cs.TraceIDs, ","), + "attempt": len(cc.responses), + "duration_ms": cs.DurationMS, + "status": map[bool]string{true: "passed", false: "failed"}[cs.Passed], + "status_codes": cs.StatusCodes, + "assertions": cs.Assertions, + "artifact_path": cs.ArtifactPath, + } + return writeJSONFile(metaPath, meta) +} +func (r *Runner) doSimpleJSON(ctx context.Context, method, path string, headers map[string]string, body any) (*responseResult, error) { + cc := &caseContext{ + runner: r, + id: "auth_prepare", + traceIDsSet: map[string]struct{}{}, + } + return cc.request(ctx, requestSpec{ + Method: method, + Path: path, + Headers: headers, + Body: body, + Retryable: true, + }) +} diff --git a/internal/testsuite/runner_registry.go b/internal/testsuite/runner_registry.go new file mode 100644 index 0000000..08b602a --- 
/dev/null +++ b/internal/testsuite/runner_registry.go @@ -0,0 +1,43 @@ +package testsuite + +import "context" + +type caseDef struct { + ID string + Run func(context.Context, *caseContext) error +} + +func (r *Runner) cases() []caseDef { + return []caseDef{ + {ID: "healthz_ok", Run: r.caseHealthz}, + {ID: "readyz_ok", Run: r.caseReadyz}, + {ID: "models_openai", Run: r.caseModelsOpenAI}, + {ID: "model_openai_by_id", Run: r.caseModelOpenAIByID}, + {ID: "models_claude", Run: r.caseModelsClaude}, + {ID: "admin_login_verify", Run: r.caseAdminLoginVerify}, + {ID: "admin_queue_status", Run: r.caseAdminQueueStatus}, + {ID: "chat_nonstream_basic", Run: r.caseChatNonstream}, + {ID: "chat_stream_basic", Run: r.caseChatStream}, + {ID: "responses_nonstream_basic", Run: r.caseResponsesNonstream}, + {ID: "responses_stream_basic", Run: r.caseResponsesStream}, + {ID: "embeddings_contract", Run: r.caseEmbeddings}, + {ID: "reasoner_stream", Run: r.caseReasonerStream}, + {ID: "toolcall_nonstream", Run: r.caseToolcallNonstream}, + {ID: "toolcall_stream", Run: r.caseToolcallStream}, + {ID: "anthropic_messages_nonstream", Run: r.caseAnthropicNonstream}, + {ID: "anthropic_messages_stream", Run: r.caseAnthropicStream}, + {ID: "anthropic_count_tokens", Run: r.caseAnthropicCountTokens}, + {ID: "admin_account_test_single", Run: r.caseAdminAccountTest}, + {ID: "concurrency_burst", Run: r.caseConcurrencyBurst}, + {ID: "concurrency_threshold_limit", Run: r.caseConcurrencyThresholdLimit}, + {ID: "stream_abort_release", Run: r.caseStreamAbortRelease}, + {ID: "toolcall_stream_mixed", Run: r.caseToolcallStreamMixed}, + {ID: "sse_json_integrity", Run: r.caseSSEJSONIntegrity}, + {ID: "error_contract_invalid_model", Run: r.caseInvalidModel}, + {ID: "error_contract_missing_messages", Run: r.caseMissingMessages}, + {ID: "admin_unauthorized_contract", Run: r.caseAdminUnauthorized}, + {ID: "config_write_isolated", Run: r.caseConfigWriteIsolated}, + {ID: "token_refresh_managed_account", Run: 
r.caseTokenRefreshManagedAccount}, + {ID: "error_contract_invalid_key", Run: r.caseInvalidKey}, + } +} diff --git a/internal/testsuite/runner_summary.go b/internal/testsuite/runner_summary.go new file mode 100644 index 0000000..25b44a4 --- /dev/null +++ b/internal/testsuite/runner_summary.go @@ -0,0 +1,97 @@ +package testsuite + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "time" +) + +func (r *Runner) writeSummary(start, end time.Time) error { + passed := 0 + failed := 0 + for _, cs := range r.results { + if cs.Passed { + passed++ + } else { + failed++ + } + } + summary := runSummary{ + RunID: r.runID, + StartedAt: start.Format(time.RFC3339Nano), + EndedAt: end.Format(time.RFC3339Nano), + DurationMS: end.Sub(start).Milliseconds(), + Stats: map[string]any{ + "total": len(r.results), + "passed": passed, + "failed": failed, + }, + Environment: map[string]any{ + "go_version": runtime.Version(), + "os": runtime.GOOS, + "arch": runtime.GOARCH, + "base_url": r.baseURL, + "config_source": r.originalConfigPath, + "config_isolated": r.configCopyPath, + "server_log": r.serverLog, + "preflight_log": r.preflightLog, + "retries": r.opts.Retries, + "timeout_seconds": int(r.opts.Timeout.Seconds()), + }, + Cases: r.results, + Warnings: r.warnings, + } + if err := writeJSONFile(filepath.Join(r.runDir, "summary.json"), summary); err != nil { + return err + } + return os.WriteFile(filepath.Join(r.runDir, "summary.md"), []byte(r.summaryMarkdown(summary)), 0o644) +} + +func (r *Runner) summaryMarkdown(s runSummary) string { + var b strings.Builder + b.WriteString("# DS2API Live Testsuite Summary\n\n") + b.WriteString("**Sensitive Notice:** this run stores full raw request/response logs. 
Do not share artifacts publicly.\n\n") + fmt.Fprintf(&b, "- Run ID: `%s`\n", s.RunID) + fmt.Fprintf(&b, "- Started: `%s`\n", s.StartedAt) + fmt.Fprintf(&b, "- Ended: `%s`\n", s.EndedAt) + fmt.Fprintf(&b, "- Duration: `%d ms`\n", s.DurationMS) + fmt.Fprintf(&b, "- Passed/Failed: `%d/%d`\n\n", s.Stats["passed"], s.Stats["failed"]) + if len(s.Warnings) > 0 { + b.WriteString("## Warnings\n\n") + for _, w := range s.Warnings { + fmt.Fprintf(&b, "- %s\n", w) + } + b.WriteString("\n") + } + b.WriteString("## Failed Cases\n\n") + hasFailed := false + for _, c := range s.Cases { + if c.Passed { + continue + } + hasFailed = true + fmt.Fprintf(&b, "- `%s`: %s\n", c.CaseID, c.Error) + if len(c.TraceIDs) > 0 { + fmt.Fprintf(&b, " - trace_ids: `%s`\n", strings.Join(c.TraceIDs, ", ")) + fmt.Fprintf(&b, " - grep: `rg \"%s\" %s`\n", c.TraceIDs[0], filepath.Join(r.runDir, "server.log")) + } + fmt.Fprintf(&b, " - artifact: `%s`\n", c.ArtifactPath) + } + if !hasFailed { + b.WriteString("- none\n") + } + b.WriteString("\n## Case Table\n\n") + b.WriteString("| case_id | status | duration_ms | statuses | artifact |\n") + b.WriteString("|---|---:|---:|---|---|\n") + for _, c := range s.Cases { + status := "PASS" + if !c.Passed { + status = "FAIL" + } + fmt.Fprintf(&b, "| %s | %s | %d | %v | `%s` |\n", c.CaseID, status, c.DurationMS, c.StatusCodes, c.ArtifactPath) + } + return b.String() +} diff --git a/internal/testsuite/runner_utils.go b/internal/testsuite/runner_utils.go new file mode 100644 index 0000000..c4879c6 --- /dev/null +++ b/internal/testsuite/runner_utils.go @@ -0,0 +1,202 @@ +package testsuite + +import ( + "encoding/json" + "errors" + "fmt" + "net" + "net/url" + "os" + "sort" + "strings" +) + +func parseSSEFrames(body []byte) ([]map[string]any, bool) { + lines := strings.Split(string(body), "\n") + frames := make([]map[string]any, 0, len(lines)) + done := false + for _, line := range lines { + line = strings.TrimSpace(line) + if !strings.HasPrefix(line, "data:") { + continue 
+ } + payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if payload == "" { + continue + } + if payload == "[DONE]" { + done = true + continue + } + var m map[string]any + if err := json.Unmarshal([]byte(payload), &m); err == nil { + frames = append(frames, m) + } + } + return frames, done +} + +func parseClaudeStreamEvents(body []byte) []string { + events := []string{} + seen := map[string]bool{} + lines := strings.Split(string(body), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if !strings.HasPrefix(line, "data:") { + continue + } + payload := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + if payload == "" { + continue + } + var m map[string]any + if err := json.Unmarshal([]byte(payload), &m); err != nil { + continue + } + t := asString(m["type"]) + if t == "" || seen[t] { + continue + } + seen[t] = true + events = append(events, t) + } + return events +} + +func extractModelIDs(body []byte) []string { + var m map[string]any + if err := json.Unmarshal(body, &m); err != nil { + return nil + } + out := []string{} + data, _ := m["data"].([]any) + for _, it := range data { + item, _ := it.(map[string]any) + id := asString(item["id"]) + if id != "" { + out = append(out, id) + } + } + return out +} + +func withTraceQuery(rawURL, traceID string) (string, error) { + u, err := url.Parse(rawURL) + if err != nil { + return "", err + } + q := u.Query() + q.Set("__trace_id", traceID) + u.RawQuery = q.Encode() + return u.String(), nil +} + +func writeJSONFile(path string, v any) error { + b, err := json.MarshalIndent(v, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, b, 0o644) +} + +func prepareServerEnv(base []string, overrides map[string]string) []string { + out := make([]string, 0, len(base)+len(overrides)) + skip := map[string]struct{}{} + for k := range overrides { + skip[k] = struct{}{} + } + for _, e := range base { + parts := strings.SplitN(e, "=", 2) + if len(parts) != 2 { + continue + } + if _, 
ok := skip[parts[0]]; ok { + continue + } + out = append(out, e) + } + for k, v := range overrides { + out = append(out, k+"="+v) + } + return out +} + +func findFreePort() (int, error) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return 0, err + } + defer ln.Close() + addr, ok := ln.Addr().(*net.TCPAddr) + if !ok { + return 0, errors.New("failed to detect tcp port") + } + return addr.Port, nil +} + +func uniqueStatusCodes(in []responseLog) []int { + set := map[int]struct{}{} + for _, it := range in { + if it.StatusCode > 0 { + set[it.StatusCode] = struct{}{} + } + } + out := make([]int, 0, len(set)) + for k := range set { + out = append(out, k) + } + sort.Ints(out) + return out +} + +func has5xx(dist map[int]int) (int, bool) { + for k := range dist { + if k >= 500 { + return k, true + } + } + return 0, false +} + +func sanitizeID(s string) string { + s = strings.ReplaceAll(s, ":", "_") + s = strings.ReplaceAll(s, "/", "_") + s = strings.ReplaceAll(s, " ", "_") + return s +} + +func asString(v any) string { + if v == nil { + return "" + } + switch x := v.(type) { + case string: + return strings.TrimSpace(x) + default: + return strings.TrimSpace(fmt.Sprintf("%v", v)) + } +} + +func toInt(v any) int { + switch x := v.(type) { + case float64: + return int(x) + case float32: + return int(x) + case int: + return x + case int64: + return int(x) + default: + return 0 + } +} + +func contains(xs []string, target string) bool { + for _, x := range xs { + if x == target { + return true + } + } + return false +} diff --git a/internal/util/toolcalls_candidates.go b/internal/util/toolcalls_candidates.go new file mode 100644 index 0000000..4e8afc4 --- /dev/null +++ b/internal/util/toolcalls_candidates.go @@ -0,0 +1,138 @@ +package util + +import ( + "regexp" + "strings" +) + +var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`) +var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```") +var 
fencedBlockPattern = regexp.MustCompile("(?s)```.*?```") + +func buildToolCallCandidates(text string) []string { + trimmed := strings.TrimSpace(text) + candidates := []string{trimmed} + + // fenced code block candidates: ```json ... ``` + for _, match := range fencedJSONPattern.FindAllStringSubmatch(trimmed, -1) { + if len(match) >= 2 { + candidates = append(candidates, strings.TrimSpace(match[1])) + } + } + + // best-effort extraction around "tool_calls" key in mixed text payloads. + candidates = append(candidates, extractToolCallObjects(trimmed)...) + + // best-effort object slice: from first '{' to last '}' + first := strings.Index(trimmed, "{") + last := strings.LastIndex(trimmed, "}") + if first >= 0 && last > first { + candidates = append(candidates, strings.TrimSpace(trimmed[first:last+1])) + } + + // legacy regex extraction fallback + if m := toolCallPattern.FindStringSubmatch(trimmed); len(m) >= 2 { + candidates = append(candidates, "{"+`"tool_calls":[`+m[1]+"]}") + } + + uniq := make([]string, 0, len(candidates)) + seen := map[string]struct{}{} + for _, c := range candidates { + if c == "" { + continue + } + if _, ok := seen[c]; ok { + continue + } + seen[c] = struct{}{} + uniq = append(uniq, c) + } + return uniq +} + +func extractToolCallObjects(text string) []string { + if text == "" { + return nil + } + lower := strings.ToLower(text) + out := []string{} + offset := 0 + for { + idx := strings.Index(lower[offset:], "tool_calls") + if idx < 0 { + break + } + idx += offset + start := strings.LastIndex(text[:idx], "{") + for start >= 0 { + candidate, end, ok := extractJSONObject(text, start) + if ok { + // Move forward to avoid repeatedly matching the same object. 
+ offset = end + out = append(out, strings.TrimSpace(candidate)) + break + } + start = strings.LastIndex(text[:start], "{") + } + if start < 0 { + offset = idx + len("tool_calls") + } + } + return out +} + +func extractJSONObject(text string, start int) (string, int, bool) { + if start < 0 || start >= len(text) || text[start] != '{' { + return "", 0, false + } + depth := 0 + quote := byte(0) + escaped := false + for i := start; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + quote = ch + continue + } + if ch == '{' { + depth++ + continue + } + if ch == '}' { + depth-- + if depth == 0 { + return text[start : i+1], i + 1, true + } + } + } + return "", 0, false +} + +func looksLikeToolExampleContext(text string) bool { + t := strings.ToLower(strings.TrimSpace(text)) + if t == "" { + return false + } + return strings.Contains(t, "```") +} + +func stripFencedCodeBlocks(text string) string { + if strings.TrimSpace(text) == "" { + return "" + } + return fencedBlockPattern.ReplaceAllString(text, " ") +} diff --git a/internal/util/toolcalls_format.go b/internal/util/toolcalls_format.go new file mode 100644 index 0000000..8feb48f --- /dev/null +++ b/internal/util/toolcalls_format.go @@ -0,0 +1,41 @@ +package util + +import ( + "encoding/json" + "strings" + + "github.com/google/uuid" +) + +func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any { + out := make([]map[string]any, 0, len(calls)) + for _, c := range calls { + args, _ := json.Marshal(c.Input) + out = append(out, map[string]any{ + "id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "type": "function", + "function": map[string]any{ + "name": c.Name, + "arguments": string(args), + }, + }) + } + return out +} + +func FormatOpenAIStreamToolCalls(calls []ParsedToolCall) []map[string]any { + out := 
make([]map[string]any, 0, len(calls)) + for i, c := range calls { + args, _ := json.Marshal(c.Input) + out = append(out, map[string]any{ + "index": i, + "id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""), + "type": "function", + "function": map[string]any{ + "name": c.Name, + "arguments": string(args), + }, + }) + } + return out +} diff --git a/internal/util/toolcalls.go b/internal/util/toolcalls_parse.go similarity index 52% rename from internal/util/toolcalls.go rename to internal/util/toolcalls_parse.go index 9e44b94..ab9fe84 100644 --- a/internal/util/toolcalls.go +++ b/internal/util/toolcalls_parse.go @@ -2,16 +2,9 @@ package util import ( "encoding/json" - "regexp" "strings" - - "github.com/google/uuid" ) -var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`) -var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```") -var fencedBlockPattern = regexp.MustCompile("(?s)```.*?```") - type ParsedToolCall struct { Name string `json:"name"` Input map[string]any `json:"input"` @@ -102,47 +95,6 @@ func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []Par return out } -func buildToolCallCandidates(text string) []string { - trimmed := strings.TrimSpace(text) - candidates := []string{trimmed} - - // fenced code block candidates: ```json ... ``` - for _, match := range fencedJSONPattern.FindAllStringSubmatch(trimmed, -1) { - if len(match) >= 2 { - candidates = append(candidates, strings.TrimSpace(match[1])) - } - } - - // best-effort extraction around "tool_calls" key in mixed text payloads. - candidates = append(candidates, extractToolCallObjects(trimmed)...) 
- - // best-effort object slice: from first '{' to last '}' - first := strings.Index(trimmed, "{") - last := strings.LastIndex(trimmed, "}") - if first >= 0 && last > first { - candidates = append(candidates, strings.TrimSpace(trimmed[first:last+1])) - } - - // legacy regex extraction fallback - if m := toolCallPattern.FindStringSubmatch(trimmed); len(m) >= 2 { - candidates = append(candidates, "{"+`"tool_calls":[`+m[1]+"]}") - } - - uniq := make([]string, 0, len(candidates)) - seen := map[string]struct{}{} - for _, c := range candidates { - if c == "" { - continue - } - if _, ok := seen[c]; ok { - continue - } - seen[c] = struct{}{} - uniq = append(uniq, c) - } - return uniq -} - func parseToolCallsPayload(payload string) []ParsedToolCall { var decoded any if err := json.Unmarshal([]byte(payload), &decoded); err != nil { @@ -243,123 +195,3 @@ func parseToolCallInput(v any) map[string]any { return map[string]any{} } } - -func extractToolCallObjects(text string) []string { - if text == "" { - return nil - } - lower := strings.ToLower(text) - out := []string{} - offset := 0 - for { - idx := strings.Index(lower[offset:], "tool_calls") - if idx < 0 { - break - } - idx += offset - start := strings.LastIndex(text[:idx], "{") - for start >= 0 { - candidate, end, ok := extractJSONObject(text, start) - if ok { - // Move forward to avoid repeatedly matching the same object. 
- offset = end - out = append(out, strings.TrimSpace(candidate)) - break - } - start = strings.LastIndex(text[:start], "{") - } - if start < 0 { - offset = idx + len("tool_calls") - } - } - return out -} - -func extractJSONObject(text string, start int) (string, int, bool) { - if start < 0 || start >= len(text) || text[start] != '{' { - return "", 0, false - } - depth := 0 - quote := byte(0) - escaped := false - for i := start; i < len(text); i++ { - ch := text[i] - if quote != 0 { - if escaped { - escaped = false - continue - } - if ch == '\\' { - escaped = true - continue - } - if ch == quote { - quote = 0 - } - continue - } - if ch == '"' || ch == '\'' { - quote = ch - continue - } - if ch == '{' { - depth++ - continue - } - if ch == '}' { - depth-- - if depth == 0 { - return text[start : i+1], i + 1, true - } - } - } - return "", 0, false -} - -func looksLikeToolExampleContext(text string) bool { - t := strings.ToLower(strings.TrimSpace(text)) - if t == "" { - return false - } - return strings.Contains(t, "```") -} - -func stripFencedCodeBlocks(text string) string { - if strings.TrimSpace(text) == "" { - return "" - } - return fencedBlockPattern.ReplaceAllString(text, " ") -} - -func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any { - out := make([]map[string]any, 0, len(calls)) - for _, c := range calls { - args, _ := json.Marshal(c.Input) - out = append(out, map[string]any{ - "id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""), - "type": "function", - "function": map[string]any{ - "name": c.Name, - "arguments": string(args), - }, - }) - } - return out -} - -func FormatOpenAIStreamToolCalls(calls []ParsedToolCall) []map[string]any { - out := make([]map[string]any, 0, len(calls)) - for i, c := range calls { - args, _ := json.Marshal(c.Input) - out = append(out, map[string]any{ - "index": i, - "id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""), - "type": "function", - "function": map[string]any{ - "name": c.Name, - 
"arguments": string(args), - }, - }) - } - return out -} diff --git a/plans/refactor-baseline.md b/plans/refactor-baseline.md new file mode 100644 index 0000000..9b31ed1 --- /dev/null +++ b/plans/refactor-baseline.md @@ -0,0 +1,31 @@ +# DS2API Refactor Baseline (Backfilled) + +- Recorded at: `2026-02-22T08:53:54Z` +- Branch: `dev` +- HEAD: `5d3989a` +- Scope: backend + node api + webui large-file decoupling (no behavior change) + +## Gate Commands + +1. `./tests/scripts/run-unit-all.sh` + - Result: PASS + - Includes: + - `go test ./...` + - `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js` +2. `npm --prefix webui run build` + - Result: PASS +3. `./tests/scripts/check-refactor-line-gate.sh` + - Result: PASS (`checked=131 missing=0 over_limit=0`) +4. Stage gates (1-5) replay: + - `go test ./internal/config ./internal/admin ./internal/account ./internal/deepseek ./internal/format/openai` -> PASS + - `go test ./internal/adapter/openai ./internal/util ./internal/sse ./internal/compat` -> PASS + - `go test ./internal/adapter/claude ./internal/adapter/gemini ./internal/config` -> PASS + - `go test ./internal/testsuite ./cmd/ds2api-tests` -> PASS + - `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js` -> PASS +5. Final full regression: + - `go test ./... -count=1` -> PASS + +## Notes + +- This baseline file is a backfill artifact for phase 0 process tracking. +- Frontend manual smoke for phase 6 still requires human execution and sign-off. diff --git a/plans/refactor-line-gate-targets.txt b/plans/refactor-line-gate-targets.txt new file mode 100644 index 0000000..d3678ad --- /dev/null +++ b/plans/refactor-line-gate-targets.txt @@ -0,0 +1,151 @@ +# Line gate targets for large-file decoupling refactor. 
+# Default limit: 300 lines +# Entry/facade limit: 120 lines (enforced in script) + +internal/config/config.go +internal/config/logger.go +internal/config/paths.go +internal/config/codec.go +internal/config/store.go +internal/config/store_index.go +internal/config/store_accessors.go +internal/config/account.go + +internal/admin/handler_config_read.go +internal/admin/handler_config_write.go +internal/admin/handler_config_import.go +internal/admin/handler_settings_read.go +internal/admin/handler_settings_write.go +internal/admin/handler_settings_parse.go +internal/admin/handler_settings_runtime.go +internal/admin/handler_accounts_crud.go +internal/admin/handler_accounts_testing.go +internal/admin/handler_accounts_queue.go + +internal/account/pool_core.go +internal/account/pool_acquire.go +internal/account/pool_waiters.go +internal/account/pool_limits.go + +internal/deepseek/client_core.go +internal/deepseek/client_auth.go +internal/deepseek/client_completion.go +internal/deepseek/client_http_json.go +internal/deepseek/client_http_helpers.go + +internal/format/openai/render_chat.go +internal/format/openai/render_responses.go +internal/format/openai/render_stream_events.go +internal/format/openai/render_usage.go + +internal/adapter/openai/handler_routes.go +internal/adapter/openai/handler_chat.go +internal/adapter/openai/handler_errors.go +internal/adapter/openai/handler_toolcall_policy.go +internal/adapter/openai/handler_toolcall_format.go +internal/adapter/openai/responses_handler.go +internal/adapter/openai/responses_input_normalize.go +internal/adapter/openai/responses_input_items.go +internal/adapter/openai/responses_stream_runtime_core.go +internal/adapter/openai/responses_stream_runtime_events.go +internal/adapter/openai/responses_stream_runtime_toolcalls.go +internal/adapter/openai/tool_sieve_state.go +internal/adapter/openai/tool_sieve_core.go +internal/adapter/openai/tool_sieve_incremental.go +internal/adapter/openai/tool_sieve_jsonscan.go + 
+internal/util/toolcalls_parse.go +internal/util/toolcalls_candidates.go +internal/util/toolcalls_format.go + +internal/adapter/claude/handler_routes.go +internal/adapter/claude/handler_messages.go +internal/adapter/claude/handler_tokens.go +internal/adapter/claude/handler_errors.go +internal/adapter/claude/handler_utils.go +internal/adapter/claude/stream_runtime_core.go +internal/adapter/claude/stream_runtime_emit.go +internal/adapter/claude/stream_runtime_finalize.go + +internal/adapter/gemini/handler_routes.go +internal/adapter/gemini/handler_generate.go +internal/adapter/gemini/handler_stream_runtime.go +internal/adapter/gemini/handler_errors.go +internal/adapter/gemini/convert_request.go +internal/adapter/gemini/convert_messages.go +internal/adapter/gemini/convert_tools.go +internal/adapter/gemini/convert_passthrough.go + +internal/testsuite/runner_core.go +internal/testsuite/runner_env.go +internal/testsuite/runner_http.go +internal/testsuite/runner_cases_openai.go +internal/testsuite/runner_cases_openai_advanced.go +internal/testsuite/runner_cases_admin.go +internal/testsuite/runner_cases_claude.go +internal/testsuite/runner_summary.go +internal/testsuite/runner_utils.go +internal/testsuite/runner_defaults.go +internal/testsuite/runner_registry.go +internal/testsuite/edge_cases_abort.go +internal/testsuite/edge_cases_error_contract.go + +api/chat-stream.js +api/chat-stream/index.js +api/chat-stream/vercel_stream.js +api/chat-stream/proxy_go.js +api/chat-stream/sse_parse.js +api/chat-stream/http_internal.js +api/chat-stream/toolcall_policy.js +api/chat-stream/error_shape.js +api/chat-stream/token_usage.js +api/chat-stream/stream_emitter.js + +api/helpers/stream-tool-sieve.js +api/helpers/stream-tool-sieve/index.js +api/helpers/stream-tool-sieve/state.js +api/helpers/stream-tool-sieve/sieve.js +api/helpers/stream-tool-sieve/incremental.js +api/helpers/stream-tool-sieve/jsonscan.js +api/helpers/stream-tool-sieve/parse.js +api/helpers/stream-tool-sieve/format.js 
+ +webui/src/App.jsx +webui/src/app/AppRoutes.jsx +webui/src/app/useAdminAuth.js +webui/src/app/useAdminConfig.js +webui/src/layout/DashboardShell.jsx + +webui/src/components/AccountManager.jsx +webui/src/features/account/AccountManagerContainer.jsx +webui/src/features/account/useAccountsData.js +webui/src/features/account/useAccountActions.js +webui/src/features/account/QueueCards.jsx +webui/src/features/account/ApiKeysPanel.jsx +webui/src/features/account/AccountsTable.jsx +webui/src/features/account/AddKeyModal.jsx +webui/src/features/account/AddAccountModal.jsx + +webui/src/components/ApiTester.jsx +webui/src/features/apiTester/ApiTesterContainer.jsx +webui/src/features/apiTester/useApiTesterState.js +webui/src/features/apiTester/useChatStreamClient.js +webui/src/features/apiTester/ConfigPanel.jsx +webui/src/features/apiTester/ChatPanel.jsx + +webui/src/components/Settings.jsx +webui/src/features/settings/SettingsContainer.jsx +webui/src/features/settings/useSettingsForm.js +webui/src/features/settings/settingsApi.js +webui/src/features/settings/SecuritySection.jsx +webui/src/features/settings/RuntimeSection.jsx +webui/src/features/settings/BehaviorSection.jsx +webui/src/features/settings/ModelSection.jsx +webui/src/features/settings/BackupSection.jsx + +webui/src/components/VercelSync.jsx +webui/src/features/vercel/VercelSyncContainer.jsx +webui/src/features/vercel/useVercelSyncState.js +webui/src/features/vercel/VercelSyncForm.jsx +webui/src/features/vercel/VercelSyncStatus.jsx +webui/src/features/vercel/VercelGuide.jsx diff --git a/plans/refactor-line-gate.md b/plans/refactor-line-gate.md new file mode 100644 index 0000000..86f0d82 --- /dev/null +++ b/plans/refactor-line-gate.md @@ -0,0 +1,21 @@ +# Refactor Line Gate + +## Rules + +1. Production file default upper bound: `<= 300` lines. +2. Entry/facade files upper bound: `<= 120` lines. +3. Scope is limited to target files in `plans/refactor-line-gate-targets.txt`. +4. 
Test files are out of scope for this gate. + +## Command + +```bash +./tests/scripts/check-refactor-line-gate.sh +``` + +## Naming Note + +- Original split plan used `internal/admin/handler_accounts_test.go` for account probing logic. +- In Go, `*_test.go` files are test-only compilation units and cannot host production handlers. +- The production file is implemented as `internal/admin/handler_accounts_testing.go`. + diff --git a/plans/stage6-manual-smoke.md b/plans/stage6-manual-smoke.md new file mode 100644 index 0000000..8875b54 --- /dev/null +++ b/plans/stage6-manual-smoke.md @@ -0,0 +1,29 @@ +# Stage 6 Manual Smoke Checklist + +- Date: +- Tester: +- Environment: + +## Items + +1. Login flow (`/admin/login`) succeeds and failure message shape unchanged. +2. Account manager: + - add/edit/delete account + - queue status cards render and refresh +3. API tester: + - non-stream request succeeds + - stream request receives incremental output and final state +4. Settings: + - read settings + - save settings + - backup/export path works +5. Vercel sync: + - status poll + - manual refresh + - sync action and status feedback text + +## Result + +- Status: `PENDING` +- Notes: + diff --git a/tests/scripts/check-refactor-line-gate.sh b/tests/scripts/check-refactor-line-gate.sh new file mode 100755 index 0000000..0568666 --- /dev/null +++ b/tests/scripts/check-refactor-line-gate.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)" +TARGETS_FILE="$ROOT_DIR/plans/refactor-line-gate-targets.txt" + +DEFAULT_MAX=300 +ENTRY_MAX=120 + +is_entry_file() { + case "$1" in + api/chat-stream.js|\ + api/helpers/stream-tool-sieve.js|\ + webui/src/App.jsx|\ + webui/src/components/AccountManager.jsx|\ + webui/src/components/ApiTester.jsx|\ + webui/src/components/Settings.jsx|\ + webui/src/components/VercelSync.jsx) + return 0 + ;; + esac + return 1 +} + +if [[ ! 
-f "$TARGETS_FILE" ]]; then + echo "missing targets file: $TARGETS_FILE" >&2 + exit 1 +fi + +missing=0 +over=0 +checked=0 + +while IFS= read -r file; do + [[ -z "$file" ]] && continue + [[ "${file:0:1}" == "#" ]] && continue + + checked=$((checked + 1)) + abs="$ROOT_DIR/$file" + if [[ ! -f "$abs" ]]; then + echo "MISSING $file" + missing=$((missing + 1)) + continue + fi + + lines="$(wc -l < "$abs" | tr -d ' ')" + limit="$DEFAULT_MAX" + if is_entry_file "$file"; then + limit="$ENTRY_MAX" + fi + + if (( lines > limit )); then + echo "OVER $file lines=$lines limit=$limit" + over=$((over + 1)) + fi +done < "$TARGETS_FILE" + +echo "checked=$checked missing=$missing over_limit=$over" + +if (( missing > 0 || over > 0 )); then + exit 1 +fi diff --git a/webui/src/App.jsx b/webui/src/App.jsx index 2c3f099..8067b8b 100644 --- a/webui/src/App.jsx +++ b/webui/src/App.jsx @@ -1,346 +1,3 @@ -import { useState, useEffect, useCallback, useMemo } from 'react' -import { - Routes, - Route, - Navigate, - useNavigate, - useLocation -} from 'react-router-dom' -import { - LayoutDashboard, - Key, - Upload, - Cloud, - Settings as SettingsIcon, - LogOut, - Menu, - X, - Server, - Users -} from 'lucide-react' -import clsx from 'clsx' +import AppRoutes from './app/AppRoutes' -import AccountManager from './components/AccountManager' -import ApiTester from './components/ApiTester' -import BatchImport from './components/BatchImport' -import VercelSync from './components/VercelSync' -import Settings from './components/Settings' -import Login from './components/Login' -import LandingPage from './components/LandingPage' -import LanguageToggle from './components/LanguageToggle' -import { useI18n } from './i18n' -import { detectRuntimeEnv } from './utils/runtimeEnv' - -function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, onForceLogout, isVercel }) { - const { t } = useI18n() - const [activeTab, setActiveTab] = useState('accounts') - const [sidebarOpen, setSidebarOpen] = 
useState(false) - - const navItems = [ - { id: 'accounts', label: t('nav.accounts.label'), icon: Users, description: t('nav.accounts.desc') }, - { id: 'test', label: t('nav.test.label'), icon: Server, description: t('nav.test.desc') }, - { id: 'import', label: t('nav.import.label'), icon: Upload, description: t('nav.import.desc') }, - { id: 'vercel', label: t('nav.vercel.label'), icon: Cloud, description: t('nav.vercel.desc') }, - { id: 'settings', label: t('nav.settings.label'), icon: SettingsIcon, description: t('nav.settings.desc') }, - ] - - const authFetch = useCallback(async (url, options = {}) => { - const headers = { - ...options.headers, - 'Authorization': `Bearer ${token}` - } - const res = await fetch(url, { ...options, headers }) - - if (res.status === 401) { - onLogout() - throw new Error(t('auth.expired')) - } - return res - }, [onLogout, t, token]) - - const renderTab = () => { - switch (activeTab) { - case 'accounts': - return - case 'test': - return - case 'import': - return - case 'vercel': - return - case 'settings': - return - default: - return null - } - } - - return ( -
- {sidebarOpen && ( -
setSidebarOpen(false)} - /> - )} - - - -
-
-
-
- -
- DS2API -
-
- - -
-
- -
-
-
-

- {navItems.find(n => n.id === activeTab)?.label} -

-

- {navItems.find(n => n.id === activeTab)?.description} -

-
- - {message && ( -
- {message.type === 'error' ? :
} - {message.text} -
- )} - -
- {renderTab()} -
-
-
-
-
- ) -} - -export default function App() { - const { t } = useI18n() - const navigate = useNavigate() - const location = useLocation() - const [config, setConfig] = useState({ keys: [], accounts: [] }) - const [message, setMessage] = useState(null) - const [token, setToken] = useState(null) - const [authChecking, setAuthChecking] = useState(true) - - const isProduction = import.meta.env.MODE === 'production' - const isAdminRoute = location.pathname.startsWith('/admin') || isProduction - const runtimeEnv = useMemo(() => detectRuntimeEnv(), []) - const isVercel = runtimeEnv.isVercel - - const showMessage = useCallback((type, text) => { - setMessage({ type, text }) - setTimeout(() => setMessage(null), 5000) - }, []) - - const handleLogout = useCallback(() => { - setToken(null) - localStorage.removeItem('ds2api_token') - localStorage.removeItem('ds2api_token_expires') - sessionStorage.removeItem('ds2api_token') - sessionStorage.removeItem('ds2api_token_expires') - }, []) - - useEffect(() => { - // Only check auth status on admin routes. 
- if (!isAdminRoute) { - setAuthChecking(false) - return - } - - const checkAuth = async () => { - const storedToken = localStorage.getItem('ds2api_token') || sessionStorage.getItem('ds2api_token') - const expiresAt = parseInt(localStorage.getItem('ds2api_token_expires') || sessionStorage.getItem('ds2api_token_expires') || '0') - - if (storedToken && expiresAt > Date.now()) { - try { - const res = await fetch('/admin/verify', { - headers: { 'Authorization': `Bearer ${storedToken}` } - }) - if (res.ok) { - setToken(storedToken) - } else { - handleLogout() - } - } catch { - setToken(storedToken) - } - } - setAuthChecking(false) - } - checkAuth() - }, [handleLogout, isAdminRoute]) - - const fetchConfig = useCallback(async () => { - if (!token) return - try { - const res = await fetch('/admin/config', { - headers: { 'Authorization': `Bearer ${token}` } - }) - if (res.ok) { - const data = await res.json() - setConfig(data) - } - } catch (e) { - console.error('Failed to fetch config:', e) - showMessage('error', t('errors.fetchConfig', { error: e.message })) - } - }, [showMessage, t, token]) - - useEffect(() => { - if (token) { - fetchConfig() - } - }, [fetchConfig, token]) - - const handleLogin = (newToken) => { - setToken(newToken) - } - - // Wait for auth checks on admin routes. - if (isAdminRoute && authChecking) { - return ( -
-
-
-

{t('auth.checking')}

-
-
- ) - } - - return ( - - {!isProduction && ( - navigate('/admin')} />} /> - )} - - ) : ( -
-
-
-
-
- - {message && ( -
- {message.text} -
- )} - -
- ) - } /> - } /> -
- ) -} +export default AppRoutes diff --git a/webui/src/app/AppRoutes.jsx b/webui/src/app/AppRoutes.jsx new file mode 100644 index 0000000..9a75dba --- /dev/null +++ b/webui/src/app/AppRoutes.jsx @@ -0,0 +1,84 @@ +import { Navigate, Route, Routes, useLocation, useNavigate } from 'react-router-dom' +import clsx from 'clsx' + +import LandingPage from '../components/LandingPage' +import Login from '../components/Login' +import DashboardShell from '../layout/DashboardShell' +import { useI18n } from '../i18n' +import { useAdminAuth } from './useAdminAuth' +import { useAdminConfig } from './useAdminConfig' + +export default function AppRoutes() { + const { t } = useI18n() + const navigate = useNavigate() + const location = useLocation() + + const isProduction = import.meta.env.MODE === 'production' + const { + token, + authChecking, + message, + isAdminRoute, + isVercel, + showMessage, + handleLogin, + handleLogout, + } = useAdminAuth({ isProduction, location, t }) + + const { + config, + fetchConfig, + } = useAdminConfig({ token, showMessage, t }) + + if (isAdminRoute && authChecking) { + return ( +
+
+
+

{t('auth.checking')}

+
+
+ ) + } + + return ( + + {!isProduction && ( + navigate('/admin')} />} /> + )} + + ) : ( +
+
+
+
+
+ + {message && ( +
+ {message.text} +
+ )} + +
+ ) + } /> + } /> +
+ ) +} diff --git a/webui/src/app/useAdminAuth.js b/webui/src/app/useAdminAuth.js new file mode 100644 index 0000000..2da2391 --- /dev/null +++ b/webui/src/app/useAdminAuth.js @@ -0,0 +1,70 @@ +import { useCallback, useEffect, useMemo, useState } from 'react' +import { detectRuntimeEnv } from '../utils/runtimeEnv' + +export function useAdminAuth({ isProduction, location, t }) { + const [message, setMessage] = useState(null) + const [token, setToken] = useState(null) + const [authChecking, setAuthChecking] = useState(true) + + const isAdminRoute = location.pathname.startsWith('/admin') || isProduction + const runtimeEnv = useMemo(() => detectRuntimeEnv(), []) + const isVercel = runtimeEnv.isVercel + + const showMessage = useCallback((type, text) => { + setMessage({ type, text }) + setTimeout(() => setMessage(null), 5000) + }, []) + + const handleLogout = useCallback(() => { + setToken(null) + localStorage.removeItem('ds2api_token') + localStorage.removeItem('ds2api_token_expires') + sessionStorage.removeItem('ds2api_token') + sessionStorage.removeItem('ds2api_token_expires') + }, []) + + const handleLogin = useCallback((newToken) => { + setToken(newToken) + }, []) + + useEffect(() => { + if (!isAdminRoute) { + setAuthChecking(false) + return + } + + const checkAuth = async () => { + const storedToken = localStorage.getItem('ds2api_token') || sessionStorage.getItem('ds2api_token') + const expiresAt = parseInt(localStorage.getItem('ds2api_token_expires') || sessionStorage.getItem('ds2api_token_expires') || '0') + + if (storedToken && expiresAt > Date.now()) { + try { + const res = await fetch('/admin/verify', { + headers: { 'Authorization': `Bearer ${storedToken}` } + }) + if (res.ok) { + setToken(storedToken) + } else { + handleLogout() + } + } catch { + setToken(storedToken) + } + } + setAuthChecking(false) + } + + checkAuth() + }, [handleLogout, isAdminRoute, t]) + + return { + token, + authChecking, + message, + isAdminRoute, + isVercel, + showMessage, + 
handleLogin, + handleLogout, + } +} diff --git a/webui/src/app/useAdminConfig.js b/webui/src/app/useAdminConfig.js new file mode 100644 index 0000000..3fa410d --- /dev/null +++ b/webui/src/app/useAdminConfig.js @@ -0,0 +1,32 @@ +import { useCallback, useEffect, useState } from 'react' + +export function useAdminConfig({ token, showMessage, t }) { + const [config, setConfig] = useState({ keys: [], accounts: [] }) + + const fetchConfig = useCallback(async () => { + if (!token) return + try { + const res = await fetch('/admin/config', { + headers: { 'Authorization': `Bearer ${token}` } + }) + if (res.ok) { + const data = await res.json() + setConfig(data) + } + } catch (e) { + console.error('Failed to fetch config:', e) + showMessage('error', t('errors.fetchConfig', { error: e.message })) + } + }, [showMessage, t, token]) + + useEffect(() => { + if (token) { + fetchConfig() + } + }, [fetchConfig, token]) + + return { + config, + fetchConfig, + } +} diff --git a/webui/src/components/AccountManager.jsx b/webui/src/components/AccountManager.jsx index 7ee3b97..2a37010 100644 --- a/webui/src/components/AccountManager.jsx +++ b/webui/src/components/AccountManager.jsx @@ -1,578 +1,3 @@ -import { useState, useEffect } from 'react' -import { - Plus, - Trash2, - CheckCircle2, - Play, - X, - Server, - ShieldCheck, - Copy, - Check, - ChevronLeft, - ChevronRight, - ChevronDown -} from 'lucide-react' -import clsx from 'clsx' -import { useI18n } from '../i18n' +import AccountManagerContainer from '../features/account/AccountManagerContainer' -export default function AccountManager({ config, onRefresh, onMessage, authFetch }) { - const { t } = useI18n() - const [showAddKey, setShowAddKey] = useState(false) - const [showAddAccount, setShowAddAccount] = useState(false) - const [newKey, setNewKey] = useState('') - const [copiedKey, setCopiedKey] = useState(null) - const [newAccount, setNewAccount] = useState({ email: '', mobile: '', password: '' }) - const [loading, setLoading] = 
useState(false) - const [testing, setTesting] = useState({}) - const [testingAll, setTestingAll] = useState(false) - const [batchProgress, setBatchProgress] = useState({ current: 0, total: 0, results: [] }) - const [queueStatus, setQueueStatus] = useState(null) - const [keysExpanded, setKeysExpanded] = useState(false) - - // 分页状态 - const [accounts, setAccounts] = useState([]) - const [page, setPage] = useState(1) - const [pageSize] = useState(10) - const [totalPages, setTotalPages] = useState(1) - const [totalAccounts, setTotalAccounts] = useState(0) - const [loadingAccounts, setLoadingAccounts] = useState(false) - - const apiFetch = authFetch || fetch - const resolveAccountIdentifier = (acc) => { - if (!acc || typeof acc !== 'object') return '' - return String(acc.identifier || acc.email || acc.mobile || '').trim() - } - - const fetchAccounts = async (targetPage = page) => { - setLoadingAccounts(true) - try { - const res = await apiFetch(`/admin/accounts?page=${targetPage}&page_size=${pageSize}`) - if (res.ok) { - const data = await res.json() - setAccounts(data.items || []) - setTotalPages(data.total_pages || 1) - setTotalAccounts(data.total || 0) - setPage(data.page || 1) - } - } catch (e) { - console.error('Failed to fetch accounts:', e) - } finally { - setLoadingAccounts(false) - } - } - - const fetchQueueStatus = async () => { - try { - const res = await apiFetch('/admin/queue/status') - if (res.ok) { - const data = await res.json() - setQueueStatus(data) - } - } catch (e) { - console.error('Failed to fetch queue status:', e) - } - } - - useEffect(() => { - fetchAccounts() - fetchQueueStatus() - const interval = setInterval(fetchQueueStatus, 5000) - return () => clearInterval(interval) - }, []) - - const addKey = async () => { - if (!newKey.trim()) return - setLoading(true) - try { - const res = await apiFetch('/admin/keys', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ key: newKey.trim() }), - }) - if 
(res.ok) { - onMessage('success', t('accountManager.addKeySuccess')) - setNewKey('') - setShowAddKey(false) - onRefresh() - } else { - const data = await res.json() - onMessage('error', data.detail || t('messages.failedToAdd')) - } - } catch (e) { - onMessage('error', t('messages.networkError')) - } finally { - setLoading(false) - } - } - - const deleteKey = async (key) => { - if (!confirm(t('accountManager.deleteKeyConfirm'))) return - try { - const res = await apiFetch(`/admin/keys/${encodeURIComponent(key)}`, { method: 'DELETE' }) - if (res.ok) { - onMessage('success', t('messages.deleted')) - onRefresh() - } else { - onMessage('error', t('messages.deleteFailed')) - } - } catch (e) { - onMessage('error', t('messages.networkError')) - } - } - - const addAccount = async () => { - if (!newAccount.password || (!newAccount.email && !newAccount.mobile)) { - onMessage('error', t('accountManager.requiredFields')) - return - } - setLoading(true) - try { - const res = await apiFetch('/admin/accounts', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify(newAccount), - }) - if (res.ok) { - onMessage('success', t('accountManager.addAccountSuccess')) - setNewAccount({ email: '', mobile: '', password: '' }) - setShowAddAccount(false) - fetchAccounts(1) // 添加后回到第一页 - onRefresh() - } else { - const data = await res.json() - onMessage('error', data.detail || t('messages.failedToAdd')) - } - } catch (e) { - onMessage('error', t('messages.networkError')) - } finally { - setLoading(false) - } - } - - const deleteAccount = async (id) => { - const identifier = String(id || '').trim() - if (!identifier) { - onMessage('error', t('accountManager.invalidIdentifier')) - return - } - if (!confirm(t('accountManager.deleteAccountConfirm'))) return - try { - const res = await apiFetch(`/admin/accounts/${encodeURIComponent(identifier)}`, { method: 'DELETE' }) - if (res.ok) { - onMessage('success', t('messages.deleted')) - fetchAccounts() // 刷新当前页 - 
onRefresh() - } else { - onMessage('error', t('messages.deleteFailed')) - } - } catch (e) { - onMessage('error', t('messages.networkError')) - } - } - - const testAccount = async (identifier) => { - const accountID = String(identifier || '').trim() - if (!accountID) { - onMessage('error', t('accountManager.invalidIdentifier')) - return - } - setTesting(prev => ({ ...prev, [accountID]: true })) - try { - const res = await apiFetch('/admin/accounts/test', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ identifier: accountID }), - }) - const data = await res.json() - const statusMessage = data.success - ? t('apiTester.testSuccess', { account: accountID, time: data.response_time }) - : `${accountID}: ${data.message}` - onMessage(data.success ? 'success' : 'error', statusMessage) - fetchAccounts() // 刷新当前页 - onRefresh() - } catch (e) { - onMessage('error', t('accountManager.testFailed', { error: e.message })) - } finally { - setTesting(prev => ({ ...prev, [accountID]: false })) - } - } - - const testAllAccounts = async () => { - if (!confirm(t('accountManager.testAllConfirm'))) return - const allAccounts = config.accounts || [] - if (allAccounts.length === 0) return - - setTestingAll(true) - setBatchProgress({ current: 0, total: allAccounts.length, results: [] }) - - let successCount = 0 - const results = [] - - for (let i = 0; i < allAccounts.length; i++) { - const acc = allAccounts[i] - const id = resolveAccountIdentifier(acc) - if (!id) { - results.push({ id: '-', success: false, message: t('accountManager.invalidIdentifier') }) - setBatchProgress({ current: i + 1, total: allAccounts.length, results: [...results] }) - continue - } - - try { - const res = await apiFetch('/admin/accounts/test', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ identifier: id }), - }) - const data = await res.json() - results.push({ id, success: data.success, message: data.message, time: 
data.response_time }) - if (data.success) successCount++ - } catch (e) { - results.push({ id, success: false, message: e.message }) - } - - setBatchProgress({ current: i + 1, total: allAccounts.length, results: [...results] }) - } - - onMessage('success', t('accountManager.testAllCompleted', { success: successCount, total: allAccounts.length })) - fetchAccounts() // 刷新当前页 - onRefresh() - setTestingAll(false) - } - - return ( -
- {/* Queue Status - Flat & Clean */} - { - queueStatus && ( -
-
-
- -
-

{t('accountManager.available')}

-
- {queueStatus.available} - {t('accountManager.accountsUnit')} -
-
-
-
- -
-

{t('accountManager.inUse')}

-
- {queueStatus.in_use} - {t('accountManager.threadsUnit')} -
-
-
-
- -
-

{t('accountManager.totalPool')}

-
- {queueStatus.total} - {t('accountManager.accountsUnit')} -
-
-
- ) - } - - {/* API Keys Section */} -
-
setKeysExpanded(!keysExpanded)} - > -
- -
-

{t('accountManager.apiKeysTitle')}

-

{t('accountManager.apiKeysDesc')} ({config.keys?.length || 0})

-
-
- -
- - {keysExpanded && ( -
- {config.keys?.length > 0 ? ( - config.keys.map((key, i) => ( -
-
-
- {key.slice(0, 16)}**** -
- {copiedKey === key && ( - {t('accountManager.copied')} - )} -
-
- - -
-
- )) - ) : ( -
{t('accountManager.noApiKeys')}
- )} -
- )} -
- - {/* Accounts Section */} -
-
-
-

{t('accountManager.accountsTitle')}

-

{t('accountManager.accountsDesc')}

-
-
- - -
-
- - {/* Batch Progress */} - {testingAll && batchProgress.total > 0 && ( -
-
- {t('accountManager.testingAllAccounts')} - {batchProgress.current} / {batchProgress.total} -
-
-
-
- {batchProgress.results.length > 0 && ( -
- {batchProgress.results.map((r, i) => ( -
- {r.success ? '✓' : '✗'} {r.id} -
- ))} -
- )} -
- )} - -
- {loadingAccounts ? ( -
{t('actions.loading')}
- ) : accounts.length > 0 ? ( - accounts.map((acc, i) => { - const id = resolveAccountIdentifier(acc) - return ( -
-
-
-
-
{id || '-'}
-
- {acc.has_token ? t('accountManager.sessionActive') : t('accountManager.reauthRequired')} - {acc.token_preview && ( - - {acc.token_preview} - - )} -
-
-
-
- - -
-
- ) - }) - ) : ( -
{t('accountManager.noAccounts')}
- )} -
- - {/* 分页控件 */} - {totalPages > 1 && ( -
-
- {t('accountManager.pageInfo', { current: page, total: totalPages, count: totalAccounts })} -
-
- - {page} / {totalPages} - -
-
- )} -
- - {/* Modals */} - { - showAddKey && ( -
-
-
-

{t('accountManager.modalAddKeyTitle')}

- -
-
-
- -
- setNewKey(e.target.value)} - autoFocus - /> - -
-

{t('accountManager.generateHint')}

-
-
- - -
-
-
-
- ) - } - - { - showAddAccount && ( -
-
-
-

{t('accountManager.modalAddAccountTitle')}

- -
-
-
- - setNewAccount({ ...newAccount, email: e.target.value })} - /> -
-
- - setNewAccount({ ...newAccount, mobile: e.target.value })} - /> -
-
- - setNewAccount({ ...newAccount, password: e.target.value })} - /> -
-
- - -
-
-
-
- ) - } -
- ) -} +export default AccountManagerContainer diff --git a/webui/src/components/ApiTester.jsx b/webui/src/components/ApiTester.jsx index 75af1c0..b688195 100644 --- a/webui/src/components/ApiTester.jsx +++ b/webui/src/components/ApiTester.jsx @@ -1,447 +1,3 @@ -import { useEffect, useRef, useState } from 'react' -import { - Send, - Square, - MessageSquare, - Cpu, - Search as SearchIcon, - Sparkles, - Bot, - User, - Loader2, - CheckCircle2, - AlertCircle, - ChevronDown, - ShieldCheck, - Terminal, - Zap, - ToggleLeft, - ToggleRight -} from 'lucide-react' -import clsx from 'clsx' -import { useI18n } from '../i18n' +import ApiTesterContainer from '../features/apiTester/ApiTesterContainer' -export default function ApiTester({ config, onMessage, authFetch }) { - const { t } = useI18n() - const [model, setModel] = useState('deepseek-chat') - const defaultMessage = t('apiTester.defaultMessage') - const [message, setMessage] = useState(defaultMessage) - const [apiKey, setApiKey] = useState('') - const [selectedAccount, setSelectedAccount] = useState('') - const [response, setResponse] = useState(null) - const [loading, setLoading] = useState(false) - const [streamingContent, setStreamingContent] = useState('') - const [streamingThinking, setStreamingThinking] = useState('') - const [isStreaming, setIsStreaming] = useState(false) - const [streamingMode, setStreamingMode] = useState(true) - const abortControllerRef = useRef(null) - const defaultMessageRef = useRef(defaultMessage) - - const [sidebarOpen, setSidebarOpen] = useState(false) - const [configExpanded, setConfigExpanded] = useState(false) - - const apiFetch = authFetch || fetch - const accounts = config.accounts || [] - const resolveAccountIdentifier = (acc) => { - if (!acc || typeof acc !== 'object') return '' - return String(acc.identifier || acc.email || acc.mobile || '').trim() - } - const configuredKeys = config.keys || [] - const trimmedApiKey = apiKey.trim() - const defaultKey = configuredKeys[0] || '' - 
const effectiveKey = trimmedApiKey || defaultKey - const customKeyActive = trimmedApiKey !== '' - const customKeyManaged = customKeyActive && configuredKeys.includes(trimmedApiKey) - const models = [ - { id: "deepseek-chat", name: "deepseek-chat", icon: MessageSquare, desc: t('apiTester.models.chat'), color: "text-amber-500" }, - { id: "deepseek-reasoner", name: "deepseek-reasoner", icon: Cpu, desc: t('apiTester.models.reasoner'), color: "text-amber-600" }, - { id: "deepseek-chat-search", name: "deepseek-chat-search", icon: SearchIcon, desc: t('apiTester.models.chatSearch'), color: "text-cyan-500" }, - { id: "deepseek-reasoner-search", name: "deepseek-reasoner-search", icon: SearchIcon, desc: t('apiTester.models.reasonerSearch'), color: "text-cyan-600" }, - ] - - const stopGeneration = () => { - if (abortControllerRef.current) { - abortControllerRef.current.abort() - abortControllerRef.current = null - } - setLoading(false) - setIsStreaming(false) - } - - const extractErrorMessage = async (res) => { - let raw = '' - try { - raw = await res.text() - } catch { - return t('apiTester.requestFailed') - } - if (!raw) { - return t('apiTester.requestFailed') - } - try { - const data = JSON.parse(raw) - const fromErrorObject = data?.error?.message - const fromErrorString = typeof data?.error === 'string' ? data.error : '' - const detail = typeof data?.detail === 'string' ? data.detail : '' - const message = typeof data?.message === 'string' ? data.message : '' - return fromErrorObject || fromErrorString || detail || message || t('apiTester.requestFailed') - } catch { - return raw.length > 240 ? 
`${raw.slice(0, 240)}...` : raw - } - } - - const runTest = async () => { - if (loading) return - - const startedAt = Date.now() - - setLoading(true) - setIsStreaming(true) - setResponse(null) - setStreamingContent('') - setStreamingThinking('') - - abortControllerRef.current = new AbortController() - - try { - if (!effectiveKey) { - onMessage('error', t('apiTester.missingApiKey')) - setLoading(false) - setIsStreaming(false) - return - } - - const headers = { - 'Content-Type': 'application/json', - 'Authorization': `Bearer ${effectiveKey}`, - } - if (selectedAccount) { - headers['X-Ds2-Target-Account'] = selectedAccount - } - - const endpoint = streamingMode ? '/v1/chat/completions' : '/v1/chat/completions?__go=1' - const res = await fetch(endpoint, { - method: 'POST', - headers, - body: JSON.stringify({ - model, - messages: [{ role: 'user', content: message }], - stream: streamingMode, - }), - signal: abortControllerRef.current.signal, - }) - - if (!res.ok) { - const errorMsg = await extractErrorMessage(res) - setResponse({ success: false, error: errorMsg }) - onMessage('error', errorMsg) - setLoading(false) - setIsStreaming(false) - return - } - - if (streamingMode) { - setResponse({ success: true, status_code: res.status }) - - const reader = res.body.getReader() - const decoder = new TextDecoder() - let buffer = '' - - while (true) { - const { done, value } = await reader.read() - if (done) break - - buffer += decoder.decode(value, { stream: true }) - const lines = buffer.split('\n') - buffer = lines.pop() || '' - - for (const line of lines) { - const trimmed = line.trim() - if (!trimmed || !trimmed.startsWith('data: ')) continue - - const dataStr = trimmed.slice(6) - if (dataStr === '[DONE]') continue - - try { - const json = JSON.parse(dataStr) - const choice = json.choices?.[0] - if (choice?.delta) { - const delta = choice.delta - if (delta.reasoning_content) { - setStreamingThinking(prev => prev + delta.reasoning_content) - } - if (delta.content) { - 
setStreamingContent(prev => prev + delta.content) - } - } - } catch (e) { - console.error('Invalid JSON hunk:', dataStr, e) - } - } - } - } else { - const data = await res.json() - setResponse({ success: true, status_code: res.status, ...data }) - const elapsed = Math.max(0, Date.now() - startedAt) - onMessage('success', t('apiTester.testSuccess', { account: selectedAccount || 'Auto', time: elapsed })) - } - } catch (e) { - if (e.name === 'AbortError') { - onMessage('info', t('messages.generationStopped')) - } else { - onMessage('error', t('apiTester.networkError', { error: e.message })) - setResponse({ error: e.message, success: false }) - } - } finally { - setLoading(false) - setIsStreaming(false) - abortControllerRef.current = null - } - } - -useEffect(() => { - setMessage((prev) => (prev === defaultMessageRef.current ? defaultMessage : prev)) - defaultMessageRef.current = defaultMessage -}, [defaultMessage]) - -return ( -
- {/* Configuration Panel */} -
-
- {/* Mobile Toggle Header */} - - -
-
- -
- {models.map(m => { - const Icon = m.icon - return ( - - ) - })} -
-
- -
- - -
- -
- -
- - -
-
- -
- - setApiKey(e.target.value)} - /> - {customKeyActive && ( -

- {customKeyManaged ? t('apiTester.modeManaged') : t('apiTester.modeDirect')} -

- )} -
-
-
-
- - {/* Chat Interface */} -
- - {/* Messages Area */} -
- {/* User Message */} -
-
- -
-
-
- {message} -
-
-
- - {/* AI Response */} - {(response || isStreaming) && ( -
-
- -
-
-
- - DeepSeek - - {response && ( - - {response.status_code || t('apiTester.statusError')} - - )} -
- - {(streamingThinking || response?.choices?.[0]?.message?.reasoning_content) && ( -
-
- - {t('apiTester.reasoningTrace')} -
-
- {streamingThinking || response?.choices?.[0]?.message?.reasoning_content} -
-
- )} - -
- {streamingContent || response?.choices?.[0]?.message?.content || (response?.error && {response.error}) || (loading && {t('apiTester.generating')})} - {isStreaming && } -
-
-
- )} -
- - {/* Input Area */} -
-
-