- {navItems.find(n => n.id === activeTab)?.label} -
-- {navItems.find(n => n.id === activeTab)?.description} -
-diff --git a/api/chat-stream.js b/api/chat-stream.js
index c4a4cd1..32e7601 100644
--- a/api/chat-stream.js
+++ b/api/chat-stream.js
@@ -1,968 +1,3 @@
'use strict';
-const crypto = require('crypto');
-
-const {
- extractToolNames,
- createToolSieveState,
- processToolSieveChunk,
- flushToolSieve,
- parseToolCalls,
- formatOpenAIStreamToolCalls,
-} = require('./helpers/stream-tool-sieve');
-const {
- BASE_HEADERS,
- SKIP_PATTERNS,
- SKIP_EXACT_PATHS,
-} = require('./shared/deepseek-constants');
-
-const DEEPSEEK_COMPLETION_URL = 'https://chat.deepseek.com/api/v0/chat/completion';
-
-module.exports = async function handler(req, res) {
- setCorsHeaders(res);
- if (req.method === 'OPTIONS') {
- res.statusCode = 204;
- res.end();
- return;
- }
- if (req.method !== 'POST') {
- writeOpenAIError(res, 405, 'method not allowed');
- return;
- }
-
- const rawBody = await readRawBody(req);
-
- // Hard guard: only use Node data path for streaming on Vercel runtime.
- // Any non-Vercel runtime always falls back to Go for full behavior parity.
- if (!isVercelRuntime()) {
- await proxyToGo(req, res, rawBody);
- return;
- }
-
- let payload;
- try {
- payload = JSON.parse(rawBody.toString('utf8') || '{}');
- } catch (_err) {
- writeOpenAIError(res, 400, 'invalid json');
- return;
- }
-
- // Keep all non-stream behavior on Go side to avoid compatibility regressions.
- if (!toBool(payload.stream)) {
- await proxyToGo(req, res, rawBody);
- return;
- }
-
- const prep = await fetchStreamPrepare(req, rawBody);
- if (!prep.ok) {
- relayPreparedFailure(res, prep);
- return;
- }
-
- const model = asString(prep.body.model) || asString(payload.model);
- const sessionID = asString(prep.body.session_id) || `chatcmpl-${Date.now()}`;
- const leaseID = asString(prep.body.lease_id);
- const deepseekToken = asString(prep.body.deepseek_token);
- const powHeader = asString(prep.body.pow_header);
- const completionPayload = prep.body.payload && typeof prep.body.payload === 'object' ? prep.body.payload : null;
- const finalPrompt = asString(prep.body.final_prompt);
- const thinkingEnabled = toBool(prep.body.thinking_enabled);
- const searchEnabled = toBool(prep.body.search_enabled);
- const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools);
- const toolNames = toolPolicy.toolNames;
-
- if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) {
- writeOpenAIError(res, 500, 'invalid vercel prepare response');
- return;
- }
- const releaseLease = createLeaseReleaser(req, leaseID);
- const upstreamController = new AbortController();
- let clientClosed = false;
- let reader = null;
- const markClientClosed = () => {
- if (clientClosed) {
- return;
- }
- clientClosed = true;
- upstreamController.abort();
- if (reader && typeof reader.cancel === 'function') {
- Promise.resolve(reader.cancel()).catch(() => {});
- }
- };
- const onReqAborted = () => markClientClosed();
- const onResClose = () => {
- if (!res.writableEnded) {
- markClientClosed();
- }
- };
- req.on('aborted', onReqAborted);
- res.on('close', onResClose);
- try {
- let completionRes;
- try {
- completionRes = await fetch(DEEPSEEK_COMPLETION_URL, {
- method: 'POST',
- headers: {
- ...BASE_HEADERS,
- authorization: `Bearer ${deepseekToken}`,
- 'x-ds-pow-response': powHeader,
- },
- body: JSON.stringify(completionPayload),
- signal: upstreamController.signal,
- });
- } catch (err) {
- if (clientClosed || isAbortError(err)) {
- return;
- }
- throw err;
- }
- if (clientClosed) {
- return;
- }
-
- if (!completionRes.ok || !completionRes.body) {
- const detail = await safeReadText(completionRes);
- writeOpenAIError(res, 500, detail ? `Failed to get completion: ${detail}` : 'Failed to get completion.');
- return;
- }
-
- res.statusCode = 200;
- res.setHeader('Content-Type', 'text/event-stream');
- res.setHeader('Cache-Control', 'no-cache, no-transform');
- res.setHeader('Connection', 'keep-alive');
- res.setHeader('X-Accel-Buffering', 'no');
- if (typeof res.flushHeaders === 'function') {
- res.flushHeaders();
- }
-
- const created = Math.floor(Date.now() / 1000);
- let firstChunkSent = false;
- let currentType = thinkingEnabled ? 'thinking' : 'text';
- let thinkingText = '';
- let outputText = '';
- const toolSieveEnabled = toolPolicy.toolSieveEnabled;
- const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas;
- const toolSieveState = createToolSieveState();
- let toolCallsEmitted = false;
- const streamToolCallIDs = new Map();
- const decoder = new TextDecoder();
- reader = completionRes.body.getReader();
- let buffered = '';
- let ended = false;
-
- const sendFrame = (obj) => {
- if (clientClosed || res.writableEnded || res.destroyed) {
- return;
- }
- res.write(`data: ${JSON.stringify(obj)}\n\n`);
- if (typeof res.flush === 'function') {
- res.flush();
- }
- };
-
- const sendDeltaFrame = (delta) => {
- const payloadDelta = { ...delta };
- if (!firstChunkSent) {
- payloadDelta.role = 'assistant';
- firstChunkSent = true;
- }
- sendFrame({
- id: sessionID,
- object: 'chat.completion.chunk',
- created,
- model,
- choices: [{ delta: payloadDelta, index: 0 }],
- });
- };
-
- const finish = async (reason) => {
- if (ended) {
- return;
- }
- ended = true;
- if (clientClosed || res.writableEnded || res.destroyed) {
- await releaseLease();
- return;
- }
- const detected = parseToolCalls(outputText, toolNames);
- if (detected.length > 0 && !toolCallsEmitted) {
- toolCallsEmitted = true;
- sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected) });
- } else if (toolSieveEnabled) {
- const tailEvents = flushToolSieve(toolSieveState, toolNames);
- for (const evt of tailEvents) {
- if (evt.text) {
- sendDeltaFrame({ content: evt.text });
- }
- }
- }
- if (detected.length > 0 || toolCallsEmitted) {
- reason = 'tool_calls';
- }
- sendFrame({
- id: sessionID,
- object: 'chat.completion.chunk',
- created,
- model,
- choices: [{ delta: {}, index: 0, finish_reason: reason }],
- usage: buildUsage(finalPrompt, thinkingText, outputText),
- });
- if (!res.writableEnded && !res.destroyed) {
- res.write('data: [DONE]\n\n');
- }
- await releaseLease();
- if (!res.writableEnded && !res.destroyed) {
- res.end();
- }
- };
-
- try {
- // eslint-disable-next-line no-constant-condition
- while (true) {
- if (clientClosed) {
- await finish('stop');
- return;
- }
- const { value, done } = await reader.read();
- if (done) {
- break;
- }
- buffered += decoder.decode(value, { stream: true });
- const lines = buffered.split('\n');
- buffered = lines.pop() || '';
-
- for (const rawLine of lines) {
- const line = rawLine.trim();
- if (!line.startsWith('data:')) {
- continue;
- }
- const dataStr = line.slice(5).trim();
- if (!dataStr) {
- continue;
- }
- if (dataStr === '[DONE]') {
- await finish('stop');
- return;
- }
- let chunk;
- try {
- chunk = JSON.parse(dataStr);
- } catch (_err) {
- continue;
- }
- if (chunk.error || chunk.code === 'content_filter') {
- await finish('content_filter');
- return;
- }
- const parsed = parseChunkForContent(chunk, thinkingEnabled, currentType);
- currentType = parsed.newType;
- if (parsed.finished) {
- await finish('stop');
- return;
- }
-
- for (const p of parsed.parts) {
- if (!p.text) {
- continue;
- }
- if (searchEnabled && isCitation(p.text)) {
- continue;
- }
- if (p.type === 'thinking') {
- if (thinkingEnabled) {
- thinkingText += p.text;
- sendDeltaFrame({ reasoning_content: p.text });
- }
- } else {
- outputText += p.text;
- if (!toolSieveEnabled) {
- sendDeltaFrame({ content: p.text });
- continue;
- }
- const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
- for (const evt of events) {
- if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) {
- if (!emitEarlyToolDeltas) {
- continue;
- }
- toolCallsEmitted = true;
- sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) });
- continue;
- }
- if (evt.type === 'tool_calls') {
- toolCallsEmitted = true;
- sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) });
- continue;
- }
- if (evt.text) {
- sendDeltaFrame({ content: evt.text });
- }
- }
- }
- }
- }
- }
- await finish('stop');
- } catch (err) {
- if (clientClosed || isAbortError(err)) {
- await finish('stop');
- return;
- }
- await finish('stop');
- }
- } finally {
- req.removeListener('aborted', onReqAborted);
- res.removeListener('close', onResClose);
- await releaseLease();
- }
-};
-
-function setCorsHeaders(res) {
- res.setHeader('Access-Control-Allow-Origin', '*');
- res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, PUT, DELETE');
- res.setHeader(
- 'Access-Control-Allow-Headers',
- 'Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass',
- );
-}
-
-function header(req, key) {
- if (!req || !req.headers) {
- return '';
- }
- return asString(req.headers[key.toLowerCase()]);
-}
-
-async function readRawBody(req) {
- if (Buffer.isBuffer(req.body)) {
- return req.body;
- }
- if (typeof req.body === 'string') {
- return Buffer.from(req.body);
- }
- if (req.body && typeof req.body === 'object') {
- return Buffer.from(JSON.stringify(req.body));
- }
- const chunks = [];
- for await (const chunk of req) {
- chunks.push(Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk));
- }
- return Buffer.concat(chunks);
-}
-
-async function fetchStreamPrepare(req, rawBody) {
- const url = buildInternalGoURL(req);
- url.searchParams.set('__stream_prepare', '1');
-
- const upstream = await fetch(url.toString(), {
- method: 'POST',
- headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }),
- body: rawBody,
- });
-
- const text = await upstream.text();
- let body = {};
- try {
- body = JSON.parse(text || '{}');
- } catch (_err) {
- body = {};
- }
-
- return {
- ok: upstream.ok,
- status: upstream.status,
- contentType: upstream.headers.get('content-type') || 'application/json',
- text,
- body,
- };
-}
-
-function relayPreparedFailure(res, prep) {
- if (prep.status === 401 && looksLikeVercelAuthPage(prep.text)) {
- writeOpenAIError(
- res,
- 401,
- 'Vercel Deployment Protection blocked internal prepare request. Disable protection for this deployment or set VERCEL_AUTOMATION_BYPASS_SECRET.',
- );
- return;
- }
- res.statusCode = prep.status || 500;
- res.setHeader('Content-Type', prep.contentType || 'application/json');
- if (prep.text) {
- res.end(prep.text);
- return;
- }
- writeOpenAIError(res, prep.status || 500, 'vercel prepare failed');
-}
-
-function resolveToolcallPolicy(prepBody, payloadTools) {
- const preparedToolNames = normalizePreparedToolNames(prepBody && prepBody.tool_names);
- const toolNames = preparedToolNames.length > 0 ? preparedToolNames : extractToolNames(payloadTools);
- const featureMatchEnabled = boolDefaultTrue(prepBody && prepBody.toolcall_feature_match);
- const emitEarlyToolDeltas = boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high);
- return {
- toolNames,
- toolSieveEnabled: toolNames.length > 0 && featureMatchEnabled,
- emitEarlyToolDeltas,
- };
-}
-
-function normalizePreparedToolNames(v) {
- if (!Array.isArray(v) || v.length === 0) {
- return [];
- }
- const out = [];
- for (const item of v) {
- const name = asString(item);
- if (!name) {
- continue;
- }
- out.push(name);
- }
- return out;
-}
-
-function boolDefaultTrue(v) {
- return v !== false;
-}
-
-async function safeReadText(resp) {
- if (!resp) {
- return '';
- }
- try {
- const text = await resp.text();
- return text.trim();
- } catch (_err) {
- return '';
- }
-}
-
-function internalSecret() {
- return asString(process.env.DS2API_VERCEL_INTERNAL_SECRET) || asString(process.env.DS2API_ADMIN_KEY) || 'admin';
-}
-
-function buildInternalGoURL(req) {
- const proto = asString(header(req, 'x-forwarded-proto')) || 'https';
- const host = asString(header(req, 'host'));
- const url = new URL(`${proto}://${host}${req.url || '/v1/chat/completions'}`);
- url.searchParams.set('__go', '1');
- const protectionBypass = resolveProtectionBypass(req);
- if (protectionBypass) {
- url.searchParams.set('x-vercel-protection-bypass', protectionBypass);
- }
- return url;
-}
-
-function buildInternalGoHeaders(req, opts = {}) {
- const headers = {
- authorization: asString(header(req, 'authorization')),
- 'x-api-key': asString(header(req, 'x-api-key')),
- 'x-ds2-target-account': asString(header(req, 'x-ds2-target-account')),
- 'x-vercel-protection-bypass': resolveProtectionBypass(req),
- };
- if (opts.withInternalToken) {
- headers['x-ds2-internal-token'] = internalSecret();
- }
- if (opts.withContentType) {
- headers['content-type'] = asString(header(req, 'content-type')) || 'application/json';
- }
- return headers;
-}
-
-function createLeaseReleaser(req, leaseID) {
- let released = false;
- return async () => {
- if (released || !leaseID) {
- return;
- }
- released = true;
- try {
- await releaseStreamLease(req, leaseID);
- } catch (_err) {
- // Ignore release errors. Lease TTL cleanup on Go side still prevents permanent leaks.
- }
- };
-}
-
-async function releaseStreamLease(req, leaseID) {
- const url = buildInternalGoURL(req);
- url.searchParams.set('__stream_release', '1');
- const body = Buffer.from(JSON.stringify({ lease_id: leaseID }));
-
- const controller = new AbortController();
- const timeout = setTimeout(() => controller.abort(), 1500);
- try {
- await fetch(url.toString(), {
- method: 'POST',
- headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }),
- body,
- signal: controller.signal,
- });
- } finally {
- clearTimeout(timeout);
- }
-}
-
-function resolveProtectionBypass(req) {
- const fromHeader = asString(header(req, 'x-vercel-protection-bypass'));
- if (fromHeader) {
- return fromHeader;
- }
- return asString(process.env.VERCEL_AUTOMATION_BYPASS_SECRET) || asString(process.env.DS2API_VERCEL_PROTECTION_BYPASS);
-}
-
-function looksLikeVercelAuthPage(text) {
- const body = asString(text).toLowerCase();
- if (!body) {
- return false;
- }
- return body.includes('authentication required') && body.includes('vercel');
-}
-
-function parseChunkForContent(chunk, thinkingEnabled, currentType) {
- if (!chunk || typeof chunk !== 'object' || !Object.prototype.hasOwnProperty.call(chunk, 'v')) {
- return { parts: [], finished: false, newType: currentType };
- }
- const pathValue = asString(chunk.p);
- if (shouldSkipPath(pathValue)) {
- return { parts: [], finished: false, newType: currentType };
- }
- if (pathValue === 'response/status' && asString(chunk.v) === 'FINISHED') {
- return { parts: [], finished: true, newType: currentType };
- }
-
- let newType = currentType;
- const parts = [];
-
- if (pathValue === 'response/fragments' && asString(chunk.o).toUpperCase() === 'APPEND' && Array.isArray(chunk.v)) {
- for (const frag of chunk.v) {
- if (!frag || typeof frag !== 'object') {
- continue;
- }
- const fragType = asString(frag.type).toUpperCase();
- const content = asString(frag.content);
- if (!content) {
- continue;
- }
- if (fragType === 'THINK' || fragType === 'THINKING') {
- newType = 'thinking';
- parts.push({ text: content, type: 'thinking' });
- } else if (fragType === 'RESPONSE') {
- newType = 'text';
- parts.push({ text: content, type: 'text' });
- } else {
- parts.push({ text: content, type: 'text' });
- }
- }
- }
-
- if (pathValue === 'response' && Array.isArray(chunk.v)) {
- for (const item of chunk.v) {
- if (!item || typeof item !== 'object') {
- continue;
- }
- if (item.p === 'fragments' && item.o === 'APPEND' && Array.isArray(item.v)) {
- for (const frag of item.v) {
- const fragType = asString(frag && frag.type).toUpperCase();
- if (fragType === 'THINK' || fragType === 'THINKING') {
- newType = 'thinking';
- } else if (fragType === 'RESPONSE') {
- newType = 'text';
- }
- }
- }
- }
- }
-
- let partType = 'text';
- if (pathValue === 'response/thinking_content') {
- partType = 'thinking';
- } else if (pathValue === 'response/content') {
- partType = 'text';
- } else if (pathValue.includes('response/fragments') && pathValue.includes('/content')) {
- partType = newType;
- } else if (!pathValue && thinkingEnabled) {
- partType = newType;
- }
-
- const val = chunk.v;
- if (typeof val === 'string') {
- if (val === 'FINISHED' && (!pathValue || pathValue === 'status')) {
- return { parts: [], finished: true, newType };
- }
- if (val) {
- parts.push({ text: val, type: partType });
- }
- return { parts, finished: false, newType };
- }
-
- if (Array.isArray(val)) {
- const extracted = extractContentRecursive(val, partType);
- if (extracted.finished) {
- return { parts: [], finished: true, newType };
- }
- parts.push(...extracted.parts);
- return { parts, finished: false, newType };
- }
-
- if (val && typeof val === 'object') {
- const resp = val.response && typeof val.response === 'object' ? val.response : val;
- if (Array.isArray(resp.fragments)) {
- for (const frag of resp.fragments) {
- if (!frag || typeof frag !== 'object') {
- continue;
- }
- const content = asString(frag.content);
- if (!content) {
- continue;
- }
- const t = asString(frag.type).toUpperCase();
- if (t === 'THINK' || t === 'THINKING') {
- newType = 'thinking';
- parts.push({ text: content, type: 'thinking' });
- } else if (t === 'RESPONSE') {
- newType = 'text';
- parts.push({ text: content, type: 'text' });
- } else {
- parts.push({ text: content, type: partType });
- }
- }
- }
- }
- return { parts, finished: false, newType };
-}
-
-function extractContentRecursive(items, defaultType) {
- const parts = [];
- for (const it of items) {
- if (!it || typeof it !== 'object') {
- continue;
- }
- if (!Object.prototype.hasOwnProperty.call(it, 'v')) {
- continue;
- }
- const itemPath = asString(it.p);
- const itemV = it.v;
- if (itemPath === 'status' && asString(itemV) === 'FINISHED') {
- return { parts: [], finished: true };
- }
- if (shouldSkipPath(itemPath)) {
- continue;
- }
- const content = asString(it.content);
- if (content) {
- const typeName = asString(it.type).toUpperCase();
- if (typeName === 'THINK' || typeName === 'THINKING') {
- parts.push({ text: content, type: 'thinking' });
- } else if (typeName === 'RESPONSE') {
- parts.push({ text: content, type: 'text' });
- } else {
- parts.push({ text: content, type: defaultType });
- }
- continue;
- }
-
- let partType = defaultType;
- if (itemPath.includes('thinking')) {
- partType = 'thinking';
- } else if (itemPath.includes('content') || itemPath === 'response' || itemPath === 'fragments') {
- partType = 'text';
- }
-
- if (typeof itemV === 'string') {
- if (itemV && itemV !== 'FINISHED') {
- parts.push({ text: itemV, type: partType });
- }
- continue;
- }
-
- if (!Array.isArray(itemV)) {
- continue;
- }
- for (const inner of itemV) {
- if (typeof inner === 'string') {
- if (inner) {
- parts.push({ text: inner, type: partType });
- }
- continue;
- }
- if (!inner || typeof inner !== 'object') {
- continue;
- }
- const ct = asString(inner.content);
- if (!ct) {
- continue;
- }
- const typeName = asString(inner.type).toUpperCase();
- if (typeName === 'THINK' || typeName === 'THINKING') {
- parts.push({ text: ct, type: 'thinking' });
- } else if (typeName === 'RESPONSE') {
- parts.push({ text: ct, type: 'text' });
- } else {
- parts.push({ text: ct, type: partType });
- }
- }
- }
- return { parts, finished: false };
-}
-
-function shouldSkipPath(pathValue) {
- if (SKIP_EXACT_PATHS.has(pathValue)) {
- return true;
- }
- for (const p of SKIP_PATTERNS) {
- if (pathValue.includes(p)) {
- return true;
- }
- }
- return false;
-}
-
-function isCitation(text) {
- return asString(text).trim().startsWith('[citation:');
-}
-
-function buildUsage(prompt, thinking, output) {
- const promptTokens = estimateTokens(prompt);
- const reasoningTokens = estimateTokens(thinking);
- const completionTokens = estimateTokens(output);
- return {
- prompt_tokens: promptTokens,
- completion_tokens: reasoningTokens + completionTokens,
- total_tokens: promptTokens + reasoningTokens + completionTokens,
- completion_tokens_details: {
- reasoning_tokens: reasoningTokens,
- },
- };
-}
-
-function formatIncrementalToolCallDeltas(deltas, idStore) {
- if (!Array.isArray(deltas) || deltas.length === 0) {
- return [];
- }
- const out = [];
- for (const d of deltas) {
- if (!d || typeof d !== 'object') {
- continue;
- }
- const index = Number.isInteger(d.index) ? d.index : 0;
- const id = ensureStreamToolCallID(idStore, index);
- const item = {
- index,
- id,
- type: 'function',
- };
- const fn = {};
- if (asString(d.name)) {
- fn.name = asString(d.name);
- }
- if (typeof d.arguments === 'string' && d.arguments !== '') {
- fn.arguments = d.arguments;
- }
- if (Object.keys(fn).length > 0) {
- item.function = fn;
- }
- out.push(item);
- }
- return out;
-}
-
-function ensureStreamToolCallID(idStore, index) {
- const key = Number.isInteger(index) ? index : 0;
- const existing = idStore.get(key);
- if (existing) {
- return existing;
- }
- const next = `call_${newCallID()}`;
- idStore.set(key, next);
- return next;
-}
-
-function newCallID() {
- if (typeof crypto.randomUUID === 'function') {
- return crypto.randomUUID().replace(/-/g, '');
- }
- return `${Date.now()}${Math.floor(Math.random() * 1e9)}`;
-}
-
-function estimateTokens(text) {
- const t = asString(text);
- if (!t) {
- return 0;
- }
- let asciiChars = 0;
- let nonASCIIChars = 0;
- for (const ch of Array.from(t)) {
- if (ch.charCodeAt(0) < 128) {
- asciiChars += 1;
- } else {
- nonASCIIChars += 1;
- }
- }
- const n = Math.floor(asciiChars / 4) + Math.floor((nonASCIIChars * 10 + 7) / 13);
- return n < 1 ? 1 : n;
-}
-
-async function proxyToGo(req, res, rawBody) {
- const url = buildInternalGoURL(req);
- const controller = new AbortController();
- let clientClosed = false;
- const markClientClosed = () => {
- if (clientClosed) {
- return;
- }
- clientClosed = true;
- controller.abort();
- };
- const onReqAborted = () => markClientClosed();
- const onResClose = () => {
- if (!res.writableEnded) {
- markClientClosed();
- }
- };
- req.on('aborted', onReqAborted);
- res.on('close', onResClose);
-
- try {
- let upstream;
- try {
- upstream = await fetch(url.toString(), {
- method: 'POST',
- headers: buildInternalGoHeaders(req, { withContentType: true }),
- body: rawBody,
- signal: controller.signal,
- });
- } catch (err) {
- if (clientClosed || isAbortError(err)) {
- if (!res.writableEnded) {
- res.end();
- }
- return;
- }
- throw err;
- }
- if (clientClosed) {
- if (!res.writableEnded) {
- res.end();
- }
- return;
- }
-
- res.statusCode = upstream.status;
- upstream.headers.forEach((value, key) => {
- if (key.toLowerCase() === 'content-length') {
- return;
- }
- res.setHeader(key, value);
- });
-
- if (!upstream.body || typeof upstream.body.getReader !== 'function') {
- const bytes = Buffer.from(await upstream.arrayBuffer());
- res.end(bytes);
- return;
- }
-
- const reader = upstream.body.getReader();
- try {
- // eslint-disable-next-line no-constant-condition
- while (true) {
- if (clientClosed) {
- break;
- }
- const { value, done } = await reader.read();
- if (done) {
- break;
- }
- if (value && value.length > 0) {
- res.write(Buffer.from(value));
- if (typeof res.flush === 'function') {
- res.flush();
- }
- }
- }
- if (!res.writableEnded) {
- res.end();
- }
- } catch (err) {
- if (!isAbortError(err) && !res.writableEnded) {
- res.end();
- }
- }
- } finally {
- req.removeListener('aborted', onReqAborted);
- res.removeListener('close', onResClose);
- if (!res.writableEnded) {
- res.end();
- }
- }
-}
-
-function writeOpenAIError(res, status, message) {
- res.statusCode = status;
- res.setHeader('Content-Type', 'application/json');
- res.end(
- JSON.stringify({
- error: {
- message,
- type: openAIErrorType(status),
- },
- }),
- );
-}
-
-function openAIErrorType(status) {
- switch (status) {
- case 400:
- return 'invalid_request_error';
- case 401:
- return 'authentication_error';
- case 403:
- return 'permission_error';
- case 429:
- return 'rate_limit_error';
- case 503:
- return 'service_unavailable_error';
- default:
- return status >= 500 ? 'api_error' : 'invalid_request_error';
- }
-}
-
-function toBool(v) {
- return v === true;
-}
-
-function isVercelRuntime() {
- return asString(process.env.VERCEL) !== '' || asString(process.env.NOW_REGION) !== '';
-}
-
-function asString(v) {
- if (typeof v === 'string') {
- return v.trim();
- }
- if (Array.isArray(v)) {
- return asString(v[0]);
- }
- if (v == null) {
- return '';
- }
- return String(v).trim();
-}
-
-function isAbortError(err) {
- if (!err || typeof err !== 'object') {
- return false;
- }
- return err.name === 'AbortError' || err.code === 'ABORT_ERR';
-}
-
-module.exports.__test = {
- parseChunkForContent,
- extractContentRecursive,
- shouldSkipPath,
- asString,
- resolveToolcallPolicy,
- normalizePreparedToolNames,
- boolDefaultTrue,
- estimateTokens,
-};
+module.exports = require('./chat-stream/index.js');
diff --git a/api/chat-stream/error_shape.js b/api/chat-stream/error_shape.js
new file mode 100644
index 0000000..18aeedb
--- /dev/null
+++ b/api/chat-stream/error_shape.js
@@ -0,0 +1,36 @@
+'use strict';
+
+function writeOpenAIError(res, status, message) {
+ res.statusCode = status;
+ res.setHeader('Content-Type', 'application/json');
+ res.end(
+ JSON.stringify({
+ error: {
+ message,
+ type: openAIErrorType(status),
+ },
+ }),
+ );
+}
+
+function openAIErrorType(status) {
+ switch (status) {
+ case 400:
+ return 'invalid_request_error';
+ case 401:
+ return 'authentication_error';
+ case 403:
+ return 'permission_error';
+ case 429:
+ return 'rate_limit_error';
+ case 503:
+ return 'service_unavailable_error';
+ default:
+ return status >= 500 ? 'api_error' : 'invalid_request_error';
+ }
+}
+
+module.exports = {
+ writeOpenAIError,
+ openAIErrorType,
+};
diff --git a/api/chat-stream/http_internal.js b/api/chat-stream/http_internal.js
new file mode 100644
index 0000000..20f24c8
--- /dev/null
+++ b/api/chat-stream/http_internal.js
@@ -0,0 +1,214 @@
+'use strict';
+
+const {
+ writeOpenAIError,
+} = require('./error_shape');
+
+function setCorsHeaders(res) {
+ res.setHeader('Access-Control-Allow-Origin', '*');
+ res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, PUT, DELETE');
+ res.setHeader(
+ 'Access-Control-Allow-Headers',
+ 'Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass',
+ );
+}
+
+function header(req, key) {
+ if (!req || !req.headers) {
+ return '';
+ }
+ return asString(req.headers[key.toLowerCase()]);
+}
+
+async function readRawBody(req) {
+ if (Buffer.isBuffer(req.body)) {
+ return req.body;
+ }
+ if (typeof req.body === 'string') {
+ return Buffer.from(req.body);
+ }
+ if (req.body && typeof req.body === 'object') {
+ return Buffer.from(JSON.stringify(req.body));
+ }
+ const chunks = [];
+ for await (const chunk of req) {
+ chunks.push(Buffer.isBuffer(chunk) ? chunk : Buffer.from(chunk));
+ }
+ return Buffer.concat(chunks);
+}
+
+async function fetchStreamPrepare(req, rawBody) {
+ const url = buildInternalGoURL(req);
+ url.searchParams.set('__stream_prepare', '1');
+
+ const upstream = await fetch(url.toString(), {
+ method: 'POST',
+ headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }),
+ body: rawBody,
+ });
+
+ const text = await upstream.text();
+ let body = {};
+ try {
+ body = JSON.parse(text || '{}');
+ } catch (_err) {
+ body = {};
+ }
+
+ return {
+ ok: upstream.ok,
+ status: upstream.status,
+ contentType: upstream.headers.get('content-type') || 'application/json',
+ text,
+ body,
+ };
+}
+
+function relayPreparedFailure(res, prep) {
+ if (prep.status === 401 && looksLikeVercelAuthPage(prep.text)) {
+ writeOpenAIError(
+ res,
+ 401,
+ 'Vercel Deployment Protection blocked internal prepare request. Disable protection for this deployment or set VERCEL_AUTOMATION_BYPASS_SECRET.',
+ );
+ return;
+ }
+ res.statusCode = prep.status || 500;
+ res.setHeader('Content-Type', prep.contentType || 'application/json');
+ if (prep.text) {
+ res.end(prep.text);
+ return;
+ }
+ writeOpenAIError(res, prep.status || 500, 'vercel prepare failed');
+}
+
+async function safeReadText(resp) {
+ if (!resp) {
+ return '';
+ }
+ try {
+ const text = await resp.text();
+ return text.trim();
+ } catch (_err) {
+ return '';
+ }
+}
+
+function internalSecret() {
+ return asString(process.env.DS2API_VERCEL_INTERNAL_SECRET) || asString(process.env.DS2API_ADMIN_KEY) || 'admin';
+}
+
+function buildInternalGoURL(req) {
+ const proto = asString(header(req, 'x-forwarded-proto')) || 'https';
+ const host = asString(header(req, 'host'));
+ const url = new URL(`${proto}://${host}${req.url || '/v1/chat/completions'}`);
+ url.searchParams.set('__go', '1');
+ const protectionBypass = resolveProtectionBypass(req);
+ if (protectionBypass) {
+ url.searchParams.set('x-vercel-protection-bypass', protectionBypass);
+ }
+ return url;
+}
+
+function buildInternalGoHeaders(req, opts = {}) {
+ const headers = {
+ authorization: asString(header(req, 'authorization')),
+ 'x-api-key': asString(header(req, 'x-api-key')),
+ 'x-ds2-target-account': asString(header(req, 'x-ds2-target-account')),
+ 'x-vercel-protection-bypass': resolveProtectionBypass(req),
+ };
+ if (opts.withInternalToken) {
+ headers['x-ds2-internal-token'] = internalSecret();
+ }
+ if (opts.withContentType) {
+ headers['content-type'] = asString(header(req, 'content-type')) || 'application/json';
+ }
+ return headers;
+}
+
+function createLeaseReleaser(req, leaseID) {
+ let released = false;
+ return async () => {
+ if (released || !leaseID) {
+ return;
+ }
+ released = true;
+ try {
+ await releaseStreamLease(req, leaseID);
+ } catch (_err) {
+ // Ignore release errors. Lease TTL cleanup on Go side still prevents permanent leaks.
+ }
+ };
+}
+
+async function releaseStreamLease(req, leaseID) {
+ const url = buildInternalGoURL(req);
+ url.searchParams.set('__stream_release', '1');
+ const body = Buffer.from(JSON.stringify({ lease_id: leaseID }));
+
+ const controller = new AbortController();
+ const timeout = setTimeout(() => controller.abort(), 1500);
+ try {
+ await fetch(url.toString(), {
+ method: 'POST',
+ headers: buildInternalGoHeaders(req, { withInternalToken: true, withContentType: true }),
+ body,
+ signal: controller.signal,
+ });
+ } finally {
+ clearTimeout(timeout);
+ }
+}
+
+function resolveProtectionBypass(req) {
+ const fromHeader = asString(header(req, 'x-vercel-protection-bypass'));
+ if (fromHeader) {
+ return fromHeader;
+ }
+ return asString(process.env.VERCEL_AUTOMATION_BYPASS_SECRET) || asString(process.env.DS2API_VERCEL_PROTECTION_BYPASS);
+}
+
+function looksLikeVercelAuthPage(text) {
+ const body = asString(text).toLowerCase();
+ if (!body) {
+ return false;
+ }
+ return body.includes('authentication required') && body.includes('vercel');
+}
+
+function asString(v) {
+ if (typeof v === 'string') {
+ return v.trim();
+ }
+ if (Array.isArray(v)) {
+ return asString(v[0]);
+ }
+ if (v == null) {
+ return '';
+ }
+ return String(v).trim();
+}
+
+function isAbortError(err) {
+ if (!err || typeof err !== 'object') {
+ return false;
+ }
+ return err.name === 'AbortError' || err.code === 'ABORT_ERR';
+}
+
+module.exports = {
+ setCorsHeaders,
+ header,
+ readRawBody,
+ fetchStreamPrepare,
+ relayPreparedFailure,
+ safeReadText,
+ buildInternalGoURL,
+ buildInternalGoHeaders,
+ createLeaseReleaser,
+ releaseStreamLease,
+ resolveProtectionBypass,
+ looksLikeVercelAuthPage,
+ asString,
+ isAbortError,
+};
diff --git a/api/chat-stream/index.js b/api/chat-stream/index.js
new file mode 100644
index 0000000..4528924
--- /dev/null
+++ b/api/chat-stream/index.js
@@ -0,0 +1,88 @@
+'use strict';
+
+const {
+ writeOpenAIError,
+} = require('./error_shape');
+const {
+ parseChunkForContent,
+ extractContentRecursive,
+ shouldSkipPath,
+} = require('./sse_parse');
+const {
+ resolveToolcallPolicy,
+ normalizePreparedToolNames,
+ boolDefaultTrue,
+} = require('./toolcall_policy');
+const {
+ estimateTokens,
+} = require('./token_usage');
+const {
+ setCorsHeaders,
+ readRawBody,
+ asString,
+} = require('./http_internal');
+const {
+ proxyToGo,
+} = require('./proxy_go');
+const {
+ handleVercelStream,
+} = require('./vercel_stream');
+
+async function handler(req, res) {
+ setCorsHeaders(res);
+ if (req.method === 'OPTIONS') {
+ res.statusCode = 204;
+ res.end();
+ return;
+ }
+ if (req.method !== 'POST') {
+ writeOpenAIError(res, 405, 'method not allowed');
+ return;
+ }
+
+ const rawBody = await readRawBody(req);
+
+ // Hard guard: only use Node data path for streaming on Vercel runtime.
+ // Any non-Vercel runtime always falls back to Go for full behavior parity.
+ if (!isVercelRuntime()) {
+ await proxyToGo(req, res, rawBody);
+ return;
+ }
+
+ let payload;
+ try {
+ payload = JSON.parse(rawBody.toString('utf8') || '{}');
+ } catch (_err) {
+ writeOpenAIError(res, 400, 'invalid json');
+ return;
+ }
+
+ // Keep all non-stream behavior on Go side to avoid compatibility regressions.
+ if (!toBool(payload.stream)) {
+ await proxyToGo(req, res, rawBody);
+ return;
+ }
+
+ await handleVercelStream(req, res, rawBody, payload);
+}
+
+function toBool(v) {
+ return v === true;
+}
+
+function isVercelRuntime() {
+ return asString(process.env.VERCEL) !== '' || asString(process.env.NOW_REGION) !== '';
+}
+
+module.exports = handler;
+
+module.exports.__test = {
+ parseChunkForContent,
+ extractContentRecursive,
+ shouldSkipPath,
+ asString,
+ resolveToolcallPolicy,
+ normalizePreparedToolNames,
+ boolDefaultTrue,
+ estimateTokens,
+};
diff --git a/api/chat-stream/proxy_go.js b/api/chat-stream/proxy_go.js
new file mode 100644
index 0000000..5218df0
--- /dev/null
+++ b/api/chat-stream/proxy_go.js
@@ -0,0 +1,105 @@
+'use strict';
+
+const {
+ buildInternalGoURL,
+ buildInternalGoHeaders,
+ isAbortError,
+} = require('./http_internal');
+
+async function proxyToGo(req, res, rawBody) {
+ const url = buildInternalGoURL(req);
+ const controller = new AbortController();
+ let clientClosed = false;
+ const markClientClosed = () => {
+ if (clientClosed) {
+ return;
+ }
+ clientClosed = true;
+ controller.abort();
+ };
+ const onReqAborted = () => markClientClosed();
+ const onResClose = () => {
+ if (!res.writableEnded) {
+ markClientClosed();
+ }
+ };
+ req.on('aborted', onReqAborted);
+ res.on('close', onResClose);
+
+ try {
+ let upstream;
+ try {
+ upstream = await fetch(url.toString(), {
+ method: 'POST',
+ headers: buildInternalGoHeaders(req, { withContentType: true }),
+ body: rawBody,
+ signal: controller.signal,
+ });
+ } catch (err) {
+ if (clientClosed || isAbortError(err)) {
+ if (!res.writableEnded) {
+ res.end();
+ }
+ return;
+ }
+ throw err;
+ }
+ if (clientClosed) {
+ if (!res.writableEnded) {
+ res.end();
+ }
+ return;
+ }
+
+ res.statusCode = upstream.status;
+ upstream.headers.forEach((value, key) => {
+ if (key.toLowerCase() === 'content-length') {
+ return;
+ }
+ res.setHeader(key, value);
+ });
+
+ if (!upstream.body || typeof upstream.body.getReader !== 'function') {
+ const bytes = Buffer.from(await upstream.arrayBuffer());
+ res.end(bytes);
+ return;
+ }
+
+ const reader = upstream.body.getReader();
+ try {
+ // eslint-disable-next-line no-constant-condition
+ while (true) {
+ if (clientClosed) {
+ break;
+ }
+ const { value, done } = await reader.read();
+ if (done) {
+ break;
+ }
+ if (value && value.length > 0) {
+ res.write(Buffer.from(value));
+ if (typeof res.flush === 'function') {
+ res.flush();
+ }
+ }
+ }
+ if (!res.writableEnded) {
+ res.end();
+ }
+ } catch (err) {
+ if (!isAbortError(err) && !res.writableEnded) {
+ res.end();
+ }
+ }
+ } finally {
+ req.removeListener('aborted', onReqAborted);
+ res.removeListener('close', onResClose);
+ if (!res.writableEnded) {
+ res.end();
+ }
+ }
+}
+
+module.exports = {
+ proxyToGo,
+};
diff --git a/api/chat-stream/sse_parse.js b/api/chat-stream/sse_parse.js
new file mode 100644
index 0000000..1774430
--- /dev/null
+++ b/api/chat-stream/sse_parse.js
@@ -0,0 +1,229 @@
+'use strict';
+
+const {
+ SKIP_PATTERNS,
+ SKIP_EXACT_PATHS,
+} = require('../shared/deepseek-constants');
+
+function parseChunkForContent(chunk, thinkingEnabled, currentType) {
+ if (!chunk || typeof chunk !== 'object' || !Object.prototype.hasOwnProperty.call(chunk, 'v')) {
+ return { parts: [], finished: false, newType: currentType };
+ }
+ const pathValue = asString(chunk.p);
+ if (shouldSkipPath(pathValue)) {
+ return { parts: [], finished: false, newType: currentType };
+ }
+ if (pathValue === 'response/status' && asString(chunk.v) === 'FINISHED') {
+ return { parts: [], finished: true, newType: currentType };
+ }
+
+ let newType = currentType;
+ const parts = [];
+
+ if (pathValue === 'response/fragments' && asString(chunk.o).toUpperCase() === 'APPEND' && Array.isArray(chunk.v)) {
+ for (const frag of chunk.v) {
+ if (!frag || typeof frag !== 'object') {
+ continue;
+ }
+ const fragType = asString(frag.type).toUpperCase();
+ const content = asString(frag.content);
+ if (!content) {
+ continue;
+ }
+ if (fragType === 'THINK' || fragType === 'THINKING') {
+ newType = 'thinking';
+ parts.push({ text: content, type: 'thinking' });
+ } else if (fragType === 'RESPONSE') {
+ newType = 'text';
+ parts.push({ text: content, type: 'text' });
+ } else {
+ parts.push({ text: content, type: 'text' });
+ }
+ }
+ }
+
+ if (pathValue === 'response' && Array.isArray(chunk.v)) {
+ for (const item of chunk.v) {
+ if (!item || typeof item !== 'object') {
+ continue;
+ }
+ if (item.p === 'fragments' && item.o === 'APPEND' && Array.isArray(item.v)) {
+ for (const frag of item.v) {
+ const fragType = asString(frag && frag.type).toUpperCase();
+ if (fragType === 'THINK' || fragType === 'THINKING') {
+ newType = 'thinking';
+ } else if (fragType === 'RESPONSE') {
+ newType = 'text';
+ }
+ }
+ }
+ }
+ }
+
+ let partType = 'text';
+ if (pathValue === 'response/thinking_content') {
+ partType = 'thinking';
+ } else if (pathValue === 'response/content') {
+ partType = 'text';
+ } else if (pathValue.includes('response/fragments') && pathValue.includes('/content')) {
+ partType = newType;
+ } else if (!pathValue && thinkingEnabled) {
+ partType = newType;
+ }
+
+ const val = chunk.v;
+ if (typeof val === 'string') {
+ if (val === 'FINISHED' && (!pathValue || pathValue === 'status')) {
+ return { parts: [], finished: true, newType };
+ }
+ if (val) {
+ parts.push({ text: val, type: partType });
+ }
+ return { parts, finished: false, newType };
+ }
+
+ if (Array.isArray(val)) {
+ const extracted = extractContentRecursive(val, partType);
+ if (extracted.finished) {
+ return { parts: [], finished: true, newType };
+ }
+ parts.push(...extracted.parts);
+ return { parts, finished: false, newType };
+ }
+
+ if (val && typeof val === 'object') {
+ const resp = val.response && typeof val.response === 'object' ? val.response : val;
+ if (Array.isArray(resp.fragments)) {
+ for (const frag of resp.fragments) {
+ if (!frag || typeof frag !== 'object') {
+ continue;
+ }
+ const content = asString(frag.content);
+ if (!content) {
+ continue;
+ }
+ const t = asString(frag.type).toUpperCase();
+ if (t === 'THINK' || t === 'THINKING') {
+ newType = 'thinking';
+ parts.push({ text: content, type: 'thinking' });
+ } else if (t === 'RESPONSE') {
+ newType = 'text';
+ parts.push({ text: content, type: 'text' });
+ } else {
+ parts.push({ text: content, type: partType });
+ }
+ }
+ }
+ }
+ return { parts, finished: false, newType };
+}
+
+function extractContentRecursive(items, defaultType) {
+ const parts = [];
+ for (const it of items) {
+ if (!it || typeof it !== 'object') {
+ continue;
+ }
+ if (!Object.prototype.hasOwnProperty.call(it, 'v')) {
+ continue;
+ }
+ const itemPath = asString(it.p);
+ const itemV = it.v;
+ if (itemPath === 'status' && asString(itemV) === 'FINISHED') {
+ return { parts: [], finished: true };
+ }
+ if (shouldSkipPath(itemPath)) {
+ continue;
+ }
+ const content = asString(it.content);
+ if (content) {
+ const typeName = asString(it.type).toUpperCase();
+ if (typeName === 'THINK' || typeName === 'THINKING') {
+ parts.push({ text: content, type: 'thinking' });
+ } else if (typeName === 'RESPONSE') {
+ parts.push({ text: content, type: 'text' });
+ } else {
+ parts.push({ text: content, type: defaultType });
+ }
+ continue;
+ }
+
+ let partType = defaultType;
+ if (itemPath.includes('thinking')) {
+ partType = 'thinking';
+ } else if (itemPath.includes('content') || itemPath === 'response' || itemPath === 'fragments') {
+ partType = 'text';
+ }
+
+ if (typeof itemV === 'string') {
+ if (itemV && itemV !== 'FINISHED') {
+ parts.push({ text: itemV, type: partType });
+ }
+ continue;
+ }
+
+ if (!Array.isArray(itemV)) {
+ continue;
+ }
+ for (const inner of itemV) {
+ if (typeof inner === 'string') {
+ if (inner) {
+ parts.push({ text: inner, type: partType });
+ }
+ continue;
+ }
+ if (!inner || typeof inner !== 'object') {
+ continue;
+ }
+ const ct = asString(inner.content);
+ if (!ct) {
+ continue;
+ }
+ const typeName = asString(inner.type).toUpperCase();
+ if (typeName === 'THINK' || typeName === 'THINKING') {
+ parts.push({ text: ct, type: 'thinking' });
+ } else if (typeName === 'RESPONSE') {
+ parts.push({ text: ct, type: 'text' });
+ } else {
+ parts.push({ text: ct, type: partType });
+ }
+ }
+ }
+ return { parts, finished: false };
+}
+
+function shouldSkipPath(pathValue) {
+ if (SKIP_EXACT_PATHS.has(pathValue)) {
+ return true;
+ }
+ for (const p of SKIP_PATTERNS) {
+ if (pathValue.includes(p)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+function isCitation(text) {
+ return asString(text).trim().startsWith('[citation:');
+}
+
+function asString(v) {
+ if (typeof v === 'string') {
+ return v.trim();
+ }
+ if (Array.isArray(v)) {
+ return asString(v[0]);
+ }
+ if (v == null) {
+ return '';
+ }
+ return String(v).trim();
+}
+
+module.exports = {
+ parseChunkForContent,
+ extractContentRecursive,
+ shouldSkipPath,
+ isCitation,
+};
diff --git a/api/chat-stream/stream_emitter.js b/api/chat-stream/stream_emitter.js
new file mode 100644
index 0000000..442c24e
--- /dev/null
+++ b/api/chat-stream/stream_emitter.js
@@ -0,0 +1,39 @@
+'use strict';
+
+function createChatCompletionEmitter({ res, sessionID, created, model, isClosed }) {
+ let firstChunkSent = false;
+
+ const sendFrame = (obj) => {
+ if (isClosed() || res.writableEnded || res.destroyed) {
+ return;
+ }
+ res.write(`data: ${JSON.stringify(obj)}\n\n`);
+ if (typeof res.flush === 'function') {
+ res.flush();
+ }
+ };
+
+ const sendDeltaFrame = (delta) => {
+ const payloadDelta = { ...delta };
+ if (!firstChunkSent) {
+ payloadDelta.role = 'assistant';
+ firstChunkSent = true;
+ }
+ sendFrame({
+ id: sessionID,
+ object: 'chat.completion.chunk',
+ created,
+ model,
+ choices: [{ delta: payloadDelta, index: 0 }],
+ });
+ };
+
+ return {
+ sendFrame,
+ sendDeltaFrame,
+ };
+}
+
+module.exports = {
+ createChatCompletionEmitter,
+};
diff --git a/api/chat-stream/token_usage.js b/api/chat-stream/token_usage.js
new file mode 100644
index 0000000..57a36fb
--- /dev/null
+++ b/api/chat-stream/token_usage.js
@@ -0,0 +1,51 @@
+'use strict';
+
+function buildUsage(prompt, thinking, output) {
+ const promptTokens = estimateTokens(prompt);
+ const reasoningTokens = estimateTokens(thinking);
+ const completionTokens = estimateTokens(output);
+ return {
+ prompt_tokens: promptTokens,
+ completion_tokens: reasoningTokens + completionTokens,
+ total_tokens: promptTokens + reasoningTokens + completionTokens,
+ completion_tokens_details: {
+ reasoning_tokens: reasoningTokens,
+ },
+ };
+}
+
+function estimateTokens(text) {
+ const t = asString(text);
+ if (!t) {
+ return 0;
+ }
+ let asciiChars = 0;
+ let nonASCIIChars = 0;
+ for (const ch of Array.from(t)) {
+ if (ch.charCodeAt(0) < 128) {
+ asciiChars += 1;
+ } else {
+ nonASCIIChars += 1;
+ }
+ }
+ const n = Math.floor(asciiChars / 4) + Math.floor((nonASCIIChars * 10 + 7) / 13);
+ return n < 1 ? 1 : n;
+}
+
+function asString(v) {
+ if (typeof v === 'string') {
+ return v.trim();
+ }
+ if (Array.isArray(v)) {
+ return asString(v[0]);
+ }
+ if (v == null) {
+ return '';
+ }
+ return String(v).trim();
+}
+
+module.exports = {
+ buildUsage,
+ estimateTokens,
+};
diff --git a/api/chat-stream/toolcall_policy.js b/api/chat-stream/toolcall_policy.js
new file mode 100644
index 0000000..4f4b37c
--- /dev/null
+++ b/api/chat-stream/toolcall_policy.js
@@ -0,0 +1,107 @@
+'use strict';
+
+const crypto = require('crypto');
+
+const {
+ extractToolNames,
+} = require('../helpers/stream-tool-sieve');
+
+function resolveToolcallPolicy(prepBody, payloadTools) {
+ const preparedToolNames = normalizePreparedToolNames(prepBody && prepBody.tool_names);
+ const toolNames = preparedToolNames.length > 0 ? preparedToolNames : extractToolNames(payloadTools);
+ const featureMatchEnabled = boolDefaultTrue(prepBody && prepBody.toolcall_feature_match);
+ const emitEarlyToolDeltas = boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high);
+ return {
+ toolNames,
+ toolSieveEnabled: toolNames.length > 0 && featureMatchEnabled,
+ emitEarlyToolDeltas,
+ };
+}
+
+function normalizePreparedToolNames(v) {
+ if (!Array.isArray(v) || v.length === 0) {
+ return [];
+ }
+ const out = [];
+ for (const item of v) {
+ const name = asString(item);
+ if (!name) {
+ continue;
+ }
+ out.push(name);
+ }
+ return out;
+}
+
+function boolDefaultTrue(v) {
+ return v !== false;
+}
+
+function formatIncrementalToolCallDeltas(deltas, idStore) {
+ if (!Array.isArray(deltas) || deltas.length === 0) {
+ return [];
+ }
+ const out = [];
+ for (const d of deltas) {
+ if (!d || typeof d !== 'object') {
+ continue;
+ }
+ const index = Number.isInteger(d.index) ? d.index : 0;
+ const id = ensureStreamToolCallID(idStore, index);
+ const item = {
+ index,
+ id,
+ type: 'function',
+ };
+ const fn = {};
+ if (asString(d.name)) {
+ fn.name = asString(d.name);
+ }
+ if (typeof d.arguments === 'string' && d.arguments !== '') {
+ fn.arguments = d.arguments;
+ }
+ if (Object.keys(fn).length > 0) {
+ item.function = fn;
+ }
+ out.push(item);
+ }
+ return out;
+}
+
+function ensureStreamToolCallID(idStore, index) {
+ const key = Number.isInteger(index) ? index : 0;
+ const existing = idStore.get(key);
+ if (existing) {
+ return existing;
+ }
+ const next = `call_${newCallID()}`;
+ idStore.set(key, next);
+ return next;
+}
+
+function newCallID() {
+ if (typeof crypto.randomUUID === 'function') {
+ return crypto.randomUUID().replace(/-/g, '');
+ }
+ return `${Date.now()}${Math.floor(Math.random() * 1e9)}`;
+}
+
+function asString(v) {
+ if (typeof v === 'string') {
+ return v.trim();
+ }
+ if (Array.isArray(v)) {
+ return asString(v[0]);
+ }
+ if (v == null) {
+ return '';
+ }
+ return String(v).trim();
+}
+
+module.exports = {
+ resolveToolcallPolicy,
+ normalizePreparedToolNames,
+ boolDefaultTrue,
+ formatIncrementalToolCallDeltas,
+};
diff --git a/api/chat-stream/vercel_stream.js b/api/chat-stream/vercel_stream.js
new file mode 100644
index 0000000..324a3d8
--- /dev/null
+++ b/api/chat-stream/vercel_stream.js
@@ -0,0 +1,297 @@
+'use strict';
+
+const {
+ extractToolNames,
+ createToolSieveState,
+ processToolSieveChunk,
+ flushToolSieve,
+ parseToolCalls,
+ formatOpenAIStreamToolCalls,
+} = require('../helpers/stream-tool-sieve');
+const {
+ BASE_HEADERS,
+} = require('../shared/deepseek-constants');
+
+const {
+ writeOpenAIError,
+} = require('./error_shape');
+const {
+ parseChunkForContent,
+ isCitation,
+} = require('./sse_parse');
+const {
+ buildUsage,
+} = require('./token_usage');
+const {
+ resolveToolcallPolicy,
+ formatIncrementalToolCallDeltas,
+} = require('./toolcall_policy');
+const {
+ createChatCompletionEmitter,
+} = require('./stream_emitter');
+const {
+ asString,
+ isAbortError,
+ fetchStreamPrepare,
+ relayPreparedFailure,
+ safeReadText,
+ createLeaseReleaser,
+} = require('./http_internal');
+
+const DEEPSEEK_COMPLETION_URL = 'https://chat.deepseek.com/api/v0/chat/completion';
+
+async function handleVercelStream(req, res, rawBody, payload) {
+ const prep = await fetchStreamPrepare(req, rawBody);
+ if (!prep.ok) {
+ relayPreparedFailure(res, prep);
+ return;
+ }
+
+ const model = asString(prep.body.model) || asString(payload.model);
+ const sessionID = asString(prep.body.session_id) || `chatcmpl-${Date.now()}`;
+ const leaseID = asString(prep.body.lease_id);
+ const deepseekToken = asString(prep.body.deepseek_token);
+ const powHeader = asString(prep.body.pow_header);
+ const completionPayload = prep.body.payload && typeof prep.body.payload === 'object' ? prep.body.payload : null;
+ const finalPrompt = asString(prep.body.final_prompt);
+ const thinkingEnabled = toBool(prep.body.thinking_enabled);
+ const searchEnabled = toBool(prep.body.search_enabled);
+ const toolPolicy = resolveToolcallPolicy(prep.body, payload.tools);
+ const toolNames = toolPolicy.toolNames;
+
+ if (!model || !leaseID || !deepseekToken || !powHeader || !completionPayload) {
+ writeOpenAIError(res, 500, 'invalid vercel prepare response');
+ return;
+ }
+
+ const releaseLease = createLeaseReleaser(req, leaseID);
+ const upstreamController = new AbortController();
+ let clientClosed = false;
+ let reader = null;
+ const markClientClosed = () => {
+ if (clientClosed) {
+ return;
+ }
+ clientClosed = true;
+ upstreamController.abort();
+ if (reader && typeof reader.cancel === 'function') {
+ Promise.resolve(reader.cancel()).catch(() => {});
+ }
+ };
+ const onReqAborted = () => markClientClosed();
+ const onResClose = () => {
+ if (!res.writableEnded) {
+ markClientClosed();
+ }
+ };
+ req.on('aborted', onReqAborted);
+ res.on('close', onResClose);
+
+ try {
+ let completionRes;
+ try {
+ completionRes = await fetch(DEEPSEEK_COMPLETION_URL, {
+ method: 'POST',
+ headers: {
+ ...BASE_HEADERS,
+ authorization: `Bearer ${deepseekToken}`,
+ 'x-ds-pow-response': powHeader,
+ },
+ body: JSON.stringify(completionPayload),
+ signal: upstreamController.signal,
+ });
+ } catch (err) {
+ if (clientClosed || isAbortError(err)) {
+ return;
+ }
+ throw err;
+ }
+ if (clientClosed) {
+ return;
+ }
+
+ if (!completionRes.ok || !completionRes.body) {
+ const detail = await safeReadText(completionRes);
+ writeOpenAIError(res, 500, detail ? `Failed to get completion: ${detail}` : 'Failed to get completion.');
+ return;
+ }
+
+ res.statusCode = 200;
+ res.setHeader('Content-Type', 'text/event-stream');
+ res.setHeader('Cache-Control', 'no-cache, no-transform');
+ res.setHeader('Connection', 'keep-alive');
+ res.setHeader('X-Accel-Buffering', 'no');
+ if (typeof res.flushHeaders === 'function') {
+ res.flushHeaders();
+ }
+
+ const created = Math.floor(Date.now() / 1000);
+ let currentType = thinkingEnabled ? 'thinking' : 'text';
+ let thinkingText = '';
+ let outputText = '';
+ const toolSieveEnabled = toolPolicy.toolSieveEnabled;
+ const emitEarlyToolDeltas = toolPolicy.emitEarlyToolDeltas;
+ const toolSieveState = createToolSieveState();
+ let toolCallsEmitted = false;
+ const streamToolCallIDs = new Map();
+ const decoder = new TextDecoder();
+ reader = completionRes.body.getReader();
+ let buffered = '';
+ let ended = false;
+ const { sendFrame, sendDeltaFrame } = createChatCompletionEmitter({
+ res,
+ sessionID,
+ created,
+ model,
+ isClosed: () => clientClosed,
+ });
+
+ const finish = async (reason) => {
+ if (ended) {
+ return;
+ }
+ ended = true;
+ if (clientClosed || res.writableEnded || res.destroyed) {
+ await releaseLease();
+ return;
+ }
+ const detected = parseToolCalls(outputText, toolNames);
+ if (detected.length > 0 && !toolCallsEmitted) {
+ toolCallsEmitted = true;
+ sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected) });
+ } else if (toolSieveEnabled) {
+ const tailEvents = flushToolSieve(toolSieveState, toolNames);
+ for (const evt of tailEvents) {
+ if (evt.text) {
+ sendDeltaFrame({ content: evt.text });
+ }
+ }
+ }
+ if (detected.length > 0 || toolCallsEmitted) {
+ reason = 'tool_calls';
+ }
+ sendFrame({
+ id: sessionID,
+ object: 'chat.completion.chunk',
+ created,
+ model,
+ choices: [{ delta: {}, index: 0, finish_reason: reason }],
+ usage: buildUsage(finalPrompt, thinkingText, outputText),
+ });
+ if (!res.writableEnded && !res.destroyed) {
+ res.write('data: [DONE]\n\n');
+ }
+ await releaseLease();
+ if (!res.writableEnded && !res.destroyed) {
+ res.end();
+ }
+ };
+
+ try {
+ // eslint-disable-next-line no-constant-condition
+ while (true) {
+ if (clientClosed) {
+ await finish('stop');
+ return;
+ }
+ const { value, done } = await reader.read();
+ if (done) {
+ break;
+ }
+ buffered += decoder.decode(value, { stream: true });
+ const lines = buffered.split('\n');
+ buffered = lines.pop() || '';
+
+ for (const rawLine of lines) {
+ const line = rawLine.trim();
+ if (!line.startsWith('data:')) {
+ continue;
+ }
+ const dataStr = line.slice(5).trim();
+ if (!dataStr) {
+ continue;
+ }
+ if (dataStr === '[DONE]') {
+ await finish('stop');
+ return;
+ }
+ let chunk;
+ try {
+ chunk = JSON.parse(dataStr);
+ } catch (_err) {
+ continue;
+ }
+ if (chunk.error || chunk.code === 'content_filter') {
+ await finish('content_filter');
+ return;
+ }
+ const parsed = parseChunkForContent(chunk, thinkingEnabled, currentType);
+ currentType = parsed.newType;
+ if (parsed.finished) {
+ await finish('stop');
+ return;
+ }
+
+ for (const p of parsed.parts) {
+ if (!p.text) {
+ continue;
+ }
+ if (searchEnabled && isCitation(p.text)) {
+ continue;
+ }
+ if (p.type === 'thinking') {
+ if (thinkingEnabled) {
+ thinkingText += p.text;
+ sendDeltaFrame({ reasoning_content: p.text });
+ }
+ } else {
+ outputText += p.text;
+ if (!toolSieveEnabled) {
+ sendDeltaFrame({ content: p.text });
+ continue;
+ }
+ const events = processToolSieveChunk(toolSieveState, p.text, toolNames);
+ for (const evt of events) {
+ if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) {
+ if (!emitEarlyToolDeltas) {
+ continue;
+ }
+ toolCallsEmitted = true;
+ sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) });
+ continue;
+ }
+ if (evt.type === 'tool_calls') {
+ toolCallsEmitted = true;
+ sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) });
+ continue;
+ }
+ if (evt.text) {
+ sendDeltaFrame({ content: evt.text });
+ }
+ }
+ }
+ }
+ }
+ }
+ await finish('stop');
+ } catch (err) {
+ if (clientClosed || isAbortError(err)) {
+ await finish('stop');
+ return;
+ }
+ await finish('stop');
+ }
+ } finally {
+ req.removeListener('aborted', onReqAborted);
+ res.removeListener('close', onResClose);
+ await releaseLease();
+ }
+}
+
+function toBool(v) {
+ return v === true;
+}
+
+module.exports = {
+ handleVercelStream,
+};
diff --git a/api/helpers/stream-tool-sieve.js b/api/helpers/stream-tool-sieve.js
index 44e31cd..8985478 100644
--- a/api/helpers/stream-tool-sieve.js
+++ b/api/helpers/stream-tool-sieve.js
@@ -1,957 +1,3 @@
'use strict';
-const crypto = require('crypto');
-const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
-const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024;
-const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256;
-
-function extractToolNames(tools) {
- if (!Array.isArray(tools) || tools.length === 0) {
- return [];
- }
- const out = [];
- for (const t of tools) {
- if (!t || typeof t !== 'object') {
- continue;
- }
- const fn = t.function && typeof t.function === 'object' ? t.function : t;
- const name = toStringSafe(fn.name);
- // Keep parity with Go injectToolPrompt: object tools without name still
- // enter tool mode via fallback name "unknown".
- out.push(name || 'unknown');
- }
- return out;
-}
-
-function createToolSieveState() {
- return {
- pending: '',
- capture: '',
- capturing: false,
- recentTextTail: '',
- toolNameSent: false,
- toolName: '',
- toolArgsStart: -1,
- toolArgsSent: -1,
- toolArgsString: false,
- toolArgsDone: false,
- };
-}
-
-function resetIncrementalToolState(state) {
- state.toolNameSent = false;
- state.toolName = '';
- state.toolArgsStart = -1;
- state.toolArgsSent = -1;
- state.toolArgsString = false;
- state.toolArgsDone = false;
-}
-
-function processToolSieveChunk(state, chunk, toolNames) {
- if (!state) {
- return [];
- }
- if (chunk) {
- state.pending += chunk;
- }
- const events = [];
- // eslint-disable-next-line no-constant-condition
- while (true) {
- if (state.capturing) {
- if (state.pending) {
- state.capture += state.pending;
- state.pending = '';
- }
- const deltas = buildIncrementalToolDeltas(state);
- if (deltas.length > 0) {
- events.push({ type: 'tool_call_deltas', deltas });
- }
- const consumed = consumeToolCapture(state, toolNames);
- if (!consumed.ready) {
- if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) {
- noteText(state, state.capture);
- events.push({ type: 'text', text: state.capture });
- state.capture = '';
- state.capturing = false;
- resetIncrementalToolState(state);
- continue;
- }
- break;
- }
- state.capture = '';
- state.capturing = false;
- resetIncrementalToolState(state);
- if (consumed.prefix) {
- noteText(state, consumed.prefix);
- events.push({ type: 'text', text: consumed.prefix });
- }
- if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
- events.push({ type: 'tool_calls', calls: consumed.calls });
- }
- if (consumed.suffix) {
- state.pending += consumed.suffix;
- }
- continue;
- }
-
- if (!state.pending) {
- break;
- }
-
- const start = findToolSegmentStart(state.pending);
- if (start >= 0) {
- const prefix = state.pending.slice(0, start);
- if (prefix) {
- noteText(state, prefix);
- events.push({ type: 'text', text: prefix });
- }
- state.capture = state.pending.slice(start);
- state.pending = '';
- state.capturing = true;
- resetIncrementalToolState(state);
- continue;
- }
-
- const [safe, hold] = splitSafeContentForToolDetection(state.pending);
- if (!safe) {
- break;
- }
- state.pending = hold;
- noteText(state, safe);
- events.push({ type: 'text', text: safe });
- }
- return events;
-}
-
-function flushToolSieve(state, toolNames) {
- if (!state) {
- return [];
- }
- const events = processToolSieveChunk(state, '', toolNames);
- if (state.capturing) {
- const consumed = consumeToolCapture(state, toolNames);
- if (consumed.ready) {
- if (consumed.prefix) {
- noteText(state, consumed.prefix);
- events.push({ type: 'text', text: consumed.prefix });
- }
- if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
- events.push({ type: 'tool_calls', calls: consumed.calls });
- }
- if (consumed.suffix) {
- noteText(state, consumed.suffix);
- events.push({ type: 'text', text: consumed.suffix });
- }
- } else if (state.capture) {
- noteText(state, state.capture);
- events.push({ type: 'text', text: state.capture });
- }
- state.capture = '';
- state.capturing = false;
- resetIncrementalToolState(state);
- }
- if (state.pending) {
- noteText(state, state.pending);
- events.push({ type: 'text', text: state.pending });
- state.pending = '';
- }
- return events;
-}
-
-function splitSafeContentForToolDetection(s) {
- const text = s || '';
- if (!text) {
- return ['', ''];
- }
- const suspiciousStart = findSuspiciousPrefixStart(text);
- if (suspiciousStart < 0) {
- return [text, ''];
- }
- if (suspiciousStart > 0) {
- return [text.slice(0, suspiciousStart), text.slice(suspiciousStart)];
- }
- // If suspicious content starts at the beginning, keep holding until we can
- // either parse a full tool JSON block or reach stream flush.
- return ['', text];
-}
-
-function findSuspiciousPrefixStart(s) {
- let start = -1;
- for (const needle of ['{', '[', '```']) {
- const idx = s.lastIndexOf(needle);
- if (idx > start) {
- start = idx;
- }
- }
- return start;
-}
-
-function findToolSegmentStart(s) {
- if (!s) {
- return -1;
- }
- const lower = s.toLowerCase();
- let offset = 0;
- // eslint-disable-next-line no-constant-condition
- while (true) {
- const keyRel = lower.indexOf('tool_calls', offset);
- if (keyRel < 0) {
- return -1;
- }
- const keyIdx = keyRel;
- const start = s.slice(0, keyIdx).lastIndexOf('{');
- const candidateStart = start >= 0 ? start : keyIdx;
- if (!insideCodeFence(s.slice(0, candidateStart))) {
- return candidateStart;
- }
- offset = keyIdx + 'tool_calls'.length;
- }
-}
-
-function consumeToolCapture(state, toolNames) {
- const captured = state.capture;
- if (!captured) {
- return { ready: false, prefix: '', calls: [], suffix: '' };
- }
- const lower = captured.toLowerCase();
- const keyIdx = lower.indexOf('tool_calls');
- if (keyIdx < 0) {
- return { ready: false, prefix: '', calls: [], suffix: '' };
- }
- const start = captured.slice(0, keyIdx).lastIndexOf('{');
- if (start < 0) {
- return { ready: false, prefix: '', calls: [], suffix: '' };
- }
- const obj = extractJSONObjectFrom(captured, start);
- if (!obj.ok) {
- return { ready: false, prefix: '', calls: [], suffix: '' };
- }
- const prefixPart = captured.slice(0, start);
- const suffixPart = captured.slice(obj.end);
- if (insideCodeFence((state.recentTextTail || '') + prefixPart)) {
- return {
- ready: true,
- prefix: captured,
- calls: [],
- suffix: '',
- };
- }
- const parsed = parseStandaloneToolCalls(captured.slice(start, obj.end), toolNames);
- if (parsed.length === 0) {
- if (state.toolNameSent) {
- return {
- ready: true,
- prefix: prefixPart,
- calls: [],
- suffix: suffixPart,
- };
- }
- return {
- ready: true,
- prefix: captured,
- calls: [],
- suffix: '',
- };
- }
- if (state.toolNameSent) {
- if (parsed.length > 1) {
- return {
- ready: true,
- prefix: prefixPart,
- calls: parsed.slice(1),
- suffix: suffixPart,
- };
- }
- return {
- ready: true,
- prefix: prefixPart,
- calls: [],
- suffix: suffixPart,
- };
- }
- return {
- ready: true,
- prefix: prefixPart,
- calls: parsed,
- suffix: suffixPart,
- };
-}
-
-function buildIncrementalToolDeltas(state) {
- const captured = state.capture || '';
- if (!captured) {
- return [];
- }
- if (looksLikeToolExampleContext(state.recentTextTail)) {
- return [];
- }
- const lower = captured.toLowerCase();
- const keyIdx = lower.indexOf('tool_calls');
- if (keyIdx < 0) {
- return [];
- }
- const start = captured.slice(0, keyIdx).lastIndexOf('{');
- if (start < 0) {
- return [];
- }
- if (insideCodeFence((state.recentTextTail || '') + captured.slice(0, start))) {
- return [];
- }
- const callStart = findFirstToolCallObjectStart(captured, keyIdx);
- if (callStart < 0) {
- return [];
- }
-
- const deltas = [];
- if (!state.toolName) {
- const name = extractToolCallName(captured, callStart);
- if (!name) {
- return [];
- }
- state.toolName = name;
- }
-
- if (state.toolArgsStart < 0) {
- const args = findToolCallArgsStart(captured, callStart);
- if (args) {
- state.toolArgsString = Boolean(args.stringMode);
- state.toolArgsStart = state.toolArgsString ? args.start + 1 : args.start;
- state.toolArgsSent = state.toolArgsStart;
- }
- }
- if (!state.toolNameSent) {
- if (state.toolArgsStart < 0) {
- return [];
- }
- state.toolNameSent = true;
- deltas.push({ index: 0, name: state.toolName });
- }
- if (state.toolArgsStart < 0 || state.toolArgsDone) {
- return deltas;
- }
- const progress = scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString);
- if (!progress) {
- return deltas;
- }
- if (progress.end > state.toolArgsSent) {
- deltas.push({
- index: 0,
- arguments: captured.slice(state.toolArgsSent, progress.end),
- });
- state.toolArgsSent = progress.end;
- }
- if (progress.complete) {
- state.toolArgsDone = true;
- }
- return deltas;
-}
-
-function findFirstToolCallObjectStart(text, keyIdx) {
- const arrStart = findToolCallsArrayStart(text, keyIdx);
- if (arrStart < 0) {
- return -1;
- }
- const i = skipSpaces(text, arrStart + 1);
- if (i >= text.length || text[i] !== '{') {
- return -1;
- }
- return i;
-}
-
-function findToolCallsArrayStart(text, keyIdx) {
- let i = keyIdx + 'tool_calls'.length;
- while (i < text.length && text[i] !== ':') {
- i += 1;
- }
- if (i >= text.length) {
- return -1;
- }
- i = skipSpaces(text, i + 1);
- if (i >= text.length || text[i] !== '[') {
- return -1;
- }
- return i;
-}
-
-function extractToolCallName(text, callStart) {
- let valueStart = findObjectFieldValueStart(text, callStart, ['name']);
- if (valueStart < 0 || text[valueStart] !== '"') {
- const fnStart = findFunctionObjectStart(text, callStart);
- if (fnStart < 0) {
- return '';
- }
- valueStart = findObjectFieldValueStart(text, fnStart, ['name']);
- if (valueStart < 0 || text[valueStart] !== '"') {
- return '';
- }
- }
- const parsed = parseJSONStringLiteral(text, valueStart);
- if (!parsed) {
- return '';
- }
- return parsed.value;
-}
-
-function findToolCallArgsStart(text, callStart) {
- const keys = ['input', 'arguments', 'args', 'parameters', 'params'];
- let valueStart = findObjectFieldValueStart(text, callStart, keys);
- if (valueStart < 0) {
- const fnStart = findFunctionObjectStart(text, callStart);
- if (fnStart < 0) {
- return null;
- }
- valueStart = findObjectFieldValueStart(text, fnStart, keys);
- if (valueStart < 0) {
- return null;
- }
- }
- if (valueStart >= text.length) {
- return null;
- }
- const ch = text[valueStart];
- if (ch === '{' || ch === '[') {
- return { start: valueStart, stringMode: false };
- }
- if (ch === '"') {
- return { start: valueStart, stringMode: true };
- }
- return null;
-}
-
-function scanToolCallArgsProgress(text, start, stringMode) {
- if (start < 0 || start > text.length) {
- return null;
- }
- if (stringMode) {
- let escaped = false;
- for (let i = start; i < text.length; i += 1) {
- const ch = text[i];
- if (escaped) {
- escaped = false;
- continue;
- }
- if (ch === '\\') {
- escaped = true;
- continue;
- }
- if (ch === '"') {
- return { end: i, complete: true };
- }
- }
- return { end: text.length, complete: false };
- }
- if (start >= text.length || (text[start] !== '{' && text[start] !== '[')) {
- return null;
- }
- let depth = 0;
- let quote = '';
- let escaped = false;
- for (let i = start; i < text.length; i += 1) {
- const ch = text[i];
- if (quote) {
- if (escaped) {
- escaped = false;
- continue;
- }
- if (ch === '\\') {
- escaped = true;
- continue;
- }
- if (ch === quote) {
- quote = '';
- }
- continue;
- }
- if (ch === '"' || ch === "'") {
- quote = ch;
- continue;
- }
- if (ch === '{' || ch === '[') {
- depth += 1;
- continue;
- }
- if (ch === '}' || ch === ']') {
- depth -= 1;
- if (depth === 0) {
- return { end: i + 1, complete: true };
- }
- }
- }
- return { end: text.length, complete: false };
-}
-
-function findObjectFieldValueStart(text, objStart, keys) {
- if (!text || objStart < 0 || objStart >= text.length || text[objStart] !== '{') {
- return -1;
- }
- let depth = 0;
- let quote = '';
- let escaped = false;
- for (let i = objStart; i < text.length; i += 1) {
- const ch = text[i];
- if (quote) {
- if (escaped) {
- escaped = false;
- continue;
- }
- if (ch === '\\') {
- escaped = true;
- continue;
- }
- if (ch === quote) {
- quote = '';
- }
- continue;
- }
- if (ch === '"' || ch === "'") {
- if (depth === 1) {
- const parsed = parseJSONStringLiteral(text, i);
- if (!parsed) {
- return -1;
- }
- let j = skipSpaces(text, parsed.end);
- if (j >= text.length || text[j] !== ':') {
- i = parsed.end - 1;
- continue;
- }
- j = skipSpaces(text, j + 1);
- if (j >= text.length) {
- return -1;
- }
- if (keys.includes(parsed.value)) {
- return j;
- }
- i = j - 1;
- continue;
- }
- quote = ch;
- continue;
- }
- if (ch === '{') {
- depth += 1;
- continue;
- }
- if (ch === '}') {
- depth -= 1;
- if (depth === 0) {
- break;
- }
- }
- }
- return -1;
-}
-
-function findFunctionObjectStart(text, callStart) {
- const valueStart = findObjectFieldValueStart(text, callStart, ['function']);
- if (valueStart < 0 || valueStart >= text.length || text[valueStart] !== '{') {
- return -1;
- }
- return valueStart;
-}
-
-function parseJSONStringLiteral(text, start) {
- if (!text || start < 0 || start >= text.length || text[start] !== '"') {
- return null;
- }
- let out = '';
- let escaped = false;
- for (let i = start + 1; i < text.length; i += 1) {
- const ch = text[i];
- if (escaped) {
- out += ch;
- escaped = false;
- continue;
- }
- if (ch === '\\') {
- escaped = true;
- continue;
- }
- if (ch === '"') {
- return { value: out, end: i + 1 };
- }
- out += ch;
- }
- return null;
-}
-
-function skipSpaces(text, i) {
- let idx = i;
- while (idx < text.length) {
- const ch = text[idx];
- if (ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r') {
- idx += 1;
- continue;
- }
- break;
- }
- return idx;
-}
-
-function extractJSONObjectFrom(text, start) {
- if (!text || start < 0 || start >= text.length || text[start] !== '{') {
- return { ok: false, end: 0 };
- }
- let depth = 0;
- let quote = '';
- let escaped = false;
- for (let i = start; i < text.length; i += 1) {
- const ch = text[i];
- if (quote) {
- if (escaped) {
- escaped = false;
- continue;
- }
- if (ch === '\\') {
- escaped = true;
- continue;
- }
- if (ch === quote) {
- quote = '';
- }
- continue;
- }
- if (ch === '"' || ch === "'") {
- quote = ch;
- continue;
- }
- if (ch === '{') {
- depth += 1;
- continue;
- }
- if (ch === '}') {
- depth -= 1;
- if (depth === 0) {
- return { ok: true, end: i + 1 };
- }
- }
- }
- return { ok: false, end: 0 };
-}
-
-function parseToolCalls(text, toolNames) {
- if (!toStringSafe(text)) {
- return [];
- }
- const sanitized = stripFencedCodeBlocks(text);
- if (!toStringSafe(sanitized)) {
- return [];
- }
- const candidates = buildToolCallCandidates(sanitized);
- let parsed = [];
- for (const c of candidates) {
- parsed = parseToolCallsPayload(c);
- if (parsed.length > 0) {
- break;
- }
- }
- if (parsed.length === 0) {
- return [];
- }
- return filterToolCalls(parsed, toolNames);
-}
-
-function stripFencedCodeBlocks(text) {
- const t = typeof text === 'string' ? text : '';
- if (!t) {
- return '';
- }
- return t.replace(/```[\s\S]*?```/g, ' ');
-}
-
-function parseStandaloneToolCalls(text, toolNames) {
- const trimmed = toStringSafe(text);
- if (!trimmed) {
- return [];
- }
- if ((trimmed.startsWith('```') && trimmed.endsWith('```')) || trimmed.includes('```')) {
- return [];
- }
- if (looksLikeToolExampleContext(trimmed)) {
- return [];
- }
- const candidates = [trimmed];
- if (trimmed.startsWith('```') && trimmed.endsWith('```')) {
- const m = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
- if (m && m[1]) {
- candidates.push(toStringSafe(m[1]));
- }
- }
- for (const candidate of candidates) {
- const c = toStringSafe(candidate);
- if (!c) {
- continue;
- }
- if (!c.startsWith('{') && !c.startsWith('[')) {
- continue;
- }
- const parsed = parseToolCallsPayload(c);
- if (parsed.length > 0) {
- return filterToolCalls(parsed, toolNames);
- }
- }
- return [];
-}
-
-function buildToolCallCandidates(text) {
- const trimmed = toStringSafe(text);
- const candidates = [trimmed];
- const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || [];
- for (const block of fenced) {
- const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
- if (m && m[1]) {
- candidates.push(toStringSafe(m[1]));
- }
- }
- for (const candidate of extractToolCallObjects(trimmed)) {
- candidates.push(toStringSafe(candidate));
- }
- const first = trimmed.indexOf('{');
- const last = trimmed.lastIndexOf('}');
- if (first >= 0 && last > first) {
- candidates.push(toStringSafe(trimmed.slice(first, last + 1)));
- }
- const m = trimmed.match(TOOL_CALL_PATTERN);
- if (m && m[1]) {
- candidates.push(`{"tool_calls":[${m[1]}]}`);
- }
- return [...new Set(candidates.filter(Boolean))];
-}
-
-function extractToolCallObjects(text) {
- const raw = toStringSafe(text);
- if (!raw) {
- return [];
- }
- const lower = raw.toLowerCase();
- const out = [];
- let offset = 0;
- // eslint-disable-next-line no-constant-condition
- while (true) {
- let idx = lower.indexOf('tool_calls', offset);
- if (idx < 0) {
- break;
- }
- let start = raw.slice(0, idx).lastIndexOf('{');
- while (start >= 0) {
- const obj = extractJSONObjectFrom(raw, start);
- if (obj.ok) {
- out.push(raw.slice(start, obj.end).trim());
- offset = obj.end;
- idx = -1;
- break;
- }
- start = raw.slice(0, start).lastIndexOf('{');
- }
- if (idx >= 0) {
- offset = idx + 'tool_calls'.length;
- }
- }
- return out;
-}
-
-function parseToolCallsPayload(payload) {
- let decoded;
- try {
- decoded = JSON.parse(payload);
- } catch (_err) {
- return [];
- }
- if (Array.isArray(decoded)) {
- return parseToolCallList(decoded);
- }
- if (!decoded || typeof decoded !== 'object') {
- return [];
- }
- if (decoded.tool_calls) {
- return parseToolCallList(decoded.tool_calls);
- }
- const one = parseToolCallItem(decoded);
- return one ? [one] : [];
-}
-
-function parseToolCallList(v) {
- if (!Array.isArray(v)) {
- return [];
- }
- const out = [];
- for (const item of v) {
- if (!item || typeof item !== 'object') {
- continue;
- }
- const one = parseToolCallItem(item);
- if (one) {
- out.push(one);
- }
- }
- return out;
-}
-
-function parseToolCallItem(m) {
- let name = toStringSafe(m.name);
- let inputRaw = m.input;
- let hasInput = Object.prototype.hasOwnProperty.call(m, 'input');
- const fn = m.function && typeof m.function === 'object' ? m.function : null;
- if (fn) {
- if (!name) {
- name = toStringSafe(fn.name);
- }
- if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) {
- inputRaw = fn.arguments;
- hasInput = true;
- }
- }
- if (!hasInput) {
- for (const k of ['arguments', 'args', 'parameters', 'params']) {
- if (Object.prototype.hasOwnProperty.call(m, k)) {
- inputRaw = m[k];
- hasInput = true;
- break;
- }
- }
- }
- if (!name) {
- return null;
- }
- return {
- name,
- input: parseToolCallInput(inputRaw),
- };
-}
-
-function parseToolCallInput(v) {
- if (v == null) {
- return {};
- }
- if (typeof v === 'string') {
- const raw = toStringSafe(v);
- if (!raw) {
- return {};
- }
- try {
- const parsed = JSON.parse(raw);
- if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
- return parsed;
- }
- return { _raw: raw };
- } catch (_err) {
- return { _raw: raw };
- }
- }
- if (typeof v === 'object' && !Array.isArray(v)) {
- return v;
- }
- try {
- const parsed = JSON.parse(JSON.stringify(v));
- if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
- return parsed;
- }
- } catch (_err) {
- return {};
- }
- return {};
-}
-
-function filterToolCalls(parsed, toolNames) {
- const allowed = new Set((toolNames || []).filter(Boolean));
- const out = [];
- for (const tc of parsed) {
- if (!tc || !tc.name) {
- continue;
- }
- if (allowed.size > 0 && !allowed.has(tc.name)) {
- continue;
- }
- out.push({ name: tc.name, input: tc.input || {} });
- }
- if (out.length === 0 && parsed.length > 0) {
- for (const tc of parsed) {
- if (!tc || !tc.name) {
- continue;
- }
- out.push({ name: tc.name, input: tc.input || {} });
- }
- }
- return out;
-}
-
-function noteText(state, text) {
- if (!state || !hasMeaningfulText(text)) {
- return;
- }
- state.recentTextTail = appendTail(state.recentTextTail, text, TOOL_SIEVE_CONTEXT_TAIL_LIMIT);
-}
-
-function appendTail(prev, next, max) {
- const left = typeof prev === 'string' ? prev : '';
- const right = typeof next === 'string' ? next : '';
- if (!Number.isFinite(max) || max <= 0) {
- return '';
- }
- const combined = left + right;
- if (combined.length <= max) {
- return combined;
- }
- return combined.slice(combined.length - max);
-}
-
-function looksLikeToolExampleContext(text) {
- return insideCodeFence(text);
-}
-
-function insideCodeFence(text) {
- const t = typeof text === 'string' ? text : '';
- if (!t) {
- return false;
- }
- const ticks = (t.match(/```/g) || []).length;
- return ticks % 2 === 1;
-}
-
-function hasMeaningfulText(text) {
- return toStringSafe(text) !== '';
-}
-
-function formatOpenAIStreamToolCalls(calls) {
- if (!Array.isArray(calls) || calls.length === 0) {
- return [];
- }
- return calls.map((c, idx) => ({
- index: idx,
- id: `call_${newCallID()}`,
- type: 'function',
- function: {
- name: c.name,
- arguments: JSON.stringify(c.input || {}),
- },
- }));
-}
-
-function newCallID() {
- if (typeof crypto.randomUUID === 'function') {
- return crypto.randomUUID().replace(/-/g, '');
- }
- return `${Date.now()}${Math.floor(Math.random() * 1e9)}`;
-}
-
-function toStringSafe(v) {
- if (typeof v === 'string') {
- return v.trim();
- }
- if (Array.isArray(v)) {
- return toStringSafe(v[0]);
- }
- if (v == null) {
- return '';
- }
- return String(v).trim();
-}
-
-module.exports = {
- extractToolNames,
- createToolSieveState,
- processToolSieveChunk,
- flushToolSieve,
- parseToolCalls,
- parseStandaloneToolCalls,
- formatOpenAIStreamToolCalls,
-};
+module.exports = require('./stream-tool-sieve/index.js');
diff --git a/api/helpers/stream-tool-sieve/format.js b/api/helpers/stream-tool-sieve/format.js
new file mode 100644
index 0000000..ff1dcef
--- /dev/null
+++ b/api/helpers/stream-tool-sieve/format.js
@@ -0,0 +1,29 @@
+'use strict';
+
+const crypto = require('crypto');
+
+function formatOpenAIStreamToolCalls(calls) {
+ if (!Array.isArray(calls) || calls.length === 0) {
+ return [];
+ }
+ return calls.map((c, idx) => ({
+ index: idx,
+ id: `call_${newCallID()}`,
+ type: 'function',
+ function: {
+ name: c.name,
+ arguments: JSON.stringify(c.input || {}),
+ },
+ }));
+}
+
+function newCallID() {
+ if (typeof crypto.randomUUID === 'function') {
+ return crypto.randomUUID().replace(/-/g, '');
+ }
+ return `${Date.now()}${Math.floor(Math.random() * 1e9)}`;
+}
+
+module.exports = {
+ formatOpenAIStreamToolCalls,
+};
diff --git a/api/helpers/stream-tool-sieve/incremental.js b/api/helpers/stream-tool-sieve/incremental.js
new file mode 100644
index 0000000..1895075
--- /dev/null
+++ b/api/helpers/stream-tool-sieve/incremental.js
@@ -0,0 +1,226 @@
+'use strict';
+
+const {
+ looksLikeToolExampleContext,
+ insideCodeFence,
+} = require('./state');
+const {
+ findObjectFieldValueStart,
+ parseJSONStringLiteral,
+ skipSpaces,
+} = require('./jsonscan');
+
+function buildIncrementalToolDeltas(state) {
+ const captured = state.capture || '';
+ if (!captured) {
+ return [];
+ }
+ if (looksLikeToolExampleContext(state.recentTextTail)) {
+ return [];
+ }
+ const lower = captured.toLowerCase();
+ const keyIdx = lower.indexOf('tool_calls');
+ if (keyIdx < 0) {
+ return [];
+ }
+ const start = captured.slice(0, keyIdx).lastIndexOf('{');
+ if (start < 0) {
+ return [];
+ }
+ if (insideCodeFence((state.recentTextTail || '') + captured.slice(0, start))) {
+ return [];
+ }
+ const callStart = findFirstToolCallObjectStart(captured, keyIdx);
+ if (callStart < 0) {
+ return [];
+ }
+
+ const deltas = [];
+ if (!state.toolName) {
+ const name = extractToolCallName(captured, callStart);
+ if (!name) {
+ return [];
+ }
+ state.toolName = name;
+ }
+
+ if (state.toolArgsStart < 0) {
+ const args = findToolCallArgsStart(captured, callStart);
+ if (args) {
+ state.toolArgsString = Boolean(args.stringMode);
+ state.toolArgsStart = state.toolArgsString ? args.start + 1 : args.start;
+ state.toolArgsSent = state.toolArgsStart;
+ }
+ }
+ if (!state.toolNameSent) {
+ if (state.toolArgsStart < 0) {
+ return [];
+ }
+ state.toolNameSent = true;
+ deltas.push({ index: 0, name: state.toolName });
+ }
+ if (state.toolArgsStart < 0 || state.toolArgsDone) {
+ return deltas;
+ }
+ const progress = scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString);
+ if (!progress) {
+ return deltas;
+ }
+ if (progress.end > state.toolArgsSent) {
+ deltas.push({
+ index: 0,
+ arguments: captured.slice(state.toolArgsSent, progress.end),
+ });
+ state.toolArgsSent = progress.end;
+ }
+ if (progress.complete) {
+ state.toolArgsDone = true;
+ }
+ return deltas;
+}
+
+function findFirstToolCallObjectStart(text, keyIdx) {
+ const arrStart = findToolCallsArrayStart(text, keyIdx);
+ if (arrStart < 0) {
+ return -1;
+ }
+ const i = skipSpaces(text, arrStart + 1);
+ if (i >= text.length || text[i] !== '{') {
+ return -1;
+ }
+ return i;
+}
+
+function findToolCallsArrayStart(text, keyIdx) {
+ let i = keyIdx + 'tool_calls'.length;
+ while (i < text.length && text[i] !== ':') {
+ i += 1;
+ }
+ if (i >= text.length) {
+ return -1;
+ }
+ i = skipSpaces(text, i + 1);
+ if (i >= text.length || text[i] !== '[') {
+ return -1;
+ }
+ return i;
+}
+
+function extractToolCallName(text, callStart) {
+ let valueStart = findObjectFieldValueStart(text, callStart, ['name']);
+ if (valueStart < 0 || text[valueStart] !== '"') {
+ const fnStart = findFunctionObjectStart(text, callStart);
+ if (fnStart < 0) {
+ return '';
+ }
+ valueStart = findObjectFieldValueStart(text, fnStart, ['name']);
+ if (valueStart < 0 || text[valueStart] !== '"') {
+ return '';
+ }
+ }
+ const parsed = parseJSONStringLiteral(text, valueStart);
+ if (!parsed) {
+ return '';
+ }
+ return parsed.value;
+}
+
+function findToolCallArgsStart(text, callStart) {
+ const keys = ['input', 'arguments', 'args', 'parameters', 'params'];
+ let valueStart = findObjectFieldValueStart(text, callStart, keys);
+ if (valueStart < 0) {
+ const fnStart = findFunctionObjectStart(text, callStart);
+ if (fnStart < 0) {
+ return null;
+ }
+ valueStart = findObjectFieldValueStart(text, fnStart, keys);
+ if (valueStart < 0) {
+ return null;
+ }
+ }
+ if (valueStart >= text.length) {
+ return null;
+ }
+ const ch = text[valueStart];
+ if (ch === '{' || ch === '[') {
+ return { start: valueStart, stringMode: false };
+ }
+ if (ch === '"') {
+ return { start: valueStart, stringMode: true };
+ }
+ return null;
+}
+
+function scanToolCallArgsProgress(text, start, stringMode) {
+ if (start < 0 || start > text.length) {
+ return null;
+ }
+ if (stringMode) {
+ let escaped = false;
+ for (let i = start; i < text.length; i += 1) {
+ const ch = text[i];
+ if (escaped) {
+ escaped = false;
+ continue;
+ }
+ if (ch === '\\') {
+ escaped = true;
+ continue;
+ }
+ if (ch === '"') {
+ return { end: i, complete: true };
+ }
+ }
+ return { end: text.length, complete: false };
+ }
+ if (start >= text.length || (text[start] !== '{' && text[start] !== '[')) {
+ return null;
+ }
+ let depth = 0;
+ let quote = '';
+ let escaped = false;
+ for (let i = start; i < text.length; i += 1) {
+ const ch = text[i];
+ if (quote) {
+ if (escaped) {
+ escaped = false;
+ continue;
+ }
+ if (ch === '\\') {
+ escaped = true;
+ continue;
+ }
+ if (ch === quote) {
+ quote = '';
+ }
+ continue;
+ }
+ if (ch === '"' || ch === "'") {
+ quote = ch;
+ continue;
+ }
+ if (ch === '{' || ch === '[') {
+ depth += 1;
+ continue;
+ }
+ if (ch === '}' || ch === ']') {
+ depth -= 1;
+ if (depth === 0) {
+ return { end: i + 1, complete: true };
+ }
+ }
+ }
+ return { end: text.length, complete: false };
+}
+
+function findFunctionObjectStart(text, callStart) {
+ const valueStart = findObjectFieldValueStart(text, callStart, ['function']);
+ if (valueStart < 0 || valueStart >= text.length || text[valueStart] !== '{') {
+ return -1;
+ }
+ return valueStart;
+}
+
+module.exports = {
+ buildIncrementalToolDeltas,
+};
diff --git a/api/helpers/stream-tool-sieve/index.js b/api/helpers/stream-tool-sieve/index.js
new file mode 100644
index 0000000..f218b52
--- /dev/null
+++ b/api/helpers/stream-tool-sieve/index.js
@@ -0,0 +1,27 @@
+'use strict';
+
+const {
+ createToolSieveState,
+} = require('./state');
+const {
+ processToolSieveChunk,
+ flushToolSieve,
+} = require('./sieve');
+const {
+ extractToolNames,
+ parseToolCalls,
+ parseStandaloneToolCalls,
+} = require('./parse');
+const {
+ formatOpenAIStreamToolCalls,
+} = require('./format');
+
+module.exports = {
+ extractToolNames,
+ createToolSieveState,
+ processToolSieveChunk,
+ flushToolSieve,
+ parseToolCalls,
+ parseStandaloneToolCalls,
+ formatOpenAIStreamToolCalls,
+};
diff --git a/api/helpers/stream-tool-sieve/jsonscan.js b/api/helpers/stream-tool-sieve/jsonscan.js
new file mode 100644
index 0000000..a86ed05
--- /dev/null
+++ b/api/helpers/stream-tool-sieve/jsonscan.js
@@ -0,0 +1,148 @@
+'use strict';
+
+function findObjectFieldValueStart(text, objStart, keys) {
+ if (!text || objStart < 0 || objStart >= text.length || text[objStart] !== '{') {
+ return -1;
+ }
+ let depth = 0;
+ let quote = '';
+ let escaped = false;
+ for (let i = objStart; i < text.length; i += 1) {
+ const ch = text[i];
+ if (quote) {
+ if (escaped) {
+ escaped = false;
+ continue;
+ }
+ if (ch === '\\') {
+ escaped = true;
+ continue;
+ }
+ if (ch === quote) {
+ quote = '';
+ }
+ continue;
+ }
+ if (ch === '"' || ch === "'") {
+ if (depth === 1) {
+ const parsed = parseJSONStringLiteral(text, i);
+ if (!parsed) {
+ return -1;
+ }
+ let j = skipSpaces(text, parsed.end);
+ if (j >= text.length || text[j] !== ':') {
+ i = parsed.end - 1;
+ continue;
+ }
+ j = skipSpaces(text, j + 1);
+ if (j >= text.length) {
+ return -1;
+ }
+ if (keys.includes(parsed.value)) {
+ return j;
+ }
+ i = j - 1;
+ continue;
+ }
+ quote = ch;
+ continue;
+ }
+ if (ch === '{') {
+ depth += 1;
+ continue;
+ }
+ if (ch === '}') {
+ depth -= 1;
+ if (depth === 0) {
+ break;
+ }
+ }
+ }
+ return -1;
+}
+
+function parseJSONStringLiteral(text, start) {
+ if (!text || start < 0 || start >= text.length || text[start] !== '"') {
+ return null;
+ }
+ let out = '';
+ let escaped = false;
+ for (let i = start + 1; i < text.length; i += 1) {
+ const ch = text[i];
+ if (escaped) {
+ out += ch;
+ escaped = false;
+ continue;
+ }
+ if (ch === '\\') {
+ escaped = true;
+ continue;
+ }
+ if (ch === '"') {
+ return { value: out, end: i + 1 };
+ }
+ out += ch;
+ }
+ return null;
+}
+
+function skipSpaces(text, i) {
+ let idx = i;
+ while (idx < text.length) {
+ const ch = text[idx];
+ if (ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r') {
+ idx += 1;
+ continue;
+ }
+ break;
+ }
+ return idx;
+}
+
+function extractJSONObjectFrom(text, start) {
+ if (!text || start < 0 || start >= text.length || text[start] !== '{') {
+ return { ok: false, end: 0 };
+ }
+ let depth = 0;
+ let quote = '';
+ let escaped = false;
+ for (let i = start; i < text.length; i += 1) {
+ const ch = text[i];
+ if (quote) {
+ if (escaped) {
+ escaped = false;
+ continue;
+ }
+ if (ch === '\\') {
+ escaped = true;
+ continue;
+ }
+ if (ch === quote) {
+ quote = '';
+ }
+ continue;
+ }
+ if (ch === '"' || ch === "'") {
+ quote = ch;
+ continue;
+ }
+ if (ch === '{') {
+ depth += 1;
+ continue;
+ }
+ if (ch === '}') {
+ depth -= 1;
+ if (depth === 0) {
+ return { ok: true, end: i + 1 };
+ }
+ }
+ }
+ return { ok: false, end: 0 };
+}
+
+module.exports = {
+ findObjectFieldValueStart,
+ parseJSONStringLiteral,
+ skipSpaces,
+ extractJSONObjectFrom,
+};
diff --git a/api/helpers/stream-tool-sieve/parse.js b/api/helpers/stream-tool-sieve/parse.js
new file mode 100644
index 0000000..def46db
--- /dev/null
+++ b/api/helpers/stream-tool-sieve/parse.js
@@ -0,0 +1,281 @@
+'use strict';
+
+const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s;
+
+const {
+ toStringSafe,
+ looksLikeToolExampleContext,
+} = require('./state');
+const {
+ extractJSONObjectFrom,
+} = require('./jsonscan');
+
+function extractToolNames(tools) {
+ if (!Array.isArray(tools) || tools.length === 0) {
+ return [];
+ }
+ const out = [];
+ for (const t of tools) {
+ if (!t || typeof t !== 'object') {
+ continue;
+ }
+ const fn = t.function && typeof t.function === 'object' ? t.function : t;
+ const name = toStringSafe(fn.name);
+ // Keep parity with Go injectToolPrompt: object tools without name still
+ // enter tool mode via fallback name "unknown".
+ out.push(name || 'unknown');
+ }
+ return out;
+}
+
+function parseToolCalls(text, toolNames) {
+ if (!toStringSafe(text)) {
+ return [];
+ }
+ const sanitized = stripFencedCodeBlocks(text);
+ if (!toStringSafe(sanitized)) {
+ return [];
+ }
+ const candidates = buildToolCallCandidates(sanitized);
+ let parsed = [];
+ for (const c of candidates) {
+ parsed = parseToolCallsPayload(c);
+ if (parsed.length > 0) {
+ break;
+ }
+ }
+ if (parsed.length === 0) {
+ return [];
+ }
+ return filterToolCalls(parsed, toolNames);
+}
+
+function stripFencedCodeBlocks(text) {
+ const t = typeof text === 'string' ? text : '';
+ if (!t) {
+ return '';
+ }
+ return t.replace(/```[\s\S]*?```/g, ' ');
+}
+
+function parseStandaloneToolCalls(text, toolNames) {
+ const trimmed = toStringSafe(text);
+ if (!trimmed) {
+ return [];
+ }
+ if ((trimmed.startsWith('```') && trimmed.endsWith('```')) || trimmed.includes('```')) {
+ return [];
+ }
+ if (looksLikeToolExampleContext(trimmed)) {
+ return [];
+ }
+ const candidates = [trimmed];
+ if (trimmed.startsWith('```') && trimmed.endsWith('```')) {
+ const m = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
+ if (m && m[1]) {
+ candidates.push(toStringSafe(m[1]));
+ }
+ }
+ for (const candidate of candidates) {
+ const c = toStringSafe(candidate);
+ if (!c) {
+ continue;
+ }
+ if (!c.startsWith('{') && !c.startsWith('[')) {
+ continue;
+ }
+ const parsed = parseToolCallsPayload(c);
+ if (parsed.length > 0) {
+ return filterToolCalls(parsed, toolNames);
+ }
+ }
+ return [];
+}
+
+function buildToolCallCandidates(text) {
+ const trimmed = toStringSafe(text);
+ const candidates = [trimmed];
+ const fenced = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/gi) || [];
+ for (const block of fenced) {
+ const m = block.match(/```(?:json)?\s*([\s\S]*?)\s*```/i);
+ if (m && m[1]) {
+ candidates.push(toStringSafe(m[1]));
+ }
+ }
+ for (const candidate of extractToolCallObjects(trimmed)) {
+ candidates.push(toStringSafe(candidate));
+ }
+ const first = trimmed.indexOf('{');
+ const last = trimmed.lastIndexOf('}');
+ if (first >= 0 && last > first) {
+ candidates.push(toStringSafe(trimmed.slice(first, last + 1)));
+ }
+ const m = trimmed.match(TOOL_CALL_PATTERN);
+ if (m && m[1]) {
+ candidates.push(`{"tool_calls":[${m[1]}]}`);
+ }
+ return [...new Set(candidates.filter(Boolean))];
+}
+
+function extractToolCallObjects(text) {
+ const raw = toStringSafe(text);
+ if (!raw) {
+ return [];
+ }
+ const lower = raw.toLowerCase();
+ const out = [];
+ let offset = 0;
+ // eslint-disable-next-line no-constant-condition
+ while (true) {
+ let idx = lower.indexOf('tool_calls', offset);
+ if (idx < 0) {
+ break;
+ }
+ let start = raw.slice(0, idx).lastIndexOf('{');
+ while (start >= 0) {
+ const obj = extractJSONObjectFrom(raw, start);
+ if (obj.ok) {
+ out.push(raw.slice(start, obj.end).trim());
+ offset = obj.end;
+ idx = -1;
+ break;
+ }
+ start = raw.slice(0, start).lastIndexOf('{');
+ }
+ if (idx >= 0) {
+ offset = idx + 'tool_calls'.length;
+ }
+ }
+ return out;
+}
+
+function parseToolCallsPayload(payload) {
+ let decoded;
+ try {
+ decoded = JSON.parse(payload);
+ } catch (_err) {
+ return [];
+ }
+ if (Array.isArray(decoded)) {
+ return parseToolCallList(decoded);
+ }
+ if (!decoded || typeof decoded !== 'object') {
+ return [];
+ }
+ if (decoded.tool_calls) {
+ return parseToolCallList(decoded.tool_calls);
+ }
+ const one = parseToolCallItem(decoded);
+ return one ? [one] : [];
+}
+
+function parseToolCallList(v) {
+ if (!Array.isArray(v)) {
+ return [];
+ }
+ const out = [];
+ for (const item of v) {
+ if (!item || typeof item !== 'object') {
+ continue;
+ }
+ const one = parseToolCallItem(item);
+ if (one) {
+ out.push(one);
+ }
+ }
+ return out;
+}
+
+function parseToolCallItem(m) {
+ let name = toStringSafe(m.name);
+ let inputRaw = m.input;
+ let hasInput = Object.prototype.hasOwnProperty.call(m, 'input');
+ const fn = m.function && typeof m.function === 'object' ? m.function : null;
+ if (fn) {
+ if (!name) {
+ name = toStringSafe(fn.name);
+ }
+ if (!hasInput && Object.prototype.hasOwnProperty.call(fn, 'arguments')) {
+ inputRaw = fn.arguments;
+ hasInput = true;
+ }
+ }
+ if (!hasInput) {
+ for (const k of ['arguments', 'args', 'parameters', 'params']) {
+ if (Object.prototype.hasOwnProperty.call(m, k)) {
+ inputRaw = m[k];
+ hasInput = true;
+ break;
+ }
+ }
+ }
+ if (!name) {
+ return null;
+ }
+ return {
+ name,
+ input: parseToolCallInput(inputRaw),
+ };
+}
+
+function parseToolCallInput(v) {
+ if (v == null) {
+ return {};
+ }
+ if (typeof v === 'string') {
+ const raw = toStringSafe(v);
+ if (!raw) {
+ return {};
+ }
+ try {
+ const parsed = JSON.parse(raw);
+ if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
+ return parsed;
+ }
+ return { _raw: raw };
+ } catch (_err) {
+ return { _raw: raw };
+ }
+ }
+ if (typeof v === 'object' && !Array.isArray(v)) {
+ return v;
+ }
+ try {
+ const parsed = JSON.parse(JSON.stringify(v));
+ if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
+ return parsed;
+ }
+ } catch (_err) {
+ return {};
+ }
+ return {};
+}
+
+function filterToolCalls(parsed, toolNames) {
+ const allowed = new Set((toolNames || []).filter(Boolean));
+ const out = [];
+ for (const tc of parsed) {
+ if (!tc || !tc.name) {
+ continue;
+ }
+ if (allowed.size > 0 && !allowed.has(tc.name)) {
+ continue;
+ }
+ out.push({ name: tc.name, input: tc.input || {} });
+ }
+ if (out.length === 0 && parsed.length > 0) {
+ for (const tc of parsed) {
+ if (!tc || !tc.name) {
+ continue;
+ }
+ out.push({ name: tc.name, input: tc.input || {} });
+ }
+ }
+ return out;
+}
+
+module.exports = {
+ extractToolNames,
+ parseToolCalls,
+ parseStandaloneToolCalls,
+};
diff --git a/api/helpers/stream-tool-sieve/sieve.js b/api/helpers/stream-tool-sieve/sieve.js
new file mode 100644
index 0000000..c10e636
--- /dev/null
+++ b/api/helpers/stream-tool-sieve/sieve.js
@@ -0,0 +1,252 @@
+'use strict';
+
+const {
+ TOOL_SIEVE_CAPTURE_LIMIT,
+ resetIncrementalToolState,
+ noteText,
+ insideCodeFence,
+} = require('./state');
+const {
+ buildIncrementalToolDeltas,
+} = require('./incremental');
+const {
+ parseStandaloneToolCalls,
+} = require('./parse');
+const {
+ extractJSONObjectFrom,
+} = require('./jsonscan');
+
+function processToolSieveChunk(state, chunk, toolNames) {
+ if (!state) {
+ return [];
+ }
+ if (chunk) {
+ state.pending += chunk;
+ }
+ const events = [];
+ // eslint-disable-next-line no-constant-condition
+ while (true) {
+ if (state.capturing) {
+ if (state.pending) {
+ state.capture += state.pending;
+ state.pending = '';
+ }
+ const deltas = buildIncrementalToolDeltas(state);
+ if (deltas.length > 0) {
+ events.push({ type: 'tool_call_deltas', deltas });
+ }
+ const consumed = consumeToolCapture(state, toolNames);
+ if (!consumed.ready) {
+ if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) {
+ noteText(state, state.capture);
+ events.push({ type: 'text', text: state.capture });
+ state.capture = '';
+ state.capturing = false;
+ resetIncrementalToolState(state);
+ continue;
+ }
+ break;
+ }
+ state.capture = '';
+ state.capturing = false;
+ resetIncrementalToolState(state);
+ if (consumed.prefix) {
+ noteText(state, consumed.prefix);
+ events.push({ type: 'text', text: consumed.prefix });
+ }
+ if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
+ events.push({ type: 'tool_calls', calls: consumed.calls });
+ }
+ if (consumed.suffix) {
+ state.pending += consumed.suffix;
+ }
+ continue;
+ }
+
+ if (!state.pending) {
+ break;
+ }
+
+ const start = findToolSegmentStart(state.pending);
+ if (start >= 0) {
+ const prefix = state.pending.slice(0, start);
+ if (prefix) {
+ noteText(state, prefix);
+ events.push({ type: 'text', text: prefix });
+ }
+ state.capture = state.pending.slice(start);
+ state.pending = '';
+ state.capturing = true;
+ resetIncrementalToolState(state);
+ continue;
+ }
+
+ const [safe, hold] = splitSafeContentForToolDetection(state.pending);
+ if (!safe) {
+ break;
+ }
+ state.pending = hold;
+ noteText(state, safe);
+ events.push({ type: 'text', text: safe });
+ }
+ return events;
+}
+
+function flushToolSieve(state, toolNames) {
+ if (!state) {
+ return [];
+ }
+ const events = processToolSieveChunk(state, '', toolNames);
+ if (state.capturing) {
+ const consumed = consumeToolCapture(state, toolNames);
+ if (consumed.ready) {
+ if (consumed.prefix) {
+ noteText(state, consumed.prefix);
+ events.push({ type: 'text', text: consumed.prefix });
+ }
+ if (Array.isArray(consumed.calls) && consumed.calls.length > 0) {
+ events.push({ type: 'tool_calls', calls: consumed.calls });
+ }
+ if (consumed.suffix) {
+ noteText(state, consumed.suffix);
+ events.push({ type: 'text', text: consumed.suffix });
+ }
+ } else if (state.capture) {
+ noteText(state, state.capture);
+ events.push({ type: 'text', text: state.capture });
+ }
+ state.capture = '';
+ state.capturing = false;
+ resetIncrementalToolState(state);
+ }
+ if (state.pending) {
+ noteText(state, state.pending);
+ events.push({ type: 'text', text: state.pending });
+ state.pending = '';
+ }
+ return events;
+}
+
+function splitSafeContentForToolDetection(s) {
+ const text = s || '';
+ if (!text) {
+ return ['', ''];
+ }
+ const suspiciousStart = findSuspiciousPrefixStart(text);
+ if (suspiciousStart < 0) {
+ return [text, ''];
+ }
+ if (suspiciousStart > 0) {
+ return [text.slice(0, suspiciousStart), text.slice(suspiciousStart)];
+ }
+ // If suspicious content starts at the beginning, keep holding until we can
+ // either parse a full tool JSON block or reach stream flush.
+ return ['', text];
+}
+
+function findSuspiciousPrefixStart(s) {
+ let start = -1;
+ for (const needle of ['{', '[', '```']) {
+ const idx = s.lastIndexOf(needle);
+ if (idx > start) {
+ start = idx;
+ }
+ }
+ return start;
+}
+
+function findToolSegmentStart(s) {
+ if (!s) {
+ return -1;
+ }
+ const lower = s.toLowerCase();
+ let offset = 0;
+ // eslint-disable-next-line no-constant-condition
+ while (true) {
+ const keyRel = lower.indexOf('tool_calls', offset);
+ if (keyRel < 0) {
+ return -1;
+ }
+ const keyIdx = keyRel;
+ const start = s.slice(0, keyIdx).lastIndexOf('{');
+ const candidateStart = start >= 0 ? start : keyIdx;
+ if (!insideCodeFence(s.slice(0, candidateStart))) {
+ return candidateStart;
+ }
+ offset = keyIdx + 'tool_calls'.length;
+ }
+}
+
+function consumeToolCapture(state, toolNames) {
+ const captured = state.capture;
+ if (!captured) {
+ return { ready: false, prefix: '', calls: [], suffix: '' };
+ }
+ const lower = captured.toLowerCase();
+ const keyIdx = lower.indexOf('tool_calls');
+ if (keyIdx < 0) {
+ return { ready: false, prefix: '', calls: [], suffix: '' };
+ }
+ const start = captured.slice(0, keyIdx).lastIndexOf('{');
+ if (start < 0) {
+ return { ready: false, prefix: '', calls: [], suffix: '' };
+ }
+ const obj = extractJSONObjectFrom(captured, start);
+ if (!obj.ok) {
+ return { ready: false, prefix: '', calls: [], suffix: '' };
+ }
+ const prefixPart = captured.slice(0, start);
+ const suffixPart = captured.slice(obj.end);
+ if (insideCodeFence((state.recentTextTail || '') + prefixPart)) {
+ return {
+ ready: true,
+ prefix: captured,
+ calls: [],
+ suffix: '',
+ };
+ }
+ const parsed = parseStandaloneToolCalls(captured.slice(start, obj.end), toolNames);
+ if (parsed.length === 0) {
+ if (state.toolNameSent) {
+ return {
+ ready: true,
+ prefix: prefixPart,
+ calls: [],
+ suffix: suffixPart,
+ };
+ }
+ return {
+ ready: true,
+ prefix: captured,
+ calls: [],
+ suffix: '',
+ };
+ }
+ if (state.toolNameSent) {
+ if (parsed.length > 1) {
+ return {
+ ready: true,
+ prefix: prefixPart,
+ calls: parsed.slice(1),
+ suffix: suffixPart,
+ };
+ }
+ return {
+ ready: true,
+ prefix: prefixPart,
+ calls: [],
+ suffix: suffixPart,
+ };
+ }
+ return {
+ ready: true,
+ prefix: prefixPart,
+ calls: parsed,
+ suffix: suffixPart,
+ };
+}
+
+module.exports = {
+ processToolSieveChunk,
+ flushToolSieve,
+};
diff --git a/api/helpers/stream-tool-sieve/state.js b/api/helpers/stream-tool-sieve/state.js
new file mode 100644
index 0000000..a2d2b5c
--- /dev/null
+++ b/api/helpers/stream-tool-sieve/state.js
@@ -0,0 +1,91 @@
+'use strict';
+
+const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024;
+const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 256;
+
+function createToolSieveState() {
+ return {
+ pending: '',
+ capture: '',
+ capturing: false,
+ recentTextTail: '',
+ toolNameSent: false,
+ toolName: '',
+ toolArgsStart: -1,
+ toolArgsSent: -1,
+ toolArgsString: false,
+ toolArgsDone: false,
+ };
+}
+
+function resetIncrementalToolState(state) {
+ state.toolNameSent = false;
+ state.toolName = '';
+ state.toolArgsStart = -1;
+ state.toolArgsSent = -1;
+ state.toolArgsString = false;
+ state.toolArgsDone = false;
+}
+
+function noteText(state, text) {
+ if (!state || !hasMeaningfulText(text)) {
+ return;
+ }
+ state.recentTextTail = appendTail(state.recentTextTail, text, TOOL_SIEVE_CONTEXT_TAIL_LIMIT);
+}
+
+function appendTail(prev, next, max) {
+ const left = typeof prev === 'string' ? prev : '';
+ const right = typeof next === 'string' ? next : '';
+ if (!Number.isFinite(max) || max <= 0) {
+ return '';
+ }
+ const combined = left + right;
+ if (combined.length <= max) {
+ return combined;
+ }
+ return combined.slice(combined.length - max);
+}
+
+function looksLikeToolExampleContext(text) {
+ return insideCodeFence(text);
+}
+
+function insideCodeFence(text) {
+ const t = typeof text === 'string' ? text : '';
+ if (!t) {
+ return false;
+ }
+ const ticks = (t.match(/```/g) || []).length;
+ return ticks % 2 === 1;
+}
+
+function hasMeaningfulText(text) {
+ return toStringSafe(text) !== '';
+}
+
+function toStringSafe(v) {
+ if (typeof v === 'string') {
+ return v.trim();
+ }
+ if (Array.isArray(v)) {
+ return toStringSafe(v[0]);
+ }
+ if (v == null) {
+ return '';
+ }
+ return String(v).trim();
+}
+
+module.exports = {
+ TOOL_SIEVE_CAPTURE_LIMIT,
+ TOOL_SIEVE_CONTEXT_TAIL_LIMIT,
+ createToolSieveState,
+ resetIncrementalToolState,
+ noteText,
+ appendTail,
+ looksLikeToolExampleContext,
+ insideCodeFence,
+ hasMeaningfulText,
+ toStringSafe,
+};
diff --git a/internal/account/pool.go b/internal/account/pool.go
deleted file mode 100644
index 12d8874..0000000
--- a/internal/account/pool.go
+++ /dev/null
@@ -1,363 +0,0 @@
-package account
-
-import (
- "context"
- "os"
- "sort"
- "strconv"
- "strings"
- "sync"
-
- "ds2api/internal/config"
-)
-
-type Pool struct {
- store *config.Store
- mu sync.Mutex
- queue []string
- inUse map[string]int
- waiters []chan struct{}
- maxInflightPerAccount int
- recommendedConcurrency int
- maxQueueSize int
- globalMaxInflight int
-}
-
-func NewPool(store *config.Store) *Pool {
- maxPer := 2
- if store != nil {
- maxPer = store.RuntimeAccountMaxInflight()
- }
- p := &Pool{
- store: store,
- inUse: map[string]int{},
- maxInflightPerAccount: maxPer,
- }
- p.Reset()
- return p
-}
-
-func (p *Pool) Reset() {
- accounts := p.store.Accounts()
- sort.SliceStable(accounts, func(i, j int) bool {
- iHas := accounts[i].Token != ""
- jHas := accounts[j].Token != ""
- if iHas == jHas {
- return i < j
- }
- return iHas
- })
- ids := make([]string, 0, len(accounts))
- for _, a := range accounts {
- id := a.Identifier()
- if id != "" {
- ids = append(ids, id)
- }
- }
- if p.store != nil {
- p.maxInflightPerAccount = p.store.RuntimeAccountMaxInflight()
- } else {
- p.maxInflightPerAccount = maxInflightFromEnv()
- }
- recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount)
- queueLimit := maxQueueFromEnv(recommended)
- globalLimit := recommended
- if p.store != nil {
- queueLimit = p.store.RuntimeAccountMaxQueue(recommended)
- globalLimit = p.store.RuntimeGlobalMaxInflight(recommended)
- }
- p.mu.Lock()
- defer p.mu.Unlock()
- p.drainWaitersLocked()
- p.queue = ids
- p.inUse = map[string]int{}
- p.recommendedConcurrency = recommended
- p.maxQueueSize = queueLimit
- p.globalMaxInflight = globalLimit
- config.Logger.Info(
- "[init_account_queue] initialized",
- "total", len(ids),
- "max_inflight_per_account", p.maxInflightPerAccount,
- "global_max_inflight", p.globalMaxInflight,
- "recommended_concurrency", p.recommendedConcurrency,
- "max_queue_size", p.maxQueueSize,
- )
-}
-
-func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) {
- p.mu.Lock()
- defer p.mu.Unlock()
- return p.acquireLocked(target, normalizeExclude(exclude))
-}
-
-func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) {
- if ctx == nil {
- ctx = context.Background()
- }
- exclude = normalizeExclude(exclude)
- for {
- if ctx.Err() != nil {
- return config.Account{}, false
- }
-
- p.mu.Lock()
- if acc, ok := p.acquireLocked(target, exclude); ok {
- p.mu.Unlock()
- return acc, true
- }
- if !p.canQueueLocked(target, exclude) {
- p.mu.Unlock()
- return config.Account{}, false
- }
- waiter := make(chan struct{})
- p.waiters = append(p.waiters, waiter)
- p.mu.Unlock()
-
- select {
- case <-ctx.Done():
- p.mu.Lock()
- p.removeWaiterLocked(waiter)
- p.mu.Unlock()
- return config.Account{}, false
- case <-waiter:
- }
- }
-}
-
-func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) {
- if target != "" {
- if exclude[target] || !p.canAcquireIDLocked(target) {
- return config.Account{}, false
- }
- acc, ok := p.store.FindAccount(target)
- if !ok {
- return config.Account{}, false
- }
- p.inUse[target]++
- p.bumpQueue(target)
- return acc, true
- }
-
- if acc, ok := p.tryAcquire(exclude, true); ok {
- return acc, true
- }
- if acc, ok := p.tryAcquire(exclude, false); ok {
- return acc, true
- }
- return config.Account{}, false
-}
-
-func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) {
- for i := 0; i < len(p.queue); i++ {
- id := p.queue[i]
- if exclude[id] || !p.canAcquireIDLocked(id) {
- continue
- }
- acc, ok := p.store.FindAccount(id)
- if !ok {
- continue
- }
- if requireToken && acc.Token == "" {
- continue
- }
- p.inUse[id]++
- p.bumpQueue(id)
- return acc, true
- }
- return config.Account{}, false
-}
-
-func (p *Pool) bumpQueue(accountID string) {
- for i, id := range p.queue {
- if id != accountID {
- continue
- }
- p.queue = append(p.queue[:i], p.queue[i+1:]...)
- p.queue = append(p.queue, accountID)
- return
- }
-}
-
-func (p *Pool) Release(accountID string) {
- if accountID == "" {
- return
- }
- p.mu.Lock()
- defer p.mu.Unlock()
- count := p.inUse[accountID]
- if count <= 0 {
- return
- }
- if count == 1 {
- delete(p.inUse, accountID)
- p.notifyWaiterLocked()
- return
- }
- p.inUse[accountID] = count - 1
- p.notifyWaiterLocked()
-}
-
-func (p *Pool) Status() map[string]any {
- p.mu.Lock()
- defer p.mu.Unlock()
- available := make([]string, 0, len(p.queue))
- inUseAccounts := make([]string, 0, len(p.inUse))
- inUseSlots := 0
- for _, id := range p.queue {
- if p.inUse[id] < p.maxInflightPerAccount {
- available = append(available, id)
- }
- }
- for id, count := range p.inUse {
- if count > 0 {
- inUseAccounts = append(inUseAccounts, id)
- inUseSlots += count
- }
- }
- sort.Strings(inUseAccounts)
- return map[string]any{
- "available": len(available),
- "in_use": inUseSlots,
- "total": len(p.store.Accounts()),
- "available_accounts": available,
- "in_use_accounts": inUseAccounts,
- "max_inflight_per_account": p.maxInflightPerAccount,
- "global_max_inflight": p.globalMaxInflight,
- "recommended_concurrency": p.recommendedConcurrency,
- "waiting": len(p.waiters),
- "max_queue_size": p.maxQueueSize,
- }
-}
-
-func (p *Pool) ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) {
- if maxInflightPerAccount <= 0 {
- maxInflightPerAccount = 1
- }
- if maxQueueSize < 0 {
- maxQueueSize = 0
- }
- if globalMaxInflight <= 0 {
- globalMaxInflight = maxInflightPerAccount * len(p.store.Accounts())
- if globalMaxInflight <= 0 {
- globalMaxInflight = maxInflightPerAccount
- }
- }
- p.mu.Lock()
- defer p.mu.Unlock()
- p.maxInflightPerAccount = maxInflightPerAccount
- p.maxQueueSize = maxQueueSize
- p.globalMaxInflight = globalMaxInflight
- p.recommendedConcurrency = defaultRecommendedConcurrency(len(p.queue), p.maxInflightPerAccount)
- p.notifyWaiterLocked()
-}
-
-func maxInflightFromEnv() int {
- for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
- raw := strings.TrimSpace(os.Getenv(key))
- if raw == "" {
- continue
- }
- n, err := strconv.Atoi(raw)
- if err == nil && n > 0 {
- return n
- }
- }
- return 2
-}
-
-func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount int) int {
- if accountCount <= 0 {
- return 0
- }
- if maxInflightPerAccount <= 0 {
- maxInflightPerAccount = 2
- }
- return accountCount * maxInflightPerAccount
-}
-
-func normalizeExclude(exclude map[string]bool) map[string]bool {
- if exclude == nil {
- return map[string]bool{}
- }
- return exclude
-}
-
-func (p *Pool) canQueueLocked(target string, exclude map[string]bool) bool {
- if target != "" {
- if exclude[target] {
- return false
- }
- if _, ok := p.store.FindAccount(target); !ok {
- return false
- }
- }
- if p.maxQueueSize <= 0 {
- return false
- }
- return len(p.waiters) < p.maxQueueSize
-}
-
-func (p *Pool) notifyWaiterLocked() {
- if len(p.waiters) == 0 {
- return
- }
- waiter := p.waiters[0]
- p.waiters = p.waiters[1:]
- close(waiter)
-}
-
-func (p *Pool) removeWaiterLocked(waiter chan struct{}) bool {
- for i, w := range p.waiters {
- if w != waiter {
- continue
- }
- p.waiters = append(p.waiters[:i], p.waiters[i+1:]...)
- return true
- }
- return false
-}
-
-func (p *Pool) drainWaitersLocked() {
- for _, waiter := range p.waiters {
- close(waiter)
- }
- p.waiters = nil
-}
-
-func maxQueueFromEnv(defaultSize int) int {
- for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
- raw := strings.TrimSpace(os.Getenv(key))
- if raw == "" {
- continue
- }
- n, err := strconv.Atoi(raw)
- if err == nil && n >= 0 {
- return n
- }
- }
- if defaultSize < 0 {
- return 0
- }
- return defaultSize
-}
-
-func (p *Pool) canAcquireIDLocked(accountID string) bool {
- if accountID == "" {
- return false
- }
- if p.inUse[accountID] >= p.maxInflightPerAccount {
- return false
- }
- if p.globalMaxInflight > 0 && p.currentInUseLocked() >= p.globalMaxInflight {
- return false
- }
- return true
-}
-
-func (p *Pool) currentInUseLocked() int {
- total := 0
- for _, n := range p.inUse {
- total += n
- }
- return total
-}
diff --git a/internal/account/pool_acquire.go b/internal/account/pool_acquire.go
new file mode 100644
index 0000000..b0c548c
--- /dev/null
+++ b/internal/account/pool_acquire.go
@@ -0,0 +1,108 @@
+package account
+
+import (
+ "context"
+
+ "ds2api/internal/config"
+)
+
+func (p *Pool) Acquire(target string, exclude map[string]bool) (config.Account, bool) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ return p.acquireLocked(target, normalizeExclude(exclude))
+}
+
+func (p *Pool) AcquireWait(ctx context.Context, target string, exclude map[string]bool) (config.Account, bool) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ exclude = normalizeExclude(exclude)
+ for {
+ if ctx.Err() != nil {
+ return config.Account{}, false
+ }
+
+ p.mu.Lock()
+ if acc, ok := p.acquireLocked(target, exclude); ok {
+ p.mu.Unlock()
+ return acc, true
+ }
+ if !p.canQueueLocked(target, exclude) {
+ p.mu.Unlock()
+ return config.Account{}, false
+ }
+ waiter := make(chan struct{})
+ p.waiters = append(p.waiters, waiter)
+ p.mu.Unlock()
+
+ select {
+ case <-ctx.Done():
+ p.mu.Lock()
+ p.removeWaiterLocked(waiter)
+ p.mu.Unlock()
+ return config.Account{}, false
+ case <-waiter:
+ }
+ }
+}
+
+func (p *Pool) acquireLocked(target string, exclude map[string]bool) (config.Account, bool) {
+ if target != "" {
+ if exclude[target] || !p.canAcquireIDLocked(target) {
+ return config.Account{}, false
+ }
+ acc, ok := p.store.FindAccount(target)
+ if !ok {
+ return config.Account{}, false
+ }
+ p.inUse[target]++
+ p.bumpQueue(target)
+ return acc, true
+ }
+
+ if acc, ok := p.tryAcquire(exclude, true); ok {
+ return acc, true
+ }
+ if acc, ok := p.tryAcquire(exclude, false); ok {
+ return acc, true
+ }
+ return config.Account{}, false
+}
+
+func (p *Pool) tryAcquire(exclude map[string]bool, requireToken bool) (config.Account, bool) {
+ for i := 0; i < len(p.queue); i++ {
+ id := p.queue[i]
+ if exclude[id] || !p.canAcquireIDLocked(id) {
+ continue
+ }
+ acc, ok := p.store.FindAccount(id)
+ if !ok {
+ continue
+ }
+ if requireToken && acc.Token == "" {
+ continue
+ }
+ p.inUse[id]++
+ p.bumpQueue(id)
+ return acc, true
+ }
+ return config.Account{}, false
+}
+
+func (p *Pool) bumpQueue(accountID string) {
+ for i, id := range p.queue {
+ if id != accountID {
+ continue
+ }
+ p.queue = append(p.queue[:i], p.queue[i+1:]...)
+ p.queue = append(p.queue, accountID)
+ return
+ }
+}
+
+func normalizeExclude(exclude map[string]bool) map[string]bool {
+ if exclude == nil {
+ return map[string]bool{}
+ }
+ return exclude
+}
diff --git a/internal/account/pool_core.go b/internal/account/pool_core.go
new file mode 100644
index 0000000..90e2594
--- /dev/null
+++ b/internal/account/pool_core.go
@@ -0,0 +1,132 @@
+package account
+
+import (
+ "sort"
+ "sync"
+
+ "ds2api/internal/config"
+)
+
+type Pool struct {
+ store *config.Store
+ mu sync.Mutex
+ queue []string
+ inUse map[string]int
+ waiters []chan struct{}
+ maxInflightPerAccount int
+ recommendedConcurrency int
+ maxQueueSize int
+ globalMaxInflight int
+}
+
+func NewPool(store *config.Store) *Pool {
+ maxPer := 2
+ if store != nil {
+ maxPer = store.RuntimeAccountMaxInflight()
+ }
+ p := &Pool{
+ store: store,
+ inUse: map[string]int{},
+ maxInflightPerAccount: maxPer,
+ }
+ p.Reset()
+ return p
+}
+
+func (p *Pool) Reset() {
+ accounts := p.store.Accounts()
+ sort.SliceStable(accounts, func(i, j int) bool {
+ iHas := accounts[i].Token != ""
+ jHas := accounts[j].Token != ""
+ if iHas == jHas {
+ return i < j
+ }
+ return iHas
+ })
+ ids := make([]string, 0, len(accounts))
+ for _, a := range accounts {
+ id := a.Identifier()
+ if id != "" {
+ ids = append(ids, id)
+ }
+ }
+ if p.store != nil {
+ p.maxInflightPerAccount = p.store.RuntimeAccountMaxInflight()
+ } else {
+ p.maxInflightPerAccount = maxInflightFromEnv()
+ }
+ recommended := defaultRecommendedConcurrency(len(ids), p.maxInflightPerAccount)
+ queueLimit := maxQueueFromEnv(recommended)
+ globalLimit := recommended
+ if p.store != nil {
+ queueLimit = p.store.RuntimeAccountMaxQueue(recommended)
+ globalLimit = p.store.RuntimeGlobalMaxInflight(recommended)
+ }
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.drainWaitersLocked()
+ p.queue = ids
+ p.inUse = map[string]int{}
+ p.recommendedConcurrency = recommended
+ p.maxQueueSize = queueLimit
+ p.globalMaxInflight = globalLimit
+ config.Logger.Info(
+ "[init_account_queue] initialized",
+ "total", len(ids),
+ "max_inflight_per_account", p.maxInflightPerAccount,
+ "global_max_inflight", p.globalMaxInflight,
+ "recommended_concurrency", p.recommendedConcurrency,
+ "max_queue_size", p.maxQueueSize,
+ )
+}
+
+func (p *Pool) Release(accountID string) {
+ if accountID == "" {
+ return
+ }
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ count := p.inUse[accountID]
+ if count <= 0 {
+ return
+ }
+ if count == 1 {
+ delete(p.inUse, accountID)
+ p.notifyWaiterLocked()
+ return
+ }
+ p.inUse[accountID] = count - 1
+ p.notifyWaiterLocked()
+}
+
+func (p *Pool) Status() map[string]any {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ available := make([]string, 0, len(p.queue))
+ inUseAccounts := make([]string, 0, len(p.inUse))
+ inUseSlots := 0
+ for _, id := range p.queue {
+ if p.inUse[id] < p.maxInflightPerAccount {
+ available = append(available, id)
+ }
+ }
+ for id, count := range p.inUse {
+ if count > 0 {
+ inUseAccounts = append(inUseAccounts, id)
+ inUseSlots += count
+ }
+ }
+ sort.Strings(inUseAccounts)
+ return map[string]any{
+ "available": len(available),
+ "in_use": inUseSlots,
+ "total": len(p.store.Accounts()),
+ "available_accounts": available,
+ "in_use_accounts": inUseAccounts,
+ "max_inflight_per_account": p.maxInflightPerAccount,
+ "global_max_inflight": p.globalMaxInflight,
+ "recommended_concurrency": p.recommendedConcurrency,
+ "waiting": len(p.waiters),
+ "max_queue_size": p.maxQueueSize,
+ }
+}
diff --git a/internal/account/pool_limits.go b/internal/account/pool_limits.go
new file mode 100644
index 0000000..0f0854f
--- /dev/null
+++ b/internal/account/pool_limits.go
@@ -0,0 +1,91 @@
+package account
+
+import (
+ "os"
+ "strconv"
+ "strings"
+)
+
+func (p *Pool) ApplyRuntimeLimits(maxInflightPerAccount, maxQueueSize, globalMaxInflight int) {
+ if maxInflightPerAccount <= 0 {
+ maxInflightPerAccount = 1
+ }
+ if maxQueueSize < 0 {
+ maxQueueSize = 0
+ }
+ if globalMaxInflight <= 0 {
+ globalMaxInflight = maxInflightPerAccount * len(p.store.Accounts())
+ if globalMaxInflight <= 0 {
+ globalMaxInflight = maxInflightPerAccount
+ }
+ }
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.maxInflightPerAccount = maxInflightPerAccount
+ p.maxQueueSize = maxQueueSize
+ p.globalMaxInflight = globalMaxInflight
+ p.recommendedConcurrency = defaultRecommendedConcurrency(len(p.queue), p.maxInflightPerAccount)
+ p.notifyWaiterLocked()
+}
+
+func maxInflightFromEnv() int {
+ for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
+ raw := strings.TrimSpace(os.Getenv(key))
+ if raw == "" {
+ continue
+ }
+ n, err := strconv.Atoi(raw)
+ if err == nil && n > 0 {
+ return n
+ }
+ }
+ return 2
+}
+
+func defaultRecommendedConcurrency(accountCount, maxInflightPerAccount int) int {
+ if accountCount <= 0 {
+ return 0
+ }
+ if maxInflightPerAccount <= 0 {
+ maxInflightPerAccount = 2
+ }
+ return accountCount * maxInflightPerAccount
+}
+
+func maxQueueFromEnv(defaultSize int) int {
+ for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
+ raw := strings.TrimSpace(os.Getenv(key))
+ if raw == "" {
+ continue
+ }
+ n, err := strconv.Atoi(raw)
+ if err == nil && n >= 0 {
+ return n
+ }
+ }
+ if defaultSize < 0 {
+ return 0
+ }
+ return defaultSize
+}
+
+func (p *Pool) canAcquireIDLocked(accountID string) bool {
+ if accountID == "" {
+ return false
+ }
+ if p.inUse[accountID] >= p.maxInflightPerAccount {
+ return false
+ }
+ if p.globalMaxInflight > 0 && p.currentInUseLocked() >= p.globalMaxInflight {
+ return false
+ }
+ return true
+}
+
+func (p *Pool) currentInUseLocked() int {
+ total := 0
+ for _, n := range p.inUse {
+ total += n
+ }
+ return total
+}
diff --git a/internal/account/pool_waiters.go b/internal/account/pool_waiters.go
new file mode 100644
index 0000000..40bd146
--- /dev/null
+++ b/internal/account/pool_waiters.go
@@ -0,0 +1,43 @@
+package account
+
+func (p *Pool) canQueueLocked(target string, exclude map[string]bool) bool {
+ if target != "" {
+ if exclude[target] {
+ return false
+ }
+ if _, ok := p.store.FindAccount(target); !ok {
+ return false
+ }
+ }
+ if p.maxQueueSize <= 0 {
+ return false
+ }
+ return len(p.waiters) < p.maxQueueSize
+}
+
+func (p *Pool) notifyWaiterLocked() {
+ if len(p.waiters) == 0 {
+ return
+ }
+ waiter := p.waiters[0]
+ p.waiters = p.waiters[1:]
+ close(waiter)
+}
+
+func (p *Pool) removeWaiterLocked(waiter chan struct{}) bool {
+ for i, w := range p.waiters {
+ if w != waiter {
+ continue
+ }
+ p.waiters = append(p.waiters[:i], p.waiters[i+1:]...)
+ return true
+ }
+ return false
+}
+
+func (p *Pool) drainWaitersLocked() {
+ for _, waiter := range p.waiters {
+ close(waiter)
+ }
+ p.waiters = nil
+}
diff --git a/internal/adapter/claude/error_shape_test.go b/internal/adapter/claude/error_shape_test.go
index 910fce8..b9dc469 100644
--- a/internal/adapter/claude/error_shape_test.go
+++ b/internal/adapter/claude/error_shape_test.go
@@ -32,4 +32,3 @@ func TestWriteClaudeErrorIncludesUnifiedFields(t *testing.T) {
t.Fatal("expected param field")
}
}
-
diff --git a/internal/adapter/claude/handler.go b/internal/adapter/claude/handler.go
deleted file mode 100644
index 2fa6796..0000000
--- a/internal/adapter/claude/handler.go
+++ /dev/null
@@ -1,368 +0,0 @@
-package claude
-
-import (
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "strings"
- "time"
-
- "github.com/go-chi/chi/v5"
-
- "ds2api/internal/auth"
- "ds2api/internal/config"
- "ds2api/internal/deepseek"
- claudefmt "ds2api/internal/format/claude"
- "ds2api/internal/sse"
- streamengine "ds2api/internal/stream"
- "ds2api/internal/util"
-)
-
-// writeJSON is a package-internal alias to avoid mass-renaming all call-sites.
-var writeJSON = util.WriteJSON
-
-type Handler struct {
- Store ConfigReader
- Auth AuthResolver
- DS DeepSeekCaller
-}
-
-var (
- claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second
- claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second
- claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount
-)
-
-func RegisterRoutes(r chi.Router, h *Handler) {
- r.Get("/anthropic/v1/models", h.ListModels)
- r.Post("/anthropic/v1/messages", h.Messages)
- r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
- r.Post("/v1/messages", h.Messages)
- r.Post("/messages", h.Messages)
- r.Post("/v1/messages/count_tokens", h.CountTokens)
- r.Post("/messages/count_tokens", h.CountTokens)
-}
-
-func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
- writeJSON(w, http.StatusOK, config.ClaudeModelsResponse())
-}
-
-func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
- if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
- r.Header.Set("anthropic-version", "2023-06-01")
- }
- a, err := h.Auth.Determine(r)
- if err != nil {
- status := http.StatusUnauthorized
- detail := err.Error()
- if err == auth.ErrNoAccount {
- status = http.StatusTooManyRequests
- }
- writeClaudeError(w, status, detail)
- return
- }
- defer h.Auth.Release(a)
-
- var req map[string]any
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- writeClaudeError(w, http.StatusBadRequest, "invalid json")
- return
- }
- norm, err := normalizeClaudeRequest(h.Store, req)
- if err != nil {
- writeClaudeError(w, http.StatusBadRequest, err.Error())
- return
- }
- stdReq := norm.Standard
-
- sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
- if err != nil {
- writeClaudeError(w, http.StatusUnauthorized, "invalid token.")
- return
- }
- pow, err := h.DS.GetPow(r.Context(), a, 3)
- if err != nil {
- writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW")
- return
- }
- requestPayload := stdReq.CompletionPayload(sessionID)
- resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
- if err != nil {
- writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.")
- return
- }
- if resp.StatusCode != http.StatusOK {
- defer resp.Body.Close()
- body, _ := io.ReadAll(resp.Body)
- writeClaudeError(w, http.StatusInternalServerError, string(body))
- return
- }
-
- if stdReq.Stream {
- h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
- return
- }
- result := sse.CollectStream(resp, stdReq.Thinking, true)
- respBody := claudefmt.BuildMessageResponse(
- fmt.Sprintf("msg_%d", time.Now().UnixNano()),
- stdReq.ResponseModel,
- norm.NormalizedMessages,
- result.Thinking,
- result.Text,
- stdReq.ToolNames,
- )
- writeJSON(w, http.StatusOK, respBody)
-}
-
-func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
- a, err := h.Auth.Determine(r)
- if err != nil {
- writeClaudeError(w, http.StatusUnauthorized, err.Error())
- return
- }
- defer h.Auth.Release(a)
-
- var req map[string]any
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- writeClaudeError(w, http.StatusBadRequest, "invalid json")
- return
- }
- model, _ := req["model"].(string)
- messages, _ := req["messages"].([]any)
- if model == "" || len(messages) == 0 {
- writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
- return
- }
- inputTokens := 0
- if sys, ok := req["system"].(string); ok {
- inputTokens += util.EstimateTokens(sys)
- }
- for _, item := range messages {
- msg, ok := item.(map[string]any)
- if !ok {
- continue
- }
- inputTokens += 2
- inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
- }
- if tools, ok := req["tools"].([]any); ok {
- for _, t := range tools {
- b, _ := json.Marshal(t)
- inputTokens += util.EstimateTokens(string(b))
- }
- }
- if inputTokens < 1 {
- inputTokens = 1
- }
- writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
-}
-
-func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK {
- body, _ := io.ReadAll(resp.Body)
- writeClaudeError(w, http.StatusInternalServerError, string(body))
- return
- }
-
- w.Header().Set("Content-Type", "text/event-stream")
- w.Header().Set("Cache-Control", "no-cache, no-transform")
- w.Header().Set("Connection", "keep-alive")
- w.Header().Set("X-Accel-Buffering", "no")
- rc := http.NewResponseController(w)
- _, canFlush := w.(http.Flusher)
- if !canFlush {
- config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
- }
-
- streamRuntime := newClaudeStreamRuntime(
- w,
- rc,
- canFlush,
- model,
- messages,
- thinkingEnabled,
- searchEnabled,
- toolNames,
- )
- streamRuntime.sendMessageStart()
-
- initialType := "text"
- if thinkingEnabled {
- initialType = "thinking"
- }
- streamengine.ConsumeSSE(streamengine.ConsumeConfig{
- Context: r.Context(),
- Body: resp.Body,
- ThinkingEnabled: thinkingEnabled,
- InitialType: initialType,
- KeepAliveInterval: claudeStreamPingInterval,
- IdleTimeout: claudeStreamIdleTimeout,
- MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt,
- }, streamengine.ConsumeHooks{
- OnKeepAlive: func() {
- streamRuntime.sendPing()
- },
- OnParsed: streamRuntime.onParsed,
- OnFinalize: streamRuntime.onFinalize,
- })
-}
-
-func writeClaudeError(w http.ResponseWriter, status int, message string) {
- code := "invalid_request"
- switch status {
- case http.StatusUnauthorized:
- code = "authentication_failed"
- case http.StatusTooManyRequests:
- code = "rate_limit_exceeded"
- case http.StatusNotFound:
- code = "not_found"
- case http.StatusInternalServerError:
- code = "internal_error"
- }
- writeJSON(w, status, map[string]any{
- "error": map[string]any{
- "type": "invalid_request_error",
- "message": message,
- "code": code,
- "param": nil,
- },
- })
-}
-
-func normalizeClaudeMessages(messages []any) []any {
- out := make([]any, 0, len(messages))
- for _, m := range messages {
- msg, ok := m.(map[string]any)
- if !ok {
- continue
- }
- copied := cloneMap(msg)
- switch content := msg["content"].(type) {
- case []any:
- parts := make([]string, 0, len(content))
- for _, block := range content {
- b, ok := block.(map[string]any)
- if !ok {
- continue
- }
- typeStr, _ := b["type"].(string)
- if typeStr == "text" {
- if t, ok := b["text"].(string); ok {
- parts = append(parts, t)
- }
- }
- if typeStr == "tool_result" {
- parts = append(parts, formatClaudeToolResultForPrompt(b))
- }
- }
- copied["content"] = strings.Join(parts, "\n")
- }
- out = append(out, copied)
- }
- return out
-}
-
-func buildClaudeToolPrompt(tools []any) string {
- parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"}
- for _, t := range tools {
- m, ok := t.(map[string]any)
- if !ok {
- continue
- }
- name, _ := m["name"].(string)
- desc, _ := m["description"].(string)
- schema, _ := json.Marshal(m["input_schema"])
- parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
- }
- parts = append(parts,
- "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}",
- "History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
- "After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
- )
- return strings.Join(parts, "\n\n")
-}
-
-func formatClaudeToolResultForPrompt(block map[string]any) string {
- if block == nil {
- return ""
- }
- toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
- if toolCallID == "" {
- toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"]))
- }
- if toolCallID == "" {
- toolCallID = "unknown"
- }
- name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
- if name == "" {
- name = "unknown"
- }
- content := strings.TrimSpace(fmt.Sprintf("%v", block["content"]))
- if content == "" {
- content = "null"
- }
- return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
-}
-
-func hasSystemMessage(messages []any) bool {
- for _, m := range messages {
- msg, ok := m.(map[string]any)
- if ok && msg["role"] == "system" {
- return true
- }
- }
- return false
-}
-
-func extractClaudeToolNames(tools []any) []string {
- out := make([]string, 0, len(tools))
- for _, t := range tools {
- m, ok := t.(map[string]any)
- if !ok {
- continue
- }
- if name, ok := m["name"].(string); ok && name != "" {
- out = append(out, name)
- }
- }
- return out
-}
-
-func toMessageMaps(v any) []map[string]any {
- arr, ok := v.([]any)
- if !ok {
- return nil
- }
- out := make([]map[string]any, 0, len(arr))
- for _, item := range arr {
- if m, ok := item.(map[string]any); ok {
- out = append(out, m)
- }
- }
- return out
-}
-
-func extractMessageContent(v any) string {
- switch x := v.(type) {
- case string:
- return x
- case []any:
- parts := make([]string, 0, len(x))
- for _, it := range x {
- parts = append(parts, fmt.Sprintf("%v", it))
- }
- return strings.Join(parts, "\n")
- default:
- return fmt.Sprintf("%v", x)
- }
-}
-
-func cloneMap(in map[string]any) map[string]any {
- out := make(map[string]any, len(in))
- for k, v := range in {
- out[k] = v
- }
- return out
-}
diff --git a/internal/adapter/claude/handler_errors.go b/internal/adapter/claude/handler_errors.go
new file mode 100644
index 0000000..f1188d6
--- /dev/null
+++ b/internal/adapter/claude/handler_errors.go
@@ -0,0 +1,25 @@
+package claude
+
+import "net/http"
+
+func writeClaudeError(w http.ResponseWriter, status int, message string) {
+ code := "invalid_request"
+ switch status {
+ case http.StatusUnauthorized:
+ code = "authentication_failed"
+ case http.StatusTooManyRequests:
+ code = "rate_limit_exceeded"
+ case http.StatusNotFound:
+ code = "not_found"
+ case http.StatusInternalServerError:
+ code = "internal_error"
+ }
+ writeJSON(w, status, map[string]any{
+ "error": map[string]any{
+ "type": "invalid_request_error",
+ "message": message,
+ "code": code,
+ "param": nil,
+ },
+ })
+}
diff --git a/internal/adapter/claude/handler_messages.go b/internal/adapter/claude/handler_messages.go
new file mode 100644
index 0000000..1c4272b
--- /dev/null
+++ b/internal/adapter/claude/handler_messages.go
@@ -0,0 +1,134 @@
+package claude
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+
+ "ds2api/internal/auth"
+ "ds2api/internal/config"
+ claudefmt "ds2api/internal/format/claude"
+ "ds2api/internal/sse"
+ streamengine "ds2api/internal/stream"
+)
+
+func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
+ if strings.TrimSpace(r.Header.Get("anthropic-version")) == "" {
+ r.Header.Set("anthropic-version", "2023-06-01")
+ }
+ a, err := h.Auth.Determine(r)
+ if err != nil {
+ status := http.StatusUnauthorized
+ detail := err.Error()
+ if err == auth.ErrNoAccount {
+ status = http.StatusTooManyRequests
+ }
+ writeClaudeError(w, status, detail)
+ return
+ }
+ defer h.Auth.Release(a)
+
+ var req map[string]any
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ writeClaudeError(w, http.StatusBadRequest, "invalid json")
+ return
+ }
+ norm, err := normalizeClaudeRequest(h.Store, req)
+ if err != nil {
+ writeClaudeError(w, http.StatusBadRequest, err.Error())
+ return
+ }
+ stdReq := norm.Standard
+
+ sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
+ if err != nil {
+ writeClaudeError(w, http.StatusUnauthorized, "invalid token.")
+ return
+ }
+ pow, err := h.DS.GetPow(r.Context(), a, 3)
+ if err != nil {
+ writeClaudeError(w, http.StatusUnauthorized, "Failed to get PoW")
+ return
+ }
+ requestPayload := stdReq.CompletionPayload(sessionID)
+ resp, err := h.DS.CallCompletion(r.Context(), a, requestPayload, pow, 3)
+ if err != nil {
+ writeClaudeError(w, http.StatusInternalServerError, "Failed to get Claude response.")
+ return
+ }
+ if resp.StatusCode != http.StatusOK {
+ defer resp.Body.Close()
+ body, _ := io.ReadAll(resp.Body)
+ writeClaudeError(w, http.StatusInternalServerError, string(body))
+ return
+ }
+
+ if stdReq.Stream {
+ h.handleClaudeStreamRealtime(w, r, resp, stdReq.ResponseModel, norm.NormalizedMessages, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
+ return
+ }
+ result := sse.CollectStream(resp, stdReq.Thinking, true)
+ respBody := claudefmt.BuildMessageResponse(
+ fmt.Sprintf("msg_%d", time.Now().UnixNano()),
+ stdReq.ResponseModel,
+ norm.NormalizedMessages,
+ result.Thinking,
+ result.Text,
+ stdReq.ToolNames,
+ )
+ writeJSON(w, http.StatusOK, respBody)
+}
+
+func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Request, resp *http.Response, model string, messages []any, thinkingEnabled, searchEnabled bool, toolNames []string) {
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ writeClaudeError(w, http.StatusInternalServerError, string(body))
+ return
+ }
+
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set("Cache-Control", "no-cache, no-transform")
+ w.Header().Set("Connection", "keep-alive")
+ w.Header().Set("X-Accel-Buffering", "no")
+ rc := http.NewResponseController(w)
+ _, canFlush := w.(http.Flusher)
+ if !canFlush {
+ config.Logger.Warn("[claude_stream] response writer does not support flush; streaming may be buffered")
+ }
+
+ streamRuntime := newClaudeStreamRuntime(
+ w,
+ rc,
+ canFlush,
+ model,
+ messages,
+ thinkingEnabled,
+ searchEnabled,
+ toolNames,
+ )
+ streamRuntime.sendMessageStart()
+
+ initialType := "text"
+ if thinkingEnabled {
+ initialType = "thinking"
+ }
+ streamengine.ConsumeSSE(streamengine.ConsumeConfig{
+ Context: r.Context(),
+ Body: resp.Body,
+ ThinkingEnabled: thinkingEnabled,
+ InitialType: initialType,
+ KeepAliveInterval: claudeStreamPingInterval,
+ IdleTimeout: claudeStreamIdleTimeout,
+ MaxKeepAliveNoInput: claudeStreamMaxKeepaliveCnt,
+ }, streamengine.ConsumeHooks{
+ OnKeepAlive: func() {
+ streamRuntime.sendPing()
+ },
+ OnParsed: streamRuntime.onParsed,
+ OnFinalize: streamRuntime.onFinalize,
+ })
+}
diff --git a/internal/adapter/claude/handler_routes.go b/internal/adapter/claude/handler_routes.go
new file mode 100644
index 0000000..0376b2c
--- /dev/null
+++ b/internal/adapter/claude/handler_routes.go
@@ -0,0 +1,41 @@
+package claude
+
+import (
+ "net/http"
+ "time"
+
+ "github.com/go-chi/chi/v5"
+
+ "ds2api/internal/config"
+ "ds2api/internal/deepseek"
+ "ds2api/internal/util"
+)
+
+// writeJSON is a package-internal alias to avoid mass-renaming all call-sites.
+var writeJSON = util.WriteJSON
+
+type Handler struct {
+ Store ConfigReader
+ Auth AuthResolver
+ DS DeepSeekCaller
+}
+
+var (
+ claudeStreamPingInterval = time.Duration(deepseek.KeepAliveTimeout) * time.Second
+ claudeStreamIdleTimeout = time.Duration(deepseek.StreamIdleTimeout) * time.Second
+ claudeStreamMaxKeepaliveCnt = deepseek.MaxKeepaliveCount
+)
+
+func RegisterRoutes(r chi.Router, h *Handler) {
+ r.Get("/anthropic/v1/models", h.ListModels)
+ r.Post("/anthropic/v1/messages", h.Messages)
+ r.Post("/anthropic/v1/messages/count_tokens", h.CountTokens)
+ r.Post("/v1/messages", h.Messages)
+ r.Post("/messages", h.Messages)
+ r.Post("/v1/messages/count_tokens", h.CountTokens)
+ r.Post("/messages/count_tokens", h.CountTokens)
+}
+
+func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
+ writeJSON(w, http.StatusOK, config.ClaudeModelsResponse())
+}
diff --git a/internal/adapter/claude/handler_tokens.go b/internal/adapter/claude/handler_tokens.go
new file mode 100644
index 0000000..a369345
--- /dev/null
+++ b/internal/adapter/claude/handler_tokens.go
@@ -0,0 +1,51 @@
+package claude
+
+import (
+ "encoding/json"
+ "net/http"
+
+ "ds2api/internal/util"
+)
+
+func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
+ a, err := h.Auth.Determine(r)
+ if err != nil {
+ writeClaudeError(w, http.StatusUnauthorized, err.Error())
+ return
+ }
+ defer h.Auth.Release(a)
+
+ var req map[string]any
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ writeClaudeError(w, http.StatusBadRequest, "invalid json")
+ return
+ }
+ model, _ := req["model"].(string)
+ messages, _ := req["messages"].([]any)
+ if model == "" || len(messages) == 0 {
+ writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
+ return
+ }
+ inputTokens := 0
+ if sys, ok := req["system"].(string); ok {
+ inputTokens += util.EstimateTokens(sys)
+ }
+ for _, item := range messages {
+ msg, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ inputTokens += 2
+ inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
+ }
+ if tools, ok := req["tools"].([]any); ok {
+ for _, t := range tools {
+ b, _ := json.Marshal(t)
+ inputTokens += util.EstimateTokens(string(b))
+ }
+ }
+ if inputTokens < 1 {
+ inputTokens = 1
+ }
+ writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
+}
diff --git a/internal/adapter/claude/handler_utils.go b/internal/adapter/claude/handler_utils.go
new file mode 100644
index 0000000..df4c6b2
--- /dev/null
+++ b/internal/adapter/claude/handler_utils.go
@@ -0,0 +1,143 @@
+package claude
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+)
+
+func normalizeClaudeMessages(messages []any) []any {
+ out := make([]any, 0, len(messages))
+ for _, m := range messages {
+ msg, ok := m.(map[string]any)
+ if !ok {
+ continue
+ }
+ copied := cloneMap(msg)
+ switch content := msg["content"].(type) {
+ case []any:
+ parts := make([]string, 0, len(content))
+ for _, block := range content {
+ b, ok := block.(map[string]any)
+ if !ok {
+ continue
+ }
+ typeStr, _ := b["type"].(string)
+ if typeStr == "text" {
+ if t, ok := b["text"].(string); ok {
+ parts = append(parts, t)
+ }
+ }
+ if typeStr == "tool_result" {
+ parts = append(parts, formatClaudeToolResultForPrompt(b))
+ }
+ }
+ copied["content"] = strings.Join(parts, "\n")
+ }
+ out = append(out, copied)
+ }
+ return out
+}
+
+func buildClaudeToolPrompt(tools []any) string {
+ parts := []string{"You are Claude, a helpful AI assistant. You have access to these tools:"}
+ for _, t := range tools {
+ m, ok := t.(map[string]any)
+ if !ok {
+ continue
+ }
+ name, _ := m["name"].(string)
+ desc, _ := m["description"].(string)
+ schema, _ := json.Marshal(m["input_schema"])
+ parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema))
+ }
+ parts = append(parts,
+ "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}",
+ "History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
+ "After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
+ )
+ return strings.Join(parts, "\n\n")
+}
+
+func formatClaudeToolResultForPrompt(block map[string]any) string {
+ if block == nil {
+ return ""
+ }
+ toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
+ if toolCallID == "" {
+ toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"]))
+ }
+ if toolCallID == "" {
+ toolCallID = "unknown"
+ }
+ name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
+ if name == "" {
+ name = "unknown"
+ }
+ content := strings.TrimSpace(fmt.Sprintf("%v", block["content"]))
+ if content == "" {
+ content = "null"
+ }
+ return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
+}
+
+func hasSystemMessage(messages []any) bool {
+ for _, m := range messages {
+ msg, ok := m.(map[string]any)
+ if ok && msg["role"] == "system" {
+ return true
+ }
+ }
+ return false
+}
+
+func extractClaudeToolNames(tools []any) []string {
+ out := make([]string, 0, len(tools))
+ for _, t := range tools {
+ m, ok := t.(map[string]any)
+ if !ok {
+ continue
+ }
+ if name, ok := m["name"].(string); ok && name != "" {
+ out = append(out, name)
+ }
+ }
+ return out
+}
+
+func toMessageMaps(v any) []map[string]any {
+ arr, ok := v.([]any)
+ if !ok {
+ return nil
+ }
+ out := make([]map[string]any, 0, len(arr))
+ for _, item := range arr {
+ if m, ok := item.(map[string]any); ok {
+ out = append(out, m)
+ }
+ }
+ return out
+}
+
+func extractMessageContent(v any) string {
+ switch x := v.(type) {
+ case string:
+ return x
+ case []any:
+ parts := make([]string, 0, len(x))
+ for _, it := range x {
+ parts = append(parts, fmt.Sprintf("%v", it))
+ }
+ return strings.Join(parts, "\n")
+ default:
+ return fmt.Sprintf("%v", x)
+ }
+}
+
+func cloneMap(in map[string]any) map[string]any {
+ out := make(map[string]any, len(in))
+ for k, v := range in {
+ out[k] = v
+ }
+ return out
+}
diff --git a/internal/adapter/claude/stream_runtime.go b/internal/adapter/claude/stream_runtime.go
deleted file mode 100644
index 01e07a9..0000000
--- a/internal/adapter/claude/stream_runtime.go
+++ /dev/null
@@ -1,308 +0,0 @@
-package claude
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "strings"
- "time"
-
- "ds2api/internal/sse"
- streamengine "ds2api/internal/stream"
- "ds2api/internal/util"
-)
-
-type claudeStreamRuntime struct {
- w http.ResponseWriter
- rc *http.ResponseController
- canFlush bool
-
- model string
- toolNames []string
- messages []any
-
- thinkingEnabled bool
- searchEnabled bool
- bufferToolContent bool
-
- messageID string
- thinking strings.Builder
- text strings.Builder
-
- nextBlockIndex int
- thinkingBlockOpen bool
- thinkingBlockIndex int
- textBlockOpen bool
- textBlockIndex int
- ended bool
- upstreamErr string
-}
-
-func newClaudeStreamRuntime(
- w http.ResponseWriter,
- rc *http.ResponseController,
- canFlush bool,
- model string,
- messages []any,
- thinkingEnabled bool,
- searchEnabled bool,
- toolNames []string,
-) *claudeStreamRuntime {
- return &claudeStreamRuntime{
- w: w,
- rc: rc,
- canFlush: canFlush,
- model: model,
- messages: messages,
- thinkingEnabled: thinkingEnabled,
- searchEnabled: searchEnabled,
- bufferToolContent: len(toolNames) > 0,
- toolNames: toolNames,
- messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()),
- thinkingBlockIndex: -1,
- textBlockIndex: -1,
- }
-}
-
-func (s *claudeStreamRuntime) send(event string, v any) {
- b, _ := json.Marshal(v)
- _, _ = s.w.Write([]byte("event: "))
- _, _ = s.w.Write([]byte(event))
- _, _ = s.w.Write([]byte("\n"))
- _, _ = s.w.Write([]byte("data: "))
- _, _ = s.w.Write(b)
- _, _ = s.w.Write([]byte("\n\n"))
- if s.canFlush {
- _ = s.rc.Flush()
- }
-}
-
-func (s *claudeStreamRuntime) sendError(message string) {
- msg := strings.TrimSpace(message)
- if msg == "" {
- msg = "upstream stream error"
- }
- s.send("error", map[string]any{
- "type": "error",
- "error": map[string]any{
- "type": "api_error",
- "message": msg,
- "code": "internal_error",
- "param": nil,
- },
- })
-}
-
-func (s *claudeStreamRuntime) sendPing() {
- s.send("ping", map[string]any{"type": "ping"})
-}
-
-func (s *claudeStreamRuntime) sendMessageStart() {
- inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages))
- s.send("message_start", map[string]any{
- "type": "message_start",
- "message": map[string]any{
- "id": s.messageID,
- "type": "message",
- "role": "assistant",
- "model": s.model,
- "content": []any{},
- "stop_reason": nil,
- "stop_sequence": nil,
- "usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
- },
- })
-}
-
-func (s *claudeStreamRuntime) closeThinkingBlock() {
- if !s.thinkingBlockOpen {
- return
- }
- s.send("content_block_stop", map[string]any{
- "type": "content_block_stop",
- "index": s.thinkingBlockIndex,
- })
- s.thinkingBlockOpen = false
- s.thinkingBlockIndex = -1
-}
-
-func (s *claudeStreamRuntime) closeTextBlock() {
- if !s.textBlockOpen {
- return
- }
- s.send("content_block_stop", map[string]any{
- "type": "content_block_stop",
- "index": s.textBlockIndex,
- })
- s.textBlockOpen = false
- s.textBlockIndex = -1
-}
-
-func (s *claudeStreamRuntime) finalize(stopReason string) {
- if s.ended {
- return
- }
- s.ended = true
-
- s.closeThinkingBlock()
- s.closeTextBlock()
-
- finalThinking := s.thinking.String()
- finalText := s.text.String()
-
- if s.bufferToolContent {
- detected := util.ParseToolCalls(finalText, s.toolNames)
- if len(detected) > 0 {
- stopReason = "tool_use"
- for i, tc := range detected {
- idx := s.nextBlockIndex + i
- s.send("content_block_start", map[string]any{
- "type": "content_block_start",
- "index": idx,
- "content_block": map[string]any{
- "type": "tool_use",
- "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
- "name": tc.Name,
- "input": tc.Input,
- },
- })
- s.send("content_block_stop", map[string]any{
- "type": "content_block_stop",
- "index": idx,
- })
- }
- s.nextBlockIndex += len(detected)
- } else if finalText != "" {
- idx := s.nextBlockIndex
- s.nextBlockIndex++
- s.send("content_block_start", map[string]any{
- "type": "content_block_start",
- "index": idx,
- "content_block": map[string]any{
- "type": "text",
- "text": "",
- },
- })
- s.send("content_block_delta", map[string]any{
- "type": "content_block_delta",
- "index": idx,
- "delta": map[string]any{
- "type": "text_delta",
- "text": finalText,
- },
- })
- s.send("content_block_stop", map[string]any{
- "type": "content_block_stop",
- "index": idx,
- })
- }
- }
-
- outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
- s.send("message_delta", map[string]any{
- "type": "message_delta",
- "delta": map[string]any{
- "stop_reason": stopReason,
- "stop_sequence": nil,
- },
- "usage": map[string]any{
- "output_tokens": outputTokens,
- },
- })
- s.send("message_stop", map[string]any{"type": "message_stop"})
-}
-
-func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
- if !parsed.Parsed {
- return streamengine.ParsedDecision{}
- }
- if parsed.ErrorMessage != "" {
- s.upstreamErr = parsed.ErrorMessage
- return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")}
- }
- if parsed.Stop {
- return streamengine.ParsedDecision{Stop: true}
- }
-
- contentSeen := false
- for _, p := range parsed.Parts {
- if p.Text == "" {
- continue
- }
- if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
- continue
- }
- contentSeen = true
-
- if p.Type == "thinking" {
- if !s.thinkingEnabled {
- continue
- }
- s.thinking.WriteString(p.Text)
- s.closeTextBlock()
- if !s.thinkingBlockOpen {
- s.thinkingBlockIndex = s.nextBlockIndex
- s.nextBlockIndex++
- s.send("content_block_start", map[string]any{
- "type": "content_block_start",
- "index": s.thinkingBlockIndex,
- "content_block": map[string]any{
- "type": "thinking",
- "thinking": "",
- },
- })
- s.thinkingBlockOpen = true
- }
- s.send("content_block_delta", map[string]any{
- "type": "content_block_delta",
- "index": s.thinkingBlockIndex,
- "delta": map[string]any{
- "type": "thinking_delta",
- "thinking": p.Text,
- },
- })
- continue
- }
-
- s.text.WriteString(p.Text)
- if s.bufferToolContent {
- continue
- }
- s.closeThinkingBlock()
- if !s.textBlockOpen {
- s.textBlockIndex = s.nextBlockIndex
- s.nextBlockIndex++
- s.send("content_block_start", map[string]any{
- "type": "content_block_start",
- "index": s.textBlockIndex,
- "content_block": map[string]any{
- "type": "text",
- "text": "",
- },
- })
- s.textBlockOpen = true
- }
- s.send("content_block_delta", map[string]any{
- "type": "content_block_delta",
- "index": s.textBlockIndex,
- "delta": map[string]any{
- "type": "text_delta",
- "text": p.Text,
- },
- })
- }
-
- return streamengine.ParsedDecision{ContentSeen: contentSeen}
-}
-
-func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) {
- if string(reason) == "upstream_error" {
- s.sendError(s.upstreamErr)
- return
- }
- if scannerErr != nil {
- s.sendError(scannerErr.Error())
- return
- }
- s.finalize("end_turn")
-}
diff --git a/internal/adapter/claude/stream_runtime_core.go b/internal/adapter/claude/stream_runtime_core.go
new file mode 100644
index 0000000..cb24bdd
--- /dev/null
+++ b/internal/adapter/claude/stream_runtime_core.go
@@ -0,0 +1,146 @@
+package claude
+
+import (
+ "fmt"
+ "net/http"
+ "strings"
+ "time"
+
+ "ds2api/internal/sse"
+ streamengine "ds2api/internal/stream"
+)
+
+type claudeStreamRuntime struct {
+ w http.ResponseWriter
+ rc *http.ResponseController
+ canFlush bool
+
+ model string
+ toolNames []string
+ messages []any
+
+ thinkingEnabled bool
+ searchEnabled bool
+ bufferToolContent bool
+
+ messageID string
+ thinking strings.Builder
+ text strings.Builder
+
+ nextBlockIndex int
+ thinkingBlockOpen bool
+ thinkingBlockIndex int
+ textBlockOpen bool
+ textBlockIndex int
+ ended bool
+ upstreamErr string
+}
+
+func newClaudeStreamRuntime(
+ w http.ResponseWriter,
+ rc *http.ResponseController,
+ canFlush bool,
+ model string,
+ messages []any,
+ thinkingEnabled bool,
+ searchEnabled bool,
+ toolNames []string,
+) *claudeStreamRuntime {
+ return &claudeStreamRuntime{
+ w: w,
+ rc: rc,
+ canFlush: canFlush,
+ model: model,
+ messages: messages,
+ thinkingEnabled: thinkingEnabled,
+ searchEnabled: searchEnabled,
+ bufferToolContent: len(toolNames) > 0,
+ toolNames: toolNames,
+ messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()),
+ thinkingBlockIndex: -1,
+ textBlockIndex: -1,
+ }
+}
+
+func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
+ if !parsed.Parsed {
+ return streamengine.ParsedDecision{}
+ }
+ if parsed.ErrorMessage != "" {
+ s.upstreamErr = parsed.ErrorMessage
+ return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")}
+ }
+ if parsed.Stop {
+ return streamengine.ParsedDecision{Stop: true}
+ }
+
+ contentSeen := false
+ for _, p := range parsed.Parts {
+ if p.Text == "" {
+ continue
+ }
+ if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
+ continue
+ }
+ contentSeen = true
+
+ if p.Type == "thinking" {
+ if !s.thinkingEnabled {
+ continue
+ }
+ s.thinking.WriteString(p.Text)
+ s.closeTextBlock()
+ if !s.thinkingBlockOpen {
+ s.thinkingBlockIndex = s.nextBlockIndex
+ s.nextBlockIndex++
+ s.send("content_block_start", map[string]any{
+ "type": "content_block_start",
+ "index": s.thinkingBlockIndex,
+ "content_block": map[string]any{
+ "type": "thinking",
+ "thinking": "",
+ },
+ })
+ s.thinkingBlockOpen = true
+ }
+ s.send("content_block_delta", map[string]any{
+ "type": "content_block_delta",
+ "index": s.thinkingBlockIndex,
+ "delta": map[string]any{
+ "type": "thinking_delta",
+ "thinking": p.Text,
+ },
+ })
+ continue
+ }
+
+ s.text.WriteString(p.Text)
+ if s.bufferToolContent {
+ continue
+ }
+ s.closeThinkingBlock()
+ if !s.textBlockOpen {
+ s.textBlockIndex = s.nextBlockIndex
+ s.nextBlockIndex++
+ s.send("content_block_start", map[string]any{
+ "type": "content_block_start",
+ "index": s.textBlockIndex,
+ "content_block": map[string]any{
+ "type": "text",
+ "text": "",
+ },
+ })
+ s.textBlockOpen = true
+ }
+ s.send("content_block_delta", map[string]any{
+ "type": "content_block_delta",
+ "index": s.textBlockIndex,
+ "delta": map[string]any{
+ "type": "text_delta",
+ "text": p.Text,
+ },
+ })
+ }
+
+ return streamengine.ParsedDecision{ContentSeen: contentSeen}
+}
diff --git a/internal/adapter/claude/stream_runtime_emit.go b/internal/adapter/claude/stream_runtime_emit.go
new file mode 100644
index 0000000..c2fba19
--- /dev/null
+++ b/internal/adapter/claude/stream_runtime_emit.go
@@ -0,0 +1,59 @@
+package claude
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "ds2api/internal/util"
+)
+
+func (s *claudeStreamRuntime) send(event string, v any) {
+ b, _ := json.Marshal(v)
+ _, _ = s.w.Write([]byte("event: "))
+ _, _ = s.w.Write([]byte(event))
+ _, _ = s.w.Write([]byte("\n"))
+ _, _ = s.w.Write([]byte("data: "))
+ _, _ = s.w.Write(b)
+ _, _ = s.w.Write([]byte("\n\n"))
+ if s.canFlush {
+ _ = s.rc.Flush()
+ }
+}
+
+func (s *claudeStreamRuntime) sendError(message string) {
+ msg := strings.TrimSpace(message)
+ if msg == "" {
+ msg = "upstream stream error"
+ }
+ s.send("error", map[string]any{
+ "type": "error",
+ "error": map[string]any{
+ "type": "api_error",
+ "message": msg,
+ "code": "internal_error",
+ "param": nil,
+ },
+ })
+}
+
+func (s *claudeStreamRuntime) sendPing() {
+ s.send("ping", map[string]any{"type": "ping"})
+}
+
+func (s *claudeStreamRuntime) sendMessageStart() {
+ inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages))
+ s.send("message_start", map[string]any{
+ "type": "message_start",
+ "message": map[string]any{
+ "id": s.messageID,
+ "type": "message",
+ "role": "assistant",
+ "model": s.model,
+ "content": []any{},
+ "stop_reason": nil,
+ "stop_sequence": nil,
+ "usage": map[string]any{"input_tokens": inputTokens, "output_tokens": 0},
+ },
+ })
+}
diff --git a/internal/adapter/claude/stream_runtime_finalize.go b/internal/adapter/claude/stream_runtime_finalize.go
new file mode 100644
index 0000000..f957ba1
--- /dev/null
+++ b/internal/adapter/claude/stream_runtime_finalize.go
@@ -0,0 +1,119 @@
+package claude
+
+import (
+ "fmt"
+ "time"
+
+ streamengine "ds2api/internal/stream"
+ "ds2api/internal/util"
+)
+
+func (s *claudeStreamRuntime) closeThinkingBlock() {
+ if !s.thinkingBlockOpen {
+ return
+ }
+ s.send("content_block_stop", map[string]any{
+ "type": "content_block_stop",
+ "index": s.thinkingBlockIndex,
+ })
+ s.thinkingBlockOpen = false
+ s.thinkingBlockIndex = -1
+}
+
+func (s *claudeStreamRuntime) closeTextBlock() {
+ if !s.textBlockOpen {
+ return
+ }
+ s.send("content_block_stop", map[string]any{
+ "type": "content_block_stop",
+ "index": s.textBlockIndex,
+ })
+ s.textBlockOpen = false
+ s.textBlockIndex = -1
+}
+
+func (s *claudeStreamRuntime) finalize(stopReason string) {
+ if s.ended {
+ return
+ }
+ s.ended = true
+
+ s.closeThinkingBlock()
+ s.closeTextBlock()
+
+ finalThinking := s.thinking.String()
+ finalText := s.text.String()
+
+ if s.bufferToolContent {
+ detected := util.ParseToolCalls(finalText, s.toolNames)
+ if len(detected) > 0 {
+ stopReason = "tool_use"
+ for i, tc := range detected {
+ idx := s.nextBlockIndex + i
+ s.send("content_block_start", map[string]any{
+ "type": "content_block_start",
+ "index": idx,
+ "content_block": map[string]any{
+ "type": "tool_use",
+ "id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
+ "name": tc.Name,
+ "input": tc.Input,
+ },
+ })
+ s.send("content_block_stop", map[string]any{
+ "type": "content_block_stop",
+ "index": idx,
+ })
+ }
+ s.nextBlockIndex += len(detected)
+ } else if finalText != "" {
+ idx := s.nextBlockIndex
+ s.nextBlockIndex++
+ s.send("content_block_start", map[string]any{
+ "type": "content_block_start",
+ "index": idx,
+ "content_block": map[string]any{
+ "type": "text",
+ "text": "",
+ },
+ })
+ s.send("content_block_delta", map[string]any{
+ "type": "content_block_delta",
+ "index": idx,
+ "delta": map[string]any{
+ "type": "text_delta",
+ "text": finalText,
+ },
+ })
+ s.send("content_block_stop", map[string]any{
+ "type": "content_block_stop",
+ "index": idx,
+ })
+ }
+ }
+
+ outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
+ s.send("message_delta", map[string]any{
+ "type": "message_delta",
+ "delta": map[string]any{
+ "stop_reason": stopReason,
+ "stop_sequence": nil,
+ },
+ "usage": map[string]any{
+ "output_tokens": outputTokens,
+ },
+ })
+ s.send("message_stop", map[string]any{"type": "message_stop"})
+}
+
+func (s *claudeStreamRuntime) onFinalize(reason streamengine.StopReason, scannerErr error) {
+ if string(reason) == "upstream_error" {
+ s.sendError(s.upstreamErr)
+ return
+ }
+ if scannerErr != nil {
+ s.sendError(scannerErr.Error())
+ return
+ }
+ s.finalize("end_turn")
+}
diff --git a/internal/adapter/gemini/convert.go b/internal/adapter/gemini/convert.go
deleted file mode 100644
index 3f63579..0000000
--- a/internal/adapter/gemini/convert.go
+++ /dev/null
@@ -1,313 +0,0 @@
-package gemini
-
-import (
- "encoding/json"
- "fmt"
- "strings"
-
- "ds2api/internal/adapter/openai"
- "ds2api/internal/config"
- "ds2api/internal/util"
-)
-
-func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) {
- requestedModel := strings.TrimSpace(routeModel)
- if requestedModel == "" {
- return util.StandardRequest{}, fmt.Errorf("model is required in request path")
- }
-
- resolvedModel, ok := config.ResolveModel(store, requestedModel)
- if !ok {
- return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel)
- }
- thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
-
- messagesRaw := geminiMessagesFromRequest(req)
- if len(messagesRaw) == 0 {
- return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.")
- }
-
- toolsRaw := convertGeminiTools(req["tools"])
- finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "")
- passThrough := collectGeminiPassThrough(req)
-
- return util.StandardRequest{
- Surface: "google_gemini",
- RequestedModel: requestedModel,
- ResolvedModel: resolvedModel,
- ResponseModel: requestedModel,
- Messages: messagesRaw,
- FinalPrompt: finalPrompt,
- ToolNames: toolNames,
- Stream: stream,
- Thinking: thinkingEnabled,
- Search: searchEnabled,
- PassThrough: passThrough,
- }, nil
-}
-
-func geminiMessagesFromRequest(req map[string]any) []any {
- out := make([]any, 0, 8)
- if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" {
- out = append(out, map[string]any{
- "role": "system",
- "content": sys,
- })
- }
-
- contents, _ := req["contents"].([]any)
- for _, item := range contents {
- content, ok := item.(map[string]any)
- if !ok {
- continue
- }
- role := mapGeminiRole(content["role"])
- if role == "" {
- role = "user"
- }
- parts, _ := content["parts"].([]any)
- if len(parts) == 0 {
- if text := strings.TrimSpace(asString(content["text"])); text != "" {
- out = append(out, map[string]any{
- "role": role,
- "content": text,
- })
- }
- continue
- }
-
- textParts := make([]string, 0, len(parts))
- flushText := func() {
- if len(textParts) == 0 {
- return
- }
- out = append(out, map[string]any{
- "role": role,
- "content": strings.Join(textParts, "\n"),
- })
- textParts = textParts[:0]
- }
-
- for _, rawPart := range parts {
- part, ok := rawPart.(map[string]any)
- if !ok {
- continue
- }
- if text := strings.TrimSpace(asString(part["text"])); text != "" {
- textParts = append(textParts, text)
- continue
- }
-
- if fnCall, ok := part["functionCall"].(map[string]any); ok {
- flushText()
- if name := strings.TrimSpace(asString(fnCall["name"])); name != "" {
- callID := strings.TrimSpace(asString(fnCall["id"]))
- if callID == "" {
- callID = "call_gemini"
- }
- out = append(out, map[string]any{
- "role": "assistant",
- "tool_calls": []any{
- map[string]any{
- "id": callID,
- "type": "function",
- "function": map[string]any{
- "name": name,
- "arguments": stringifyJSON(fnCall["args"]),
- },
- },
- },
- })
- }
- continue
- }
-
- if fnResp, ok := part["functionResponse"].(map[string]any); ok {
- flushText()
- name := strings.TrimSpace(asString(fnResp["name"]))
- callID := strings.TrimSpace(asString(fnResp["id"]))
- if callID == "" {
- callID = strings.TrimSpace(asString(fnResp["callId"]))
- }
- if callID == "" {
- callID = strings.TrimSpace(asString(fnResp["tool_call_id"]))
- }
- if callID == "" {
- callID = "call_gemini"
- }
- content := fnResp["response"]
- if content == nil {
- content = fnResp["output"]
- }
- if content == nil {
- content = ""
- }
- msg := map[string]any{
- "role": "tool",
- "tool_call_id": callID,
- "content": content,
- }
- if name != "" {
- msg["name"] = name
- }
- out = append(out, msg)
- }
- }
- flushText()
- }
- return out
-}
-
-func normalizeGeminiSystemInstruction(raw any) string {
- switch v := raw.(type) {
- case string:
- return strings.TrimSpace(v)
- case map[string]any:
- if parts, ok := v["parts"].([]any); ok {
- texts := make([]string, 0, len(parts))
- for _, item := range parts {
- part, ok := item.(map[string]any)
- if !ok {
- continue
- }
- if text := strings.TrimSpace(asString(part["text"])); text != "" {
- texts = append(texts, text)
- }
- }
- return strings.Join(texts, "\n")
- }
- if text := strings.TrimSpace(asString(v["text"])); text != "" {
- return text
- }
- }
- return ""
-}
-
-func mapGeminiRole(v any) string {
- switch strings.ToLower(strings.TrimSpace(asString(v))) {
- case "user":
- return "user"
- case "model", "assistant":
- return "assistant"
- case "system":
- return "system"
- default:
- return ""
- }
-}
-
-func convertGeminiTools(raw any) []any {
- tools, _ := raw.([]any)
- if len(tools) == 0 {
- return nil
- }
- out := make([]any, 0, len(tools))
- for _, item := range tools {
- tool, ok := item.(map[string]any)
- if !ok {
- continue
- }
-
- if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 {
- for _, declRaw := range fnDecls {
- decl, ok := declRaw.(map[string]any)
- if !ok {
- continue
- }
- name := strings.TrimSpace(asString(decl["name"]))
- if name == "" {
- continue
- }
- function := map[string]any{
- "name": name,
- }
- if desc := strings.TrimSpace(asString(decl["description"])); desc != "" {
- function["description"] = desc
- }
- if params, ok := decl["parameters"].(map[string]any); ok {
- function["parameters"] = params
- }
- out = append(out, map[string]any{
- "type": "function",
- "function": function,
- })
- }
- continue
- }
-
- // OpenAI-style passthrough fallback.
- if _, ok := tool["function"].(map[string]any); ok {
- out = append(out, tool)
- continue
- }
-
- // Loose fallback for flattened function schema objects.
- name := strings.TrimSpace(asString(tool["name"]))
- if name == "" {
- continue
- }
- fn := map[string]any{"name": name}
- if desc := strings.TrimSpace(asString(tool["description"])); desc != "" {
- fn["description"] = desc
- }
- if params, ok := tool["parameters"].(map[string]any); ok {
- fn["parameters"] = params
- }
- out = append(out, map[string]any{
- "type": "function",
- "function": fn,
- })
- }
- if len(out) == 0 {
- return nil
- }
- return out
-}
-
-func collectGeminiPassThrough(req map[string]any) map[string]any {
- cfg, _ := req["generationConfig"].(map[string]any)
- if len(cfg) == 0 {
- return nil
- }
- out := map[string]any{}
- if v, ok := cfg["temperature"]; ok {
- out["temperature"] = v
- }
- if v, ok := cfg["topP"]; ok {
- out["top_p"] = v
- }
- if v, ok := cfg["maxOutputTokens"]; ok {
- out["max_tokens"] = v
- }
- if v, ok := cfg["stopSequences"]; ok {
- out["stop"] = v
- }
- if len(out) == 0 {
- return nil
- }
- return out
-}
-
-func asString(v any) string {
- s, _ := v.(string)
- return s
-}
-
-func stringifyJSON(v any) string {
- switch x := v.(type) {
- case nil:
- return "{}"
- case string:
- s := strings.TrimSpace(x)
- if s == "" {
- return "{}"
- }
- return s
- default:
- b, err := json.Marshal(x)
- if err != nil || len(b) == 0 {
- return "{}"
- }
- return string(b)
- }
-}
diff --git a/internal/adapter/gemini/convert_messages.go b/internal/adapter/gemini/convert_messages.go
new file mode 100644
index 0000000..1148a7a
--- /dev/null
+++ b/internal/adapter/gemini/convert_messages.go
@@ -0,0 +1,153 @@
+package gemini
+
+import "strings"
+
+func geminiMessagesFromRequest(req map[string]any) []any {
+ out := make([]any, 0, 8)
+ if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" {
+ out = append(out, map[string]any{
+ "role": "system",
+ "content": sys,
+ })
+ }
+
+ contents, _ := req["contents"].([]any)
+ for _, item := range contents {
+ content, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ role := mapGeminiRole(content["role"])
+ if role == "" {
+ role = "user"
+ }
+ parts, _ := content["parts"].([]any)
+ if len(parts) == 0 {
+ if text := strings.TrimSpace(asString(content["text"])); text != "" {
+ out = append(out, map[string]any{
+ "role": role,
+ "content": text,
+ })
+ }
+ continue
+ }
+
+ textParts := make([]string, 0, len(parts))
+ flushText := func() {
+ if len(textParts) == 0 {
+ return
+ }
+ out = append(out, map[string]any{
+ "role": role,
+ "content": strings.Join(textParts, "\n"),
+ })
+ textParts = textParts[:0]
+ }
+
+ for _, rawPart := range parts {
+ part, ok := rawPart.(map[string]any)
+ if !ok {
+ continue
+ }
+ if text := strings.TrimSpace(asString(part["text"])); text != "" {
+ textParts = append(textParts, text)
+ continue
+ }
+
+ if fnCall, ok := part["functionCall"].(map[string]any); ok {
+ flushText()
+ if name := strings.TrimSpace(asString(fnCall["name"])); name != "" {
+ callID := strings.TrimSpace(asString(fnCall["id"]))
+ if callID == "" {
+ callID = "call_gemini"
+ }
+ out = append(out, map[string]any{
+ "role": "assistant",
+ "tool_calls": []any{
+ map[string]any{
+ "id": callID,
+ "type": "function",
+ "function": map[string]any{
+ "name": name,
+ "arguments": stringifyJSON(fnCall["args"]),
+ },
+ },
+ },
+ })
+ }
+ continue
+ }
+
+ if fnResp, ok := part["functionResponse"].(map[string]any); ok {
+ flushText()
+ name := strings.TrimSpace(asString(fnResp["name"]))
+ callID := strings.TrimSpace(asString(fnResp["id"]))
+ if callID == "" {
+ callID = strings.TrimSpace(asString(fnResp["callId"]))
+ }
+ if callID == "" {
+ callID = strings.TrimSpace(asString(fnResp["tool_call_id"]))
+ }
+ if callID == "" {
+ callID = "call_gemini"
+ }
+ content := fnResp["response"]
+ if content == nil {
+ content = fnResp["output"]
+ }
+ if content == nil {
+ content = ""
+ }
+ msg := map[string]any{
+ "role": "tool",
+ "tool_call_id": callID,
+ "content": content,
+ }
+ if name != "" {
+ msg["name"] = name
+ }
+ out = append(out, msg)
+ }
+ }
+ flushText()
+ }
+ return out
+}
+
+func normalizeGeminiSystemInstruction(raw any) string {
+ switch v := raw.(type) {
+ case string:
+ return strings.TrimSpace(v)
+ case map[string]any:
+ if parts, ok := v["parts"].([]any); ok {
+ texts := make([]string, 0, len(parts))
+ for _, item := range parts {
+ part, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ if text := strings.TrimSpace(asString(part["text"])); text != "" {
+ texts = append(texts, text)
+ }
+ }
+ return strings.Join(texts, "\n")
+ }
+ if text := strings.TrimSpace(asString(v["text"])); text != "" {
+ return text
+ }
+ }
+ return ""
+}
+
+func mapGeminiRole(v any) string {
+ switch strings.ToLower(strings.TrimSpace(asString(v))) {
+ case "user":
+ return "user"
+ case "model", "assistant":
+ return "assistant"
+ case "system":
+ return "system"
+ default:
+ return ""
+ }
+}
diff --git a/internal/adapter/gemini/convert_passthrough.go b/internal/adapter/gemini/convert_passthrough.go
new file mode 100644
index 0000000..05cd6cd
--- /dev/null
+++ b/internal/adapter/gemini/convert_passthrough.go
@@ -0,0 +1,54 @@
+package gemini
+
+import (
+ "encoding/json"
+ "strings"
+)
+
+func collectGeminiPassThrough(req map[string]any) map[string]any {
+ cfg, _ := req["generationConfig"].(map[string]any)
+ if len(cfg) == 0 {
+ return nil
+ }
+ out := map[string]any{}
+ if v, ok := cfg["temperature"]; ok {
+ out["temperature"] = v
+ }
+ if v, ok := cfg["topP"]; ok {
+ out["top_p"] = v
+ }
+ if v, ok := cfg["maxOutputTokens"]; ok {
+ out["max_tokens"] = v
+ }
+ if v, ok := cfg["stopSequences"]; ok {
+ out["stop"] = v
+ }
+ if len(out) == 0 {
+ return nil
+ }
+ return out
+}
+
+func asString(v any) string {
+ s, _ := v.(string)
+ return s
+}
+
+func stringifyJSON(v any) string {
+ switch x := v.(type) {
+ case nil:
+ return "{}"
+ case string:
+ s := strings.TrimSpace(x)
+ if s == "" {
+ return "{}"
+ }
+ return s
+ default:
+ b, err := json.Marshal(x)
+ if err != nil || len(b) == 0 {
+ return "{}"
+ }
+ return string(b)
+ }
+}
diff --git a/internal/adapter/gemini/convert_request.go b/internal/adapter/gemini/convert_request.go
new file mode 100644
index 0000000..2eca687
--- /dev/null
+++ b/internal/adapter/gemini/convert_request.go
@@ -0,0 +1,46 @@
+package gemini
+
+import (
+ "fmt"
+ "strings"
+
+ "ds2api/internal/adapter/openai"
+ "ds2api/internal/config"
+ "ds2api/internal/util"
+)
+
+func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[string]any, stream bool) (util.StandardRequest, error) {
+ requestedModel := strings.TrimSpace(routeModel)
+ if requestedModel == "" {
+ return util.StandardRequest{}, fmt.Errorf("model is required in request path")
+ }
+
+ resolvedModel, ok := config.ResolveModel(store, requestedModel)
+ if !ok {
+ return util.StandardRequest{}, fmt.Errorf("Model '%s' is not available.", requestedModel)
+ }
+ thinkingEnabled, searchEnabled, _ := config.GetModelConfig(resolvedModel)
+
+ messagesRaw := geminiMessagesFromRequest(req)
+ if len(messagesRaw) == 0 {
+ return util.StandardRequest{}, fmt.Errorf("Request must include non-empty contents.")
+ }
+
+ toolsRaw := convertGeminiTools(req["tools"])
+ finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "")
+ passThrough := collectGeminiPassThrough(req)
+
+ return util.StandardRequest{
+ Surface: "google_gemini",
+ RequestedModel: requestedModel,
+ ResolvedModel: resolvedModel,
+ ResponseModel: requestedModel,
+ Messages: messagesRaw,
+ FinalPrompt: finalPrompt,
+ ToolNames: toolNames,
+ Stream: stream,
+ Thinking: thinkingEnabled,
+ Search: searchEnabled,
+ PassThrough: passThrough,
+ }, nil
+}
diff --git a/internal/adapter/gemini/convert_tools.go b/internal/adapter/gemini/convert_tools.go
new file mode 100644
index 0000000..4611f85
--- /dev/null
+++ b/internal/adapter/gemini/convert_tools.go
@@ -0,0 +1,71 @@
+package gemini
+
+import "strings"
+
+func convertGeminiTools(raw any) []any {
+ tools, _ := raw.([]any)
+ if len(tools) == 0 {
+ return nil
+ }
+ out := make([]any, 0, len(tools))
+ for _, item := range tools {
+ tool, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+
+ if fnDecls, ok := tool["functionDeclarations"].([]any); ok && len(fnDecls) > 0 {
+ for _, declRaw := range fnDecls {
+ decl, ok := declRaw.(map[string]any)
+ if !ok {
+ continue
+ }
+ name := strings.TrimSpace(asString(decl["name"]))
+ if name == "" {
+ continue
+ }
+ function := map[string]any{
+ "name": name,
+ }
+ if desc := strings.TrimSpace(asString(decl["description"])); desc != "" {
+ function["description"] = desc
+ }
+ if params, ok := decl["parameters"].(map[string]any); ok {
+ function["parameters"] = params
+ }
+ out = append(out, map[string]any{
+ "type": "function",
+ "function": function,
+ })
+ }
+ continue
+ }
+
+ // OpenAI-style passthrough fallback.
+ if _, ok := tool["function"].(map[string]any); ok {
+ out = append(out, tool)
+ continue
+ }
+
+ // Loose fallback for flattened function schema objects.
+ name := strings.TrimSpace(asString(tool["name"]))
+ if name == "" {
+ continue
+ }
+ fn := map[string]any{"name": name}
+ if desc := strings.TrimSpace(asString(tool["description"])); desc != "" {
+ fn["description"] = desc
+ }
+ if params, ok := tool["parameters"].(map[string]any); ok {
+ fn["parameters"] = params
+ }
+ out = append(out, map[string]any{
+ "type": "function",
+ "function": fn,
+ })
+ }
+ if len(out) == 0 {
+ return nil
+ }
+ return out
+}
diff --git a/internal/adapter/gemini/handler.go b/internal/adapter/gemini/handler.go
deleted file mode 100644
index 8daaeda..0000000
--- a/internal/adapter/gemini/handler.go
+++ /dev/null
@@ -1,348 +0,0 @@
-package gemini
-
-import (
- "encoding/json"
- "io"
- "net/http"
- "strings"
- "time"
-
- "github.com/go-chi/chi/v5"
-
- "ds2api/internal/auth"
- "ds2api/internal/deepseek"
- "ds2api/internal/sse"
- streamengine "ds2api/internal/stream"
- "ds2api/internal/util"
-)
-
-var writeJSON = util.WriteJSON
-
-type Handler struct {
- Store ConfigReader
- Auth AuthResolver
- DS DeepSeekCaller
-}
-
-func RegisterRoutes(r chi.Router, h *Handler) {
- r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent)
- r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent)
- r.Post("/v1/models/{model}:generateContent", h.GenerateContent)
- r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent)
-}
-
-func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) {
- h.handleGenerateContent(w, r, false)
-}
-
-func (h *Handler) StreamGenerateContent(w http.ResponseWriter, r *http.Request) {
- h.handleGenerateContent(w, r, true)
-}
-
-func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
- a, err := h.Auth.Determine(r)
- if err != nil {
- status := http.StatusUnauthorized
- detail := err.Error()
- if err == auth.ErrNoAccount {
- status = http.StatusTooManyRequests
- }
- writeGeminiError(w, status, detail)
- return
- }
- defer h.Auth.Release(a)
-
- var req map[string]any
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- writeGeminiError(w, http.StatusBadRequest, "invalid json")
- return
- }
-
- routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
- stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream)
- if err != nil {
- writeGeminiError(w, http.StatusBadRequest, err.Error())
- return
- }
-
- sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
- if err != nil {
- if a.UseConfigToken {
- writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
- } else {
- writeGeminiError(w, http.StatusUnauthorized, "Invalid token.")
- }
- return
- }
- pow, err := h.DS.GetPow(r.Context(), a, 3)
- if err != nil {
- writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
- return
- }
- payload := stdReq.CompletionPayload(sessionID)
- resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
- if err != nil {
- writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.")
- return
- }
-
- if stream {
- h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
- return
- }
- h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
-}
-
-func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK {
- body, _ := io.ReadAll(resp.Body)
- writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
- return
- }
-
- result := sse.CollectStream(resp, thinkingEnabled, true)
- writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames))
-}
-
-func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK {
- body, _ := io.ReadAll(resp.Body)
- writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
- return
- }
-
- w.Header().Set("Content-Type", "text/event-stream")
- w.Header().Set("Cache-Control", "no-cache, no-transform")
- w.Header().Set("Connection", "keep-alive")
- w.Header().Set("X-Accel-Buffering", "no")
-
- rc := http.NewResponseController(w)
- _, canFlush := w.(http.Flusher)
- runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
-
- initialType := "text"
- if thinkingEnabled {
- initialType = "thinking"
- }
- streamengine.ConsumeSSE(streamengine.ConsumeConfig{
- Context: r.Context(),
- Body: resp.Body,
- ThinkingEnabled: thinkingEnabled,
- InitialType: initialType,
- KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
- IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
- MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
- }, streamengine.ConsumeHooks{
- OnParsed: runtime.onParsed,
- OnFinalize: func(_ streamengine.StopReason, _ error) {
- runtime.finalize()
- },
- })
-}
-
-func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
- parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
- usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
- return map[string]any{
- "candidates": []map[string]any{
- {
- "index": 0,
- "content": map[string]any{
- "role": "model",
- "parts": parts,
- },
- "finishReason": "STOP",
- },
- },
- "modelVersion": model,
- "usageMetadata": usage,
- }
-}
-
-func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
- promptTokens := util.EstimateTokens(finalPrompt)
- reasoningTokens := util.EstimateTokens(finalThinking)
- completionTokens := util.EstimateTokens(finalText)
- return map[string]any{
- "promptTokenCount": promptTokens,
- "candidatesTokenCount": reasoningTokens + completionTokens,
- "totalTokenCount": promptTokens + reasoningTokens + completionTokens,
- }
-}
-
-func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any {
- detected := util.ParseToolCalls(finalText, toolNames)
- if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
- detected = util.ParseToolCalls(finalThinking, toolNames)
- }
- if len(detected) > 0 {
- parts := make([]map[string]any, 0, len(detected))
- for _, tc := range detected {
- parts = append(parts, map[string]any{
- "functionCall": map[string]any{
- "name": tc.Name,
- "args": tc.Input,
- },
- })
- }
- return parts
- }
-
- text := finalText
- if strings.TrimSpace(text) == "" {
- text = finalThinking
- }
- return []map[string]any{{"text": text}}
-}
-
-type geminiStreamRuntime struct {
- w http.ResponseWriter
- rc *http.ResponseController
- canFlush bool
-
- model string
- finalPrompt string
-
- thinkingEnabled bool
- searchEnabled bool
- bufferContent bool
- toolNames []string
-
- thinking strings.Builder
- text strings.Builder
-}
-
-func newGeminiStreamRuntime(
- w http.ResponseWriter,
- rc *http.ResponseController,
- canFlush bool,
- model string,
- finalPrompt string,
- thinkingEnabled bool,
- searchEnabled bool,
- toolNames []string,
-) *geminiStreamRuntime {
- return &geminiStreamRuntime{
- w: w,
- rc: rc,
- canFlush: canFlush,
- model: model,
- finalPrompt: finalPrompt,
- thinkingEnabled: thinkingEnabled,
- searchEnabled: searchEnabled,
- bufferContent: len(toolNames) > 0,
- toolNames: toolNames,
- }
-}
-
-func (s *geminiStreamRuntime) sendChunk(payload map[string]any) {
- b, _ := json.Marshal(payload)
- _, _ = s.w.Write([]byte("data: "))
- _, _ = s.w.Write(b)
- _, _ = s.w.Write([]byte("\n\n"))
- if s.canFlush {
- _ = s.rc.Flush()
- }
-}
-
-func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
- if !parsed.Parsed {
- return streamengine.ParsedDecision{}
- }
- if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
- return streamengine.ParsedDecision{Stop: true}
- }
-
- contentSeen := false
- for _, p := range parsed.Parts {
- if p.Text == "" {
- continue
- }
- if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
- continue
- }
- contentSeen = true
- if p.Type == "thinking" {
- if s.thinkingEnabled {
- s.thinking.WriteString(p.Text)
- }
- continue
- }
- s.text.WriteString(p.Text)
- if s.bufferContent {
- continue
- }
- s.sendChunk(map[string]any{
- "candidates": []map[string]any{
- {
- "index": 0,
- "content": map[string]any{
- "role": "model",
- "parts": []map[string]any{{"text": p.Text}},
- },
- },
- },
- "modelVersion": s.model,
- })
- }
- return streamengine.ParsedDecision{ContentSeen: contentSeen}
-}
-
-func (s *geminiStreamRuntime) finalize() {
- finalThinking := s.thinking.String()
- finalText := s.text.String()
-
- if s.bufferContent {
- parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames)
- s.sendChunk(map[string]any{
- "candidates": []map[string]any{
- {
- "index": 0,
- "content": map[string]any{
- "role": "model",
- "parts": parts,
- },
- },
- },
- "modelVersion": s.model,
- })
- }
-
- s.sendChunk(map[string]any{
- "candidates": []map[string]any{
- {
- "index": 0,
- "finishReason": "STOP",
- },
- },
- "modelVersion": s.model,
- "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
- })
-}
-
-func writeGeminiError(w http.ResponseWriter, status int, message string) {
- errorStatus := "INVALID_ARGUMENT"
- switch status {
- case http.StatusUnauthorized:
- errorStatus = "UNAUTHENTICATED"
- case http.StatusForbidden:
- errorStatus = "PERMISSION_DENIED"
- case http.StatusTooManyRequests:
- errorStatus = "RESOURCE_EXHAUSTED"
- case http.StatusNotFound:
- errorStatus = "NOT_FOUND"
- default:
- if status >= 500 {
- errorStatus = "INTERNAL"
- }
- }
- writeJSON(w, status, map[string]any{
- "error": map[string]any{
- "code": status,
- "message": message,
- "status": errorStatus,
- },
- })
-}
diff --git a/internal/adapter/gemini/handler_errors.go b/internal/adapter/gemini/handler_errors.go
new file mode 100644
index 0000000..09df09b
--- /dev/null
+++ b/internal/adapter/gemini/handler_errors.go
@@ -0,0 +1,28 @@
+package gemini
+
+import "net/http"
+
+func writeGeminiError(w http.ResponseWriter, status int, message string) {
+ errorStatus := "INVALID_ARGUMENT"
+ switch status {
+ case http.StatusUnauthorized:
+ errorStatus = "UNAUTHENTICATED"
+ case http.StatusForbidden:
+ errorStatus = "PERMISSION_DENIED"
+ case http.StatusTooManyRequests:
+ errorStatus = "RESOURCE_EXHAUSTED"
+ case http.StatusNotFound:
+ errorStatus = "NOT_FOUND"
+ default:
+ if status >= 500 {
+ errorStatus = "INTERNAL"
+ }
+ }
+ writeJSON(w, status, map[string]any{
+ "error": map[string]any{
+ "code": status,
+ "message": message,
+ "status": errorStatus,
+ },
+ })
+}
diff --git a/internal/adapter/gemini/handler_generate.go b/internal/adapter/gemini/handler_generate.go
new file mode 100644
index 0000000..9144a42
--- /dev/null
+++ b/internal/adapter/gemini/handler_generate.go
@@ -0,0 +1,135 @@
+package gemini
+
+import (
+ "encoding/json"
+ "io"
+ "net/http"
+ "strings"
+
+ "github.com/go-chi/chi/v5"
+
+ "ds2api/internal/auth"
+ "ds2api/internal/sse"
+ "ds2api/internal/util"
+)
+
+func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, stream bool) {
+ a, err := h.Auth.Determine(r)
+ if err != nil {
+ status := http.StatusUnauthorized
+ detail := err.Error()
+ if err == auth.ErrNoAccount {
+ status = http.StatusTooManyRequests
+ }
+ writeGeminiError(w, status, detail)
+ return
+ }
+ defer h.Auth.Release(a)
+
+ var req map[string]any
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ writeGeminiError(w, http.StatusBadRequest, "invalid json")
+ return
+ }
+
+ routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
+ stdReq, err := normalizeGeminiRequest(h.Store, routeModel, req, stream)
+ if err != nil {
+ writeGeminiError(w, http.StatusBadRequest, err.Error())
+ return
+ }
+
+ sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
+ if err != nil {
+ if a.UseConfigToken {
+ writeGeminiError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
+ } else {
+ writeGeminiError(w, http.StatusUnauthorized, "Invalid token.")
+ }
+ return
+ }
+ pow, err := h.DS.GetPow(r.Context(), a, 3)
+ if err != nil {
+ writeGeminiError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
+ return
+ }
+ payload := stdReq.CompletionPayload(sessionID)
+ resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
+ if err != nil {
+ writeGeminiError(w, http.StatusInternalServerError, "Failed to get completion.")
+ return
+ }
+
+ if stream {
+ h.handleStreamGenerateContent(w, r, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
+ return
+ }
+ h.handleNonStreamGenerateContent(w, resp, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
+}
+
+func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *http.Response, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
+ return
+ }
+
+ result := sse.CollectStream(resp, thinkingEnabled, true)
+ writeJSON(w, http.StatusOK, buildGeminiGenerateContentResponse(model, finalPrompt, result.Thinking, result.Text, toolNames))
+}
+
+func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
+ parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
+ usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
+ return map[string]any{
+ "candidates": []map[string]any{
+ {
+ "index": 0,
+ "content": map[string]any{
+ "role": "model",
+ "parts": parts,
+ },
+ "finishReason": "STOP",
+ },
+ },
+ "modelVersion": model,
+ "usageMetadata": usage,
+ }
+}
+
+func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
+ promptTokens := util.EstimateTokens(finalPrompt)
+ reasoningTokens := util.EstimateTokens(finalThinking)
+ completionTokens := util.EstimateTokens(finalText)
+ return map[string]any{
+ "promptTokenCount": promptTokens,
+ "candidatesTokenCount": reasoningTokens + completionTokens,
+ "totalTokenCount": promptTokens + reasoningTokens + completionTokens,
+ }
+}
+
+func buildGeminiPartsFromFinal(finalText, finalThinking string, toolNames []string) []map[string]any {
+ detected := util.ParseToolCalls(finalText, toolNames)
+ if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
+ detected = util.ParseToolCalls(finalThinking, toolNames)
+ }
+ if len(detected) > 0 {
+ parts := make([]map[string]any, 0, len(detected))
+ for _, tc := range detected {
+ parts = append(parts, map[string]any{
+ "functionCall": map[string]any{
+ "name": tc.Name,
+ "args": tc.Input,
+ },
+ })
+ }
+ return parts
+ }
+
+ text := finalText
+ if strings.TrimSpace(text) == "" {
+ text = finalThinking
+ }
+ return []map[string]any{{"text": text}}
+}
diff --git a/internal/adapter/gemini/handler_routes.go b/internal/adapter/gemini/handler_routes.go
new file mode 100644
index 0000000..6850b51
--- /dev/null
+++ b/internal/adapter/gemini/handler_routes.go
@@ -0,0 +1,32 @@
+package gemini
+
+import (
+ "net/http"
+
+ "github.com/go-chi/chi/v5"
+
+ "ds2api/internal/util"
+)
+
+var writeJSON = util.WriteJSON
+
+type Handler struct {
+ Store ConfigReader
+ Auth AuthResolver
+ DS DeepSeekCaller
+}
+
+func RegisterRoutes(r chi.Router, h *Handler) {
+ r.Post("/v1beta/models/{model}:generateContent", h.GenerateContent)
+ r.Post("/v1beta/models/{model}:streamGenerateContent", h.StreamGenerateContent)
+ r.Post("/v1/models/{model}:generateContent", h.GenerateContent)
+ r.Post("/v1/models/{model}:streamGenerateContent", h.StreamGenerateContent)
+}
+
+func (h *Handler) GenerateContent(w http.ResponseWriter, r *http.Request) {
+ h.handleGenerateContent(w, r, false)
+}
+
+func (h *Handler) StreamGenerateContent(w http.ResponseWriter, r *http.Request) {
+ h.handleGenerateContent(w, r, true)
+}
diff --git a/internal/adapter/gemini/handler_stream_runtime.go b/internal/adapter/gemini/handler_stream_runtime.go
new file mode 100644
index 0000000..83a393d
--- /dev/null
+++ b/internal/adapter/gemini/handler_stream_runtime.go
@@ -0,0 +1,175 @@
+package gemini
+
+import (
+ "encoding/json"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+
+ "ds2api/internal/deepseek"
+ "ds2api/internal/sse"
+ streamengine "ds2api/internal/stream"
+)
+
+func (h *Handler) handleStreamGenerateContent(w http.ResponseWriter, r *http.Request, resp *http.Response, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ writeGeminiError(w, resp.StatusCode, strings.TrimSpace(string(body)))
+ return
+ }
+
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set("Cache-Control", "no-cache, no-transform")
+ w.Header().Set("Connection", "keep-alive")
+ w.Header().Set("X-Accel-Buffering", "no")
+
+ rc := http.NewResponseController(w)
+ _, canFlush := w.(http.Flusher)
+ runtime := newGeminiStreamRuntime(w, rc, canFlush, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames)
+
+ initialType := "text"
+ if thinkingEnabled {
+ initialType = "thinking"
+ }
+ streamengine.ConsumeSSE(streamengine.ConsumeConfig{
+ Context: r.Context(),
+ Body: resp.Body,
+ ThinkingEnabled: thinkingEnabled,
+ InitialType: initialType,
+ KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
+ IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
+ MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
+ }, streamengine.ConsumeHooks{
+ OnParsed: runtime.onParsed,
+ OnFinalize: func(_ streamengine.StopReason, _ error) {
+ runtime.finalize()
+ },
+ })
+}
+
+type geminiStreamRuntime struct {
+ w http.ResponseWriter
+ rc *http.ResponseController
+ canFlush bool
+
+ model string
+ finalPrompt string
+
+ thinkingEnabled bool
+ searchEnabled bool
+ bufferContent bool
+ toolNames []string
+
+ thinking strings.Builder
+ text strings.Builder
+}
+
+func newGeminiStreamRuntime(
+ w http.ResponseWriter,
+ rc *http.ResponseController,
+ canFlush bool,
+ model string,
+ finalPrompt string,
+ thinkingEnabled bool,
+ searchEnabled bool,
+ toolNames []string,
+) *geminiStreamRuntime {
+ return &geminiStreamRuntime{
+ w: w,
+ rc: rc,
+ canFlush: canFlush,
+ model: model,
+ finalPrompt: finalPrompt,
+ thinkingEnabled: thinkingEnabled,
+ searchEnabled: searchEnabled,
+ bufferContent: len(toolNames) > 0,
+ toolNames: toolNames,
+ }
+}
+
+func (s *geminiStreamRuntime) sendChunk(payload map[string]any) {
+ b, _ := json.Marshal(payload)
+ _, _ = s.w.Write([]byte("data: "))
+ _, _ = s.w.Write(b)
+ _, _ = s.w.Write([]byte("\n\n"))
+ if s.canFlush {
+ _ = s.rc.Flush()
+ }
+}
+
+func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
+ if !parsed.Parsed {
+ return streamengine.ParsedDecision{}
+ }
+ if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
+ return streamengine.ParsedDecision{Stop: true}
+ }
+
+ contentSeen := false
+ for _, p := range parsed.Parts {
+ if p.Text == "" {
+ continue
+ }
+ if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
+ continue
+ }
+ contentSeen = true
+ if p.Type == "thinking" {
+ if s.thinkingEnabled {
+ s.thinking.WriteString(p.Text)
+ }
+ continue
+ }
+ s.text.WriteString(p.Text)
+ if s.bufferContent {
+ continue
+ }
+ s.sendChunk(map[string]any{
+ "candidates": []map[string]any{
+ {
+ "index": 0,
+ "content": map[string]any{
+ "role": "model",
+ "parts": []map[string]any{{"text": p.Text}},
+ },
+ },
+ },
+ "modelVersion": s.model,
+ })
+ }
+ return streamengine.ParsedDecision{ContentSeen: contentSeen}
+}
+
+func (s *geminiStreamRuntime) finalize() {
+ finalThinking := s.thinking.String()
+ finalText := s.text.String()
+
+ if s.bufferContent {
+ parts := buildGeminiPartsFromFinal(finalText, finalThinking, s.toolNames)
+ s.sendChunk(map[string]any{
+ "candidates": []map[string]any{
+ {
+ "index": 0,
+ "content": map[string]any{
+ "role": "model",
+ "parts": parts,
+ },
+ },
+ },
+ "modelVersion": s.model,
+ })
+ }
+
+ s.sendChunk(map[string]any{
+ "candidates": []map[string]any{
+ {
+ "index": 0,
+ "finishReason": "STOP",
+ },
+ },
+ "modelVersion": s.model,
+ "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
+ })
+}
diff --git a/internal/adapter/openai/error_shape_test.go b/internal/adapter/openai/error_shape_test.go
index c169e04..8c73e4b 100644
--- a/internal/adapter/openai/error_shape_test.go
+++ b/internal/adapter/openai/error_shape_test.go
@@ -32,4 +32,3 @@ func TestWriteOpenAIErrorIncludesUnifiedFields(t *testing.T) {
t.Fatal("expected param field")
}
}
-
diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go
deleted file mode 100644
index 391a035..0000000
--- a/internal/adapter/openai/handler.go
+++ /dev/null
@@ -1,386 +0,0 @@
-package openai
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "strings"
- "sync"
- "time"
-
- "github.com/go-chi/chi/v5"
- "github.com/google/uuid"
-
- "ds2api/internal/auth"
- "ds2api/internal/config"
- "ds2api/internal/deepseek"
- openaifmt "ds2api/internal/format/openai"
- "ds2api/internal/sse"
- streamengine "ds2api/internal/stream"
- "ds2api/internal/util"
-)
-
-// writeJSON is a package-internal alias kept to avoid mass-renaming across
-// every call-site in this file. It delegates to the shared util version.
-var writeJSON = util.WriteJSON
-
-type Handler struct {
- Store ConfigReader
- Auth AuthResolver
- DS DeepSeekCaller
-
- leaseMu sync.Mutex
- streamLeases map[string]streamLease
- responsesMu sync.Mutex
- responses *responseStore
-}
-
-type streamLease struct {
- Auth *auth.RequestAuth
- ExpiresAt time.Time
-}
-
-func RegisterRoutes(r chi.Router, h *Handler) {
- r.Get("/v1/models", h.ListModels)
- r.Get("/v1/models/{model_id}", h.GetModel)
- r.Post("/v1/chat/completions", h.ChatCompletions)
- r.Post("/v1/responses", h.Responses)
- r.Get("/v1/responses/{response_id}", h.GetResponseByID)
- r.Post("/v1/embeddings", h.Embeddings)
-}
-
-func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
- writeJSON(w, http.StatusOK, config.OpenAIModelsResponse())
-}
-
-func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) {
- modelID := strings.TrimSpace(chi.URLParam(r, "model_id"))
- model, ok := config.OpenAIModelByID(h.Store, modelID)
- if !ok {
- writeOpenAIError(w, http.StatusNotFound, "Model not found.")
- return
- }
- writeJSON(w, http.StatusOK, model)
-}
-
-func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
- if isVercelStreamReleaseRequest(r) {
- h.handleVercelStreamRelease(w, r)
- return
- }
- if isVercelStreamPrepareRequest(r) {
- h.handleVercelStreamPrepare(w, r)
- return
- }
-
- a, err := h.Auth.Determine(r)
- if err != nil {
- status := http.StatusUnauthorized
- detail := err.Error()
- if err == auth.ErrNoAccount {
- status = http.StatusTooManyRequests
- }
- writeOpenAIError(w, status, detail)
- return
- }
- defer h.Auth.Release(a)
- r = r.WithContext(auth.WithAuth(r.Context(), a))
-
- var req map[string]any
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- writeOpenAIError(w, http.StatusBadRequest, "invalid json")
- return
- }
- stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
- if err != nil {
- writeOpenAIError(w, http.StatusBadRequest, err.Error())
- return
- }
-
- sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
- if err != nil {
- if a.UseConfigToken {
- writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
- } else {
- writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
- }
- return
- }
- pow, err := h.DS.GetPow(r.Context(), a, 3)
- if err != nil {
- writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
- return
- }
- payload := stdReq.CompletionPayload(sessionID)
- resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
- if err != nil {
- writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
- return
- }
- if stdReq.Stream {
- h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
- return
- }
- h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
-}
-
-func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
- if resp.StatusCode != http.StatusOK {
- defer resp.Body.Close()
- body, _ := io.ReadAll(resp.Body)
- writeOpenAIError(w, resp.StatusCode, string(body))
- return
- }
- _ = ctx
- result := sse.CollectStream(resp, thinkingEnabled, true)
-
- finalThinking := result.Thinking
- finalText := result.Text
- respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
- writeJSON(w, http.StatusOK, respBody)
-}
-
-func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK {
- body, _ := io.ReadAll(resp.Body)
- writeOpenAIError(w, resp.StatusCode, string(body))
- return
- }
- w.Header().Set("Content-Type", "text/event-stream")
- w.Header().Set("Cache-Control", "no-cache, no-transform")
- w.Header().Set("Connection", "keep-alive")
- w.Header().Set("X-Accel-Buffering", "no")
- rc := http.NewResponseController(w)
- _, canFlush := w.(http.Flusher)
- if !canFlush {
- config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered")
- }
-
- created := time.Now().Unix()
- bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
- emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
- initialType := "text"
- if thinkingEnabled {
- initialType = "thinking"
- }
-
- streamRuntime := newChatStreamRuntime(
- w,
- rc,
- canFlush,
- completionID,
- created,
- model,
- finalPrompt,
- thinkingEnabled,
- searchEnabled,
- toolNames,
- bufferToolContent,
- emitEarlyToolDeltas,
- )
-
- streamengine.ConsumeSSE(streamengine.ConsumeConfig{
- Context: r.Context(),
- Body: resp.Body,
- ThinkingEnabled: thinkingEnabled,
- InitialType: initialType,
- KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
- IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
- MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
- }, streamengine.ConsumeHooks{
- OnKeepAlive: func() {
- streamRuntime.sendKeepAlive()
- },
- OnParsed: streamRuntime.onParsed,
- OnFinalize: func(reason streamengine.StopReason, _ error) {
- if string(reason) == "content_filter" {
- streamRuntime.finalize("content_filter")
- return
- }
- streamRuntime.finalize("stop")
- },
- })
-}
-
-func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
- toolSchemas := make([]string, 0, len(tools))
- names := make([]string, 0, len(tools))
- for _, t := range tools {
- tool, ok := t.(map[string]any)
- if !ok {
- continue
- }
- fn, _ := tool["function"].(map[string]any)
- if len(fn) == 0 {
- fn = tool
- }
- name, _ := fn["name"].(string)
- desc, _ := fn["description"].(string)
- schema, _ := fn["parameters"].(map[string]any)
- if name == "" {
- name = "unknown"
- }
- names = append(names, name)
- if desc == "" {
- desc = "No description available"
- }
- b, _ := json.Marshal(schema)
- toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b)))
- }
- if len(toolSchemas) == 0 {
- return messages, names
- }
- toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."
-
- for i := range messages {
- if messages[i]["role"] == "system" {
- old, _ := messages[i]["content"].(string)
- messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt)
- return messages, names
- }
- }
- messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...)
- return messages, names
-}
-
-func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any {
- if len(deltas) == 0 {
- return nil
- }
- out := make([]map[string]any, 0, len(deltas))
- for _, d := range deltas {
- if d.Name == "" && d.Arguments == "" {
- continue
- }
- callID, ok := ids[d.Index]
- if !ok || callID == "" {
- callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
- ids[d.Index] = callID
- }
- item := map[string]any{
- "index": d.Index,
- "id": callID,
- "type": "function",
- }
- fn := map[string]any{}
- if d.Name != "" {
- fn["name"] = d.Name
- }
- if d.Arguments != "" {
- fn["arguments"] = d.Arguments
- }
- if len(fn) > 0 {
- item["function"] = fn
- }
- out = append(out, item)
- }
- return out
-}
-
-func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
- if len(calls) == 0 {
- return nil
- }
- out := make([]map[string]any, 0, len(calls))
- for i, c := range calls {
- callID := ""
- if ids != nil {
- callID = strings.TrimSpace(ids[i])
- }
- if callID == "" {
- callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
- if ids != nil {
- ids[i] = callID
- }
- }
- args, _ := json.Marshal(c.Input)
- out = append(out, map[string]any{
- "index": i,
- "id": callID,
- "type": "function",
- "function": map[string]any{
- "name": c.Name,
- "arguments": string(args),
- },
- })
- }
- return out
-}
-
-func writeOpenAIError(w http.ResponseWriter, status int, message string) {
- writeJSON(w, status, map[string]any{
- "error": map[string]any{
- "message": message,
- "type": openAIErrorType(status),
- "code": openAIErrorCode(status),
- "param": nil,
- },
- })
-}
-
-func openAIErrorType(status int) string {
- switch status {
- case http.StatusBadRequest:
- return "invalid_request_error"
- case http.StatusUnauthorized:
- return "authentication_error"
- case http.StatusForbidden:
- return "permission_error"
- case http.StatusTooManyRequests:
- return "rate_limit_error"
- case http.StatusServiceUnavailable:
- return "service_unavailable_error"
- default:
- if status >= 500 {
- return "api_error"
- }
- return "invalid_request_error"
- }
-}
-
-func openAIErrorCode(status int) string {
- switch status {
- case http.StatusBadRequest:
- return "invalid_request"
- case http.StatusUnauthorized:
- return "authentication_failed"
- case http.StatusForbidden:
- return "forbidden"
- case http.StatusTooManyRequests:
- return "rate_limit_exceeded"
- case http.StatusNotFound:
- return "not_found"
- case http.StatusServiceUnavailable:
- return "service_unavailable"
- default:
- if status >= 500 {
- return "internal_error"
- }
- return "invalid_request"
- }
-}
-
-func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) {
- for k, v := range collectOpenAIChatPassThrough(req) {
- payload[k] = v
- }
-}
-
-func (h *Handler) toolcallFeatureMatchEnabled() bool {
- if h == nil || h.Store == nil {
- return true
- }
- mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode()))
- return mode == "" || mode == "feature_match"
-}
-
-func (h *Handler) toolcallEarlyEmitHighConfidence() bool {
- if h == nil || h.Store == nil {
- return true
- }
- level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence()))
- return level == "" || level == "high"
-}
diff --git a/internal/adapter/openai/handler_chat.go b/internal/adapter/openai/handler_chat.go
new file mode 100644
index 0000000..26a4bf2
--- /dev/null
+++ b/internal/adapter/openai/handler_chat.go
@@ -0,0 +1,156 @@
+package openai
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "time"
+
+ "ds2api/internal/auth"
+ "ds2api/internal/config"
+ "ds2api/internal/deepseek"
+ openaifmt "ds2api/internal/format/openai"
+ "ds2api/internal/sse"
+ streamengine "ds2api/internal/stream"
+)
+
+func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
+ if isVercelStreamReleaseRequest(r) {
+ h.handleVercelStreamRelease(w, r)
+ return
+ }
+ if isVercelStreamPrepareRequest(r) {
+ h.handleVercelStreamPrepare(w, r)
+ return
+ }
+
+ a, err := h.Auth.Determine(r)
+ if err != nil {
+ status := http.StatusUnauthorized
+ detail := err.Error()
+ if err == auth.ErrNoAccount {
+ status = http.StatusTooManyRequests
+ }
+ writeOpenAIError(w, status, detail)
+ return
+ }
+ defer h.Auth.Release(a)
+ r = r.WithContext(auth.WithAuth(r.Context(), a))
+
+ var req map[string]any
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ writeOpenAIError(w, http.StatusBadRequest, "invalid json")
+ return
+ }
+ stdReq, err := normalizeOpenAIChatRequest(h.Store, req, requestTraceID(r))
+ if err != nil {
+ writeOpenAIError(w, http.StatusBadRequest, err.Error())
+ return
+ }
+
+ sessionID, err := h.DS.CreateSession(r.Context(), a, 3)
+ if err != nil {
+ if a.UseConfigToken {
+ writeOpenAIError(w, http.StatusUnauthorized, "Account token is invalid. Please re-login the account in admin.")
+ } else {
+ writeOpenAIError(w, http.StatusUnauthorized, "Invalid token. If this should be a DS2API key, add it to config.keys first.")
+ }
+ return
+ }
+ pow, err := h.DS.GetPow(r.Context(), a, 3)
+ if err != nil {
+ writeOpenAIError(w, http.StatusUnauthorized, "Failed to get PoW (invalid token or unknown error).")
+ return
+ }
+ payload := stdReq.CompletionPayload(sessionID)
+ resp, err := h.DS.CallCompletion(r.Context(), a, payload, pow, 3)
+ if err != nil {
+ writeOpenAIError(w, http.StatusInternalServerError, "Failed to get completion.")
+ return
+ }
+ if stdReq.Stream {
+ h.handleStream(w, r, resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames)
+ return
+ }
+ h.handleNonStream(w, r.Context(), resp, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.ToolNames)
+}
+
+func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled bool, toolNames []string) {
+ if resp.StatusCode != http.StatusOK {
+ defer resp.Body.Close()
+ body, _ := io.ReadAll(resp.Body)
+ writeOpenAIError(w, resp.StatusCode, string(body))
+ return
+ }
+ _ = ctx
+ result := sse.CollectStream(resp, thinkingEnabled, true)
+
+ finalThinking := result.Thinking
+ finalText := result.Text
+ respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
+ writeJSON(w, http.StatusOK, respBody)
+}
+
+func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) {
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ writeOpenAIError(w, resp.StatusCode, string(body))
+ return
+ }
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set("Cache-Control", "no-cache, no-transform")
+ w.Header().Set("Connection", "keep-alive")
+ w.Header().Set("X-Accel-Buffering", "no")
+ rc := http.NewResponseController(w)
+ _, canFlush := w.(http.Flusher)
+ if !canFlush {
+ config.Logger.Warn("[stream] response writer does not support flush; streaming may be buffered")
+ }
+
+ created := time.Now().Unix()
+ bufferToolContent := len(toolNames) > 0 && h.toolcallFeatureMatchEnabled()
+ emitEarlyToolDeltas := h.toolcallEarlyEmitHighConfidence()
+ initialType := "text"
+ if thinkingEnabled {
+ initialType = "thinking"
+ }
+
+ streamRuntime := newChatStreamRuntime(
+ w,
+ rc,
+ canFlush,
+ completionID,
+ created,
+ model,
+ finalPrompt,
+ thinkingEnabled,
+ searchEnabled,
+ toolNames,
+ bufferToolContent,
+ emitEarlyToolDeltas,
+ )
+
+ streamengine.ConsumeSSE(streamengine.ConsumeConfig{
+ Context: r.Context(),
+ Body: resp.Body,
+ ThinkingEnabled: thinkingEnabled,
+ InitialType: initialType,
+ KeepAliveInterval: time.Duration(deepseek.KeepAliveTimeout) * time.Second,
+ IdleTimeout: time.Duration(deepseek.StreamIdleTimeout) * time.Second,
+ MaxKeepAliveNoInput: deepseek.MaxKeepaliveCount,
+ }, streamengine.ConsumeHooks{
+ OnKeepAlive: func() {
+ streamRuntime.sendKeepAlive()
+ },
+ OnParsed: streamRuntime.onParsed,
+ OnFinalize: func(reason streamengine.StopReason, _ error) {
+ if string(reason) == "content_filter" {
+ streamRuntime.finalize("content_filter")
+ return
+ }
+ streamRuntime.finalize("stop")
+ },
+ })
+}
diff --git a/internal/adapter/openai/handler_errors.go b/internal/adapter/openai/handler_errors.go
new file mode 100644
index 0000000..62249d2
--- /dev/null
+++ b/internal/adapter/openai/handler_errors.go
@@ -0,0 +1,56 @@
+package openai
+
+import "net/http"
+
+func writeOpenAIError(w http.ResponseWriter, status int, message string) {
+ writeJSON(w, status, map[string]any{
+ "error": map[string]any{
+ "message": message,
+ "type": openAIErrorType(status),
+ "code": openAIErrorCode(status),
+ "param": nil,
+ },
+ })
+}
+
+func openAIErrorType(status int) string {
+ switch status {
+ case http.StatusBadRequest:
+ return "invalid_request_error"
+ case http.StatusUnauthorized:
+ return "authentication_error"
+ case http.StatusForbidden:
+ return "permission_error"
+ case http.StatusTooManyRequests:
+ return "rate_limit_error"
+ case http.StatusServiceUnavailable:
+ return "service_unavailable_error"
+ default:
+ if status >= 500 {
+ return "api_error"
+ }
+ return "invalid_request_error"
+ }
+}
+
+func openAIErrorCode(status int) string {
+ switch status {
+ case http.StatusBadRequest:
+ return "invalid_request"
+ case http.StatusUnauthorized:
+ return "authentication_failed"
+ case http.StatusForbidden:
+ return "forbidden"
+ case http.StatusTooManyRequests:
+ return "rate_limit_exceeded"
+ case http.StatusNotFound:
+ return "not_found"
+ case http.StatusServiceUnavailable:
+ return "service_unavailable"
+ default:
+ if status >= 500 {
+ return "internal_error"
+ }
+ return "invalid_request"
+ }
+}
diff --git a/internal/adapter/openai/handler_routes.go b/internal/adapter/openai/handler_routes.go
new file mode 100644
index 0000000..a0cfcd6
--- /dev/null
+++ b/internal/adapter/openai/handler_routes.go
@@ -0,0 +1,57 @@
+package openai
+
+import (
+ "net/http"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/go-chi/chi/v5"
+
+ "ds2api/internal/auth"
+ "ds2api/internal/config"
+ "ds2api/internal/util"
+)
+
+// writeJSON is a package-internal alias kept to avoid mass-renaming across
+// every call-site in this package.
+var writeJSON = util.WriteJSON
+
+type Handler struct {
+ Store ConfigReader
+ Auth AuthResolver
+ DS DeepSeekCaller
+
+ leaseMu sync.Mutex
+ streamLeases map[string]streamLease
+ responsesMu sync.Mutex
+ responses *responseStore
+}
+
+type streamLease struct {
+ Auth *auth.RequestAuth
+ ExpiresAt time.Time
+}
+
+func RegisterRoutes(r chi.Router, h *Handler) {
+ r.Get("/v1/models", h.ListModels)
+ r.Get("/v1/models/{model_id}", h.GetModel)
+ r.Post("/v1/chat/completions", h.ChatCompletions)
+ r.Post("/v1/responses", h.Responses)
+ r.Get("/v1/responses/{response_id}", h.GetResponseByID)
+ r.Post("/v1/embeddings", h.Embeddings)
+}
+
+func (h *Handler) ListModels(w http.ResponseWriter, _ *http.Request) {
+ writeJSON(w, http.StatusOK, config.OpenAIModelsResponse())
+}
+
+func (h *Handler) GetModel(w http.ResponseWriter, r *http.Request) {
+ modelID := strings.TrimSpace(chi.URLParam(r, "model_id"))
+ model, ok := config.OpenAIModelByID(h.Store, modelID)
+ if !ok {
+ writeOpenAIError(w, http.StatusNotFound, "Model not found.")
+ return
+ }
+ writeJSON(w, http.StatusOK, model)
+}
diff --git a/internal/adapter/openai/handler_toolcall_format.go b/internal/adapter/openai/handler_toolcall_format.go
new file mode 100644
index 0000000..d939c68
--- /dev/null
+++ b/internal/adapter/openai/handler_toolcall_format.go
@@ -0,0 +1,116 @@
+package openai
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "github.com/google/uuid"
+
+ "ds2api/internal/util"
+)
+
+func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
+ toolSchemas := make([]string, 0, len(tools))
+ names := make([]string, 0, len(tools))
+ for _, t := range tools {
+ tool, ok := t.(map[string]any)
+ if !ok {
+ continue
+ }
+ fn, _ := tool["function"].(map[string]any)
+ if len(fn) == 0 {
+ fn = tool
+ }
+ name, _ := fn["name"].(string)
+ desc, _ := fn["description"].(string)
+ schema, _ := fn["parameters"].(map[string]any)
+ if name == "" {
+ name = "unknown"
+ }
+ names = append(names, name)
+ if desc == "" {
+ desc = "No description available"
+ }
+ b, _ := json.Marshal(schema)
+ toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b)))
+ }
+ if len(toolSchemas) == 0 {
+ return messages, names
+ }
+ toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."
+
+ for i := range messages {
+ if messages[i]["role"] == "system" {
+ old, _ := messages[i]["content"].(string)
+ messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt)
+ return messages, names
+ }
+ }
+ messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...)
+ return messages, names
+}
+
+func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any {
+ if len(deltas) == 0 {
+ return nil
+ }
+ out := make([]map[string]any, 0, len(deltas))
+ for _, d := range deltas {
+ if d.Name == "" && d.Arguments == "" {
+ continue
+ }
+ callID, ok := ids[d.Index]
+ if !ok || callID == "" {
+ callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
+ ids[d.Index] = callID
+ }
+ item := map[string]any{
+ "index": d.Index,
+ "id": callID,
+ "type": "function",
+ }
+ fn := map[string]any{}
+ if d.Name != "" {
+ fn["name"] = d.Name
+ }
+ if d.Arguments != "" {
+ fn["arguments"] = d.Arguments
+ }
+ if len(fn) > 0 {
+ item["function"] = fn
+ }
+ out = append(out, item)
+ }
+ return out
+}
+
+func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
+ if len(calls) == 0 {
+ return nil
+ }
+ out := make([]map[string]any, 0, len(calls))
+ for i, c := range calls {
+ callID := ""
+ if ids != nil {
+ callID = strings.TrimSpace(ids[i])
+ }
+ if callID == "" {
+ callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
+ if ids != nil {
+ ids[i] = callID
+ }
+ }
+ args, _ := json.Marshal(c.Input)
+ out = append(out, map[string]any{
+ "index": i,
+ "id": callID,
+ "type": "function",
+ "function": map[string]any{
+ "name": c.Name,
+ "arguments": string(args),
+ },
+ })
+ }
+ return out
+}
diff --git a/internal/adapter/openai/handler_toolcall_policy.go b/internal/adapter/openai/handler_toolcall_policy.go
new file mode 100644
index 0000000..9f0e839
--- /dev/null
+++ b/internal/adapter/openai/handler_toolcall_policy.go
@@ -0,0 +1,25 @@
+package openai
+
+import "strings"
+
+func applyOpenAIChatPassThrough(req map[string]any, payload map[string]any) {
+ for k, v := range collectOpenAIChatPassThrough(req) {
+ payload[k] = v
+ }
+}
+
+func (h *Handler) toolcallFeatureMatchEnabled() bool {
+ if h == nil || h.Store == nil {
+ return true
+ }
+ mode := strings.TrimSpace(strings.ToLower(h.Store.ToolcallMode()))
+ return mode == "" || mode == "feature_match"
+}
+
+func (h *Handler) toolcallEarlyEmitHighConfidence() bool {
+ if h == nil || h.Store == nil {
+ return true
+ }
+ level := strings.TrimSpace(strings.ToLower(h.Store.ToolcallEarlyEmitConfidence()))
+ return level == "" || level == "high"
+}
diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go
index e71cafe..7b35e0c 100644
--- a/internal/adapter/openai/responses_handler.go
+++ b/internal/adapter/openai/responses_handler.go
@@ -2,7 +2,6 @@ package openai
import (
"encoding/json"
- "fmt"
"io"
"net/http"
"strings"
@@ -170,264 +169,3 @@ func (h *Handler) handleResponsesStream(w http.ResponseWriter, r *http.Request,
},
})
}
-
-func responsesMessagesFromRequest(req map[string]any) []any {
- if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 {
- return prependInstructionMessage(msgs, req["instructions"])
- }
- if rawInput, ok := req["input"]; ok {
- if msgs := normalizeResponsesInputAsMessages(rawInput); len(msgs) > 0 {
- return prependInstructionMessage(msgs, req["instructions"])
- }
- }
- return nil
-}
-
-func prependInstructionMessage(messages []any, instructions any) []any {
- sys, _ := instructions.(string)
- sys = strings.TrimSpace(sys)
- if sys == "" {
- return messages
- }
- out := make([]any, 0, len(messages)+1)
- out = append(out, map[string]any{"role": "system", "content": sys})
- out = append(out, messages...)
- return out
-}
-
-func normalizeResponsesInputAsMessages(input any) []any {
- switch v := input.(type) {
- case string:
- if strings.TrimSpace(v) == "" {
- return nil
- }
- return []any{map[string]any{"role": "user", "content": v}}
- case []any:
- return normalizeResponsesInputArray(v)
- case map[string]any:
- if msg := normalizeResponsesInputItem(v); msg != nil {
- return []any{msg}
- }
- if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" {
- return []any{map[string]any{"role": "user", "content": txt}}
- }
- if content, ok := v["content"]; ok {
- if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
- return []any{map[string]any{"role": "user", "content": content}}
- }
- }
- }
- return nil
-}
-
-func normalizeResponsesInputArray(items []any) []any {
- if len(items) == 0 {
- return nil
- }
- out := make([]any, 0, len(items))
- fallbackParts := make([]string, 0, len(items))
- flushFallback := func() {
- if len(fallbackParts) == 0 {
- return
- }
- out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")})
- fallbackParts = fallbackParts[:0]
- }
-
- for _, item := range items {
- switch x := item.(type) {
- case map[string]any:
- if msg := normalizeResponsesInputItem(x); msg != nil {
- flushFallback()
- out = append(out, msg)
- continue
- }
- if s := normalizeResponsesFallbackPart(x); s != "" {
- fallbackParts = append(fallbackParts, s)
- }
- default:
- if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
- fallbackParts = append(fallbackParts, s)
- }
- }
- }
- flushFallback()
- if len(out) == 0 {
- return nil
- }
- return out
-}
-
-func normalizeResponsesInputItem(m map[string]any) map[string]any {
- if m == nil {
- return nil
- }
-
- role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
- if role != "" {
- content := m["content"]
- if content == nil {
- if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
- content = txt
- }
- }
- if content == nil {
- return nil
- }
- return map[string]any{
- "role": role,
- "content": content,
- }
- }
-
- itemType := strings.ToLower(strings.TrimSpace(asString(m["type"])))
- switch itemType {
- case "message", "input_message":
- content := m["content"]
- if content == nil {
- if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
- content = txt
- }
- }
- if content == nil {
- return nil
- }
- role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
- if role == "" {
- role = "user"
- }
- return map[string]any{
- "role": role,
- "content": content,
- }
- case "function_call_output", "tool_result":
- content := m["output"]
- if content == nil {
- content = m["content"]
- }
- if content == nil {
- content = ""
- }
- out := map[string]any{
- "role": "tool",
- "content": content,
- }
- if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
- out["tool_call_id"] = callID
- } else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" {
- out["tool_call_id"] = callID
- }
- if name := strings.TrimSpace(asString(m["name"])); name != "" {
- out["name"] = name
- } else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" {
- out["name"] = name
- }
- return out
- case "function_call", "tool_call":
- name := strings.TrimSpace(asString(m["name"]))
- var fn map[string]any
- if rawFn, ok := m["function"].(map[string]any); ok {
- fn = rawFn
- if name == "" {
- name = strings.TrimSpace(asString(fn["name"]))
- }
- }
- if name == "" {
- return nil
- }
-
- var argsRaw any
- if v, ok := m["arguments"]; ok {
- argsRaw = v
- } else if v, ok := m["input"]; ok {
- argsRaw = v
- }
- if argsRaw == nil && fn != nil {
- if v, ok := fn["arguments"]; ok {
- argsRaw = v
- } else if v, ok := fn["input"]; ok {
- argsRaw = v
- }
- }
-
- functionPayload := map[string]any{
- "name": name,
- "arguments": stringifyToolCallArguments(argsRaw),
- }
- call := map[string]any{
- "type": "function",
- "function": functionPayload,
- }
- if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
- call["id"] = callID
- } else if callID = strings.TrimSpace(asString(m["id"])); callID != "" {
- call["id"] = callID
- }
- return map[string]any{
- "role": "assistant",
- "tool_calls": []any{call},
- }
- case "input_text":
- if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
- return map[string]any{
- "role": "user",
- "content": txt,
- }
- }
- }
-
- if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
- return map[string]any{
- "role": "user",
- "content": txt,
- }
- }
- if content, ok := m["content"]; ok {
- if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
- return map[string]any{
- "role": "user",
- "content": content,
- }
- }
- }
- return nil
-}
-
-func normalizeResponsesFallbackPart(m map[string]any) string {
- if m == nil {
- return ""
- }
- if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
- if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
- return txt
- }
- }
- if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
- return txt
- }
- if content, ok := m["content"]; ok {
- if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" {
- return normalized
- }
- }
- return strings.TrimSpace(fmt.Sprintf("%v", m))
-}
-
-func stringifyToolCallArguments(v any) string {
- switch x := v.(type) {
- case nil:
- return "{}"
- case string:
- s := strings.TrimSpace(x)
- if s == "" {
- return "{}"
- }
- return s
- default:
- b, err := json.Marshal(x)
- if err != nil || len(b) == 0 {
- return "{}"
- }
- return string(b)
- }
-}
diff --git a/internal/adapter/openai/responses_input_items.go b/internal/adapter/openai/responses_input_items.go
new file mode 100644
index 0000000..2a2dfc4
--- /dev/null
+++ b/internal/adapter/openai/responses_input_items.go
@@ -0,0 +1,181 @@
+package openai
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+)
+
+func normalizeResponsesInputItem(m map[string]any) map[string]any {
+ if m == nil {
+ return nil
+ }
+
+ role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
+ if role != "" {
+ content := m["content"]
+ if content == nil {
+ if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
+ content = txt
+ }
+ }
+ if content == nil {
+ return nil
+ }
+ return map[string]any{
+ "role": role,
+ "content": content,
+ }
+ }
+
+ itemType := strings.ToLower(strings.TrimSpace(asString(m["type"])))
+ switch itemType {
+ case "message", "input_message":
+ content := m["content"]
+ if content == nil {
+ if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
+ content = txt
+ }
+ }
+ if content == nil {
+ return nil
+ }
+ role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
+ if role == "" {
+ role = "user"
+ }
+ return map[string]any{
+ "role": role,
+ "content": content,
+ }
+ case "function_call_output", "tool_result":
+ content := m["output"]
+ if content == nil {
+ content = m["content"]
+ }
+ if content == nil {
+ content = ""
+ }
+ out := map[string]any{
+ "role": "tool",
+ "content": content,
+ }
+ if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
+ out["tool_call_id"] = callID
+ } else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" {
+ out["tool_call_id"] = callID
+ }
+ if name := strings.TrimSpace(asString(m["name"])); name != "" {
+ out["name"] = name
+ } else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" {
+ out["name"] = name
+ }
+ return out
+ case "function_call", "tool_call":
+ name := strings.TrimSpace(asString(m["name"]))
+ var fn map[string]any
+ if rawFn, ok := m["function"].(map[string]any); ok {
+ fn = rawFn
+ if name == "" {
+ name = strings.TrimSpace(asString(fn["name"]))
+ }
+ }
+ if name == "" {
+ return nil
+ }
+
+ var argsRaw any
+ if v, ok := m["arguments"]; ok {
+ argsRaw = v
+ } else if v, ok := m["input"]; ok {
+ argsRaw = v
+ }
+ if argsRaw == nil && fn != nil {
+ if v, ok := fn["arguments"]; ok {
+ argsRaw = v
+ } else if v, ok := fn["input"]; ok {
+ argsRaw = v
+ }
+ }
+
+ functionPayload := map[string]any{
+ "name": name,
+ "arguments": stringifyToolCallArguments(argsRaw),
+ }
+ call := map[string]any{
+ "type": "function",
+ "function": functionPayload,
+ }
+ if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
+ call["id"] = callID
+ } else if callID = strings.TrimSpace(asString(m["id"])); callID != "" {
+ call["id"] = callID
+ }
+ return map[string]any{
+ "role": "assistant",
+ "tool_calls": []any{call},
+ }
+ case "input_text":
+ if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
+ return map[string]any{
+ "role": "user",
+ "content": txt,
+ }
+ }
+ }
+
+ if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
+ return map[string]any{
+ "role": "user",
+ "content": txt,
+ }
+ }
+ if content, ok := m["content"]; ok {
+ if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
+ return map[string]any{
+ "role": "user",
+ "content": content,
+ }
+ }
+ }
+ return nil
+}
+
+func normalizeResponsesFallbackPart(m map[string]any) string {
+ if m == nil {
+ return ""
+ }
+ if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
+ if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
+ return txt
+ }
+ }
+ if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
+ return txt
+ }
+ if content, ok := m["content"]; ok {
+ if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" {
+ return normalized
+ }
+ }
+ return strings.TrimSpace(fmt.Sprintf("%v", m))
+}
+
+func stringifyToolCallArguments(v any) string {
+ switch x := v.(type) {
+ case nil:
+ return "{}"
+ case string:
+ s := strings.TrimSpace(x)
+ if s == "" {
+ return "{}"
+ }
+ return s
+ default:
+ b, err := json.Marshal(x)
+ if err != nil || len(b) == 0 {
+ return "{}"
+ }
+ return string(b)
+ }
+}
diff --git a/internal/adapter/openai/responses_input_normalize.go b/internal/adapter/openai/responses_input_normalize.go
new file mode 100644
index 0000000..13f9e1a
--- /dev/null
+++ b/internal/adapter/openai/responses_input_normalize.go
@@ -0,0 +1,93 @@
+package openai
+
+import (
+ "fmt"
+ "strings"
+)
+
+func responsesMessagesFromRequest(req map[string]any) []any {
+ if msgs, ok := req["messages"].([]any); ok && len(msgs) > 0 {
+ return prependInstructionMessage(msgs, req["instructions"])
+ }
+ if rawInput, ok := req["input"]; ok {
+ if msgs := normalizeResponsesInputAsMessages(rawInput); len(msgs) > 0 {
+ return prependInstructionMessage(msgs, req["instructions"])
+ }
+ }
+ return nil
+}
+
+func prependInstructionMessage(messages []any, instructions any) []any {
+ sys, _ := instructions.(string)
+ sys = strings.TrimSpace(sys)
+ if sys == "" {
+ return messages
+ }
+ out := make([]any, 0, len(messages)+1)
+ out = append(out, map[string]any{"role": "system", "content": sys})
+ out = append(out, messages...)
+ return out
+}
+
+func normalizeResponsesInputAsMessages(input any) []any {
+ switch v := input.(type) {
+ case string:
+ if strings.TrimSpace(v) == "" {
+ return nil
+ }
+ return []any{map[string]any{"role": "user", "content": v}}
+ case []any:
+ return normalizeResponsesInputArray(v)
+ case map[string]any:
+ if msg := normalizeResponsesInputItem(v); msg != nil {
+ return []any{msg}
+ }
+ if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" {
+ return []any{map[string]any{"role": "user", "content": txt}}
+ }
+ if content, ok := v["content"]; ok {
+ if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
+ return []any{map[string]any{"role": "user", "content": content}}
+ }
+ }
+ }
+ return nil
+}
+
+func normalizeResponsesInputArray(items []any) []any {
+ if len(items) == 0 {
+ return nil
+ }
+ out := make([]any, 0, len(items))
+ fallbackParts := make([]string, 0, len(items))
+ flushFallback := func() {
+ if len(fallbackParts) == 0 {
+ return
+ }
+ out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")})
+ fallbackParts = fallbackParts[:0]
+ }
+
+ for _, item := range items {
+ switch x := item.(type) {
+ case map[string]any:
+ if msg := normalizeResponsesInputItem(x); msg != nil {
+ flushFallback()
+ out = append(out, msg)
+ continue
+ }
+ if s := normalizeResponsesFallbackPart(x); s != "" {
+ fallbackParts = append(fallbackParts, s)
+ }
+ default:
+ if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
+ fallbackParts = append(fallbackParts, s)
+ }
+ }
+ }
+ flushFallback()
+ if len(out) == 0 {
+ return nil
+ }
+ return out
+}
diff --git a/internal/adapter/openai/responses_stream_runtime.go b/internal/adapter/openai/responses_stream_runtime.go
deleted file mode 100644
index 050965c..0000000
--- a/internal/adapter/openai/responses_stream_runtime.go
+++ /dev/null
@@ -1,366 +0,0 @@
-package openai
-
-import (
- "encoding/json"
- "net/http"
- "sort"
- "strings"
-
- openaifmt "ds2api/internal/format/openai"
- "ds2api/internal/sse"
- streamengine "ds2api/internal/stream"
- "ds2api/internal/util"
-
- "github.com/google/uuid"
-)
-
-type responsesStreamRuntime struct {
- w http.ResponseWriter
- rc *http.ResponseController
- canFlush bool
-
- responseID string
- model string
- finalPrompt string
- toolNames []string
-
- thinkingEnabled bool
- searchEnabled bool
-
- bufferToolContent bool
- emitEarlyToolDeltas bool
- toolCallsEmitted bool
- toolCallsDoneEmitted bool
-
- sieve toolStreamSieveState
- thinkingSieve toolStreamSieveState
- thinking strings.Builder
- text strings.Builder
- streamToolCallIDs map[int]string
- streamFunctionIDs map[int]string
- functionDone map[int]bool
- toolCallsDoneSigs map[string]bool
- reasoningItemID string
-
- persistResponse func(obj map[string]any)
-}
-
-func newResponsesStreamRuntime(
- w http.ResponseWriter,
- rc *http.ResponseController,
- canFlush bool,
- responseID string,
- model string,
- finalPrompt string,
- thinkingEnabled bool,
- searchEnabled bool,
- toolNames []string,
- bufferToolContent bool,
- emitEarlyToolDeltas bool,
- persistResponse func(obj map[string]any),
-) *responsesStreamRuntime {
- return &responsesStreamRuntime{
- w: w,
- rc: rc,
- canFlush: canFlush,
- responseID: responseID,
- model: model,
- finalPrompt: finalPrompt,
- thinkingEnabled: thinkingEnabled,
- searchEnabled: searchEnabled,
- toolNames: toolNames,
- bufferToolContent: bufferToolContent,
- emitEarlyToolDeltas: emitEarlyToolDeltas,
- streamToolCallIDs: map[int]string{},
- streamFunctionIDs: map[int]string{},
- functionDone: map[int]bool{},
- toolCallsDoneSigs: map[string]bool{},
- persistResponse: persistResponse,
- }
-}
-
-func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) {
- b, _ := json.Marshal(payload)
- _, _ = s.w.Write([]byte("event: " + event + "\n"))
- _, _ = s.w.Write([]byte("data: "))
- _, _ = s.w.Write(b)
- _, _ = s.w.Write([]byte("\n\n"))
- if s.canFlush {
- _ = s.rc.Flush()
- }
-}
-
-func (s *responsesStreamRuntime) sendCreated() {
- s.sendEvent("response.created", openaifmt.BuildResponsesCreatedPayload(s.responseID, s.model))
-}
-
-func (s *responsesStreamRuntime) sendDone() {
- _, _ = s.w.Write([]byte("data: [DONE]\n\n"))
- if s.canFlush {
- _ = s.rc.Flush()
- }
-}
-
-func (s *responsesStreamRuntime) finalize() {
- finalThinking := s.thinking.String()
- finalText := s.text.String()
- if strings.TrimSpace(finalThinking) != "" {
- s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking))
- }
- if s.bufferToolContent {
- s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
- s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
- }
- // Compatibility fallback: some streams only emit incremental tool deltas.
- // Ensure final function_call_arguments.done is emitted at least once.
- if s.toolCallsEmitted {
- detected := util.ParseToolCalls(finalText, s.toolNames)
- if len(detected) == 0 {
- detected = util.ParseToolCalls(finalThinking, s.toolNames)
- }
- if len(detected) > 0 {
- if !s.toolCallsDoneEmitted {
- s.emitToolCallsDone(detected)
- } else {
- s.emitFunctionCallDoneEvents(detected)
- }
- }
- }
-
- obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames)
- if s.toolCallsEmitted {
- s.alignCompletedOutputCallIDs(obj)
- }
- if s.toolCallsEmitted {
- obj["status"] = "completed"
- }
- if s.persistResponse != nil {
- s.persistResponse(obj)
- }
- s.sendEvent("response.completed", openaifmt.BuildResponsesCompletedPayload(obj))
- s.sendDone()
-}
-
-func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
- if !parsed.Parsed {
- return streamengine.ParsedDecision{}
- }
- if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
- return streamengine.ParsedDecision{Stop: true}
- }
-
- contentSeen := false
- for _, p := range parsed.Parts {
- if p.Text == "" {
- continue
- }
- if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
- continue
- }
- contentSeen = true
- if p.Type == "thinking" {
- if !s.thinkingEnabled {
- continue
- }
- s.thinking.WriteString(p.Text)
- s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
- s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text))
- if s.bufferToolContent {
- s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
- }
- continue
- }
-
- s.text.WriteString(p.Text)
- if !s.bufferToolContent {
- s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
- continue
- }
- s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
- }
-
- return streamengine.ParsedDecision{ContentSeen: contentSeen}
-}
-
-func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
- for _, evt := range events {
- if emitContent && evt.Content != "" {
- s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
- }
- if len(evt.ToolCallDeltas) > 0 {
- if !s.emitEarlyToolDeltas {
- continue
- }
- formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
- if len(formatted) == 0 {
- continue
- }
- s.toolCallsEmitted = true
- s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted))
- s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
- }
- if len(evt.ToolCalls) > 0 {
- s.emitToolCallsDone(evt.ToolCalls)
- }
- }
-}
-
-func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) {
- if len(calls) == 0 {
- return
- }
- sig := toolCallListSignature(calls)
- if sig != "" && s.toolCallsDoneSigs[sig] {
- return
- }
- if sig != "" {
- s.toolCallsDoneSigs[sig] = true
- }
- formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs)
- if len(formatted) == 0 {
- return
- }
- s.toolCallsEmitted = true
- s.toolCallsDoneEmitted = true
- s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted))
- s.emitFunctionCallDoneEvents(calls)
-}
-
-func (s *responsesStreamRuntime) ensureReasoningItemID() string {
- if strings.TrimSpace(s.reasoningItemID) != "" {
- return s.reasoningItemID
- }
- s.reasoningItemID = "rs_" + strings.ReplaceAll(uuid.NewString(), "-", "")
- return s.reasoningItemID
-}
-
-func (s *responsesStreamRuntime) ensureFunctionItemID(index int) string {
- if id, ok := s.streamFunctionIDs[index]; ok && strings.TrimSpace(id) != "" {
- return id
- }
- id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "")
- s.streamFunctionIDs[index] = id
- return id
-}
-
-func (s *responsesStreamRuntime) ensureToolCallID(index int) string {
- if id, ok := s.streamToolCallIDs[index]; ok && strings.TrimSpace(id) != "" {
- return id
- }
- id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
- s.streamToolCallIDs[index] = id
- return id
-}
-
-func (s *responsesStreamRuntime) functionOutputBaseIndex() int {
- if strings.TrimSpace(s.thinking.String()) != "" {
- return 1
- }
- return 0
-}
-
-func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) {
- for _, d := range deltas {
- if strings.TrimSpace(d.Arguments) == "" {
- continue
- }
- outputIndex := s.functionOutputBaseIndex() + d.Index
- itemID := s.ensureFunctionItemID(outputIndex)
- callID := s.ensureToolCallID(d.Index)
- s.sendEvent(
- "response.function_call_arguments.delta",
- openaifmt.BuildResponsesFunctionCallArgumentsDeltaPayload(s.responseID, itemID, outputIndex, callID, d.Arguments),
- )
- }
-}
-
-func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) {
- base := s.functionOutputBaseIndex()
- for idx, tc := range calls {
- if strings.TrimSpace(tc.Name) == "" {
- continue
- }
- outputIndex := base + idx
- if s.functionDone[outputIndex] {
- continue
- }
- itemID := s.ensureFunctionItemID(outputIndex)
- callID := s.ensureToolCallID(idx)
- argsBytes, _ := json.Marshal(tc.Input)
- s.sendEvent(
- "response.function_call_arguments.done",
- openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)),
- )
- s.functionDone[outputIndex] = true
- }
-}
-
-func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any) {
- if obj == nil || len(s.streamToolCallIDs) == 0 {
- return
- }
- output, _ := obj["output"].([]any)
- if len(output) == 0 {
- return
- }
- indices := make([]int, 0, len(s.streamToolCallIDs))
- for idx := range s.streamToolCallIDs {
- indices = append(indices, idx)
- }
- sort.Ints(indices)
- ordered := make([]string, 0, len(indices))
- for _, idx := range indices {
- id := strings.TrimSpace(s.streamToolCallIDs[idx])
- if id == "" {
- continue
- }
- ordered = append(ordered, id)
- }
- if len(ordered) == 0 {
- return
- }
-
- functionIdx := 0
- for _, item := range output {
- m, _ := item.(map[string]any)
- if m == nil {
- continue
- }
- typ, _ := m["type"].(string)
- switch typ {
- case "function_call":
- if functionIdx < len(ordered) {
- m["call_id"] = ordered[functionIdx]
- functionIdx++
- }
- case "tool_calls":
- tcArr, _ := m["tool_calls"].([]any)
- for i, raw := range tcArr {
- tc, _ := raw.(map[string]any)
- if tc == nil {
- continue
- }
- if i < len(ordered) {
- tc["id"] = ordered[i]
- }
- }
- }
- }
-}
-
-func toolCallListSignature(calls []util.ParsedToolCall) string {
- if len(calls) == 0 {
- return ""
- }
- var b strings.Builder
- for i, tc := range calls {
- if i > 0 {
- b.WriteString("|")
- }
- b.WriteString(strings.TrimSpace(tc.Name))
- b.WriteString(":")
- args, _ := json.Marshal(tc.Input)
- b.Write(args)
- }
- return b.String()
-}
diff --git a/internal/adapter/openai/responses_stream_runtime_core.go b/internal/adapter/openai/responses_stream_runtime_core.go
new file mode 100644
index 0000000..5aad11e
--- /dev/null
+++ b/internal/adapter/openai/responses_stream_runtime_core.go
@@ -0,0 +1,157 @@
+package openai
+
+import (
+ "net/http"
+ "strings"
+
+ openaifmt "ds2api/internal/format/openai"
+ "ds2api/internal/sse"
+ streamengine "ds2api/internal/stream"
+ "ds2api/internal/util"
+)
+
+type responsesStreamRuntime struct {
+ w http.ResponseWriter
+ rc *http.ResponseController
+ canFlush bool
+
+ responseID string
+ model string
+ finalPrompt string
+ toolNames []string
+
+ thinkingEnabled bool
+ searchEnabled bool
+
+ bufferToolContent bool
+ emitEarlyToolDeltas bool
+ toolCallsEmitted bool
+ toolCallsDoneEmitted bool
+
+ sieve toolStreamSieveState
+ thinkingSieve toolStreamSieveState
+ thinking strings.Builder
+ text strings.Builder
+ streamToolCallIDs map[int]string
+ streamFunctionIDs map[int]string
+ functionDone map[int]bool
+ toolCallsDoneSigs map[string]bool
+ reasoningItemID string
+
+ persistResponse func(obj map[string]any)
+}
+
+func newResponsesStreamRuntime(
+ w http.ResponseWriter,
+ rc *http.ResponseController,
+ canFlush bool,
+ responseID string,
+ model string,
+ finalPrompt string,
+ thinkingEnabled bool,
+ searchEnabled bool,
+ toolNames []string,
+ bufferToolContent bool,
+ emitEarlyToolDeltas bool,
+ persistResponse func(obj map[string]any),
+) *responsesStreamRuntime {
+ return &responsesStreamRuntime{
+ w: w,
+ rc: rc,
+ canFlush: canFlush,
+ responseID: responseID,
+ model: model,
+ finalPrompt: finalPrompt,
+ thinkingEnabled: thinkingEnabled,
+ searchEnabled: searchEnabled,
+ toolNames: toolNames,
+ bufferToolContent: bufferToolContent,
+ emitEarlyToolDeltas: emitEarlyToolDeltas,
+ streamToolCallIDs: map[int]string{},
+ streamFunctionIDs: map[int]string{},
+ functionDone: map[int]bool{},
+ toolCallsDoneSigs: map[string]bool{},
+ persistResponse: persistResponse,
+ }
+}
+
+func (s *responsesStreamRuntime) finalize() {
+ finalThinking := s.thinking.String()
+ finalText := s.text.String()
+ if strings.TrimSpace(finalThinking) != "" {
+ s.sendEvent("response.reasoning_text.done", openaifmt.BuildResponsesReasoningTextDonePayload(s.responseID, s.ensureReasoningItemID(), 0, 0, finalThinking))
+ }
+ if s.bufferToolContent {
+ s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
+ s.processToolStreamEvents(flushToolSieve(&s.thinkingSieve, s.toolNames), false)
+ }
+ // Compatibility fallback: some streams only emit incremental tool deltas.
+ // Ensure final function_call_arguments.done is emitted at least once.
+ if s.toolCallsEmitted {
+ detected := util.ParseToolCalls(finalText, s.toolNames)
+ if len(detected) == 0 {
+ detected = util.ParseToolCalls(finalThinking, s.toolNames)
+ }
+ if len(detected) > 0 {
+ if !s.toolCallsDoneEmitted {
+ s.emitToolCallsDone(detected)
+ } else {
+ s.emitFunctionCallDoneEvents(detected)
+ }
+ }
+ }
+
+ obj := openaifmt.BuildResponseObject(s.responseID, s.model, s.finalPrompt, finalThinking, finalText, s.toolNames)
+ if s.toolCallsEmitted {
+ s.alignCompletedOutputCallIDs(obj)
+ }
+ if s.toolCallsEmitted {
+ obj["status"] = "completed"
+ }
+ if s.persistResponse != nil {
+ s.persistResponse(obj)
+ }
+ s.sendEvent("response.completed", openaifmt.BuildResponsesCompletedPayload(obj))
+ s.sendDone()
+}
+
+func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedDecision {
+ if !parsed.Parsed {
+ return streamengine.ParsedDecision{}
+ }
+ if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
+ return streamengine.ParsedDecision{Stop: true}
+ }
+
+ contentSeen := false
+ for _, p := range parsed.Parts {
+ if p.Text == "" {
+ continue
+ }
+ if p.Type != "thinking" && s.searchEnabled && sse.IsCitation(p.Text) {
+ continue
+ }
+ contentSeen = true
+ if p.Type == "thinking" {
+ if !s.thinkingEnabled {
+ continue
+ }
+ s.thinking.WriteString(p.Text)
+ s.sendEvent("response.reasoning.delta", openaifmt.BuildResponsesReasoningDeltaPayload(s.responseID, p.Text))
+ s.sendEvent("response.reasoning_text.delta", openaifmt.BuildResponsesReasoningTextDeltaPayload(s.responseID, s.ensureReasoningItemID(), 0, 0, p.Text))
+ if s.bufferToolContent {
+ s.processToolStreamEvents(processToolSieveChunk(&s.thinkingSieve, p.Text, s.toolNames), false)
+ }
+ continue
+ }
+
+ s.text.WriteString(p.Text)
+ if !s.bufferToolContent {
+ s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, p.Text))
+ continue
+ }
+ s.processToolStreamEvents(processToolSieveChunk(&s.sieve, p.Text, s.toolNames), true)
+ }
+
+ return streamengine.ParsedDecision{ContentSeen: contentSeen}
+}
diff --git a/internal/adapter/openai/responses_stream_runtime_events.go b/internal/adapter/openai/responses_stream_runtime_events.go
new file mode 100644
index 0000000..fd36b6a
--- /dev/null
+++ b/internal/adapter/openai/responses_stream_runtime_events.go
@@ -0,0 +1,52 @@
+package openai
+
+import (
+ "encoding/json"
+
+ openaifmt "ds2api/internal/format/openai"
+)
+
+func (s *responsesStreamRuntime) sendEvent(event string, payload map[string]any) {
+ b, _ := json.Marshal(payload)
+ _, _ = s.w.Write([]byte("event: " + event + "\n"))
+ _, _ = s.w.Write([]byte("data: "))
+ _, _ = s.w.Write(b)
+ _, _ = s.w.Write([]byte("\n\n"))
+ if s.canFlush {
+ _ = s.rc.Flush()
+ }
+}
+
+func (s *responsesStreamRuntime) sendCreated() {
+ s.sendEvent("response.created", openaifmt.BuildResponsesCreatedPayload(s.responseID, s.model))
+}
+
+func (s *responsesStreamRuntime) sendDone() {
+ _, _ = s.w.Write([]byte("data: [DONE]\n\n"))
+ if s.canFlush {
+ _ = s.rc.Flush()
+ }
+}
+
+func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
+ for _, evt := range events {
+ if emitContent && evt.Content != "" {
+ s.sendEvent("response.output_text.delta", openaifmt.BuildResponsesTextDeltaPayload(s.responseID, evt.Content))
+ }
+ if len(evt.ToolCallDeltas) > 0 {
+ if !s.emitEarlyToolDeltas {
+ continue
+ }
+ formatted := formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, s.streamToolCallIDs)
+ if len(formatted) == 0 {
+ continue
+ }
+ s.toolCallsEmitted = true
+ s.sendEvent("response.output_tool_call.delta", openaifmt.BuildResponsesToolCallDeltaPayload(s.responseID, formatted))
+ s.emitFunctionCallDeltaEvents(evt.ToolCallDeltas)
+ }
+ if len(evt.ToolCalls) > 0 {
+ s.emitToolCallsDone(evt.ToolCalls)
+ }
+ }
+}
diff --git a/internal/adapter/openai/responses_stream_runtime_toolcalls.go b/internal/adapter/openai/responses_stream_runtime_toolcalls.go
new file mode 100644
index 0000000..7891425
--- /dev/null
+++ b/internal/adapter/openai/responses_stream_runtime_toolcalls.go
@@ -0,0 +1,172 @@
+package openai
+
+import (
+ "encoding/json"
+ "sort"
+ "strings"
+
+ openaifmt "ds2api/internal/format/openai"
+ "ds2api/internal/util"
+
+ "github.com/google/uuid"
+)
+
+func (s *responsesStreamRuntime) emitToolCallsDone(calls []util.ParsedToolCall) {
+ if len(calls) == 0 {
+ return
+ }
+ sig := toolCallListSignature(calls)
+ if sig != "" && s.toolCallsDoneSigs[sig] {
+ return
+ }
+ if sig != "" {
+ s.toolCallsDoneSigs[sig] = true
+ }
+ formatted := formatFinalStreamToolCallsWithStableIDs(calls, s.streamToolCallIDs)
+ if len(formatted) == 0 {
+ return
+ }
+ s.toolCallsEmitted = true
+ s.toolCallsDoneEmitted = true
+ s.sendEvent("response.output_tool_call.done", openaifmt.BuildResponsesToolCallDonePayload(s.responseID, formatted))
+ s.emitFunctionCallDoneEvents(calls)
+}
+
+func (s *responsesStreamRuntime) ensureReasoningItemID() string {
+ if strings.TrimSpace(s.reasoningItemID) != "" {
+ return s.reasoningItemID
+ }
+ s.reasoningItemID = "rs_" + strings.ReplaceAll(uuid.NewString(), "-", "")
+ return s.reasoningItemID
+}
+
+func (s *responsesStreamRuntime) ensureFunctionItemID(index int) string {
+ if id, ok := s.streamFunctionIDs[index]; ok && strings.TrimSpace(id) != "" {
+ return id
+ }
+ id := "fc_" + strings.ReplaceAll(uuid.NewString(), "-", "")
+ s.streamFunctionIDs[index] = id
+ return id
+}
+
+func (s *responsesStreamRuntime) ensureToolCallID(index int) string {
+ if id, ok := s.streamToolCallIDs[index]; ok && strings.TrimSpace(id) != "" {
+ return id
+ }
+ id := "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
+ s.streamToolCallIDs[index] = id
+ return id
+}
+
+func (s *responsesStreamRuntime) functionOutputBaseIndex() int {
+ if strings.TrimSpace(s.thinking.String()) != "" {
+ return 1
+ }
+ return 0
+}
+
+func (s *responsesStreamRuntime) emitFunctionCallDeltaEvents(deltas []toolCallDelta) {
+ for _, d := range deltas {
+ if strings.TrimSpace(d.Arguments) == "" {
+ continue
+ }
+ outputIndex := s.functionOutputBaseIndex() + d.Index
+ itemID := s.ensureFunctionItemID(outputIndex)
+ callID := s.ensureToolCallID(d.Index)
+ s.sendEvent(
+ "response.function_call_arguments.delta",
+ openaifmt.BuildResponsesFunctionCallArgumentsDeltaPayload(s.responseID, itemID, outputIndex, callID, d.Arguments),
+ )
+ }
+}
+
+func (s *responsesStreamRuntime) emitFunctionCallDoneEvents(calls []util.ParsedToolCall) {
+ base := s.functionOutputBaseIndex()
+ for idx, tc := range calls {
+ if strings.TrimSpace(tc.Name) == "" {
+ continue
+ }
+ outputIndex := base + idx
+ if s.functionDone[outputIndex] {
+ continue
+ }
+ itemID := s.ensureFunctionItemID(outputIndex)
+ callID := s.ensureToolCallID(idx)
+ argsBytes, _ := json.Marshal(tc.Input)
+ s.sendEvent(
+ "response.function_call_arguments.done",
+ openaifmt.BuildResponsesFunctionCallArgumentsDonePayload(s.responseID, itemID, outputIndex, callID, tc.Name, string(argsBytes)),
+ )
+ s.functionDone[outputIndex] = true
+ }
+}
+
+func (s *responsesStreamRuntime) alignCompletedOutputCallIDs(obj map[string]any) {
+ if obj == nil || len(s.streamToolCallIDs) == 0 {
+ return
+ }
+ output, _ := obj["output"].([]any)
+ if len(output) == 0 {
+ return
+ }
+ indices := make([]int, 0, len(s.streamToolCallIDs))
+ for idx := range s.streamToolCallIDs {
+ indices = append(indices, idx)
+ }
+ sort.Ints(indices)
+ ordered := make([]string, 0, len(indices))
+ for _, idx := range indices {
+ id := strings.TrimSpace(s.streamToolCallIDs[idx])
+ if id == "" {
+ continue
+ }
+ ordered = append(ordered, id)
+ }
+ if len(ordered) == 0 {
+ return
+ }
+
+ functionIdx := 0
+ for _, item := range output {
+ m, _ := item.(map[string]any)
+ if m == nil {
+ continue
+ }
+ typ, _ := m["type"].(string)
+ switch typ {
+ case "function_call":
+ if functionIdx < len(ordered) {
+ m["call_id"] = ordered[functionIdx]
+ functionIdx++
+ }
+ case "tool_calls":
+ tcArr, _ := m["tool_calls"].([]any)
+ for i, raw := range tcArr {
+ tc, _ := raw.(map[string]any)
+ if tc == nil {
+ continue
+ }
+ if i < len(ordered) {
+ tc["id"] = ordered[i]
+ }
+ }
+ }
+ }
+}
+
+func toolCallListSignature(calls []util.ParsedToolCall) string {
+ if len(calls) == 0 {
+ return ""
+ }
+ var b strings.Builder
+ for i, tc := range calls {
+ if i > 0 {
+ b.WriteString("|")
+ }
+ b.WriteString(strings.TrimSpace(tc.Name))
+ b.WriteString(":")
+ args, _ := json.Marshal(tc.Input)
+ b.Write(args)
+ }
+ return b.String()
+}
diff --git a/internal/adapter/openai/tool_sieve.go b/internal/adapter/openai/tool_sieve.go
deleted file mode 100644
index 9c46649..0000000
--- a/internal/adapter/openai/tool_sieve.go
+++ /dev/null
@@ -1,713 +0,0 @@
-package openai
-
-import (
- "strings"
-
- "ds2api/internal/util"
-)
-
-type toolStreamSieveState struct {
- pending strings.Builder
- capture strings.Builder
- capturing bool
- recentTextTail string
- disableDeltas bool
- toolNameSent bool
- toolName string
- toolArgsStart int
- toolArgsSent int
- toolArgsString bool
- toolArgsDone bool
-}
-
-type toolStreamEvent struct {
- Content string
- ToolCalls []util.ParsedToolCall
- ToolCallDeltas []toolCallDelta
-}
-
-type toolCallDelta struct {
- Index int
- Name string
- Arguments string
-}
-
-const toolSieveCaptureLimit = 8 * 1024
-const toolSieveContextTailLimit = 256
-
-func (s *toolStreamSieveState) resetIncrementalToolState() {
- s.disableDeltas = false
- s.toolNameSent = false
- s.toolName = ""
- s.toolArgsStart = -1
- s.toolArgsSent = -1
- s.toolArgsString = false
- s.toolArgsDone = false
-}
-
-func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
- if state == nil {
- return nil
- }
- if chunk != "" {
- state.pending.WriteString(chunk)
- }
- events := make([]toolStreamEvent, 0, 2)
-
- for {
- if state.capturing {
- if state.pending.Len() > 0 {
- state.capture.WriteString(state.pending.String())
- state.pending.Reset()
- }
- if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
- events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
- }
- prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
- if !ready {
- if state.capture.Len() > toolSieveCaptureLimit {
- content := state.capture.String()
- state.capture.Reset()
- state.capturing = false
- state.resetIncrementalToolState()
- state.noteText(content)
- events = append(events, toolStreamEvent{Content: content})
- continue
- }
- break
- }
- state.capture.Reset()
- state.capturing = false
- state.resetIncrementalToolState()
- if prefix != "" {
- state.noteText(prefix)
- events = append(events, toolStreamEvent{Content: prefix})
- }
- if len(calls) > 0 {
- events = append(events, toolStreamEvent{ToolCalls: calls})
- }
- if suffix != "" {
- state.pending.WriteString(suffix)
- }
- continue
- }
-
- pending := state.pending.String()
- if pending == "" {
- break
- }
- start := findToolSegmentStart(pending)
- if start >= 0 {
- prefix := pending[:start]
- if prefix != "" {
- state.noteText(prefix)
- events = append(events, toolStreamEvent{Content: prefix})
- }
- state.pending.Reset()
- state.capture.WriteString(pending[start:])
- state.capturing = true
- state.resetIncrementalToolState()
- continue
- }
-
- safe, hold := splitSafeContentForToolDetection(pending)
- if safe == "" {
- break
- }
- state.pending.Reset()
- state.pending.WriteString(hold)
- state.noteText(safe)
- events = append(events, toolStreamEvent{Content: safe})
- }
-
- return events
-}
-
-func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent {
- if state == nil {
- return nil
- }
- events := processToolSieveChunk(state, "", toolNames)
- if state.capturing {
- consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
- if ready {
- if consumedPrefix != "" {
- state.noteText(consumedPrefix)
- events = append(events, toolStreamEvent{Content: consumedPrefix})
- }
- if len(consumedCalls) > 0 {
- events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
- }
- if consumedSuffix != "" {
- state.noteText(consumedSuffix)
- events = append(events, toolStreamEvent{Content: consumedSuffix})
- }
- } else {
- content := state.capture.String()
- if content != "" {
- state.noteText(content)
- events = append(events, toolStreamEvent{Content: content})
- }
- }
- state.capture.Reset()
- state.capturing = false
- state.resetIncrementalToolState()
- }
- if state.pending.Len() > 0 {
- content := state.pending.String()
- state.noteText(content)
- events = append(events, toolStreamEvent{Content: content})
- state.pending.Reset()
- }
- return events
-}
-
-func splitSafeContentForToolDetection(s string) (safe, hold string) {
- if s == "" {
- return "", ""
- }
- suspiciousStart := findSuspiciousPrefixStart(s)
- if suspiciousStart < 0 {
- return s, ""
- }
- if suspiciousStart > 0 {
- return s[:suspiciousStart], s[suspiciousStart:]
- }
- // If suspicious content starts at position 0, keep holding until we can
- // parse a complete tool JSON block or reach stream flush.
- return "", s
-}
-
-func findSuspiciousPrefixStart(s string) int {
- start := -1
- indices := []int{
- strings.LastIndex(s, "{"),
- strings.LastIndex(s, "["),
- strings.LastIndex(s, "```"),
- }
- for _, idx := range indices {
- if idx > start {
- start = idx
- }
- }
- return start
-}
-
-func findToolSegmentStart(s string) int {
- if s == "" {
- return -1
- }
- lower := strings.ToLower(s)
- offset := 0
- for {
- keyRel := strings.Index(lower[offset:], "tool_calls")
- if keyRel < 0 {
- return -1
- }
- keyIdx := offset + keyRel
- start := strings.LastIndex(s[:keyIdx], "{")
- if start < 0 {
- start = keyIdx
- }
- if !insideCodeFence(s[:start]) {
- return start
- }
- offset = keyIdx + len("tool_calls")
- }
-}
-
-func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
- captured := state.capture.String()
- if captured == "" {
- return "", nil, "", false
- }
- lower := strings.ToLower(captured)
- keyIdx := strings.Index(lower, "tool_calls")
- if keyIdx < 0 {
- return "", nil, "", false
- }
- start := strings.LastIndex(captured[:keyIdx], "{")
- if start < 0 {
- return "", nil, "", false
- }
- obj, end, ok := extractJSONObjectFrom(captured, start)
- if !ok {
- return "", nil, "", false
- }
- prefixPart := captured[:start]
- suffixPart := captured[end:]
- if insideCodeFence(state.recentTextTail + prefixPart) {
- return captured, nil, "", true
- }
- parsed := util.ParseStandaloneToolCalls(obj, toolNames)
- if len(parsed) == 0 {
- return captured, nil, "", true
- }
- return prefixPart, parsed, suffixPart, true
-}
-
-func extractJSONObjectFrom(text string, start int) (string, int, bool) {
- if start < 0 || start >= len(text) || text[start] != '{' {
- return "", 0, false
- }
- depth := 0
- quote := byte(0)
- escaped := false
- for i := start; i < len(text); i++ {
- ch := text[i]
- if quote != 0 {
- if escaped {
- escaped = false
- continue
- }
- if ch == '\\' {
- escaped = true
- continue
- }
- if ch == quote {
- quote = 0
- }
- continue
- }
- if ch == '"' || ch == '\'' {
- quote = ch
- continue
- }
- if ch == '{' {
- depth++
- continue
- }
- if ch == '}' {
- depth--
- if depth == 0 {
- end := i + 1
- return text[start:end], end, true
- }
- }
- }
- return "", 0, false
-}
-
-func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
- if state.disableDeltas {
- return nil
- }
- captured := state.capture.String()
- if captured == "" {
- return nil
- }
- lower := strings.ToLower(captured)
- keyIdx := strings.Index(lower, "tool_calls")
- if keyIdx < 0 {
- return nil
- }
- start := strings.LastIndex(captured[:keyIdx], "{")
- if start < 0 {
- return nil
- }
- if insideCodeFence(state.recentTextTail + captured[:start]) {
- return nil
- }
- certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
- if hasMultiple {
- state.disableDeltas = true
- return nil
- }
- if !certainSingle {
- // In uncertain phases (e.g. first call arrived but array not closed yet),
- // avoid speculative deltas and wait for final parsed tool_calls payload.
- return nil
- }
- callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
- if !ok {
- return nil
- }
- deltas := make([]toolCallDelta, 0, 2)
- if state.toolName == "" {
- name, ok := extractToolCallName(captured, callStart)
- if !ok || name == "" {
- return nil
- }
- state.toolName = name
- }
- if state.toolArgsStart < 0 {
- argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart)
- if ok {
- state.toolArgsString = stringMode
- if stringMode {
- state.toolArgsStart = argsStart + 1
- } else {
- state.toolArgsStart = argsStart
- }
- state.toolArgsSent = state.toolArgsStart
- }
- }
- if !state.toolNameSent {
- if state.toolArgsStart < 0 {
- return nil
- }
- state.toolNameSent = true
- deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName})
- }
- if state.toolArgsStart < 0 || state.toolArgsDone {
- return deltas
- }
- end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString)
- if !ok {
- return deltas
- }
- if end > state.toolArgsSent {
- deltas = append(deltas, toolCallDelta{
- Index: 0,
- Arguments: captured[state.toolArgsSent:end],
- })
- state.toolArgsSent = end
- }
- if complete {
- state.toolArgsDone = true
- }
- return deltas
-}
-
-func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
- arrStart, ok := findToolCallsArrayStart(text, keyIdx)
- if !ok {
- return false, false
- }
- i := skipSpaces(text, arrStart+1)
- if i >= len(text) || text[i] != '{' {
- return false, false
- }
- count := 0
- depth := 0
- quote := byte(0)
- escaped := false
- for ; i < len(text); i++ {
- ch := text[i]
- if quote != 0 {
- if escaped {
- escaped = false
- continue
- }
- if ch == '\\' {
- escaped = true
- continue
- }
- if ch == quote {
- quote = 0
- }
- continue
- }
- if ch == '"' || ch == '\'' {
- quote = ch
- continue
- }
- if ch == '{' {
- if depth == 0 {
- count++
- if count > 1 {
- return false, true
- }
- }
- depth++
- continue
- }
- if ch == '}' {
- if depth > 0 {
- depth--
- }
- continue
- }
- if ch == ',' && depth == 0 {
- // top-level separator means at least one more tool call exists
- // (or is expected). Treat as multi-call and stop incremental deltas.
- return false, true
- }
- if ch == ']' && depth == 0 {
- return count == 1, false
- }
- }
- // array not closed yet: still uncertain whether more calls will appear
- return false, false
-}
-
-func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
- arrStart, ok := findToolCallsArrayStart(text, keyIdx)
- if !ok {
- return -1, false
- }
- i := skipSpaces(text, arrStart+1)
- if i >= len(text) || text[i] != '{' {
- return -1, false
- }
- return i, true
-}
-
-func findToolCallsArrayStart(text string, keyIdx int) (int, bool) {
- i := keyIdx + len("tool_calls")
- for i < len(text) && text[i] != ':' {
- i++
- }
- if i >= len(text) {
- return -1, false
- }
- i = skipSpaces(text, i+1)
- if i >= len(text) || text[i] != '[' {
- return -1, false
- }
- return i, true
-}
-
-func extractToolCallName(text string, callStart int) (string, bool) {
- valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"})
- if !ok || valueStart >= len(text) || text[valueStart] != '"' {
- fnStart, fnOK := findFunctionObjectStart(text, callStart)
- if !fnOK {
- return "", false
- }
- valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"})
- if !ok || valueStart >= len(text) || text[valueStart] != '"' {
- return "", false
- }
- }
- name, _, ok := parseJSONStringLiteral(text, valueStart)
- if !ok {
- return "", false
- }
- return name, true
-}
-
-func findToolCallArgsStart(text string, callStart int) (int, bool, bool) {
- keys := []string{"input", "arguments", "args", "parameters", "params"}
- valueStart, ok := findObjectFieldValueStart(text, callStart, keys)
- if !ok {
- fnStart, fnOK := findFunctionObjectStart(text, callStart)
- if !fnOK {
- return -1, false, false
- }
- valueStart, ok = findObjectFieldValueStart(text, fnStart, keys)
- if !ok {
- return -1, false, false
- }
- }
- if valueStart >= len(text) {
- return -1, false, false
- }
- ch := text[valueStart]
- if ch == '{' || ch == '[' {
- return valueStart, false, true
- }
- if ch == '"' {
- return valueStart, true, true
- }
- return -1, false, false
-}
-
-func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) {
- if start < 0 || start > len(text) {
- return 0, false, false
- }
- if stringMode {
- escaped := false
- for i := start; i < len(text); i++ {
- ch := text[i]
- if escaped {
- escaped = false
- continue
- }
- if ch == '\\' {
- escaped = true
- continue
- }
- if ch == '"' {
- return i, true, true
- }
- }
- return len(text), false, true
- }
- if start >= len(text) {
- return start, false, false
- }
- if text[start] != '{' && text[start] != '[' {
- return 0, false, false
- }
- depth := 0
- quote := byte(0)
- escaped := false
- for i := start; i < len(text); i++ {
- ch := text[i]
- if quote != 0 {
- if escaped {
- escaped = false
- continue
- }
- if ch == '\\' {
- escaped = true
- continue
- }
- if ch == quote {
- quote = 0
- }
- continue
- }
- if ch == '"' || ch == '\'' {
- quote = ch
- continue
- }
- if ch == '{' || ch == '[' {
- depth++
- continue
- }
- if ch == '}' || ch == ']' {
- depth--
- if depth == 0 {
- return i + 1, true, true
- }
- }
- }
- return len(text), false, true
-}
-
-func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) {
- if objStart < 0 || objStart >= len(text) || text[objStart] != '{' {
- return 0, false
- }
- depth := 0
- quote := byte(0)
- escaped := false
- for i := objStart; i < len(text); i++ {
- ch := text[i]
- if quote != 0 {
- if escaped {
- escaped = false
- continue
- }
- if ch == '\\' {
- escaped = true
- continue
- }
- if ch == quote {
- quote = 0
- }
- continue
- }
- if ch == '"' || ch == '\'' {
- if depth == 1 {
- key, end, ok := parseJSONStringLiteral(text, i)
- if !ok {
- return 0, false
- }
- j := skipSpaces(text, end)
- if j >= len(text) || text[j] != ':' {
- i = end - 1
- continue
- }
- j = skipSpaces(text, j+1)
- if j >= len(text) {
- return 0, false
- }
- if containsKey(keys, key) {
- return j, true
- }
- i = j - 1
- continue
- }
- quote = ch
- continue
- }
- if ch == '{' {
- depth++
- continue
- }
- if ch == '}' {
- depth--
- if depth == 0 {
- break
- }
- }
- }
- return 0, false
-}
-
-func findFunctionObjectStart(text string, callStart int) (int, bool) {
- valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"})
- if !ok || valueStart >= len(text) || text[valueStart] != '{' {
- return -1, false
- }
- return valueStart, true
-}
-
-func parseJSONStringLiteral(text string, start int) (string, int, bool) {
- if start < 0 || start >= len(text) || text[start] != '"' {
- return "", 0, false
- }
- var b strings.Builder
- escaped := false
- for i := start + 1; i < len(text); i++ {
- ch := text[i]
- if escaped {
- b.WriteByte(ch)
- escaped = false
- continue
- }
- if ch == '\\' {
- escaped = true
- continue
- }
- if ch == '"' {
- return b.String(), i + 1, true
- }
- b.WriteByte(ch)
- }
- return "", 0, false
-}
-
-func containsKey(keys []string, value string) bool {
- for _, k := range keys {
- if k == value {
- return true
- }
- }
- return false
-}
-
-func skipSpaces(text string, i int) int {
- for i < len(text) {
- switch text[i] {
- case ' ', '\t', '\n', '\r':
- i++
- default:
- return i
- }
- }
- return i
-}
-
-func (s *toolStreamSieveState) noteText(content string) {
- if strings.TrimSpace(content) == "" {
- return
- }
- s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
-}
-
-func appendTail(prev, next string, max int) string {
- if max <= 0 {
- return ""
- }
- combined := prev + next
- if len(combined) <= max {
- return combined
- }
- return combined[len(combined)-max:]
-}
-
-func looksLikeToolExampleContext(text string) bool {
- return insideCodeFence(text)
-}
-
-func insideCodeFence(text string) bool {
- if text == "" {
- return false
- }
- return strings.Count(text, "```")%2 == 1
-}
diff --git a/internal/adapter/openai/tool_sieve_core.go b/internal/adapter/openai/tool_sieve_core.go
new file mode 100644
index 0000000..1bcf102
--- /dev/null
+++ b/internal/adapter/openai/tool_sieve_core.go
@@ -0,0 +1,208 @@
+package openai
+
+import (
+ "strings"
+
+ "ds2api/internal/util"
+)
+
+func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
+ if state == nil {
+ return nil
+ }
+ if chunk != "" {
+ state.pending.WriteString(chunk)
+ }
+ events := make([]toolStreamEvent, 0, 2)
+
+ for {
+ if state.capturing {
+ if state.pending.Len() > 0 {
+ state.capture.WriteString(state.pending.String())
+ state.pending.Reset()
+ }
+ if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
+ events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
+ }
+ prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
+ if !ready {
+ if state.capture.Len() > toolSieveCaptureLimit {
+ content := state.capture.String()
+ state.capture.Reset()
+ state.capturing = false
+ state.resetIncrementalToolState()
+ state.noteText(content)
+ events = append(events, toolStreamEvent{Content: content})
+ continue
+ }
+ break
+ }
+ state.capture.Reset()
+ state.capturing = false
+ state.resetIncrementalToolState()
+ if prefix != "" {
+ state.noteText(prefix)
+ events = append(events, toolStreamEvent{Content: prefix})
+ }
+ if len(calls) > 0 {
+ events = append(events, toolStreamEvent{ToolCalls: calls})
+ }
+ if suffix != "" {
+ state.pending.WriteString(suffix)
+ }
+ continue
+ }
+
+ pending := state.pending.String()
+ if pending == "" {
+ break
+ }
+ start := findToolSegmentStart(pending)
+ if start >= 0 {
+ prefix := pending[:start]
+ if prefix != "" {
+ state.noteText(prefix)
+ events = append(events, toolStreamEvent{Content: prefix})
+ }
+ state.pending.Reset()
+ state.capture.WriteString(pending[start:])
+ state.capturing = true
+ state.resetIncrementalToolState()
+ continue
+ }
+
+ safe, hold := splitSafeContentForToolDetection(pending)
+ if safe == "" {
+ break
+ }
+ state.pending.Reset()
+ state.pending.WriteString(hold)
+ state.noteText(safe)
+ events = append(events, toolStreamEvent{Content: safe})
+ }
+
+ return events
+}
+
+func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStreamEvent {
+ if state == nil {
+ return nil
+ }
+ events := processToolSieveChunk(state, "", toolNames)
+ if state.capturing {
+ consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
+ if ready {
+ if consumedPrefix != "" {
+ state.noteText(consumedPrefix)
+ events = append(events, toolStreamEvent{Content: consumedPrefix})
+ }
+ if len(consumedCalls) > 0 {
+ events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
+ }
+ if consumedSuffix != "" {
+ state.noteText(consumedSuffix)
+ events = append(events, toolStreamEvent{Content: consumedSuffix})
+ }
+ } else {
+ content := state.capture.String()
+ if content != "" {
+ state.noteText(content)
+ events = append(events, toolStreamEvent{Content: content})
+ }
+ }
+ state.capture.Reset()
+ state.capturing = false
+ state.resetIncrementalToolState()
+ }
+ if state.pending.Len() > 0 {
+ content := state.pending.String()
+ state.noteText(content)
+ events = append(events, toolStreamEvent{Content: content})
+ state.pending.Reset()
+ }
+ return events
+}
+
+func splitSafeContentForToolDetection(s string) (safe, hold string) {
+ if s == "" {
+ return "", ""
+ }
+ suspiciousStart := findSuspiciousPrefixStart(s)
+ if suspiciousStart < 0 {
+ return s, ""
+ }
+ if suspiciousStart > 0 {
+ return s[:suspiciousStart], s[suspiciousStart:]
+ }
+ // If suspicious content starts at position 0, keep holding until we can
+ // parse a complete tool JSON block or reach stream flush.
+ return "", s
+}
+
+func findSuspiciousPrefixStart(s string) int {
+ start := -1
+ indices := []int{
+ strings.LastIndex(s, "{"),
+ strings.LastIndex(s, "["),
+ strings.LastIndex(s, "```"),
+ }
+ for _, idx := range indices {
+ if idx > start {
+ start = idx
+ }
+ }
+ return start
+}
+
+func findToolSegmentStart(s string) int {
+ if s == "" {
+ return -1
+ }
+ lower := strings.ToLower(s)
+ offset := 0
+ for {
+ keyRel := strings.Index(lower[offset:], "tool_calls")
+ if keyRel < 0 {
+ return -1
+ }
+ keyIdx := offset + keyRel
+ start := strings.LastIndex(s[:keyIdx], "{")
+ if start < 0 {
+ start = keyIdx
+ }
+ if !insideCodeFence(s[:start]) {
+ return start
+ }
+ offset = keyIdx + len("tool_calls")
+ }
+}
+
+func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
+ captured := state.capture.String()
+ if captured == "" {
+ return "", nil, "", false
+ }
+ lower := strings.ToLower(captured)
+ keyIdx := strings.Index(lower, "tool_calls")
+ if keyIdx < 0 {
+ return "", nil, "", false
+ }
+ start := strings.LastIndex(captured[:keyIdx], "{")
+ if start < 0 {
+ return "", nil, "", false
+ }
+ obj, end, ok := extractJSONObjectFrom(captured, start)
+ if !ok {
+ return "", nil, "", false
+ }
+ prefixPart := captured[:start]
+ suffixPart := captured[end:]
+ if insideCodeFence(state.recentTextTail + prefixPart) {
+ return captured, nil, "", true
+ }
+ parsed := util.ParseStandaloneToolCalls(obj, toolNames)
+ if len(parsed) == 0 {
+ return captured, nil, "", true
+ }
+ return prefixPart, parsed, suffixPart, true
+}
diff --git a/internal/adapter/openai/tool_sieve_incremental.go b/internal/adapter/openai/tool_sieve_incremental.go
new file mode 100644
index 0000000..ad0f901
--- /dev/null
+++ b/internal/adapter/openai/tool_sieve_incremental.go
@@ -0,0 +1,291 @@
+package openai
+
+import "strings"
+
+func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
+ if state.disableDeltas {
+ return nil
+ }
+ captured := state.capture.String()
+ if captured == "" {
+ return nil
+ }
+ lower := strings.ToLower(captured)
+ keyIdx := strings.Index(lower, "tool_calls")
+ if keyIdx < 0 {
+ return nil
+ }
+ start := strings.LastIndex(captured[:keyIdx], "{")
+ if start < 0 {
+ return nil
+ }
+ if insideCodeFence(state.recentTextTail + captured[:start]) {
+ return nil
+ }
+ certainSingle, hasMultiple := classifyToolCallsIncrementalSafety(captured, keyIdx)
+ if hasMultiple {
+ state.disableDeltas = true
+ return nil
+ }
+ if !certainSingle {
+ // In uncertain phases (e.g. first call arrived but array not closed yet),
+ // avoid speculative deltas and wait for final parsed tool_calls payload.
+ return nil
+ }
+ callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
+ if !ok {
+ return nil
+ }
+ deltas := make([]toolCallDelta, 0, 2)
+ if state.toolName == "" {
+ name, ok := extractToolCallName(captured, callStart)
+ if !ok || name == "" {
+ return nil
+ }
+ state.toolName = name
+ }
+ if state.toolArgsStart < 0 {
+ argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart)
+ if ok {
+ state.toolArgsString = stringMode
+ if stringMode {
+ state.toolArgsStart = argsStart + 1
+ } else {
+ state.toolArgsStart = argsStart
+ }
+ state.toolArgsSent = state.toolArgsStart
+ }
+ }
+ if !state.toolNameSent {
+ if state.toolArgsStart < 0 {
+ return nil
+ }
+ state.toolNameSent = true
+ deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName})
+ }
+ if state.toolArgsStart < 0 || state.toolArgsDone {
+ return deltas
+ }
+ end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString)
+ if !ok {
+ return deltas
+ }
+ if end > state.toolArgsSent {
+ deltas = append(deltas, toolCallDelta{
+ Index: 0,
+ Arguments: captured[state.toolArgsSent:end],
+ })
+ state.toolArgsSent = end
+ }
+ if complete {
+ state.toolArgsDone = true
+ }
+ return deltas
+}
+
+func classifyToolCallsIncrementalSafety(text string, keyIdx int) (certainSingle bool, hasMultiple bool) {
+ arrStart, ok := findToolCallsArrayStart(text, keyIdx)
+ if !ok {
+ return false, false
+ }
+ i := skipSpaces(text, arrStart+1)
+ if i >= len(text) || text[i] != '{' {
+ return false, false
+ }
+ count := 0
+ depth := 0
+ quote := byte(0)
+ escaped := false
+ for ; i < len(text); i++ {
+ ch := text[i]
+ if quote != 0 {
+ if escaped {
+ escaped = false
+ continue
+ }
+ if ch == '\\' {
+ escaped = true
+ continue
+ }
+ if ch == quote {
+ quote = 0
+ }
+ continue
+ }
+ if ch == '"' || ch == '\'' {
+ quote = ch
+ continue
+ }
+ if ch == '{' {
+ if depth == 0 {
+ count++
+ if count > 1 {
+ return false, true
+ }
+ }
+ depth++
+ continue
+ }
+ if ch == '}' {
+ if depth > 0 {
+ depth--
+ }
+ continue
+ }
+ if ch == ',' && depth == 0 {
+ // top-level separator means at least one more tool call exists
+ // (or is expected). Treat as multi-call and stop incremental deltas.
+ return false, true
+ }
+ if ch == ']' && depth == 0 {
+ return count == 1, false
+ }
+ }
+ // array not closed yet: still uncertain whether more calls will appear
+ return false, false
+}
+
+func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
+ arrStart, ok := findToolCallsArrayStart(text, keyIdx)
+ if !ok {
+ return -1, false
+ }
+ i := skipSpaces(text, arrStart+1)
+ if i >= len(text) || text[i] != '{' {
+ return -1, false
+ }
+ return i, true
+}
+
+func findToolCallsArrayStart(text string, keyIdx int) (int, bool) {
+ i := keyIdx + len("tool_calls")
+ for i < len(text) && text[i] != ':' {
+ i++
+ }
+ if i >= len(text) {
+ return -1, false
+ }
+ i = skipSpaces(text, i+1)
+ if i >= len(text) || text[i] != '[' {
+ return -1, false
+ }
+ return i, true
+}
+
+func extractToolCallName(text string, callStart int) (string, bool) {
+ valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"})
+ if !ok || valueStart >= len(text) || text[valueStart] != '"' {
+ fnStart, fnOK := findFunctionObjectStart(text, callStart)
+ if !fnOK {
+ return "", false
+ }
+ valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"})
+ if !ok || valueStart >= len(text) || text[valueStart] != '"' {
+ return "", false
+ }
+ }
+ name, _, ok := parseJSONStringLiteral(text, valueStart)
+ if !ok {
+ return "", false
+ }
+ return name, true
+}
+
+func findToolCallArgsStart(text string, callStart int) (int, bool, bool) {
+ keys := []string{"input", "arguments", "args", "parameters", "params"}
+ valueStart, ok := findObjectFieldValueStart(text, callStart, keys)
+ if !ok {
+ fnStart, fnOK := findFunctionObjectStart(text, callStart)
+ if !fnOK {
+ return -1, false, false
+ }
+ valueStart, ok = findObjectFieldValueStart(text, fnStart, keys)
+ if !ok {
+ return -1, false, false
+ }
+ }
+ if valueStart >= len(text) {
+ return -1, false, false
+ }
+ ch := text[valueStart]
+ if ch == '{' || ch == '[' {
+ return valueStart, false, true
+ }
+ if ch == '"' {
+ return valueStart, true, true
+ }
+ return -1, false, false
+}
+
+func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) {
+ if start < 0 || start > len(text) {
+ return 0, false, false
+ }
+ if stringMode {
+ escaped := false
+ for i := start; i < len(text); i++ {
+ ch := text[i]
+ if escaped {
+ escaped = false
+ continue
+ }
+ if ch == '\\' {
+ escaped = true
+ continue
+ }
+ if ch == '"' {
+ return i, true, true
+ }
+ }
+ return len(text), false, true
+ }
+ if start >= len(text) {
+ return start, false, false
+ }
+ if text[start] != '{' && text[start] != '[' {
+ return 0, false, false
+ }
+ depth := 0
+ quote := byte(0)
+ escaped := false
+ for i := start; i < len(text); i++ {
+ ch := text[i]
+ if quote != 0 {
+ if escaped {
+ escaped = false
+ continue
+ }
+ if ch == '\\' {
+ escaped = true
+ continue
+ }
+ if ch == quote {
+ quote = 0
+ }
+ continue
+ }
+ if ch == '"' || ch == '\'' {
+ quote = ch
+ continue
+ }
+ if ch == '{' || ch == '[' {
+ depth++
+ continue
+ }
+ if ch == '}' || ch == ']' {
+ depth--
+ if depth == 0 {
+ return i + 1, true, true
+ }
+ }
+ }
+ return len(text), false, true
+}
+
+func findFunctionObjectStart(text string, callStart int) (int, bool) {
+ valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"})
+ if !ok || valueStart >= len(text) || text[valueStart] != '{' {
+ return -1, false
+ }
+ return valueStart, true
+}
diff --git a/internal/adapter/openai/tool_sieve_jsonscan.go b/internal/adapter/openai/tool_sieve_jsonscan.go
new file mode 100644
index 0000000..d3abcc5
--- /dev/null
+++ b/internal/adapter/openai/tool_sieve_jsonscan.go
@@ -0,0 +1,152 @@
+package openai
+
+import "strings"
+
+func extractJSONObjectFrom(text string, start int) (string, int, bool) {
+ if start < 0 || start >= len(text) || text[start] != '{' {
+ return "", 0, false
+ }
+ depth := 0
+ quote := byte(0)
+ escaped := false
+ for i := start; i < len(text); i++ {
+ ch := text[i]
+ if quote != 0 {
+ if escaped {
+ escaped = false
+ continue
+ }
+ if ch == '\\' {
+ escaped = true
+ continue
+ }
+ if ch == quote {
+ quote = 0
+ }
+ continue
+ }
+ if ch == '"' || ch == '\'' {
+ quote = ch
+ continue
+ }
+ if ch == '{' {
+ depth++
+ continue
+ }
+ if ch == '}' {
+ depth--
+ if depth == 0 {
+ end := i + 1
+ return text[start:end], end, true
+ }
+ }
+ }
+ return "", 0, false
+}
+
+func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) {
+ if objStart < 0 || objStart >= len(text) || text[objStart] != '{' {
+ return 0, false
+ }
+ depth := 0
+ quote := byte(0)
+ escaped := false
+ for i := objStart; i < len(text); i++ {
+ ch := text[i]
+ if quote != 0 {
+ if escaped {
+ escaped = false
+ continue
+ }
+ if ch == '\\' {
+ escaped = true
+ continue
+ }
+ if ch == quote {
+ quote = 0
+ }
+ continue
+ }
+ if ch == '"' || ch == '\'' {
+ if depth == 1 {
+ key, end, ok := parseJSONStringLiteral(text, i)
+ if !ok {
+ return 0, false
+ }
+ j := skipSpaces(text, end)
+ if j >= len(text) || text[j] != ':' {
+ i = end - 1
+ continue
+ }
+ j = skipSpaces(text, j+1)
+ if j >= len(text) {
+ return 0, false
+ }
+ if containsKey(keys, key) {
+ return j, true
+ }
+ i = j - 1
+ continue
+ }
+ quote = ch
+ continue
+ }
+ if ch == '{' {
+ depth++
+ continue
+ }
+ if ch == '}' {
+ depth--
+ if depth == 0 {
+ break
+ }
+ }
+ }
+ return 0, false
+}
+
+func parseJSONStringLiteral(text string, start int) (string, int, bool) {
+ if start < 0 || start >= len(text) || text[start] != '"' {
+ return "", 0, false
+ }
+ var b strings.Builder
+ escaped := false
+ for i := start + 1; i < len(text); i++ {
+ ch := text[i]
+ if escaped {
+ b.WriteByte(ch)
+ escaped = false
+ continue
+ }
+ if ch == '\\' {
+ escaped = true
+ continue
+ }
+ if ch == '"' {
+ return b.String(), i + 1, true
+ }
+ b.WriteByte(ch)
+ }
+ return "", 0, false
+}
+
+func containsKey(keys []string, value string) bool {
+ for _, k := range keys {
+ if k == value {
+ return true
+ }
+ }
+ return false
+}
+
+func skipSpaces(text string, i int) int {
+ for i < len(text) {
+ switch text[i] {
+ case ' ', '\t', '\n', '\r':
+ i++
+ default:
+ return i
+ }
+ }
+ return i
+}
diff --git a/internal/adapter/openai/tool_sieve_state.go b/internal/adapter/openai/tool_sieve_state.go
new file mode 100644
index 0000000..04699e6
--- /dev/null
+++ b/internal/adapter/openai/tool_sieve_state.go
@@ -0,0 +1,75 @@
+package openai
+
+import (
+ "strings"
+
+ "ds2api/internal/util"
+)
+
+type toolStreamSieveState struct {
+ pending strings.Builder
+ capture strings.Builder
+ capturing bool
+ recentTextTail string
+ disableDeltas bool
+ toolNameSent bool
+ toolName string
+ toolArgsStart int
+ toolArgsSent int
+ toolArgsString bool
+ toolArgsDone bool
+}
+
+type toolStreamEvent struct {
+ Content string
+ ToolCalls []util.ParsedToolCall
+ ToolCallDeltas []toolCallDelta
+}
+
+type toolCallDelta struct {
+ Index int
+ Name string
+ Arguments string
+}
+
+const toolSieveCaptureLimit = 8 * 1024
+const toolSieveContextTailLimit = 256
+
+func (s *toolStreamSieveState) resetIncrementalToolState() {
+ s.disableDeltas = false
+ s.toolNameSent = false
+ s.toolName = ""
+ s.toolArgsStart = -1
+ s.toolArgsSent = -1
+ s.toolArgsString = false
+ s.toolArgsDone = false
+}
+
+func (s *toolStreamSieveState) noteText(content string) {
+ if strings.TrimSpace(content) == "" {
+ return
+ }
+ s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
+}
+
+func appendTail(prev, next string, max int) string {
+ if max <= 0 {
+ return ""
+ }
+ combined := prev + next
+ if len(combined) <= max {
+ return combined
+ }
+ return combined[len(combined)-max:]
+}
+
+func looksLikeToolExampleContext(text string) bool {
+ return insideCodeFence(text)
+}
+
+func insideCodeFence(text string) bool {
+ if text == "" {
+ return false
+ }
+ return strings.Count(text, "```")%2 == 1
+}
diff --git a/internal/admin/handler_accounts_crud.go b/internal/admin/handler_accounts_crud.go
new file mode 100644
index 0000000..daaa434
--- /dev/null
+++ b/internal/admin/handler_accounts_crud.go
@@ -0,0 +1,114 @@
+package admin
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+
+ "github.com/go-chi/chi/v5"
+
+ "ds2api/internal/config"
+)
+
+func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
+ page := intFromQuery(r, "page", 1)
+ pageSize := intFromQuery(r, "page_size", 10)
+ if page < 1 {
+ page = 1
+ }
+ if pageSize < 1 {
+ pageSize = 1
+ }
+ if pageSize > 100 {
+ pageSize = 100
+ }
+ accounts := h.Store.Snapshot().Accounts
+ total := len(accounts)
+ reverseAccounts(accounts)
+ totalPages := 1
+ if total > 0 {
+ totalPages = (total + pageSize - 1) / pageSize
+ }
+ start := (page - 1) * pageSize
+ if start > total {
+ start = total
+ }
+ end := start + pageSize
+ if end > total {
+ end = total
+ }
+ items := make([]map[string]any, 0, end-start)
+ for _, acc := range accounts[start:end] {
+ token := strings.TrimSpace(acc.Token)
+ preview := ""
+ if token != "" {
+ if len(token) > 20 {
+ preview = token[:20] + "..."
+ } else {
+ preview = token
+ }
+ }
+ items = append(items, map[string]any{
+ "identifier": acc.Identifier(),
+ "email": acc.Email,
+ "mobile": acc.Mobile,
+ "has_password": acc.Password != "",
+ "has_token": token != "",
+ "token_preview": preview,
+ })
+ }
+ writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
+}
+
+func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
+ var req map[string]any
+ _ = json.NewDecoder(r.Body).Decode(&req)
+ acc := toAccount(req)
+ if acc.Identifier() == "" {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"})
+ return
+ }
+ err := h.Store.Update(func(c *config.Config) error {
+ for _, a := range c.Accounts {
+ if acc.Email != "" && a.Email == acc.Email {
+ return fmt.Errorf("邮箱已存在")
+ }
+ if acc.Mobile != "" && a.Mobile == acc.Mobile {
+ return fmt.Errorf("手机号已存在")
+ }
+ }
+ c.Accounts = append(c.Accounts, acc)
+ return nil
+ })
+ if err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
+ return
+ }
+ h.Pool.Reset()
+ writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
+}
+
+func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
+ identifier := chi.URLParam(r, "identifier")
+ err := h.Store.Update(func(c *config.Config) error {
+ idx := -1
+ for i, a := range c.Accounts {
+ if accountMatchesIdentifier(a, identifier) {
+ idx = i
+ break
+ }
+ }
+ if idx < 0 {
+ return fmt.Errorf("账号不存在")
+ }
+ c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...)
+ return nil
+ })
+ if err != nil {
+ writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
+ return
+ }
+ h.Pool.Reset()
+ writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
+}
diff --git a/internal/admin/handler_accounts_queue.go b/internal/admin/handler_accounts_queue.go
new file mode 100644
index 0000000..108f802
--- /dev/null
+++ b/internal/admin/handler_accounts_queue.go
@@ -0,0 +1,7 @@
+package admin
+
+import "net/http"
+
+func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) {
+ writeJSON(w, http.StatusOK, h.Pool.Status())
+}
diff --git a/internal/admin/handler_accounts.go b/internal/admin/handler_accounts_testing.go
similarity index 69%
rename from internal/admin/handler_accounts.go
rename to internal/admin/handler_accounts_testing.go
index 5cb88cc..2bd7706 100644
--- a/internal/admin/handler_accounts.go
+++ b/internal/admin/handler_accounts_testing.go
@@ -11,119 +11,11 @@ import (
"sync"
"time"
- "github.com/go-chi/chi/v5"
-
authn "ds2api/internal/auth"
"ds2api/internal/config"
"ds2api/internal/sse"
)
-func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
- page := intFromQuery(r, "page", 1)
- pageSize := intFromQuery(r, "page_size", 10)
- if page < 1 {
- page = 1
- }
- if pageSize < 1 {
- pageSize = 1
- }
- if pageSize > 100 {
- pageSize = 100
- }
- accounts := h.Store.Snapshot().Accounts
- total := len(accounts)
- reverseAccounts(accounts)
- totalPages := 1
- if total > 0 {
- totalPages = (total + pageSize - 1) / pageSize
- }
- start := (page - 1) * pageSize
- if start > total {
- start = total
- }
- end := start + pageSize
- if end > total {
- end = total
- }
- items := make([]map[string]any, 0, end-start)
- for _, acc := range accounts[start:end] {
- token := strings.TrimSpace(acc.Token)
- preview := ""
- if token != "" {
- if len(token) > 20 {
- preview = token[:20] + "..."
- } else {
- preview = token
- }
- }
- items = append(items, map[string]any{
- "identifier": acc.Identifier(),
- "email": acc.Email,
- "mobile": acc.Mobile,
- "has_password": acc.Password != "",
- "has_token": token != "",
- "token_preview": preview,
- })
- }
- writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": total, "page": page, "page_size": pageSize, "total_pages": totalPages})
-}
-
-func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
- var req map[string]any
- _ = json.NewDecoder(r.Body).Decode(&req)
- acc := toAccount(req)
- if acc.Identifier() == "" {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "需要 email 或 mobile"})
- return
- }
- err := h.Store.Update(func(c *config.Config) error {
- for _, a := range c.Accounts {
- if acc.Email != "" && a.Email == acc.Email {
- return fmt.Errorf("邮箱已存在")
- }
- if acc.Mobile != "" && a.Mobile == acc.Mobile {
- return fmt.Errorf("手机号已存在")
- }
- }
- c.Accounts = append(c.Accounts, acc)
- return nil
- })
- if err != nil {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
- return
- }
- h.Pool.Reset()
- writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
-}
-
-func (h *Handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
- identifier := chi.URLParam(r, "identifier")
- err := h.Store.Update(func(c *config.Config) error {
- idx := -1
- for i, a := range c.Accounts {
- if accountMatchesIdentifier(a, identifier) {
- idx = i
- break
- }
- }
- if idx < 0 {
- return fmt.Errorf("账号不存在")
- }
- c.Accounts = append(c.Accounts[:idx], c.Accounts[idx+1:]...)
- return nil
- })
- if err != nil {
- writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
- return
- }
- h.Pool.Reset()
- writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_accounts": len(h.Store.Snapshot().Accounts)})
-}
-
-func (h *Handler) queueStatus(w http.ResponseWriter, _ *http.Request) {
- writeJSON(w, http.StatusOK, h.Pool.Status())
-}
-
func (h *Handler) testSingleAccount(w http.ResponseWriter, r *http.Request) {
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
diff --git a/internal/admin/handler_config.go b/internal/admin/handler_config.go
deleted file mode 100644
index dfbd005..0000000
--- a/internal/admin/handler_config.go
+++ /dev/null
@@ -1,393 +0,0 @@
-package admin
-
-import (
- "crypto/md5"
- "encoding/json"
- "fmt"
- "net/http"
- "strings"
-
- "github.com/go-chi/chi/v5"
-
- "ds2api/internal/config"
-)
-
-func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
- snap := h.Store.Snapshot()
- safe := map[string]any{
- "keys": snap.Keys,
- "accounts": []map[string]any{},
- "claude_mapping": func() map[string]string {
- if len(snap.ClaudeMapping) > 0 {
- return snap.ClaudeMapping
- }
- return snap.ClaudeModelMap
- }(),
- }
- accounts := make([]map[string]any, 0, len(snap.Accounts))
- for _, acc := range snap.Accounts {
- token := strings.TrimSpace(acc.Token)
- preview := ""
- if token != "" {
- if len(token) > 20 {
- preview = token[:20] + "..."
- } else {
- preview = token
- }
- }
- accounts = append(accounts, map[string]any{
- "identifier": acc.Identifier(),
- "email": acc.Email,
- "mobile": acc.Mobile,
- "has_password": strings.TrimSpace(acc.Password) != "",
- "has_token": token != "",
- "token_preview": preview,
- })
- }
- safe["accounts"] = accounts
- writeJSON(w, http.StatusOK, safe)
-}
-
-func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
- var req map[string]any
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
- return
- }
- old := h.Store.Snapshot()
- err := h.Store.Update(func(c *config.Config) error {
- if keys, ok := toStringSlice(req["keys"]); ok {
- c.Keys = keys
- }
- if accountsRaw, ok := req["accounts"].([]any); ok {
- existing := map[string]config.Account{}
- for _, a := range old.Accounts {
- existing[a.Identifier()] = a
- }
- accounts := make([]config.Account, 0, len(accountsRaw))
- for _, item := range accountsRaw {
- m, ok := item.(map[string]any)
- if !ok {
- continue
- }
- acc := toAccount(m)
- id := acc.Identifier()
- if prev, ok := existing[id]; ok {
- if strings.TrimSpace(acc.Password) == "" {
- acc.Password = prev.Password
- }
- if strings.TrimSpace(acc.Token) == "" {
- acc.Token = prev.Token
- }
- }
- accounts = append(accounts, acc)
- }
- c.Accounts = accounts
- }
- if m, ok := req["claude_mapping"].(map[string]any); ok {
- newMap := map[string]string{}
- for k, v := range m {
- newMap[k] = fmt.Sprintf("%v", v)
- }
- c.ClaudeMapping = newMap
- }
- return nil
- })
- if err != nil {
- writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
- return
- }
- h.Pool.Reset()
- writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "配置已更新"})
-}
-
-func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) {
- var req map[string]any
- _ = json.NewDecoder(r.Body).Decode(&req)
- key, _ := req["key"].(string)
- key = strings.TrimSpace(key)
- if key == "" {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "Key 不能为空"})
- return
- }
- err := h.Store.Update(func(c *config.Config) error {
- for _, k := range c.Keys {
- if k == key {
- return fmt.Errorf("Key 已存在")
- }
- }
- c.Keys = append(c.Keys, key)
- return nil
- })
- if err != nil {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
- return
- }
- writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
-}
-
-func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) {
- key := chi.URLParam(r, "key")
- err := h.Store.Update(func(c *config.Config) error {
- idx := -1
- for i, k := range c.Keys {
- if k == key {
- idx = i
- break
- }
- }
- if idx < 0 {
- return fmt.Errorf("Key 不存在")
- }
- c.Keys = append(c.Keys[:idx], c.Keys[idx+1:]...)
- return nil
- })
- if err != nil {
- writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
- return
- }
- writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
-}
-
-func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) {
- var req map[string]any
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "无效的 JSON 格式"})
- return
- }
- importedKeys, importedAccounts := 0, 0
- err := h.Store.Update(func(c *config.Config) error {
- if keys, ok := req["keys"].([]any); ok {
- existing := map[string]bool{}
- for _, k := range c.Keys {
- existing[k] = true
- }
- for _, k := range keys {
- key := strings.TrimSpace(fmt.Sprintf("%v", k))
- if key == "" || existing[key] {
- continue
- }
- c.Keys = append(c.Keys, key)
- existing[key] = true
- importedKeys++
- }
- }
- if accounts, ok := req["accounts"].([]any); ok {
- existing := map[string]bool{}
- for _, a := range c.Accounts {
- existing[a.Identifier()] = true
- }
- for _, item := range accounts {
- m, ok := item.(map[string]any)
- if !ok {
- continue
- }
- acc := toAccount(m)
- id := acc.Identifier()
- if id == "" || existing[id] {
- continue
- }
- c.Accounts = append(c.Accounts, acc)
- existing[id] = true
- importedAccounts++
- }
- }
- return nil
- })
- if err != nil {
- writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
- return
- }
- h.Pool.Reset()
- writeJSON(w, http.StatusOK, map[string]any{"success": true, "imported_keys": importedKeys, "imported_accounts": importedAccounts})
-}
-
-func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) {
- h.configExport(w, nil)
-}
-
-func (h *Handler) configExport(w http.ResponseWriter, _ *http.Request) {
- snap := h.Store.Snapshot()
- jsonStr, b64, err := h.Store.ExportJSONAndBase64()
- if err != nil {
- writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
- return
- }
- writeJSON(w, http.StatusOK, map[string]any{
- "success": true,
- "config": snap,
- "json": jsonStr,
- "base64": b64,
- })
-}
-
-func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
- var req map[string]any
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
- return
- }
-
- mode := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("mode")))
- if mode == "" {
- mode = strings.TrimSpace(strings.ToLower(fieldString(req, "mode")))
- }
- if mode == "" {
- mode = "merge"
- }
- if mode != "merge" && mode != "replace" {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "mode must be merge or replace"})
- return
- }
-
- payload := req
- if raw, ok := req["config"].(map[string]any); ok && len(raw) > 0 {
- payload = raw
- }
- rawJSON, err := json.Marshal(payload)
- if err != nil {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid config payload"})
- return
- }
- var incoming config.Config
- if err := json.Unmarshal(rawJSON, &incoming); err != nil {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
- return
- }
-
- importedKeys, importedAccounts := 0, 0
- err = h.Store.Update(func(c *config.Config) error {
- next := c.Clone()
- if mode == "replace" {
- next = incoming.Clone()
- next.VercelSyncHash = c.VercelSyncHash
- next.VercelSyncTime = c.VercelSyncTime
- importedKeys = len(next.Keys)
- importedAccounts = len(next.Accounts)
- } else {
- existingKeys := map[string]struct{}{}
- for _, k := range next.Keys {
- existingKeys[k] = struct{}{}
- }
- for _, k := range incoming.Keys {
- key := strings.TrimSpace(k)
- if key == "" {
- continue
- }
- if _, ok := existingKeys[key]; ok {
- continue
- }
- existingKeys[key] = struct{}{}
- next.Keys = append(next.Keys, key)
- importedKeys++
- }
-
- existingAccounts := map[string]struct{}{}
- for _, acc := range next.Accounts {
- existingAccounts[acc.Identifier()] = struct{}{}
- }
- for _, acc := range incoming.Accounts {
- id := acc.Identifier()
- if id == "" {
- continue
- }
- if _, ok := existingAccounts[id]; ok {
- continue
- }
- existingAccounts[id] = struct{}{}
- next.Accounts = append(next.Accounts, acc)
- importedAccounts++
- }
-
- if len(incoming.ClaudeMapping) > 0 {
- if next.ClaudeMapping == nil {
- next.ClaudeMapping = map[string]string{}
- }
- for k, v := range incoming.ClaudeMapping {
- next.ClaudeMapping[k] = v
- }
- }
- if len(incoming.ClaudeModelMap) > 0 {
- if next.ClaudeModelMap == nil {
- next.ClaudeModelMap = map[string]string{}
- }
- for k, v := range incoming.ClaudeModelMap {
- next.ClaudeModelMap[k] = v
- }
- }
-
- if len(incoming.ModelAliases) > 0 {
- if next.ModelAliases == nil {
- next.ModelAliases = map[string]string{}
- }
- for k, v := range incoming.ModelAliases {
- next.ModelAliases[k] = v
- }
- }
- if strings.TrimSpace(incoming.Toolcall.Mode) != "" {
- next.Toolcall.Mode = incoming.Toolcall.Mode
- }
- if strings.TrimSpace(incoming.Toolcall.EarlyEmitConfidence) != "" {
- next.Toolcall.EarlyEmitConfidence = incoming.Toolcall.EarlyEmitConfidence
- }
- if incoming.Responses.StoreTTLSeconds > 0 {
- next.Responses.StoreTTLSeconds = incoming.Responses.StoreTTLSeconds
- }
- if strings.TrimSpace(incoming.Embeddings.Provider) != "" {
- next.Embeddings.Provider = incoming.Embeddings.Provider
- }
- if strings.TrimSpace(incoming.Admin.PasswordHash) != "" {
- next.Admin.PasswordHash = incoming.Admin.PasswordHash
- }
- if incoming.Admin.JWTExpireHours > 0 {
- next.Admin.JWTExpireHours = incoming.Admin.JWTExpireHours
- }
- if incoming.Admin.JWTValidAfterUnix > 0 {
- next.Admin.JWTValidAfterUnix = incoming.Admin.JWTValidAfterUnix
- }
- if incoming.Runtime.AccountMaxInflight > 0 {
- next.Runtime.AccountMaxInflight = incoming.Runtime.AccountMaxInflight
- }
- if incoming.Runtime.AccountMaxQueue > 0 {
- next.Runtime.AccountMaxQueue = incoming.Runtime.AccountMaxQueue
- }
- if incoming.Runtime.GlobalMaxInflight > 0 {
- next.Runtime.GlobalMaxInflight = incoming.Runtime.GlobalMaxInflight
- }
- }
-
- normalizeSettingsConfig(&next)
- if err := validateSettingsConfig(next); err != nil {
- return newRequestError(err.Error())
- }
-
- *c = next
- return nil
- })
- if err != nil {
- if detail, ok := requestErrorDetail(err); ok {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": detail})
- return
- }
- writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
- return
- }
-
- h.Pool.Reset()
- writeJSON(w, http.StatusOK, map[string]any{
- "success": true,
- "mode": mode,
- "imported_keys": importedKeys,
- "imported_accounts": importedAccounts,
- "message": "config imported",
- })
-}
-
-func (h *Handler) computeSyncHash() string {
- snap := h.Store.Snapshot().Clone()
- snap.VercelSyncHash = ""
- snap.VercelSyncTime = 0
- b, _ := json.Marshal(snap)
- sum := md5.Sum(b)
- return fmt.Sprintf("%x", sum)
-}
diff --git a/internal/admin/handler_config_import.go b/internal/admin/handler_config_import.go
new file mode 100644
index 0000000..674d8b2
--- /dev/null
+++ b/internal/admin/handler_config_import.go
@@ -0,0 +1,182 @@
+package admin
+
+import (
+ "crypto/md5"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+
+ "ds2api/internal/config"
+)
+
+func (h *Handler) configImport(w http.ResponseWriter, r *http.Request) {
+ var req map[string]any
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
+ return
+ }
+
+ mode := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("mode")))
+ if mode == "" {
+ mode = strings.TrimSpace(strings.ToLower(fieldString(req, "mode")))
+ }
+ if mode == "" {
+ mode = "merge"
+ }
+ if mode != "merge" && mode != "replace" {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "mode must be merge or replace"})
+ return
+ }
+
+ payload := req
+ if raw, ok := req["config"].(map[string]any); ok && len(raw) > 0 {
+ payload = raw
+ }
+ rawJSON, err := json.Marshal(payload)
+ if err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid config payload"})
+ return
+ }
+ var incoming config.Config
+ if err := json.Unmarshal(rawJSON, &incoming); err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
+ return
+ }
+
+ importedKeys, importedAccounts := 0, 0
+ err = h.Store.Update(func(c *config.Config) error {
+ next := c.Clone()
+ if mode == "replace" {
+ next = incoming.Clone()
+ next.VercelSyncHash = c.VercelSyncHash
+ next.VercelSyncTime = c.VercelSyncTime
+ importedKeys = len(next.Keys)
+ importedAccounts = len(next.Accounts)
+ } else {
+ existingKeys := map[string]struct{}{}
+ for _, k := range next.Keys {
+ existingKeys[k] = struct{}{}
+ }
+ for _, k := range incoming.Keys {
+ key := strings.TrimSpace(k)
+ if key == "" {
+ continue
+ }
+ if _, ok := existingKeys[key]; ok {
+ continue
+ }
+ existingKeys[key] = struct{}{}
+ next.Keys = append(next.Keys, key)
+ importedKeys++
+ }
+
+ existingAccounts := map[string]struct{}{}
+ for _, acc := range next.Accounts {
+ existingAccounts[acc.Identifier()] = struct{}{}
+ }
+ for _, acc := range incoming.Accounts {
+ id := acc.Identifier()
+ if id == "" {
+ continue
+ }
+ if _, ok := existingAccounts[id]; ok {
+ continue
+ }
+ existingAccounts[id] = struct{}{}
+ next.Accounts = append(next.Accounts, acc)
+ importedAccounts++
+ }
+
+ if len(incoming.ClaudeMapping) > 0 {
+ if next.ClaudeMapping == nil {
+ next.ClaudeMapping = map[string]string{}
+ }
+ for k, v := range incoming.ClaudeMapping {
+ next.ClaudeMapping[k] = v
+ }
+ }
+ if len(incoming.ClaudeModelMap) > 0 {
+ if next.ClaudeModelMap == nil {
+ next.ClaudeModelMap = map[string]string{}
+ }
+ for k, v := range incoming.ClaudeModelMap {
+ next.ClaudeModelMap[k] = v
+ }
+ }
+
+ if len(incoming.ModelAliases) > 0 {
+ if next.ModelAliases == nil {
+ next.ModelAliases = map[string]string{}
+ }
+ for k, v := range incoming.ModelAliases {
+ next.ModelAliases[k] = v
+ }
+ }
+ if strings.TrimSpace(incoming.Toolcall.Mode) != "" {
+ next.Toolcall.Mode = incoming.Toolcall.Mode
+ }
+ if strings.TrimSpace(incoming.Toolcall.EarlyEmitConfidence) != "" {
+ next.Toolcall.EarlyEmitConfidence = incoming.Toolcall.EarlyEmitConfidence
+ }
+ if incoming.Responses.StoreTTLSeconds > 0 {
+ next.Responses.StoreTTLSeconds = incoming.Responses.StoreTTLSeconds
+ }
+ if strings.TrimSpace(incoming.Embeddings.Provider) != "" {
+ next.Embeddings.Provider = incoming.Embeddings.Provider
+ }
+ if strings.TrimSpace(incoming.Admin.PasswordHash) != "" {
+ next.Admin.PasswordHash = incoming.Admin.PasswordHash
+ }
+ if incoming.Admin.JWTExpireHours > 0 {
+ next.Admin.JWTExpireHours = incoming.Admin.JWTExpireHours
+ }
+ if incoming.Admin.JWTValidAfterUnix > 0 {
+ next.Admin.JWTValidAfterUnix = incoming.Admin.JWTValidAfterUnix
+ }
+ if incoming.Runtime.AccountMaxInflight > 0 {
+ next.Runtime.AccountMaxInflight = incoming.Runtime.AccountMaxInflight
+ }
+ if incoming.Runtime.AccountMaxQueue > 0 {
+ next.Runtime.AccountMaxQueue = incoming.Runtime.AccountMaxQueue
+ }
+ if incoming.Runtime.GlobalMaxInflight > 0 {
+ next.Runtime.GlobalMaxInflight = incoming.Runtime.GlobalMaxInflight
+ }
+ }
+
+ normalizeSettingsConfig(&next)
+ if err := validateSettingsConfig(next); err != nil {
+ return newRequestError(err.Error())
+ }
+
+ *c = next
+ return nil
+ })
+ if err != nil {
+ if detail, ok := requestErrorDetail(err); ok {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": detail})
+ return
+ }
+ writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
+ return
+ }
+
+ h.Pool.Reset()
+ writeJSON(w, http.StatusOK, map[string]any{
+ "success": true,
+ "mode": mode,
+ "imported_keys": importedKeys,
+ "imported_accounts": importedAccounts,
+ "message": "config imported",
+ })
+}
+
+func (h *Handler) computeSyncHash() string {
+ snap := h.Store.Snapshot().Clone()
+ snap.VercelSyncHash = ""
+ snap.VercelSyncTime = 0
+ b, _ := json.Marshal(snap)
+ sum := md5.Sum(b)
+ return fmt.Sprintf("%x", sum)
+}
diff --git a/internal/admin/handler_config_read.go b/internal/admin/handler_config_read.go
new file mode 100644
index 0000000..e32aabd
--- /dev/null
+++ b/internal/admin/handler_config_read.go
@@ -0,0 +1,61 @@
+package admin
+
+import (
+ "net/http"
+ "strings"
+)
+
+func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
+ snap := h.Store.Snapshot()
+ safe := map[string]any{
+ "keys": snap.Keys,
+ "accounts": []map[string]any{},
+ "claude_mapping": func() map[string]string {
+ if len(snap.ClaudeMapping) > 0 {
+ return snap.ClaudeMapping
+ }
+ return snap.ClaudeModelMap
+ }(),
+ }
+ accounts := make([]map[string]any, 0, len(snap.Accounts))
+ for _, acc := range snap.Accounts {
+ token := strings.TrimSpace(acc.Token)
+ preview := ""
+ if token != "" {
+ if len(token) > 20 {
+ preview = token[:20] + "..."
+ } else {
+ preview = token
+ }
+ }
+ accounts = append(accounts, map[string]any{
+ "identifier": acc.Identifier(),
+ "email": acc.Email,
+ "mobile": acc.Mobile,
+ "has_password": strings.TrimSpace(acc.Password) != "",
+ "has_token": token != "",
+ "token_preview": preview,
+ })
+ }
+ safe["accounts"] = accounts
+ writeJSON(w, http.StatusOK, safe)
+}
+
+func (h *Handler) exportConfig(w http.ResponseWriter, _ *http.Request) {
+ h.configExport(w, nil)
+}
+
+func (h *Handler) configExport(w http.ResponseWriter, _ *http.Request) {
+ snap := h.Store.Snapshot()
+ jsonStr, b64, err := h.Store.ExportJSONAndBase64()
+ if err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
+ return
+ }
+ writeJSON(w, http.StatusOK, map[string]any{
+ "success": true,
+ "config": snap,
+ "json": jsonStr,
+ "base64": b64,
+ })
+}
diff --git a/internal/admin/handler_config_write.go b/internal/admin/handler_config_write.go
new file mode 100644
index 0000000..792e696
--- /dev/null
+++ b/internal/admin/handler_config_write.go
@@ -0,0 +1,166 @@
+package admin
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+
+ "github.com/go-chi/chi/v5"
+
+ "ds2api/internal/config"
+)
+
+func (h *Handler) updateConfig(w http.ResponseWriter, r *http.Request) {
+ var req map[string]any
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
+ return
+ }
+ old := h.Store.Snapshot()
+ err := h.Store.Update(func(c *config.Config) error {
+ if keys, ok := toStringSlice(req["keys"]); ok {
+ c.Keys = keys
+ }
+ if accountsRaw, ok := req["accounts"].([]any); ok {
+ existing := map[string]config.Account{}
+ for _, a := range old.Accounts {
+ existing[a.Identifier()] = a
+ }
+ accounts := make([]config.Account, 0, len(accountsRaw))
+ for _, item := range accountsRaw {
+ m, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ acc := toAccount(m)
+ id := acc.Identifier()
+ if prev, ok := existing[id]; ok {
+ if strings.TrimSpace(acc.Password) == "" {
+ acc.Password = prev.Password
+ }
+ if strings.TrimSpace(acc.Token) == "" {
+ acc.Token = prev.Token
+ }
+ }
+ accounts = append(accounts, acc)
+ }
+ c.Accounts = accounts
+ }
+ if m, ok := req["claude_mapping"].(map[string]any); ok {
+ newMap := map[string]string{}
+ for k, v := range m {
+ newMap[k] = fmt.Sprintf("%v", v)
+ }
+ c.ClaudeMapping = newMap
+ }
+ return nil
+ })
+ if err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
+ return
+ }
+ h.Pool.Reset()
+ writeJSON(w, http.StatusOK, map[string]any{"success": true, "message": "配置已更新"})
+}
+
+func (h *Handler) addKey(w http.ResponseWriter, r *http.Request) {
+ var req map[string]any
+ _ = json.NewDecoder(r.Body).Decode(&req)
+ key, _ := req["key"].(string)
+ key = strings.TrimSpace(key)
+ if key == "" {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "Key 不能为空"})
+ return
+ }
+ err := h.Store.Update(func(c *config.Config) error {
+ for _, k := range c.Keys {
+ if k == key {
+ return fmt.Errorf("Key 已存在")
+ }
+ }
+ c.Keys = append(c.Keys, key)
+ return nil
+ })
+ if err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
+ return
+ }
+ writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
+}
+
+func (h *Handler) deleteKey(w http.ResponseWriter, r *http.Request) {
+ key := chi.URLParam(r, "key")
+ err := h.Store.Update(func(c *config.Config) error {
+ idx := -1
+ for i, k := range c.Keys {
+ if k == key {
+ idx = i
+ break
+ }
+ }
+ if idx < 0 {
+ return fmt.Errorf("Key 不存在")
+ }
+ c.Keys = append(c.Keys[:idx], c.Keys[idx+1:]...)
+ return nil
+ })
+ if err != nil {
+ writeJSON(w, http.StatusNotFound, map[string]any{"detail": err.Error()})
+ return
+ }
+ writeJSON(w, http.StatusOK, map[string]any{"success": true, "total_keys": len(h.Store.Snapshot().Keys)})
+}
+
+func (h *Handler) batchImport(w http.ResponseWriter, r *http.Request) {
+ var req map[string]any
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "无效的 JSON 格式"})
+ return
+ }
+ importedKeys, importedAccounts := 0, 0
+ err := h.Store.Update(func(c *config.Config) error {
+ if keys, ok := req["keys"].([]any); ok {
+ existing := map[string]bool{}
+ for _, k := range c.Keys {
+ existing[k] = true
+ }
+ for _, k := range keys {
+ key := strings.TrimSpace(fmt.Sprintf("%v", k))
+ if key == "" || existing[key] {
+ continue
+ }
+ c.Keys = append(c.Keys, key)
+ existing[key] = true
+ importedKeys++
+ }
+ }
+ if accounts, ok := req["accounts"].([]any); ok {
+ existing := map[string]bool{}
+ for _, a := range c.Accounts {
+ existing[a.Identifier()] = true
+ }
+ for _, item := range accounts {
+ m, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ acc := toAccount(m)
+ id := acc.Identifier()
+ if id == "" || existing[id] {
+ continue
+ }
+ c.Accounts = append(c.Accounts, acc)
+ existing[id] = true
+ importedAccounts++
+ }
+ }
+ return nil
+ })
+ if err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
+ return
+ }
+ h.Pool.Reset()
+ writeJSON(w, http.StatusOK, map[string]any{"success": true, "imported_keys": importedKeys, "imported_accounts": importedAccounts})
+}
diff --git a/internal/admin/handler_settings.go b/internal/admin/handler_settings.go
deleted file mode 100644
index 06c234c..0000000
--- a/internal/admin/handler_settings.go
+++ /dev/null
@@ -1,321 +0,0 @@
-package admin
-
-import (
- "encoding/json"
- "fmt"
- "net/http"
- "strings"
- "time"
-
- authn "ds2api/internal/auth"
- "ds2api/internal/config"
-)
-
-func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) {
- snap := h.Store.Snapshot()
- recommended := defaultRuntimeRecommended(len(snap.Accounts), h.Store.RuntimeAccountMaxInflight())
- needsSync := config.IsVercel() && snap.VercelSyncHash != "" && snap.VercelSyncHash != h.computeSyncHash()
- writeJSON(w, http.StatusOK, map[string]any{
- "success": true,
- "admin": map[string]any{
- "has_password_hash": strings.TrimSpace(snap.Admin.PasswordHash) != "",
- "jwt_expire_hours": h.Store.AdminJWTExpireHours(),
- "jwt_valid_after_unix": snap.Admin.JWTValidAfterUnix,
- "default_password_warning": authn.UsingDefaultAdminKey(h.Store),
- },
- "runtime": map[string]any{
- "account_max_inflight": h.Store.RuntimeAccountMaxInflight(),
- "account_max_queue": h.Store.RuntimeAccountMaxQueue(recommended),
- "global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended),
- },
- "toolcall": snap.Toolcall,
- "responses": snap.Responses,
- "embeddings": snap.Embeddings,
- "claude_mapping": settingsClaudeMapping(snap),
- "model_aliases": snap.ModelAliases,
- "env_backed": h.Store.IsEnvBacked(),
- "needs_vercel_sync": needsSync,
- })
-}
-
-func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
- var req map[string]any
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
- return
- }
-
- adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
- if err != nil {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
- return
- }
- if runtimeCfg != nil {
- if err := validateMergedRuntimeSettings(h.Store.Snapshot().Runtime, runtimeCfg); err != nil {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
- return
- }
- }
-
- if err := h.Store.Update(func(c *config.Config) error {
- if adminCfg != nil {
- if adminCfg.JWTExpireHours > 0 {
- c.Admin.JWTExpireHours = adminCfg.JWTExpireHours
- }
- }
- if runtimeCfg != nil {
- if runtimeCfg.AccountMaxInflight > 0 {
- c.Runtime.AccountMaxInflight = runtimeCfg.AccountMaxInflight
- }
- if runtimeCfg.AccountMaxQueue > 0 {
- c.Runtime.AccountMaxQueue = runtimeCfg.AccountMaxQueue
- }
- if runtimeCfg.GlobalMaxInflight > 0 {
- c.Runtime.GlobalMaxInflight = runtimeCfg.GlobalMaxInflight
- }
- }
- if toolcallCfg != nil {
- if strings.TrimSpace(toolcallCfg.Mode) != "" {
- c.Toolcall.Mode = strings.TrimSpace(toolcallCfg.Mode)
- }
- if strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) != "" {
- c.Toolcall.EarlyEmitConfidence = strings.TrimSpace(toolcallCfg.EarlyEmitConfidence)
- }
- }
- if responsesCfg != nil && responsesCfg.StoreTTLSeconds > 0 {
- c.Responses.StoreTTLSeconds = responsesCfg.StoreTTLSeconds
- }
- if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" {
- c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider)
- }
- if claudeMap != nil {
- c.ClaudeMapping = claudeMap
- c.ClaudeModelMap = nil
- }
- if aliasMap != nil {
- c.ModelAliases = aliasMap
- }
- return nil
- }); err != nil {
- writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
- return
- }
-
- h.applyRuntimeSettings()
- needsSync := config.IsVercel() || h.Store.IsEnvBacked()
- writeJSON(w, http.StatusOK, map[string]any{
- "success": true,
- "message": "settings updated and hot reloaded",
- "env_backed": h.Store.IsEnvBacked(),
- "needs_vercel_sync": needsSync,
- "manual_sync_message": "配置已保存。Vercel 部署请在 Vercel Sync 页面手动同步。",
- })
-}
-
-func validateMergedRuntimeSettings(current config.RuntimeConfig, incoming *config.RuntimeConfig) error {
- merged := current
- if incoming != nil {
- if incoming.AccountMaxInflight > 0 {
- merged.AccountMaxInflight = incoming.AccountMaxInflight
- }
- if incoming.AccountMaxQueue > 0 {
- merged.AccountMaxQueue = incoming.AccountMaxQueue
- }
- if incoming.GlobalMaxInflight > 0 {
- merged.GlobalMaxInflight = incoming.GlobalMaxInflight
- }
- }
- return validateRuntimeSettings(merged)
-}
-
-func (h *Handler) updateSettingsPassword(w http.ResponseWriter, r *http.Request) {
- var req map[string]any
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
- return
- }
- newPassword := strings.TrimSpace(fieldString(req, "new_password"))
- if newPassword == "" {
- newPassword = strings.TrimSpace(fieldString(req, "password"))
- }
- if len(newPassword) < 4 {
- writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "new password must be at least 4 characters"})
- return
- }
-
- now := time.Now().Unix()
- hash := authn.HashAdminPassword(newPassword)
- if err := h.Store.Update(func(c *config.Config) error {
- c.Admin.PasswordHash = hash
- c.Admin.JWTValidAfterUnix = now
- return nil
- }); err != nil {
- writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
- return
- }
-
- writeJSON(w, http.StatusOK, map[string]any{
- "success": true,
- "message": "password updated",
- "force_relogin": true,
- "jwt_valid_after_unix": now,
- })
-}
-
-func (h *Handler) applyRuntimeSettings() {
- if h == nil || h.Store == nil || h.Pool == nil {
- return
- }
- accountCount := len(h.Store.Accounts())
- maxPer := h.Store.RuntimeAccountMaxInflight()
- recommended := defaultRuntimeRecommended(accountCount, maxPer)
- maxQueue := h.Store.RuntimeAccountMaxQueue(recommended)
- global := h.Store.RuntimeGlobalMaxInflight(recommended)
- h.Pool.ApplyRuntimeLimits(maxPer, maxQueue, global)
-}
-
-func defaultRuntimeRecommended(accountCount, maxPer int) int {
- if maxPer <= 0 {
- maxPer = 1
- }
- if accountCount <= 0 {
- return maxPer
- }
- return accountCount * maxPer
-}
-
-func settingsClaudeMapping(c config.Config) map[string]string {
- if len(c.ClaudeMapping) > 0 {
- return c.ClaudeMapping
- }
- if len(c.ClaudeModelMap) > 0 {
- return c.ClaudeModelMap
- }
- return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
-}
-
-func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) {
- var (
- adminCfg *config.AdminConfig
- runtimeCfg *config.RuntimeConfig
- toolcallCfg *config.ToolcallConfig
- respCfg *config.ResponsesConfig
- embCfg *config.EmbeddingsConfig
- claudeMap map[string]string
- aliasMap map[string]string
- )
-
- if raw, ok := req["admin"].(map[string]any); ok {
- cfg := &config.AdminConfig{}
- if v, exists := raw["jwt_expire_hours"]; exists {
- n := intFrom(v)
- if n < 1 || n > 720 {
- return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720")
- }
- cfg.JWTExpireHours = n
- }
- adminCfg = cfg
- }
-
- if raw, ok := req["runtime"].(map[string]any); ok {
- cfg := &config.RuntimeConfig{}
- if v, exists := raw["account_max_inflight"]; exists {
- n := intFrom(v)
- if n < 1 || n > 256 {
- return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256")
- }
- cfg.AccountMaxInflight = n
- }
- if v, exists := raw["account_max_queue"]; exists {
- n := intFrom(v)
- if n < 1 || n > 200000 {
- return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000")
- }
- cfg.AccountMaxQueue = n
- }
- if v, exists := raw["global_max_inflight"]; exists {
- n := intFrom(v)
- if n < 1 || n > 200000 {
- return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
- }
- cfg.GlobalMaxInflight = n
- }
- if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight {
- return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
- }
- runtimeCfg = cfg
- }
-
- if raw, ok := req["toolcall"].(map[string]any); ok {
- cfg := &config.ToolcallConfig{}
- if v, exists := raw["mode"]; exists {
- mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
- switch mode {
- case "feature_match", "off":
- cfg.Mode = mode
- default:
- return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off")
- }
- }
- if v, exists := raw["early_emit_confidence"]; exists {
- level := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
- switch level {
- case "high", "low", "off":
- cfg.EarlyEmitConfidence = level
- default:
- return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off")
- }
- }
- toolcallCfg = cfg
- }
-
- if raw, ok := req["responses"].(map[string]any); ok {
- cfg := &config.ResponsesConfig{}
- if v, exists := raw["store_ttl_seconds"]; exists {
- n := intFrom(v)
- if n < 30 || n > 86400 {
- return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
- }
- cfg.StoreTTLSeconds = n
- }
- respCfg = cfg
- }
-
- if raw, ok := req["embeddings"].(map[string]any); ok {
- cfg := &config.EmbeddingsConfig{}
- if v, exists := raw["provider"]; exists {
- p := strings.TrimSpace(fmt.Sprintf("%v", v))
- if p == "" {
- return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty")
- }
- cfg.Provider = p
- }
- embCfg = cfg
- }
-
- if raw, ok := req["claude_mapping"].(map[string]any); ok {
- claudeMap = map[string]string{}
- for k, v := range raw {
- key := strings.TrimSpace(k)
- val := strings.TrimSpace(fmt.Sprintf("%v", v))
- if key == "" || val == "" {
- continue
- }
- claudeMap[key] = val
- }
- }
-
- if raw, ok := req["model_aliases"].(map[string]any); ok {
- aliasMap = map[string]string{}
- for k, v := range raw {
- key := strings.TrimSpace(k)
- val := strings.TrimSpace(fmt.Sprintf("%v", v))
- if key == "" || val == "" {
- continue
- }
- aliasMap[key] = val
- }
- }
-
- return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil
-}
diff --git a/internal/admin/handler_settings_parse.go b/internal/admin/handler_settings_parse.go
new file mode 100644
index 0000000..6c5b7ee
--- /dev/null
+++ b/internal/admin/handler_settings_parse.go
@@ -0,0 +1,134 @@
+package admin
+
+import (
+ "fmt"
+ "strings"
+
+ "ds2api/internal/config"
+)
+
+func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *config.RuntimeConfig, *config.ToolcallConfig, *config.ResponsesConfig, *config.EmbeddingsConfig, map[string]string, map[string]string, error) {
+ var (
+ adminCfg *config.AdminConfig
+ runtimeCfg *config.RuntimeConfig
+ toolcallCfg *config.ToolcallConfig
+ respCfg *config.ResponsesConfig
+ embCfg *config.EmbeddingsConfig
+ claudeMap map[string]string
+ aliasMap map[string]string
+ )
+
+ if raw, ok := req["admin"].(map[string]any); ok {
+ cfg := &config.AdminConfig{}
+ if v, exists := raw["jwt_expire_hours"]; exists {
+ n := intFrom(v)
+ if n < 1 || n > 720 {
+ return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720")
+ }
+ cfg.JWTExpireHours = n
+ }
+ adminCfg = cfg
+ }
+
+ if raw, ok := req["runtime"].(map[string]any); ok {
+ cfg := &config.RuntimeConfig{}
+ if v, exists := raw["account_max_inflight"]; exists {
+ n := intFrom(v)
+ if n < 1 || n > 256 {
+ return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256")
+ }
+ cfg.AccountMaxInflight = n
+ }
+ if v, exists := raw["account_max_queue"]; exists {
+ n := intFrom(v)
+ if n < 1 || n > 200000 {
+ return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000")
+ }
+ cfg.AccountMaxQueue = n
+ }
+ if v, exists := raw["global_max_inflight"]; exists {
+ n := intFrom(v)
+ if n < 1 || n > 200000 {
+ return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
+ }
+ cfg.GlobalMaxInflight = n
+ }
+ if cfg.AccountMaxInflight > 0 && cfg.GlobalMaxInflight > 0 && cfg.GlobalMaxInflight < cfg.AccountMaxInflight {
+ return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
+ }
+ runtimeCfg = cfg
+ }
+
+ if raw, ok := req["toolcall"].(map[string]any); ok {
+ cfg := &config.ToolcallConfig{}
+ if v, exists := raw["mode"]; exists {
+ mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
+ switch mode {
+ case "feature_match", "off":
+ cfg.Mode = mode
+ default:
+ return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.mode must be feature_match or off")
+ }
+ }
+ if v, exists := raw["early_emit_confidence"]; exists {
+ level := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
+ switch level {
+ case "high", "low", "off":
+ cfg.EarlyEmitConfidence = level
+ default:
+ return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("toolcall.early_emit_confidence must be high, low or off")
+ }
+ }
+ toolcallCfg = cfg
+ }
+
+ if raw, ok := req["responses"].(map[string]any); ok {
+ cfg := &config.ResponsesConfig{}
+ if v, exists := raw["store_ttl_seconds"]; exists {
+ n := intFrom(v)
+ if n < 30 || n > 86400 {
+ return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
+ }
+ cfg.StoreTTLSeconds = n
+ }
+ respCfg = cfg
+ }
+
+ if raw, ok := req["embeddings"].(map[string]any); ok {
+ cfg := &config.EmbeddingsConfig{}
+ if v, exists := raw["provider"]; exists {
+ p := strings.TrimSpace(fmt.Sprintf("%v", v))
+ if p == "" {
+ return nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("embeddings.provider cannot be empty")
+ }
+ cfg.Provider = p
+ }
+ embCfg = cfg
+ }
+
+ if raw, ok := req["claude_mapping"].(map[string]any); ok {
+ claudeMap = map[string]string{}
+ for k, v := range raw {
+ key := strings.TrimSpace(k)
+ val := strings.TrimSpace(fmt.Sprintf("%v", v))
+ if key == "" || val == "" {
+ continue
+ }
+ claudeMap[key] = val
+ }
+ }
+
+ if raw, ok := req["model_aliases"].(map[string]any); ok {
+ aliasMap = map[string]string{}
+ for k, v := range raw {
+ key := strings.TrimSpace(k)
+ val := strings.TrimSpace(fmt.Sprintf("%v", v))
+ if key == "" || val == "" {
+ continue
+ }
+ aliasMap[key] = val
+ }
+ }
+
+ return adminCfg, runtimeCfg, toolcallCfg, respCfg, embCfg, claudeMap, aliasMap, nil
+}
diff --git a/internal/admin/handler_settings_read.go b/internal/admin/handler_settings_read.go
new file mode 100644
index 0000000..565519f
--- /dev/null
+++ b/internal/admin/handler_settings_read.go
@@ -0,0 +1,36 @@
+package admin
+
+import (
+ "net/http"
+ "strings"
+
+ authn "ds2api/internal/auth"
+ "ds2api/internal/config"
+)
+
+func (h *Handler) getSettings(w http.ResponseWriter, _ *http.Request) {
+ snap := h.Store.Snapshot()
+ recommended := defaultRuntimeRecommended(len(snap.Accounts), h.Store.RuntimeAccountMaxInflight())
+ needsSync := config.IsVercel() && snap.VercelSyncHash != "" && snap.VercelSyncHash != h.computeSyncHash()
+ writeJSON(w, http.StatusOK, map[string]any{
+ "success": true,
+ "admin": map[string]any{
+ "has_password_hash": strings.TrimSpace(snap.Admin.PasswordHash) != "",
+ "jwt_expire_hours": h.Store.AdminJWTExpireHours(),
+ "jwt_valid_after_unix": snap.Admin.JWTValidAfterUnix,
+ "default_password_warning": authn.UsingDefaultAdminKey(h.Store),
+ },
+ "runtime": map[string]any{
+ "account_max_inflight": h.Store.RuntimeAccountMaxInflight(),
+ "account_max_queue": h.Store.RuntimeAccountMaxQueue(recommended),
+ "global_max_inflight": h.Store.RuntimeGlobalMaxInflight(recommended),
+ },
+ "toolcall": snap.Toolcall,
+ "responses": snap.Responses,
+ "embeddings": snap.Embeddings,
+ "claude_mapping": settingsClaudeMapping(snap),
+ "model_aliases": snap.ModelAliases,
+ "env_backed": h.Store.IsEnvBacked(),
+ "needs_vercel_sync": needsSync,
+ })
+}
diff --git a/internal/admin/handler_settings_runtime.go b/internal/admin/handler_settings_runtime.go
new file mode 100644
index 0000000..6ff6902
--- /dev/null
+++ b/internal/admin/handler_settings_runtime.go
@@ -0,0 +1,51 @@
+package admin
+
+import "ds2api/internal/config"
+
+func validateMergedRuntimeSettings(current config.RuntimeConfig, incoming *config.RuntimeConfig) error {
+ merged := current
+ if incoming != nil {
+ if incoming.AccountMaxInflight > 0 {
+ merged.AccountMaxInflight = incoming.AccountMaxInflight
+ }
+ if incoming.AccountMaxQueue > 0 {
+ merged.AccountMaxQueue = incoming.AccountMaxQueue
+ }
+ if incoming.GlobalMaxInflight > 0 {
+ merged.GlobalMaxInflight = incoming.GlobalMaxInflight
+ }
+ }
+ return validateRuntimeSettings(merged)
+}
+
+func (h *Handler) applyRuntimeSettings() {
+ if h == nil || h.Store == nil || h.Pool == nil {
+ return
+ }
+ accountCount := len(h.Store.Accounts())
+ maxPer := h.Store.RuntimeAccountMaxInflight()
+ recommended := defaultRuntimeRecommended(accountCount, maxPer)
+ maxQueue := h.Store.RuntimeAccountMaxQueue(recommended)
+ global := h.Store.RuntimeGlobalMaxInflight(recommended)
+ h.Pool.ApplyRuntimeLimits(maxPer, maxQueue, global)
+}
+
+func defaultRuntimeRecommended(accountCount, maxPer int) int {
+ if maxPer <= 0 {
+ maxPer = 1
+ }
+ if accountCount <= 0 {
+ return maxPer
+ }
+ return accountCount * maxPer
+}
+
+func settingsClaudeMapping(c config.Config) map[string]string {
+ if len(c.ClaudeMapping) > 0 {
+ return c.ClaudeMapping
+ }
+ if len(c.ClaudeModelMap) > 0 {
+ return c.ClaudeModelMap
+ }
+ return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
+}
diff --git a/internal/admin/handler_settings_write.go b/internal/admin/handler_settings_write.go
new file mode 100644
index 0000000..c0076ea
--- /dev/null
+++ b/internal/admin/handler_settings_write.go
@@ -0,0 +1,119 @@
+package admin
+
+import (
+ "encoding/json"
+ "net/http"
+ "strings"
+ "time"
+
+ authn "ds2api/internal/auth"
+ "ds2api/internal/config"
+)
+
+func (h *Handler) updateSettings(w http.ResponseWriter, r *http.Request) {
+ var req map[string]any
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
+ return
+ }
+
+ adminCfg, runtimeCfg, toolcallCfg, responsesCfg, embeddingsCfg, claudeMap, aliasMap, err := parseSettingsUpdateRequest(req)
+ if err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
+ return
+ }
+ if runtimeCfg != nil {
+ if err := validateMergedRuntimeSettings(h.Store.Snapshot().Runtime, runtimeCfg); err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
+ return
+ }
+ }
+
+ if err := h.Store.Update(func(c *config.Config) error {
+ if adminCfg != nil {
+ if adminCfg.JWTExpireHours > 0 {
+ c.Admin.JWTExpireHours = adminCfg.JWTExpireHours
+ }
+ }
+ if runtimeCfg != nil {
+ if runtimeCfg.AccountMaxInflight > 0 {
+ c.Runtime.AccountMaxInflight = runtimeCfg.AccountMaxInflight
+ }
+ if runtimeCfg.AccountMaxQueue > 0 {
+ c.Runtime.AccountMaxQueue = runtimeCfg.AccountMaxQueue
+ }
+ if runtimeCfg.GlobalMaxInflight > 0 {
+ c.Runtime.GlobalMaxInflight = runtimeCfg.GlobalMaxInflight
+ }
+ }
+ if toolcallCfg != nil {
+ if strings.TrimSpace(toolcallCfg.Mode) != "" {
+ c.Toolcall.Mode = strings.TrimSpace(toolcallCfg.Mode)
+ }
+ if strings.TrimSpace(toolcallCfg.EarlyEmitConfidence) != "" {
+ c.Toolcall.EarlyEmitConfidence = strings.TrimSpace(toolcallCfg.EarlyEmitConfidence)
+ }
+ }
+ if responsesCfg != nil && responsesCfg.StoreTTLSeconds > 0 {
+ c.Responses.StoreTTLSeconds = responsesCfg.StoreTTLSeconds
+ }
+ if embeddingsCfg != nil && strings.TrimSpace(embeddingsCfg.Provider) != "" {
+ c.Embeddings.Provider = strings.TrimSpace(embeddingsCfg.Provider)
+ }
+ if claudeMap != nil {
+ c.ClaudeMapping = claudeMap
+ c.ClaudeModelMap = nil
+ }
+ if aliasMap != nil {
+ c.ModelAliases = aliasMap
+ }
+ return nil
+ }); err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
+ return
+ }
+
+ h.applyRuntimeSettings()
+ needsSync := config.IsVercel() || h.Store.IsEnvBacked()
+ writeJSON(w, http.StatusOK, map[string]any{
+ "success": true,
+ "message": "settings updated and hot reloaded",
+ "env_backed": h.Store.IsEnvBacked(),
+ "needs_vercel_sync": needsSync,
+ "manual_sync_message": "配置已保存。Vercel 部署请在 Vercel Sync 页面手动同步。",
+ })
+}
+
+func (h *Handler) updateSettingsPassword(w http.ResponseWriter, r *http.Request) {
+ var req map[string]any
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "invalid json"})
+ return
+ }
+ newPassword := strings.TrimSpace(fieldString(req, "new_password"))
+ if newPassword == "" {
+ newPassword = strings.TrimSpace(fieldString(req, "password"))
+ }
+ if len(newPassword) < 4 {
+ writeJSON(w, http.StatusBadRequest, map[string]any{"detail": "new password must be at least 4 characters"})
+ return
+ }
+
+ now := time.Now().Unix()
+ hash := authn.HashAdminPassword(newPassword)
+ if err := h.Store.Update(func(c *config.Config) error {
+ c.Admin.PasswordHash = hash
+ c.Admin.JWTValidAfterUnix = now
+ return nil
+ }); err != nil {
+ writeJSON(w, http.StatusInternalServerError, map[string]any{"detail": err.Error()})
+ return
+ }
+
+ writeJSON(w, http.StatusOK, map[string]any{
+ "success": true,
+ "message": "password updated",
+ "force_relogin": true,
+ "jwt_valid_after_unix": now,
+ })
+}
diff --git a/internal/config/account.go b/internal/config/account.go
new file mode 100644
index 0000000..29a4947
--- /dev/null
+++ b/internal/config/account.go
@@ -0,0 +1,24 @@
+package config
+
+import (
+ "crypto/sha256"
+ "encoding/hex"
+ "strings"
+)
+
+func (a Account) Identifier() string {
+ if strings.TrimSpace(a.Email) != "" {
+ return strings.TrimSpace(a.Email)
+ }
+ if strings.TrimSpace(a.Mobile) != "" {
+ return strings.TrimSpace(a.Mobile)
+ }
+ // Backward compatibility: old configs may contain token-only accounts.
+ // Use a stable non-sensitive synthetic id so they can still join the pool.
+ token := strings.TrimSpace(a.Token)
+ if token == "" {
+ return ""
+ }
+ sum := sha256.Sum256([]byte(token))
+ return "token:" + hex.EncodeToString(sum[:8])
+}
diff --git a/internal/config/codec.go b/internal/config/codec.go
new file mode 100644
index 0000000..2a23e20
--- /dev/null
+++ b/internal/config/codec.go
@@ -0,0 +1,241 @@
+package config
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "slices"
+ "strings"
+)
+
+func (c Config) MarshalJSON() ([]byte, error) {
+ m := map[string]any{}
+ for k, v := range c.AdditionalFields {
+ m[k] = v
+ }
+ if len(c.Keys) > 0 {
+ m["keys"] = c.Keys
+ }
+ if len(c.Accounts) > 0 {
+ m["accounts"] = c.Accounts
+ }
+ if len(c.ClaudeMapping) > 0 {
+ m["claude_mapping"] = c.ClaudeMapping
+ }
+ if len(c.ClaudeModelMap) > 0 {
+ m["claude_model_mapping"] = c.ClaudeModelMap
+ }
+ if len(c.ModelAliases) > 0 {
+ m["model_aliases"] = c.ModelAliases
+ }
+ if strings.TrimSpace(c.Admin.PasswordHash) != "" || c.Admin.JWTExpireHours > 0 || c.Admin.JWTValidAfterUnix > 0 {
+ m["admin"] = c.Admin
+ }
+ if c.Runtime.AccountMaxInflight > 0 || c.Runtime.AccountMaxQueue > 0 || c.Runtime.GlobalMaxInflight > 0 {
+ m["runtime"] = c.Runtime
+ }
+ if c.Compat.WideInputStrictOutput != nil {
+ m["compat"] = c.Compat
+ }
+ if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" {
+ m["toolcall"] = c.Toolcall
+ }
+ if c.Responses.StoreTTLSeconds > 0 {
+ m["responses"] = c.Responses
+ }
+ if strings.TrimSpace(c.Embeddings.Provider) != "" {
+ m["embeddings"] = c.Embeddings
+ }
+ if c.VercelSyncHash != "" {
+ m["_vercel_sync_hash"] = c.VercelSyncHash
+ }
+ if c.VercelSyncTime != 0 {
+ m["_vercel_sync_time"] = c.VercelSyncTime
+ }
+ return json.Marshal(m)
+}
+
+func (c *Config) UnmarshalJSON(b []byte) error {
+ raw := map[string]json.RawMessage{}
+ if err := json.Unmarshal(b, &raw); err != nil {
+ return err
+ }
+ c.AdditionalFields = map[string]any{}
+ for k, v := range raw {
+ switch k {
+ case "keys":
+ if err := json.Unmarshal(v, &c.Keys); err != nil {
+ return fmt.Errorf("invalid field %q: %w", k, err)
+ }
+ case "accounts":
+ if err := json.Unmarshal(v, &c.Accounts); err != nil {
+ return fmt.Errorf("invalid field %q: %w", k, err)
+ }
+ case "claude_mapping":
+ if err := json.Unmarshal(v, &c.ClaudeMapping); err != nil {
+ return fmt.Errorf("invalid field %q: %w", k, err)
+ }
+ case "claude_model_mapping":
+ if err := json.Unmarshal(v, &c.ClaudeModelMap); err != nil {
+ return fmt.Errorf("invalid field %q: %w", k, err)
+ }
+ case "model_aliases":
+ if err := json.Unmarshal(v, &c.ModelAliases); err != nil {
+ return fmt.Errorf("invalid field %q: %w", k, err)
+ }
+ case "admin":
+ if err := json.Unmarshal(v, &c.Admin); err != nil {
+ return fmt.Errorf("invalid field %q: %w", k, err)
+ }
+ case "runtime":
+ if err := json.Unmarshal(v, &c.Runtime); err != nil {
+ return fmt.Errorf("invalid field %q: %w", k, err)
+ }
+ case "compat":
+ if err := json.Unmarshal(v, &c.Compat); err != nil {
+ return fmt.Errorf("invalid field %q: %w", k, err)
+ }
+ case "toolcall":
+ if err := json.Unmarshal(v, &c.Toolcall); err != nil {
+ return fmt.Errorf("invalid field %q: %w", k, err)
+ }
+ case "responses":
+ if err := json.Unmarshal(v, &c.Responses); err != nil {
+ return fmt.Errorf("invalid field %q: %w", k, err)
+ }
+ case "embeddings":
+ if err := json.Unmarshal(v, &c.Embeddings); err != nil {
+ return fmt.Errorf("invalid field %q: %w", k, err)
+ }
+ case "_vercel_sync_hash":
+ if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil {
+ return fmt.Errorf("invalid field %q: %w", k, err)
+ }
+ case "_vercel_sync_time":
+ if err := json.Unmarshal(v, &c.VercelSyncTime); err != nil {
+ return fmt.Errorf("invalid field %q: %w", k, err)
+ }
+ default:
+ var anyVal any
+ if err := json.Unmarshal(v, &anyVal); err == nil {
+ c.AdditionalFields[k] = anyVal
+ }
+ }
+ }
+ return nil
+}
+
+func (c Config) Clone() Config {
+ clone := Config{
+ Keys: slices.Clone(c.Keys),
+ Accounts: slices.Clone(c.Accounts),
+ ClaudeMapping: cloneStringMap(c.ClaudeMapping),
+ ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
+ ModelAliases: cloneStringMap(c.ModelAliases),
+ Admin: c.Admin,
+ Runtime: c.Runtime,
+ Compat: CompatConfig{
+ WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput),
+ },
+ Toolcall: c.Toolcall,
+ Responses: c.Responses,
+ Embeddings: c.Embeddings,
+ VercelSyncHash: c.VercelSyncHash,
+ VercelSyncTime: c.VercelSyncTime,
+ AdditionalFields: map[string]any{},
+ }
+ for k, v := range c.AdditionalFields {
+ clone.AdditionalFields[k] = v
+ }
+ return clone
+}
+
+func cloneStringMap(in map[string]string) map[string]string {
+ if len(in) == 0 {
+ return nil
+ }
+ out := make(map[string]string, len(in))
+ for k, v := range in {
+ out[k] = v
+ }
+ return out
+}
+
+func cloneBoolPtr(in *bool) *bool {
+ if in == nil {
+ return nil
+ }
+ v := *in
+ return &v
+}
+
+func parseConfigString(raw string) (Config, error) {
+ var cfg Config
+ candidates := []string{raw}
+ if normalized := normalizeConfigInput(raw); normalized != raw {
+ candidates = append(candidates, normalized)
+ }
+ for _, candidate := range candidates {
+ if err := json.Unmarshal([]byte(candidate), &cfg); err == nil {
+ return cfg, nil
+ }
+ }
+
+ base64Input := candidates[len(candidates)-1]
+ decoded, err := decodeConfigBase64(base64Input)
+ if err != nil {
+ return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON: %w", err)
+ }
+ if err := json.Unmarshal(decoded, &cfg); err != nil {
+ return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON decoded JSON: %w", err)
+ }
+ return cfg, nil
+}
+
+func normalizeConfigInput(raw string) string {
+ normalized := strings.TrimSpace(raw)
+ if normalized == "" {
+ return normalized
+ }
+ for {
+ changed := false
+ if len(normalized) >= 2 {
+ first := normalized[0]
+ last := normalized[len(normalized)-1]
+ if (first == '"' && last == '"') || (first == '\'' && last == '\'') {
+ normalized = strings.TrimSpace(normalized[1 : len(normalized)-1])
+ changed = true
+ }
+ }
+ if strings.HasPrefix(strings.ToLower(normalized), "base64:") {
+ normalized = strings.TrimSpace(normalized[len("base64:"):])
+ changed = true
+ }
+ if !changed {
+ break
+ }
+ }
+ return strings.TrimSpace(normalized)
+}
+
+func decodeConfigBase64(raw string) ([]byte, error) {
+ encodings := []*base64.Encoding{
+ base64.StdEncoding,
+ base64.RawStdEncoding,
+ base64.URLEncoding,
+ base64.RawURLEncoding,
+ }
+ var lastErr error
+ for _, enc := range encodings {
+ decoded, err := enc.DecodeString(raw)
+ if err == nil {
+ return decoded, nil
+ }
+ lastErr = err
+ }
+ if lastErr != nil {
+ return nil, lastErr
+ }
+ return nil, errors.New("base64 decode failed")
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index 3bc0409..4b281a2 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -1,63 +1,5 @@
package config
-import (
- "crypto/sha256"
- "encoding/base64"
- "encoding/hex"
- "encoding/json"
- "errors"
- "fmt"
- "log/slog"
- "os"
- "path/filepath"
- "slices"
- "strconv"
- "strings"
- "sync"
-)
-
-var Logger = newLogger()
-
-func newLogger() *slog.Logger {
- level := new(slog.LevelVar)
- switch strings.ToUpper(strings.TrimSpace(os.Getenv("LOG_LEVEL"))) {
- case "DEBUG":
- level.Set(slog.LevelDebug)
- case "WARN":
- level.Set(slog.LevelWarn)
- case "ERROR":
- level.Set(slog.LevelError)
- default:
- level.Set(slog.LevelInfo)
- }
- h := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level})
- return slog.New(h)
-}
-
-type Account struct {
- Email string `json:"email,omitempty"`
- Mobile string `json:"mobile,omitempty"`
- Password string `json:"password,omitempty"`
- Token string `json:"token,omitempty"`
-}
-
-func (a Account) Identifier() string {
- if strings.TrimSpace(a.Email) != "" {
- return strings.TrimSpace(a.Email)
- }
- if strings.TrimSpace(a.Mobile) != "" {
- return strings.TrimSpace(a.Mobile)
- }
- // Backward compatibility: old configs may contain token-only accounts.
- // Use a stable non-sensitive synthetic id so they can still join the pool.
- token := strings.TrimSpace(a.Token)
- if token == "" {
- return ""
- }
- sum := sha256.Sum256([]byte(token))
- return "token:" + hex.EncodeToString(sum[:8])
-}
-
type Config struct {
Keys []string `json:"keys,omitempty"`
Accounts []Account `json:"accounts,omitempty"`
@@ -75,6 +17,13 @@ type Config struct {
AdditionalFields map[string]any `json:"-"`
}
+type Account struct {
+ Email string `json:"email,omitempty"`
+ Mobile string `json:"mobile,omitempty"`
+ Password string `json:"password,omitempty"`
+ Token string `json:"token,omitempty"`
+}
+
type CompatConfig struct {
WideInputStrictOutput *bool `json:"wide_input_strict_output,omitempty"`
}
@@ -103,641 +52,3 @@ type ResponsesConfig struct {
type EmbeddingsConfig struct {
Provider string `json:"provider,omitempty"`
}
-
-func (c Config) MarshalJSON() ([]byte, error) {
- m := map[string]any{}
- for k, v := range c.AdditionalFields {
- m[k] = v
- }
- if len(c.Keys) > 0 {
- m["keys"] = c.Keys
- }
- if len(c.Accounts) > 0 {
- m["accounts"] = c.Accounts
- }
- if len(c.ClaudeMapping) > 0 {
- m["claude_mapping"] = c.ClaudeMapping
- }
- if len(c.ClaudeModelMap) > 0 {
- m["claude_model_mapping"] = c.ClaudeModelMap
- }
- if len(c.ModelAliases) > 0 {
- m["model_aliases"] = c.ModelAliases
- }
- if strings.TrimSpace(c.Admin.PasswordHash) != "" || c.Admin.JWTExpireHours > 0 || c.Admin.JWTValidAfterUnix > 0 {
- m["admin"] = c.Admin
- }
- if c.Runtime.AccountMaxInflight > 0 || c.Runtime.AccountMaxQueue > 0 || c.Runtime.GlobalMaxInflight > 0 {
- m["runtime"] = c.Runtime
- }
- if c.Compat.WideInputStrictOutput != nil {
- m["compat"] = c.Compat
- }
- if strings.TrimSpace(c.Toolcall.Mode) != "" || strings.TrimSpace(c.Toolcall.EarlyEmitConfidence) != "" {
- m["toolcall"] = c.Toolcall
- }
- if c.Responses.StoreTTLSeconds > 0 {
- m["responses"] = c.Responses
- }
- if strings.TrimSpace(c.Embeddings.Provider) != "" {
- m["embeddings"] = c.Embeddings
- }
- if c.VercelSyncHash != "" {
- m["_vercel_sync_hash"] = c.VercelSyncHash
- }
- if c.VercelSyncTime != 0 {
- m["_vercel_sync_time"] = c.VercelSyncTime
- }
- return json.Marshal(m)
-}
-
-func (c *Config) UnmarshalJSON(b []byte) error {
- raw := map[string]json.RawMessage{}
- if err := json.Unmarshal(b, &raw); err != nil {
- return err
- }
- c.AdditionalFields = map[string]any{}
- for k, v := range raw {
- switch k {
- case "keys":
- if err := json.Unmarshal(v, &c.Keys); err != nil {
- return fmt.Errorf("invalid field %q: %w", k, err)
- }
- case "accounts":
- if err := json.Unmarshal(v, &c.Accounts); err != nil {
- return fmt.Errorf("invalid field %q: %w", k, err)
- }
- case "claude_mapping":
- if err := json.Unmarshal(v, &c.ClaudeMapping); err != nil {
- return fmt.Errorf("invalid field %q: %w", k, err)
- }
- case "claude_model_mapping":
- if err := json.Unmarshal(v, &c.ClaudeModelMap); err != nil {
- return fmt.Errorf("invalid field %q: %w", k, err)
- }
- case "model_aliases":
- if err := json.Unmarshal(v, &c.ModelAliases); err != nil {
- return fmt.Errorf("invalid field %q: %w", k, err)
- }
- case "admin":
- if err := json.Unmarshal(v, &c.Admin); err != nil {
- return fmt.Errorf("invalid field %q: %w", k, err)
- }
- case "runtime":
- if err := json.Unmarshal(v, &c.Runtime); err != nil {
- return fmt.Errorf("invalid field %q: %w", k, err)
- }
- case "compat":
- if err := json.Unmarshal(v, &c.Compat); err != nil {
- return fmt.Errorf("invalid field %q: %w", k, err)
- }
- case "toolcall":
- if err := json.Unmarshal(v, &c.Toolcall); err != nil {
- return fmt.Errorf("invalid field %q: %w", k, err)
- }
- case "responses":
- if err := json.Unmarshal(v, &c.Responses); err != nil {
- return fmt.Errorf("invalid field %q: %w", k, err)
- }
- case "embeddings":
- if err := json.Unmarshal(v, &c.Embeddings); err != nil {
- return fmt.Errorf("invalid field %q: %w", k, err)
- }
- case "_vercel_sync_hash":
- if err := json.Unmarshal(v, &c.VercelSyncHash); err != nil {
- return fmt.Errorf("invalid field %q: %w", k, err)
- }
- case "_vercel_sync_time":
- if err := json.Unmarshal(v, &c.VercelSyncTime); err != nil {
- return fmt.Errorf("invalid field %q: %w", k, err)
- }
- default:
- var anyVal any
- if err := json.Unmarshal(v, &anyVal); err == nil {
- c.AdditionalFields[k] = anyVal
- }
- }
- }
- return nil
-}
-
-func (c Config) Clone() Config {
- clone := Config{
- Keys: slices.Clone(c.Keys),
- Accounts: slices.Clone(c.Accounts),
- ClaudeMapping: cloneStringMap(c.ClaudeMapping),
- ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
- ModelAliases: cloneStringMap(c.ModelAliases),
- Admin: c.Admin,
- Runtime: c.Runtime,
- Compat: CompatConfig{
- WideInputStrictOutput: cloneBoolPtr(c.Compat.WideInputStrictOutput),
- },
- Toolcall: c.Toolcall,
- Responses: c.Responses,
- Embeddings: c.Embeddings,
- VercelSyncHash: c.VercelSyncHash,
- VercelSyncTime: c.VercelSyncTime,
- AdditionalFields: map[string]any{},
- }
- for k, v := range c.AdditionalFields {
- clone.AdditionalFields[k] = v
- }
- return clone
-}
-
-func cloneStringMap(in map[string]string) map[string]string {
- if len(in) == 0 {
- return nil
- }
- out := make(map[string]string, len(in))
- for k, v := range in {
- out[k] = v
- }
- return out
-}
-
-func cloneBoolPtr(in *bool) *bool {
- if in == nil {
- return nil
- }
- v := *in
- return &v
-}
-
-type Store struct {
- mu sync.RWMutex
- cfg Config
- path string
- fromEnv bool
- keyMap map[string]struct{} // O(1) API key lookup index
- accMap map[string]int // O(1) account lookup: identifier -> slice index
-}
-
-func BaseDir() string {
- cwd, err := os.Getwd()
- if err != nil {
- return "."
- }
- return cwd
-}
-
-func IsVercel() bool {
- return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != ""
-}
-
-func ResolvePath(envKey, defaultRel string) string {
- raw := strings.TrimSpace(os.Getenv(envKey))
- if raw != "" {
- if filepath.IsAbs(raw) {
- return raw
- }
- return filepath.Join(BaseDir(), raw)
- }
- return filepath.Join(BaseDir(), defaultRel)
-}
-
-func ConfigPath() string {
- return ResolvePath("DS2API_CONFIG_PATH", "config.json")
-}
-
-func WASMPath() string {
- return ResolvePath("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm")
-}
-
-func StaticAdminDir() string {
- return ResolvePath("DS2API_STATIC_ADMIN_DIR", "static/admin")
-}
-
-func LoadStore() *Store {
- cfg, fromEnv, err := loadConfig()
- if err != nil {
- Logger.Warn("[config] load failed", "error", err)
- }
- if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 {
- Logger.Warn("[config] empty config loaded")
- }
- s := &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv}
- s.rebuildIndexes()
- return s
-}
-
-// rebuildIndexes must be called with the lock already held (or during init).
-func (s *Store) rebuildIndexes() {
- s.keyMap = make(map[string]struct{}, len(s.cfg.Keys))
- for _, k := range s.cfg.Keys {
- s.keyMap[k] = struct{}{}
- }
- s.accMap = make(map[string]int, len(s.cfg.Accounts))
- for i, acc := range s.cfg.Accounts {
- id := acc.Identifier()
- if id != "" {
- s.accMap[id] = i
- }
- }
-}
-
-func loadConfig() (Config, bool, error) {
- rawCfg := strings.TrimSpace(os.Getenv("DS2API_CONFIG_JSON"))
- if rawCfg == "" {
- rawCfg = strings.TrimSpace(os.Getenv("CONFIG_JSON"))
- }
- if rawCfg != "" {
- cfg, err := parseConfigString(rawCfg)
- return cfg, true, err
- }
-
- content, err := os.ReadFile(ConfigPath())
- if err != nil {
- if IsVercel() {
- // Vercel one-click deploy may start without a writable/present config file.
- // Keep an in-memory config so users can bootstrap via WebUI then sync env.
- return Config{}, true, nil
- }
- return Config{}, false, err
- }
- var cfg Config
- if err := json.Unmarshal(content, &cfg); err != nil {
- return Config{}, false, err
- }
- if IsVercel() {
- // Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors.
- return cfg, true, nil
- }
- return cfg, false, nil
-}
-
-func parseConfigString(raw string) (Config, error) {
- var cfg Config
- candidates := []string{raw}
- if normalized := normalizeConfigInput(raw); normalized != raw {
- candidates = append(candidates, normalized)
- }
- for _, candidate := range candidates {
- if err := json.Unmarshal([]byte(candidate), &cfg); err == nil {
- return cfg, nil
- }
- }
-
- base64Input := candidates[len(candidates)-1]
- decoded, err := decodeConfigBase64(base64Input)
- if err != nil {
- return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON: %w", err)
- }
- if err := json.Unmarshal(decoded, &cfg); err != nil {
- return Config{}, fmt.Errorf("invalid DS2API_CONFIG_JSON decoded JSON: %w", err)
- }
- return cfg, nil
-}
-
-func normalizeConfigInput(raw string) string {
- normalized := strings.TrimSpace(raw)
- if normalized == "" {
- return normalized
- }
- for {
- changed := false
- if len(normalized) >= 2 {
- first := normalized[0]
- last := normalized[len(normalized)-1]
- if (first == '"' && last == '"') || (first == '\'' && last == '\'') {
- normalized = strings.TrimSpace(normalized[1 : len(normalized)-1])
- changed = true
- }
- }
- if strings.HasPrefix(strings.ToLower(normalized), "base64:") {
- normalized = strings.TrimSpace(normalized[len("base64:"):])
- changed = true
- }
- if !changed {
- break
- }
- }
- return strings.TrimSpace(normalized)
-}
-
-func decodeConfigBase64(raw string) ([]byte, error) {
- encodings := []*base64.Encoding{
- base64.StdEncoding,
- base64.RawStdEncoding,
- base64.URLEncoding,
- base64.RawURLEncoding,
- }
- var lastErr error
- for _, enc := range encodings {
- decoded, err := enc.DecodeString(raw)
- if err == nil {
- return decoded, nil
- }
- lastErr = err
- }
- if lastErr != nil {
- return nil, lastErr
- }
- return nil, errors.New("base64 decode failed")
-}
-
-func (s *Store) Snapshot() Config {
- s.mu.RLock()
- defer s.mu.RUnlock()
- return s.cfg.Clone()
-}
-
-func (s *Store) HasAPIKey(k string) bool {
- s.mu.RLock()
- defer s.mu.RUnlock()
- _, ok := s.keyMap[k]
- return ok
-}
-
-func (s *Store) Keys() []string {
- s.mu.RLock()
- defer s.mu.RUnlock()
- return slices.Clone(s.cfg.Keys)
-}
-
-func (s *Store) Accounts() []Account {
- s.mu.RLock()
- defer s.mu.RUnlock()
- return slices.Clone(s.cfg.Accounts)
-}
-
-func (s *Store) FindAccount(identifier string) (Account, bool) {
- identifier = strings.TrimSpace(identifier)
- s.mu.RLock()
- defer s.mu.RUnlock()
- if idx, ok := s.findAccountIndexLocked(identifier); ok {
- return s.cfg.Accounts[idx], true
- }
- return Account{}, false
-}
-
-func (s *Store) UpdateAccountToken(identifier, token string) error {
- identifier = strings.TrimSpace(identifier)
- s.mu.Lock()
- defer s.mu.Unlock()
- idx, ok := s.findAccountIndexLocked(identifier)
- if !ok {
- return errors.New("account not found")
- }
- oldID := s.cfg.Accounts[idx].Identifier()
- s.cfg.Accounts[idx].Token = token
- newID := s.cfg.Accounts[idx].Identifier()
- // Keep historical aliases usable for long-lived queues while also adding
- // the latest identifier after token refresh.
- if identifier != "" {
- s.accMap[identifier] = idx
- }
- if oldID != "" {
- s.accMap[oldID] = idx
- }
- if newID != "" {
- s.accMap[newID] = idx
- }
- return s.saveLocked()
-}
-
-func (s *Store) Replace(cfg Config) error {
- s.mu.Lock()
- defer s.mu.Unlock()
- s.cfg = cfg.Clone()
- s.rebuildIndexes()
- return s.saveLocked()
-}
-
-func (s *Store) Update(mutator func(*Config) error) error {
- s.mu.Lock()
- defer s.mu.Unlock()
- cfg := s.cfg.Clone()
- if err := mutator(&cfg); err != nil {
- return err
- }
- s.cfg = cfg
- s.rebuildIndexes()
- return s.saveLocked()
-}
-
-func (s *Store) Save() error {
- s.mu.Lock()
- defer s.mu.Unlock()
- if s.fromEnv {
- Logger.Info("[save_config] source from env, skip write")
- return nil
- }
- b, err := json.MarshalIndent(s.cfg, "", " ")
- if err != nil {
- return err
- }
- return os.WriteFile(s.path, b, 0o644)
-}
-
-func (s *Store) saveLocked() error {
- if s.fromEnv {
- Logger.Info("[save_config] source from env, skip write")
- return nil
- }
- b, err := json.MarshalIndent(s.cfg, "", " ")
- if err != nil {
- return err
- }
- return os.WriteFile(s.path, b, 0o644)
-}
-
-// findAccountIndexLocked expects the store lock to already be held.
-func (s *Store) findAccountIndexLocked(identifier string) (int, bool) {
- if idx, ok := s.accMap[identifier]; ok && idx >= 0 && idx < len(s.cfg.Accounts) {
- return idx, true
- }
- // Fallback for token-only accounts whose derived identifier changed after
- // a token refresh; this preserves correctness on map misses.
- for i, acc := range s.cfg.Accounts {
- if acc.Identifier() == identifier {
- return i, true
- }
- }
- return -1, false
-}
-
-func (s *Store) IsEnvBacked() bool {
- s.mu.RLock()
- defer s.mu.RUnlock()
- return s.fromEnv
-}
-
-func (s *Store) SetVercelSync(hash string, ts int64) error {
- return s.Update(func(c *Config) error {
- c.VercelSyncHash = hash
- c.VercelSyncTime = ts
- return nil
- })
-}
-
-func (s *Store) ExportJSONAndBase64() (string, string, error) {
- s.mu.RLock()
- defer s.mu.RUnlock()
- b, err := json.Marshal(s.cfg)
- if err != nil {
- return "", "", err
- }
- return string(b), base64.StdEncoding.EncodeToString(b), nil
-}
-
-func (s *Store) ClaudeMapping() map[string]string {
- s.mu.RLock()
- defer s.mu.RUnlock()
- if len(s.cfg.ClaudeModelMap) > 0 {
- return cloneStringMap(s.cfg.ClaudeModelMap)
- }
- if len(s.cfg.ClaudeMapping) > 0 {
- return cloneStringMap(s.cfg.ClaudeMapping)
- }
- return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
-}
-
-func (s *Store) ModelAliases() map[string]string {
- s.mu.RLock()
- defer s.mu.RUnlock()
- out := DefaultModelAliases()
- for k, v := range s.cfg.ModelAliases {
- key := strings.TrimSpace(lower(k))
- val := strings.TrimSpace(lower(v))
- if key == "" || val == "" {
- continue
- }
- out[key] = val
- }
- return out
-}
-
-func (s *Store) CompatWideInputStrictOutput() bool {
- s.mu.RLock()
- defer s.mu.RUnlock()
- if s.cfg.Compat.WideInputStrictOutput == nil {
- return true
- }
- return *s.cfg.Compat.WideInputStrictOutput
-}
-
-func (s *Store) ToolcallMode() string {
- s.mu.RLock()
- defer s.mu.RUnlock()
- mode := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.Mode))
- if mode == "" {
- return "feature_match"
- }
- return mode
-}
-
-func (s *Store) ToolcallEarlyEmitConfidence() string {
- s.mu.RLock()
- defer s.mu.RUnlock()
- level := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.EarlyEmitConfidence))
- if level == "" {
- return "high"
- }
- return level
-}
-
-func (s *Store) ResponsesStoreTTLSeconds() int {
- s.mu.RLock()
- defer s.mu.RUnlock()
- if s.cfg.Responses.StoreTTLSeconds > 0 {
- return s.cfg.Responses.StoreTTLSeconds
- }
- return 900
-}
-
-func (s *Store) EmbeddingsProvider() string {
- s.mu.RLock()
- defer s.mu.RUnlock()
- return strings.TrimSpace(s.cfg.Embeddings.Provider)
-}
-
-func (s *Store) AdminPasswordHash() string {
- s.mu.RLock()
- defer s.mu.RUnlock()
- return strings.TrimSpace(s.cfg.Admin.PasswordHash)
-}
-
-func (s *Store) AdminJWTExpireHours() int {
- s.mu.RLock()
- defer s.mu.RUnlock()
- if s.cfg.Admin.JWTExpireHours > 0 {
- return s.cfg.Admin.JWTExpireHours
- }
- if raw := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); raw != "" {
- if n, err := strconv.Atoi(raw); err == nil && n > 0 {
- return n
- }
- }
- return 24
-}
-
-func (s *Store) AdminJWTValidAfterUnix() int64 {
- s.mu.RLock()
- defer s.mu.RUnlock()
- return s.cfg.Admin.JWTValidAfterUnix
-}
-
-func (s *Store) RuntimeAccountMaxInflight() int {
- s.mu.RLock()
- defer s.mu.RUnlock()
- if s.cfg.Runtime.AccountMaxInflight > 0 {
- return s.cfg.Runtime.AccountMaxInflight
- }
- for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
- raw := strings.TrimSpace(os.Getenv(key))
- if raw == "" {
- continue
- }
- n, err := strconv.Atoi(raw)
- if err == nil && n > 0 {
- return n
- }
- }
- return 2
-}
-
-func (s *Store) RuntimeAccountMaxQueue(defaultSize int) int {
- s.mu.RLock()
- defer s.mu.RUnlock()
- if s.cfg.Runtime.AccountMaxQueue > 0 {
- return s.cfg.Runtime.AccountMaxQueue
- }
- for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
- raw := strings.TrimSpace(os.Getenv(key))
- if raw == "" {
- continue
- }
- n, err := strconv.Atoi(raw)
- if err == nil && n >= 0 {
- return n
- }
- }
- if defaultSize < 0 {
- return 0
- }
- return defaultSize
-}
-
-func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int {
- s.mu.RLock()
- defer s.mu.RUnlock()
- if s.cfg.Runtime.GlobalMaxInflight > 0 {
- return s.cfg.Runtime.GlobalMaxInflight
- }
- for _, key := range []string{"DS2API_GLOBAL_MAX_INFLIGHT", "DS2API_MAX_INFLIGHT"} {
- raw := strings.TrimSpace(os.Getenv(key))
- if raw == "" {
- continue
- }
- n, err := strconv.Atoi(raw)
- if err == nil && n > 0 {
- return n
- }
- }
- if defaultSize < 0 {
- return 0
- }
- return defaultSize
-}
diff --git a/internal/config/logger.go b/internal/config/logger.go
new file mode 100644
index 0000000..8b2de91
--- /dev/null
+++ b/internal/config/logger.go
@@ -0,0 +1,25 @@
+package config
+
+import (
+ "log/slog"
+ "os"
+ "strings"
+)
+
+var Logger = newLogger()
+
+func newLogger() *slog.Logger {
+ level := new(slog.LevelVar)
+ switch strings.ToUpper(strings.TrimSpace(os.Getenv("LOG_LEVEL"))) {
+ case "DEBUG":
+ level.Set(slog.LevelDebug)
+ case "WARN":
+ level.Set(slog.LevelWarn)
+ case "ERROR":
+ level.Set(slog.LevelError)
+ default:
+ level.Set(slog.LevelInfo)
+ }
+ h := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level})
+ return slog.New(h)
+}
diff --git a/internal/config/paths.go b/internal/config/paths.go
new file mode 100644
index 0000000..23dfe54
--- /dev/null
+++ b/internal/config/paths.go
@@ -0,0 +1,42 @@
+package config
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+)
+
+func BaseDir() string {
+ cwd, err := os.Getwd()
+ if err != nil {
+ return "."
+ }
+ return cwd
+}
+
+func IsVercel() bool {
+ return strings.TrimSpace(os.Getenv("VERCEL")) != "" || strings.TrimSpace(os.Getenv("NOW_REGION")) != ""
+}
+
+func ResolvePath(envKey, defaultRel string) string {
+ raw := strings.TrimSpace(os.Getenv(envKey))
+ if raw != "" {
+ if filepath.IsAbs(raw) {
+ return raw
+ }
+ return filepath.Join(BaseDir(), raw)
+ }
+ return filepath.Join(BaseDir(), defaultRel)
+}
+
+func ConfigPath() string {
+ return ResolvePath("DS2API_CONFIG_PATH", "config.json")
+}
+
+func WASMPath() string {
+ return ResolvePath("DS2API_WASM_PATH", "sha3_wasm_bg.7b9ca65ddd.wasm")
+}
+
+func StaticAdminDir() string {
+ return ResolvePath("DS2API_STATIC_ADMIN_DIR", "static/admin")
+}
diff --git a/internal/config/store.go b/internal/config/store.go
new file mode 100644
index 0000000..2e6fcaf
--- /dev/null
+++ b/internal/config/store.go
@@ -0,0 +1,193 @@
+package config
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "os"
+ "slices"
+ "strings"
+ "sync"
+)
+
+type Store struct {
+ mu sync.RWMutex
+ cfg Config
+ path string
+ fromEnv bool
+ keyMap map[string]struct{} // O(1) API key lookup index
+ accMap map[string]int // O(1) account lookup: identifier -> slice index
+}
+
+func LoadStore() *Store {
+ cfg, fromEnv, err := loadConfig()
+ if err != nil {
+ Logger.Warn("[config] load failed", "error", err)
+ }
+ if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 {
+ Logger.Warn("[config] empty config loaded")
+ }
+ s := &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv}
+ s.rebuildIndexes()
+ return s
+}
+
+func loadConfig() (Config, bool, error) {
+ rawCfg := strings.TrimSpace(os.Getenv("DS2API_CONFIG_JSON"))
+ if rawCfg == "" {
+ rawCfg = strings.TrimSpace(os.Getenv("CONFIG_JSON"))
+ }
+ if rawCfg != "" {
+ cfg, err := parseConfigString(rawCfg)
+ return cfg, true, err
+ }
+
+ content, err := os.ReadFile(ConfigPath())
+ if err != nil {
+ if IsVercel() {
+ // Vercel one-click deploy may start without a writable/present config file.
+ // Keep an in-memory config so users can bootstrap via WebUI then sync env.
+ return Config{}, true, nil
+ }
+ return Config{}, false, err
+ }
+ var cfg Config
+ if err := json.Unmarshal(content, &cfg); err != nil {
+ return Config{}, false, err
+ }
+ if IsVercel() {
+ // Vercel filesystem is ephemeral/read-only for runtime writes; avoid save errors.
+ return cfg, true, nil
+ }
+ return cfg, false, nil
+}
+
+func (s *Store) Snapshot() Config {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return s.cfg.Clone()
+}
+
+func (s *Store) HasAPIKey(k string) bool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ _, ok := s.keyMap[k]
+ return ok
+}
+
+func (s *Store) Keys() []string {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return slices.Clone(s.cfg.Keys)
+}
+
+func (s *Store) Accounts() []Account {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return slices.Clone(s.cfg.Accounts)
+}
+
+func (s *Store) FindAccount(identifier string) (Account, bool) {
+ identifier = strings.TrimSpace(identifier)
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ if idx, ok := s.findAccountIndexLocked(identifier); ok {
+ return s.cfg.Accounts[idx], true
+ }
+ return Account{}, false
+}
+
+func (s *Store) UpdateAccountToken(identifier, token string) error {
+ identifier = strings.TrimSpace(identifier)
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ idx, ok := s.findAccountIndexLocked(identifier)
+ if !ok {
+ return errors.New("account not found")
+ }
+ oldID := s.cfg.Accounts[idx].Identifier()
+ s.cfg.Accounts[idx].Token = token
+ newID := s.cfg.Accounts[idx].Identifier()
+ // Keep historical aliases usable for long-lived queues while also adding
+ // the latest identifier after token refresh.
+ if identifier != "" {
+ s.accMap[identifier] = idx
+ }
+ if oldID != "" {
+ s.accMap[oldID] = idx
+ }
+ if newID != "" {
+ s.accMap[newID] = idx
+ }
+ return s.saveLocked()
+}
+
+func (s *Store) Replace(cfg Config) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.cfg = cfg.Clone()
+ s.rebuildIndexes()
+ return s.saveLocked()
+}
+
+func (s *Store) Update(mutator func(*Config) error) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ cfg := s.cfg.Clone()
+ if err := mutator(&cfg); err != nil {
+ return err
+ }
+ s.cfg = cfg
+ s.rebuildIndexes()
+ return s.saveLocked()
+}
+
+func (s *Store) Save() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.fromEnv {
+ Logger.Info("[save_config] source from env, skip write")
+ return nil
+ }
+ b, err := json.MarshalIndent(s.cfg, "", " ")
+ if err != nil {
+ return err
+ }
+ return os.WriteFile(s.path, b, 0o644)
+}
+
+func (s *Store) saveLocked() error {
+ if s.fromEnv {
+ Logger.Info("[save_config] source from env, skip write")
+ return nil
+ }
+ b, err := json.MarshalIndent(s.cfg, "", " ")
+ if err != nil {
+ return err
+ }
+ return os.WriteFile(s.path, b, 0o644)
+}
+
+func (s *Store) IsEnvBacked() bool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return s.fromEnv
+}
+
+func (s *Store) SetVercelSync(hash string, ts int64) error {
+ return s.Update(func(c *Config) error {
+ c.VercelSyncHash = hash
+ c.VercelSyncTime = ts
+ return nil
+ })
+}
+
+func (s *Store) ExportJSONAndBase64() (string, string, error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ b, err := json.Marshal(s.cfg)
+ if err != nil {
+ return "", "", err
+ }
+ return string(b), base64.StdEncoding.EncodeToString(b), nil
+}
diff --git a/internal/config/store_accessors.go b/internal/config/store_accessors.go
new file mode 100644
index 0000000..f0c5938
--- /dev/null
+++ b/internal/config/store_accessors.go
@@ -0,0 +1,167 @@
+package config
+
+import (
+ "os"
+ "strconv"
+ "strings"
+)
+
+func (s *Store) ClaudeMapping() map[string]string {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ if len(s.cfg.ClaudeModelMap) > 0 {
+ return cloneStringMap(s.cfg.ClaudeModelMap)
+ }
+ if len(s.cfg.ClaudeMapping) > 0 {
+ return cloneStringMap(s.cfg.ClaudeMapping)
+ }
+ return map[string]string{"fast": "deepseek-chat", "slow": "deepseek-reasoner"}
+}
+
+func (s *Store) ModelAliases() map[string]string {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ out := DefaultModelAliases()
+ for k, v := range s.cfg.ModelAliases {
+ key := strings.TrimSpace(lower(k))
+ val := strings.TrimSpace(lower(v))
+ if key == "" || val == "" {
+ continue
+ }
+ out[key] = val
+ }
+ return out
+}
+
+func (s *Store) CompatWideInputStrictOutput() bool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ if s.cfg.Compat.WideInputStrictOutput == nil {
+ return true
+ }
+ return *s.cfg.Compat.WideInputStrictOutput
+}
+
+func (s *Store) ToolcallMode() string {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ mode := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.Mode))
+ if mode == "" {
+ return "feature_match"
+ }
+ return mode
+}
+
+func (s *Store) ToolcallEarlyEmitConfidence() string {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ level := strings.TrimSpace(strings.ToLower(s.cfg.Toolcall.EarlyEmitConfidence))
+ if level == "" {
+ return "high"
+ }
+ return level
+}
+
+func (s *Store) ResponsesStoreTTLSeconds() int {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ if s.cfg.Responses.StoreTTLSeconds > 0 {
+ return s.cfg.Responses.StoreTTLSeconds
+ }
+ return 900
+}
+
+func (s *Store) EmbeddingsProvider() string {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return strings.TrimSpace(s.cfg.Embeddings.Provider)
+}
+
+func (s *Store) AdminPasswordHash() string {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return strings.TrimSpace(s.cfg.Admin.PasswordHash)
+}
+
+func (s *Store) AdminJWTExpireHours() int {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ if s.cfg.Admin.JWTExpireHours > 0 {
+ return s.cfg.Admin.JWTExpireHours
+ }
+ if raw := strings.TrimSpace(os.Getenv("DS2API_JWT_EXPIRE_HOURS")); raw != "" {
+ if n, err := strconv.Atoi(raw); err == nil && n > 0 {
+ return n
+ }
+ }
+ return 24
+}
+
+func (s *Store) AdminJWTValidAfterUnix() int64 {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return s.cfg.Admin.JWTValidAfterUnix
+}
+
+func (s *Store) RuntimeAccountMaxInflight() int {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ if s.cfg.Runtime.AccountMaxInflight > 0 {
+ return s.cfg.Runtime.AccountMaxInflight
+ }
+ for _, key := range []string{"DS2API_ACCOUNT_MAX_INFLIGHT", "DS2API_ACCOUNT_CONCURRENCY"} {
+ raw := strings.TrimSpace(os.Getenv(key))
+ if raw == "" {
+ continue
+ }
+ n, err := strconv.Atoi(raw)
+ if err == nil && n > 0 {
+ return n
+ }
+ }
+ return 2
+}
+
+func (s *Store) RuntimeAccountMaxQueue(defaultSize int) int {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ if s.cfg.Runtime.AccountMaxQueue > 0 {
+ return s.cfg.Runtime.AccountMaxQueue
+ }
+ for _, key := range []string{"DS2API_ACCOUNT_MAX_QUEUE", "DS2API_ACCOUNT_QUEUE_SIZE"} {
+ raw := strings.TrimSpace(os.Getenv(key))
+ if raw == "" {
+ continue
+ }
+ n, err := strconv.Atoi(raw)
+ if err == nil && n >= 0 {
+ return n
+ }
+ }
+ if defaultSize < 0 {
+ return 0
+ }
+ return defaultSize
+}
+
+func (s *Store) RuntimeGlobalMaxInflight(defaultSize int) int {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ if s.cfg.Runtime.GlobalMaxInflight > 0 {
+ return s.cfg.Runtime.GlobalMaxInflight
+ }
+ for _, key := range []string{"DS2API_GLOBAL_MAX_INFLIGHT", "DS2API_MAX_INFLIGHT"} {
+ raw := strings.TrimSpace(os.Getenv(key))
+ if raw == "" {
+ continue
+ }
+ n, err := strconv.Atoi(raw)
+ if err == nil && n > 0 {
+ return n
+ }
+ }
+ if defaultSize < 0 {
+ return 0
+ }
+ return defaultSize
+}
diff --git a/internal/config/store_index.go b/internal/config/store_index.go
new file mode 100644
index 0000000..7d0f62a
--- /dev/null
+++ b/internal/config/store_index.go
@@ -0,0 +1,31 @@
+package config
+
+// rebuildIndexes must be called with the lock already held (or during init).
+func (s *Store) rebuildIndexes() {
+ s.keyMap = make(map[string]struct{}, len(s.cfg.Keys))
+ for _, k := range s.cfg.Keys {
+ s.keyMap[k] = struct{}{}
+ }
+ s.accMap = make(map[string]int, len(s.cfg.Accounts))
+ for i, acc := range s.cfg.Accounts {
+ id := acc.Identifier()
+ if id != "" {
+ s.accMap[id] = i
+ }
+ }
+}
+
+// findAccountIndexLocked expects the store lock to already be held.
+func (s *Store) findAccountIndexLocked(identifier string) (int, bool) {
+ if idx, ok := s.accMap[identifier]; ok && idx >= 0 && idx < len(s.cfg.Accounts) {
+ return idx, true
+ }
+ // Fallback for token-only accounts whose derived identifier changed after
+ // a token refresh; this preserves correctness on map misses.
+ for i, acc := range s.cfg.Accounts {
+ if acc.Identifier() == identifier {
+ return i, true
+ }
+ }
+ return -1, false
+}
diff --git a/internal/deepseek/client.go b/internal/deepseek/client.go
deleted file mode 100644
index 2ffe05d..0000000
--- a/internal/deepseek/client.go
+++ /dev/null
@@ -1,347 +0,0 @@
-package deepseek
-
-import (
- "bufio"
- "bytes"
- "compress/gzip"
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "strings"
- "time"
-
- "ds2api/internal/auth"
- "ds2api/internal/config"
- trans "ds2api/internal/deepseek/transport"
- "ds2api/internal/devcapture"
- "ds2api/internal/util"
-
- "github.com/andybalholm/brotli"
-)
-
-// intFrom is a package-internal alias for the shared util version.
-var intFrom = util.IntFrom
-
-type Client struct {
- Store *config.Store
- Auth *auth.Resolver
- capture *devcapture.Store
- regular trans.Doer
- stream trans.Doer
- fallback *http.Client
- fallbackS *http.Client
- powSolver *PowSolver
- maxRetries int
-}
-
-func NewClient(store *config.Store, resolver *auth.Resolver) *Client {
- return &Client{
- Store: store,
- Auth: resolver,
- capture: devcapture.Global(),
- regular: trans.New(60 * time.Second),
- stream: trans.New(0),
- fallback: &http.Client{Timeout: 60 * time.Second},
- fallbackS: &http.Client{Timeout: 0},
- powSolver: NewPowSolver(config.WASMPath()),
- maxRetries: 3,
- }
-}
-
-func (c *Client) PreloadPow(ctx context.Context) error {
- return c.powSolver.init(ctx)
-}
-
-func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) {
- payload := map[string]any{
- "password": strings.TrimSpace(acc.Password),
- "device_id": "deepseek_to_api",
- "os": "android",
- }
- if email := strings.TrimSpace(acc.Email); email != "" {
- payload["email"] = email
- } else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" {
- payload["mobile"] = mobile
- payload["area_code"] = nil
- } else {
- return "", errors.New("missing email/mobile")
- }
- resp, err := c.postJSON(ctx, c.regular, DeepSeekLoginURL, BaseHeaders, payload)
- if err != nil {
- return "", err
- }
- code := intFrom(resp["code"])
- if code != 0 {
- return "", fmt.Errorf("login failed: %v", resp["msg"])
- }
- data, _ := resp["data"].(map[string]any)
- if intFrom(data["biz_code"]) != 0 {
- return "", fmt.Errorf("login failed: %v", data["biz_msg"])
- }
- bizData, _ := data["biz_data"].(map[string]any)
- user, _ := bizData["user"].(map[string]any)
- token, _ := user["token"].(string)
- if strings.TrimSpace(token) == "" {
- return "", errors.New("missing login token")
- }
- return token, nil
-}
-
-func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) {
- if maxAttempts <= 0 {
- maxAttempts = c.maxRetries
- }
- attempts := 0
- refreshed := false
- for attempts < maxAttempts {
- headers := c.authHeaders(a.DeepSeekToken)
- resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"})
- if err != nil {
- config.Logger.Warn("[create_session] request error", "error", err, "account", a.AccountID)
- attempts++
- continue
- }
- code := intFrom(resp["code"])
- if status == http.StatusOK && code == 0 {
- data, _ := resp["data"].(map[string]any)
- bizData, _ := data["biz_data"].(map[string]any)
- sessionID, _ := bizData["id"].(string)
- if sessionID != "" {
- return sessionID, nil
- }
- }
- msg, _ := resp["msg"].(string)
- config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
- if a.UseConfigToken {
- if isTokenInvalid(status, code, msg) && !refreshed {
- if c.Auth.RefreshToken(ctx, a) {
- refreshed = true
- continue
- }
- }
- if c.Auth.SwitchAccount(ctx, a) {
- refreshed = false
- attempts++
- continue
- }
- }
- attempts++
- }
- return "", errors.New("create session failed")
-}
-
-func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) {
- if maxAttempts <= 0 {
- maxAttempts = c.maxRetries
- }
- attempts := 0
- for attempts < maxAttempts {
- headers := c.authHeaders(a.DeepSeekToken)
- resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
- if err != nil {
- config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID)
- attempts++
- continue
- }
- code := intFrom(resp["code"])
- if status == http.StatusOK && code == 0 {
- data, _ := resp["data"].(map[string]any)
- bizData, _ := data["biz_data"].(map[string]any)
- challenge, _ := bizData["challenge"].(map[string]any)
- answer, err := c.powSolver.Compute(ctx, challenge)
- if err != nil {
- attempts++
- continue
- }
- return BuildPowHeader(challenge, answer)
- }
- msg, _ := resp["msg"].(string)
- config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
- if a.UseConfigToken {
- if isTokenInvalid(status, code, msg) {
- if c.Auth.RefreshToken(ctx, a) {
- continue
- }
- }
- if c.Auth.SwitchAccount(ctx, a) {
- attempts++
- continue
- }
- }
- attempts++
- }
- return "", errors.New("get pow failed")
-}
-
-func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) {
- if maxAttempts <= 0 {
- maxAttempts = c.maxRetries
- }
- headers := c.authHeaders(a.DeepSeekToken)
- headers["x-ds-pow-response"] = powResp
- captureSession := c.capture.Start("deepseek_completion", DeepSeekCompletionURL, a.AccountID, payload)
- attempts := 0
- for attempts < maxAttempts {
- resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload)
- if err != nil {
- attempts++
- time.Sleep(time.Second)
- continue
- }
- if resp.StatusCode == http.StatusOK {
- if captureSession != nil {
- resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
- }
- return resp, nil
- }
- if captureSession != nil {
- resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
- }
- _ = resp.Body.Close()
- attempts++
- time.Sleep(time.Second)
- }
- return nil, errors.New("completion failed")
-}
-
-func (c *Client) postJSON(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, error) {
- body, status, err := c.postJSONWithStatus(ctx, doer, url, headers, payload)
- if err != nil {
- return nil, err
- }
- if status == 0 {
- return nil, errors.New("request failed")
- }
- return body, nil
-}
-
-func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, int, error) {
- b, err := json.Marshal(payload)
- if err != nil {
- return nil, 0, err
- }
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
- if err != nil {
- return nil, 0, err
- }
- for k, v := range headers {
- req.Header.Set(k, v)
- }
- resp, err := doer.Do(req)
- if err != nil {
- config.Logger.Warn("[deepseek] fingerprint request failed, fallback to std transport", "url", url, "error", err)
- req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
- if reqErr != nil {
- return nil, 0, err
- }
- for k, v := range headers {
- req2.Header.Set(k, v)
- }
- resp, err = c.fallback.Do(req2)
- if err != nil {
- return nil, 0, err
- }
- }
- defer resp.Body.Close()
- payloadBytes, err := readResponseBody(resp)
- if err != nil {
- return nil, resp.StatusCode, err
- }
- out := map[string]any{}
- if len(payloadBytes) > 0 {
- if err := json.Unmarshal(payloadBytes, &out); err != nil {
- config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes))
- }
- }
- return out, resp.StatusCode, nil
-}
-
-func (c *Client) streamPost(ctx context.Context, url string, headers map[string]string, payload any) (*http.Response, error) {
- b, err := json.Marshal(payload)
- if err != nil {
- return nil, err
- }
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
- if err != nil {
- return nil, err
- }
- for k, v := range headers {
- req.Header.Set(k, v)
- }
- resp, err := c.stream.Do(req)
- if err != nil {
- config.Logger.Warn("[deepseek] fingerprint stream request failed, fallback to std transport", "url", url, "error", err)
- req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
- if reqErr != nil {
- return nil, err
- }
- for k, v := range headers {
- req2.Header.Set(k, v)
- }
- return c.fallbackS.Do(req2)
- }
- return resp, nil
-}
-
-func (c *Client) authHeaders(token string) map[string]string {
- headers := make(map[string]string, len(BaseHeaders)+1)
- for k, v := range BaseHeaders {
- headers[k] = v
- }
- headers["authorization"] = "Bearer " + token
- return headers
-}
-
-func isTokenInvalid(status int, code int, msg string) bool {
- msg = strings.ToLower(msg)
- if status == http.StatusUnauthorized || status == http.StatusForbidden {
- return true
- }
- if code == 40001 || code == 40002 || code == 40003 {
- return true
- }
- return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized")
-}
-
-func readResponseBody(resp *http.Response) ([]byte, error) {
- encoding := strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Encoding")))
- var reader io.Reader = resp.Body
- switch encoding {
- case "gzip":
- gz, err := gzip.NewReader(resp.Body)
- if err != nil {
- return nil, err
- }
- defer gz.Close()
- reader = gz
- case "br":
- reader = brotli.NewReader(resp.Body)
- }
- return io.ReadAll(reader)
-}
-
-func preview(b []byte) string {
- s := strings.TrimSpace(string(b))
- if len(s) > 160 {
- return s[:160]
- }
- return s
-}
-
-func ScanSSELines(resp *http.Response, onLine func([]byte) bool) error {
- scanner := bufio.NewScanner(resp.Body)
- buf := make([]byte, 0, 64*1024)
- scanner.Buffer(buf, 2*1024*1024)
- for scanner.Scan() {
- if !onLine(scanner.Bytes()) {
- break
- }
- }
- if err := scanner.Err(); err != nil {
- return err
- }
- return nil
-}
diff --git a/internal/deepseek/client_auth.go b/internal/deepseek/client_auth.go
new file mode 100644
index 0000000..820acaf
--- /dev/null
+++ b/internal/deepseek/client_auth.go
@@ -0,0 +1,153 @@
+package deepseek
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+ "strings"
+
+ "ds2api/internal/auth"
+ "ds2api/internal/config"
+)
+
+func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) {
+ payload := map[string]any{
+ "password": strings.TrimSpace(acc.Password),
+ "device_id": "deepseek_to_api",
+ "os": "android",
+ }
+ if email := strings.TrimSpace(acc.Email); email != "" {
+ payload["email"] = email
+ } else if mobile := strings.TrimSpace(acc.Mobile); mobile != "" {
+ payload["mobile"] = mobile
+ payload["area_code"] = nil
+ } else {
+ return "", errors.New("missing email/mobile")
+ }
+ resp, err := c.postJSON(ctx, c.regular, DeepSeekLoginURL, BaseHeaders, payload)
+ if err != nil {
+ return "", err
+ }
+ code := intFrom(resp["code"])
+ if code != 0 {
+ return "", fmt.Errorf("login failed: %v", resp["msg"])
+ }
+ data, _ := resp["data"].(map[string]any)
+ if intFrom(data["biz_code"]) != 0 {
+ return "", fmt.Errorf("login failed: %v", data["biz_msg"])
+ }
+ bizData, _ := data["biz_data"].(map[string]any)
+ user, _ := bizData["user"].(map[string]any)
+ token, _ := user["token"].(string)
+ if strings.TrimSpace(token) == "" {
+ return "", errors.New("missing login token")
+ }
+ return token, nil
+}
+
+func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) {
+ if maxAttempts <= 0 {
+ maxAttempts = c.maxRetries
+ }
+ attempts := 0
+ refreshed := false
+ for attempts < maxAttempts {
+ headers := c.authHeaders(a.DeepSeekToken)
+ resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"})
+ if err != nil {
+ config.Logger.Warn("[create_session] request error", "error", err, "account", a.AccountID)
+ attempts++
+ continue
+ }
+ code := intFrom(resp["code"])
+ if status == http.StatusOK && code == 0 {
+ data, _ := resp["data"].(map[string]any)
+ bizData, _ := data["biz_data"].(map[string]any)
+ sessionID, _ := bizData["id"].(string)
+ if sessionID != "" {
+ return sessionID, nil
+ }
+ }
+ msg, _ := resp["msg"].(string)
+ config.Logger.Warn("[create_session] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
+ if a.UseConfigToken {
+ if isTokenInvalid(status, code, msg) && !refreshed {
+ if c.Auth.RefreshToken(ctx, a) {
+ refreshed = true
+ continue
+ }
+ }
+ if c.Auth.SwitchAccount(ctx, a) {
+ refreshed = false
+ attempts++
+ continue
+ }
+ }
+ attempts++
+ }
+ return "", errors.New("create session failed")
+}
+
+func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts int) (string, error) {
+ if maxAttempts <= 0 {
+ maxAttempts = c.maxRetries
+ }
+ attempts := 0
+ for attempts < maxAttempts {
+ headers := c.authHeaders(a.DeepSeekToken)
+ resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
+ if err != nil {
+ config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID)
+ attempts++
+ continue
+ }
+ code := intFrom(resp["code"])
+ if status == http.StatusOK && code == 0 {
+ data, _ := resp["data"].(map[string]any)
+ bizData, _ := data["biz_data"].(map[string]any)
+ challenge, _ := bizData["challenge"].(map[string]any)
+ answer, err := c.powSolver.Compute(ctx, challenge)
+ if err != nil {
+ attempts++
+ continue
+ }
+ return BuildPowHeader(challenge, answer)
+ }
+ msg, _ := resp["msg"].(string)
+ config.Logger.Warn("[get_pow] failed", "status", status, "code", code, "msg", msg, "use_config_token", a.UseConfigToken, "account", a.AccountID)
+ if a.UseConfigToken {
+ if isTokenInvalid(status, code, msg) {
+ if c.Auth.RefreshToken(ctx, a) {
+ continue
+ }
+ }
+ if c.Auth.SwitchAccount(ctx, a) {
+ attempts++
+ continue
+ }
+ }
+ attempts++
+ }
+ return "", errors.New("get pow failed")
+}
+
+func (c *Client) authHeaders(token string) map[string]string {
+ headers := make(map[string]string, len(BaseHeaders)+1)
+ for k, v := range BaseHeaders {
+ headers[k] = v
+ }
+ headers["authorization"] = "Bearer " + token
+ return headers
+}
+
+func isTokenInvalid(status int, code int, msg string) bool {
+ msg = strings.ToLower(msg)
+ if status == http.StatusUnauthorized || status == http.StatusForbidden {
+ return true
+ }
+ if code == 40001 || code == 40002 || code == 40003 {
+ return true
+ }
+ return strings.Contains(msg, "token") || strings.Contains(msg, "unauthorized")
+}
diff --git a/internal/deepseek/client_completion.go b/internal/deepseek/client_completion.go
new file mode 100644
index 0000000..051bffe
--- /dev/null
+++ b/internal/deepseek/client_completion.go
@@ -0,0 +1,71 @@
+package deepseek
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "time"
+
+ "ds2api/internal/auth"
+ "ds2api/internal/config"
+)
+
+func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) {
+ if maxAttempts <= 0 {
+ maxAttempts = c.maxRetries
+ }
+ headers := c.authHeaders(a.DeepSeekToken)
+ headers["x-ds-pow-response"] = powResp
+ captureSession := c.capture.Start("deepseek_completion", DeepSeekCompletionURL, a.AccountID, payload)
+ attempts := 0
+ for attempts < maxAttempts {
+ resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload)
+ if err != nil {
+ attempts++
+ time.Sleep(time.Second)
+ continue
+ }
+ if resp.StatusCode == http.StatusOK {
+ if captureSession != nil {
+ resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
+ }
+ return resp, nil
+ }
+ if captureSession != nil {
+ resp.Body = captureSession.WrapBody(resp.Body, resp.StatusCode)
+ }
+ _ = resp.Body.Close()
+ attempts++
+ time.Sleep(time.Second)
+ }
+ return nil, errors.New("completion failed")
+}
+
+func (c *Client) streamPost(ctx context.Context, url string, headers map[string]string, payload any) (*http.Response, error) {
+ b, err := json.Marshal(payload)
+ if err != nil {
+ return nil, err
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
+ if err != nil {
+ return nil, err
+ }
+ for k, v := range headers {
+ req.Header.Set(k, v)
+ }
+ resp, err := c.stream.Do(req)
+ if err != nil {
+ config.Logger.Warn("[deepseek] fingerprint stream request failed, fallback to std transport", "url", url, "error", err)
+ req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
+ if reqErr != nil {
+ return nil, err
+ }
+ for k, v := range headers {
+ req2.Header.Set(k, v)
+ }
+ return c.fallbackS.Do(req2)
+ }
+ return resp, nil
+}
diff --git a/internal/deepseek/client_core.go b/internal/deepseek/client_core.go
new file mode 100644
index 0000000..cda6edc
--- /dev/null
+++ b/internal/deepseek/client_core.go
@@ -0,0 +1,46 @@
+package deepseek
+
+import (
+ "context"
+ "net/http"
+ "time"
+
+ "ds2api/internal/auth"
+ "ds2api/internal/config"
+ trans "ds2api/internal/deepseek/transport"
+ "ds2api/internal/devcapture"
+ "ds2api/internal/util"
+)
+
+// intFrom is a package-internal alias for the shared util version.
+var intFrom = util.IntFrom
+
+type Client struct {
+ Store *config.Store
+ Auth *auth.Resolver
+ capture *devcapture.Store
+ regular trans.Doer
+ stream trans.Doer
+ fallback *http.Client
+ fallbackS *http.Client
+ powSolver *PowSolver
+ maxRetries int
+}
+
+func NewClient(store *config.Store, resolver *auth.Resolver) *Client {
+ return &Client{
+ Store: store,
+ Auth: resolver,
+ capture: devcapture.Global(),
+ regular: trans.New(60 * time.Second),
+ stream: trans.New(0),
+ fallback: &http.Client{Timeout: 60 * time.Second},
+ fallbackS: &http.Client{Timeout: 0},
+ powSolver: NewPowSolver(config.WASMPath()),
+ maxRetries: 3,
+ }
+}
+
+func (c *Client) PreloadPow(ctx context.Context) error {
+ return c.powSolver.init(ctx)
+}
diff --git a/internal/deepseek/client_http_helpers.go b/internal/deepseek/client_http_helpers.go
new file mode 100644
index 0000000..05de224
--- /dev/null
+++ b/internal/deepseek/client_http_helpers.go
@@ -0,0 +1,51 @@
+package deepseek
+
+import (
+ "bufio"
+ "compress/gzip"
+ "io"
+ "net/http"
+ "strings"
+
+ "github.com/andybalholm/brotli"
+)
+
+func readResponseBody(resp *http.Response) ([]byte, error) {
+ encoding := strings.ToLower(strings.TrimSpace(resp.Header.Get("Content-Encoding")))
+ var reader io.Reader = resp.Body
+ switch encoding {
+ case "gzip":
+ gz, err := gzip.NewReader(resp.Body)
+ if err != nil {
+ return nil, err
+ }
+ defer gz.Close()
+ reader = gz
+ case "br":
+ reader = brotli.NewReader(resp.Body)
+ }
+ return io.ReadAll(reader)
+}
+
+func preview(b []byte) string {
+ s := strings.TrimSpace(string(b))
+ if len(s) > 160 {
+ return s[:160]
+ }
+ return s
+}
+
+func ScanSSELines(resp *http.Response, onLine func([]byte) bool) error {
+ scanner := bufio.NewScanner(resp.Body)
+ buf := make([]byte, 0, 64*1024)
+ scanner.Buffer(buf, 2*1024*1024)
+ for scanner.Scan() {
+ if !onLine(scanner.Bytes()) {
+ break
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/internal/deepseek/client_http_json.go b/internal/deepseek/client_http_json.go
new file mode 100644
index 0000000..6d3599d
--- /dev/null
+++ b/internal/deepseek/client_http_json.go
@@ -0,0 +1,64 @@
+package deepseek
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+
+ "ds2api/internal/config"
+ trans "ds2api/internal/deepseek/transport"
+)
+
+func (c *Client) postJSON(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, error) {
+ body, status, err := c.postJSONWithStatus(ctx, doer, url, headers, payload)
+ if err != nil {
+ return nil, err
+ }
+ if status == 0 {
+ return nil, errors.New("request failed")
+ }
+ return body, nil
+}
+
+func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (map[string]any, int, error) {
+ b, err := json.Marshal(payload)
+ if err != nil {
+ return nil, 0, err
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
+ if err != nil {
+ return nil, 0, err
+ }
+ for k, v := range headers {
+ req.Header.Set(k, v)
+ }
+ resp, err := doer.Do(req)
+ if err != nil {
+ config.Logger.Warn("[deepseek] fingerprint request failed, fallback to std transport", "url", url, "error", err)
+ req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
+ if reqErr != nil {
+ return nil, 0, err
+ }
+ for k, v := range headers {
+ req2.Header.Set(k, v)
+ }
+ resp, err = c.fallback.Do(req2)
+ if err != nil {
+ return nil, 0, err
+ }
+ }
+ defer resp.Body.Close()
+ payloadBytes, err := readResponseBody(resp)
+ if err != nil {
+ return nil, resp.StatusCode, err
+ }
+ out := map[string]any{}
+ if len(payloadBytes) > 0 {
+ if err := json.Unmarshal(payloadBytes, &out); err != nil {
+ config.Logger.Warn("[deepseek] json parse failed", "url", url, "status", resp.StatusCode, "content_encoding", resp.Header.Get("Content-Encoding"), "preview", preview(payloadBytes))
+ }
+ }
+ return out, resp.StatusCode, nil
+}
diff --git a/internal/format/openai/render.go b/internal/format/openai/render.go
deleted file mode 100644
index 2107d4e..0000000
--- a/internal/format/openai/render.go
+++ /dev/null
@@ -1,307 +0,0 @@
-package openai
-
-import (
- "encoding/json"
- "strings"
- "time"
-
- "github.com/google/uuid"
-
- "ds2api/internal/util"
-)
-
-func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
- detected := util.ParseToolCalls(finalText, toolNames)
- finishReason := "stop"
- messageObj := map[string]any{"role": "assistant", "content": finalText}
- if strings.TrimSpace(finalThinking) != "" {
- messageObj["reasoning_content"] = finalThinking
- }
- if len(detected) > 0 {
- finishReason = "tool_calls"
- messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected)
- messageObj["content"] = nil
- }
- promptTokens := util.EstimateTokens(finalPrompt)
- reasoningTokens := util.EstimateTokens(finalThinking)
- completionTokens := util.EstimateTokens(finalText)
-
- return map[string]any{
- "id": completionID,
- "object": "chat.completion",
- "created": time.Now().Unix(),
- "model": model,
- "choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}},
- "usage": map[string]any{
- "prompt_tokens": promptTokens,
- "completion_tokens": reasoningTokens + completionTokens,
- "total_tokens": promptTokens + reasoningTokens + completionTokens,
- "completion_tokens_details": map[string]any{
- "reasoning_tokens": reasoningTokens,
- },
- },
- }
-}
-
-func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
- // Align responses tool-call semantics with chat/completions:
- // mixed prose + tool_call payloads should still be interpreted as tool calls.
- detected := util.ParseToolCalls(finalText, toolNames)
- if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
- detected = util.ParseToolCalls(finalThinking, toolNames)
- }
- exposedOutputText := finalText
- output := make([]any, 0, 2)
- if len(detected) > 0 {
- exposedOutputText = ""
- if strings.TrimSpace(finalThinking) != "" {
- output = append(output, map[string]any{
- "type": "reasoning",
- "text": finalThinking,
- })
- }
- formatted := util.FormatOpenAIToolCalls(detected)
- output = append(output, toResponsesFunctionCallItems(formatted)...)
- output = append(output, map[string]any{
- "type": "tool_calls",
- "tool_calls": formatted,
- })
- } else {
- content := make([]any, 0, 2)
- if finalThinking != "" {
- content = append([]any{map[string]any{
- "type": "reasoning",
- "text": finalThinking,
- }}, content...)
- }
- if strings.TrimSpace(finalText) != "" {
- content = append(content, map[string]any{
- "type": "output_text",
- "text": finalText,
- })
- }
- if strings.TrimSpace(finalText) == "" && strings.TrimSpace(finalThinking) != "" {
- exposedOutputText = finalThinking
- }
- output = append(output, map[string]any{
- "type": "message",
- "id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
- "role": "assistant",
- "content": content,
- })
- }
- promptTokens := util.EstimateTokens(finalPrompt)
- reasoningTokens := util.EstimateTokens(finalThinking)
- completionTokens := util.EstimateTokens(finalText)
- return map[string]any{
- "id": responseID,
- "type": "response",
- "object": "response",
- "created_at": time.Now().Unix(),
- "status": "completed",
- "model": model,
- "output": output,
- "output_text": exposedOutputText,
- "usage": map[string]any{
- "input_tokens": promptTokens,
- "output_tokens": reasoningTokens + completionTokens,
- "total_tokens": promptTokens + reasoningTokens + completionTokens,
- },
- }
-}
-
-func toResponsesFunctionCallItems(toolCalls []map[string]any) []any {
- if len(toolCalls) == 0 {
- return nil
- }
- out := make([]any, 0, len(toolCalls))
- for _, tc := range toolCalls {
- callID, _ := tc["id"].(string)
- if strings.TrimSpace(callID) == "" {
- callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
- }
- name := ""
- args := "{}"
- if fn, ok := tc["function"].(map[string]any); ok {
- if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" {
- name = n
- }
- if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" {
- args = a
- }
- }
- out = append(out, map[string]any{
- "id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
- "type": "function_call",
- "call_id": callID,
- "name": name,
- "arguments": normalizeJSONString(args),
- "status": "completed",
- })
- }
- return out
-}
-
-func normalizeJSONString(raw string) string {
- s := strings.TrimSpace(raw)
- if s == "" {
- return "{}"
- }
- var v any
- if err := json.Unmarshal([]byte(s), &v); err != nil {
- return raw
- }
- b, err := json.Marshal(v)
- if err != nil {
- return raw
- }
- return string(b)
-}
-
-func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any {
- return map[string]any{
- "delta": delta,
- "index": index,
- }
-}
-
-func BuildChatStreamFinishChoice(index int, finishReason string) map[string]any {
- return map[string]any{
- "delta": map[string]any{},
- "index": index,
- "finish_reason": finishReason,
- }
-}
-
-func BuildChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any {
- out := map[string]any{
- "id": completionID,
- "object": "chat.completion.chunk",
- "created": created,
- "model": model,
- "choices": choices,
- }
- if len(usage) > 0 {
- out["usage"] = usage
- }
- return out
-}
-
-func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any {
- promptTokens := util.EstimateTokens(finalPrompt)
- reasoningTokens := util.EstimateTokens(finalThinking)
- completionTokens := util.EstimateTokens(finalText)
- return map[string]any{
- "prompt_tokens": promptTokens,
- "completion_tokens": reasoningTokens + completionTokens,
- "total_tokens": promptTokens + reasoningTokens + completionTokens,
- "completion_tokens_details": map[string]any{
- "reasoning_tokens": reasoningTokens,
- },
- }
-}
-
-func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
- return map[string]any{
- "type": "response.created",
- "id": responseID,
- "response_id": responseID,
- "object": "response",
- "model": model,
- "status": "in_progress",
- }
-}
-
-func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any {
- return map[string]any{
- "type": "response.output_text.delta",
- "id": responseID,
- "response_id": responseID,
- "delta": delta,
- }
-}
-
-func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
- return map[string]any{
- "type": "response.reasoning.delta",
- "id": responseID,
- "response_id": responseID,
- "delta": delta,
- }
-}
-
-func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any {
- return map[string]any{
- "type": "response.reasoning_text.delta",
- "id": responseID,
- "response_id": responseID,
- "item_id": itemID,
- "output_index": outputIndex,
- "content_index": contentIndex,
- "delta": delta,
- }
-}
-
-func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
- return map[string]any{
- "type": "response.reasoning_text.done",
- "id": responseID,
- "response_id": responseID,
- "item_id": itemID,
- "output_index": outputIndex,
- "content_index": contentIndex,
- "text": text,
- }
-}
-
-func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
- return map[string]any{
- "type": "response.output_tool_call.delta",
- "id": responseID,
- "response_id": responseID,
- "tool_calls": toolCalls,
- }
-}
-
-func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
- return map[string]any{
- "type": "response.output_tool_call.done",
- "id": responseID,
- "response_id": responseID,
- "tool_calls": toolCalls,
- }
-}
-
-func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any {
- return map[string]any{
- "type": "response.function_call_arguments.delta",
- "id": responseID,
- "response_id": responseID,
- "item_id": itemID,
- "output_index": outputIndex,
- "call_id": callID,
- "delta": delta,
- }
-}
-
-func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, outputIndex int, callID, name, arguments string) map[string]any {
- return map[string]any{
- "type": "response.function_call_arguments.done",
- "id": responseID,
- "response_id": responseID,
- "item_id": itemID,
- "output_index": outputIndex,
- "call_id": callID,
- "name": name,
- "arguments": normalizeJSONString(arguments),
- }
-}
-
-func BuildResponsesCompletedPayload(response map[string]any) map[string]any {
- responseID, _ := response["id"].(string)
- return map[string]any{
- "type": "response.completed",
- "response_id": responseID,
- "response": response,
- }
-}
diff --git a/internal/format/openai/render_chat.go b/internal/format/openai/render_chat.go
new file mode 100644
index 0000000..1e58fbd
--- /dev/null
+++ b/internal/format/openai/render_chat.go
@@ -0,0 +1,60 @@
+package openai
+
+import (
+ "strings"
+ "time"
+
+ "ds2api/internal/util"
+)
+
+func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
+ detected := util.ParseToolCalls(finalText, toolNames)
+ finishReason := "stop"
+ messageObj := map[string]any{"role": "assistant", "content": finalText}
+ if strings.TrimSpace(finalThinking) != "" {
+ messageObj["reasoning_content"] = finalThinking
+ }
+ if len(detected) > 0 {
+ finishReason = "tool_calls"
+ messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected)
+ messageObj["content"] = nil
+ }
+
+ return map[string]any{
+ "id": completionID,
+ "object": "chat.completion",
+ "created": time.Now().Unix(),
+ "model": model,
+ "choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}},
+ "usage": BuildChatUsage(finalPrompt, finalThinking, finalText),
+ }
+}
+
+func BuildChatStreamDeltaChoice(index int, delta map[string]any) map[string]any {
+ return map[string]any{
+ "delta": delta,
+ "index": index,
+ }
+}
+
+func BuildChatStreamFinishChoice(index int, finishReason string) map[string]any {
+ return map[string]any{
+ "delta": map[string]any{},
+ "index": index,
+ "finish_reason": finishReason,
+ }
+}
+
+func BuildChatStreamChunk(completionID string, created int64, model string, choices []map[string]any, usage map[string]any) map[string]any {
+ out := map[string]any{
+ "id": completionID,
+ "object": "chat.completion.chunk",
+ "created": created,
+ "model": model,
+ "choices": choices,
+ }
+ if len(usage) > 0 {
+ out["usage"] = usage
+ }
+ return out
+}
diff --git a/internal/format/openai/render_responses.go b/internal/format/openai/render_responses.go
new file mode 100644
index 0000000..4fd17c3
--- /dev/null
+++ b/internal/format/openai/render_responses.go
@@ -0,0 +1,119 @@
+package openai
+
+import (
+ "encoding/json"
+ "strings"
+ "time"
+
+ "github.com/google/uuid"
+
+ "ds2api/internal/util"
+)
+
+func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
+ // Align responses tool-call semantics with chat/completions:
+ // mixed prose + tool_call payloads should still be interpreted as tool calls.
+ detected := util.ParseToolCalls(finalText, toolNames)
+ if len(detected) == 0 && strings.TrimSpace(finalThinking) != "" {
+ detected = util.ParseToolCalls(finalThinking, toolNames)
+ }
+ exposedOutputText := finalText
+ output := make([]any, 0, 2)
+ if len(detected) > 0 {
+ exposedOutputText = ""
+ if strings.TrimSpace(finalThinking) != "" {
+ output = append(output, map[string]any{
+ "type": "reasoning",
+ "text": finalThinking,
+ })
+ }
+ formatted := util.FormatOpenAIToolCalls(detected)
+ output = append(output, toResponsesFunctionCallItems(formatted)...)
+ output = append(output, map[string]any{
+ "type": "tool_calls",
+ "tool_calls": formatted,
+ })
+ } else {
+ content := make([]any, 0, 2)
+ if finalThinking != "" {
+ content = append([]any{map[string]any{
+ "type": "reasoning",
+ "text": finalThinking,
+ }}, content...)
+ }
+ if strings.TrimSpace(finalText) != "" {
+ content = append(content, map[string]any{
+ "type": "output_text",
+ "text": finalText,
+ })
+ }
+ if strings.TrimSpace(finalText) == "" && strings.TrimSpace(finalThinking) != "" {
+ exposedOutputText = finalThinking
+ }
+ output = append(output, map[string]any{
+ "type": "message",
+ "id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
+ "role": "assistant",
+ "content": content,
+ })
+ }
+ return map[string]any{
+ "id": responseID,
+ "type": "response",
+ "object": "response",
+ "created_at": time.Now().Unix(),
+ "status": "completed",
+ "model": model,
+ "output": output,
+ "output_text": exposedOutputText,
+ "usage": BuildResponsesUsage(finalPrompt, finalThinking, finalText),
+ }
+}
+
+func toResponsesFunctionCallItems(toolCalls []map[string]any) []any {
+ if len(toolCalls) == 0 {
+ return nil
+ }
+ out := make([]any, 0, len(toolCalls))
+ for _, tc := range toolCalls {
+ callID, _ := tc["id"].(string)
+ if strings.TrimSpace(callID) == "" {
+ callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
+ }
+ name := ""
+ args := "{}"
+ if fn, ok := tc["function"].(map[string]any); ok {
+ if n, _ := fn["name"].(string); strings.TrimSpace(n) != "" {
+ name = n
+ }
+ if a, _ := fn["arguments"].(string); strings.TrimSpace(a) != "" {
+ args = a
+ }
+ }
+ out = append(out, map[string]any{
+ "id": "fc_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
+ "type": "function_call",
+ "call_id": callID,
+ "name": name,
+ "arguments": normalizeJSONString(args),
+ "status": "completed",
+ })
+ }
+ return out
+}
+
+func normalizeJSONString(raw string) string {
+ s := strings.TrimSpace(raw)
+ if s == "" {
+ return "{}"
+ }
+ var v any
+ if err := json.Unmarshal([]byte(s), &v); err != nil {
+ return raw
+ }
+ b, err := json.Marshal(v)
+ if err != nil {
+ return raw
+ }
+ return string(b)
+}
diff --git a/internal/format/openai/render_stream_events.go b/internal/format/openai/render_stream_events.go
new file mode 100644
index 0000000..cc62604
--- /dev/null
+++ b/internal/format/openai/render_stream_events.go
@@ -0,0 +1,106 @@
+package openai
+
+func BuildResponsesCreatedPayload(responseID, model string) map[string]any {
+ return map[string]any{
+ "type": "response.created",
+ "id": responseID,
+ "response_id": responseID,
+ "object": "response",
+ "model": model,
+ "status": "in_progress",
+ }
+}
+
+func BuildResponsesTextDeltaPayload(responseID, delta string) map[string]any {
+ return map[string]any{
+ "type": "response.output_text.delta",
+ "id": responseID,
+ "response_id": responseID,
+ "delta": delta,
+ }
+}
+
+func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any {
+ return map[string]any{
+ "type": "response.reasoning.delta",
+ "id": responseID,
+ "response_id": responseID,
+ "delta": delta,
+ }
+}
+
+func BuildResponsesReasoningTextDeltaPayload(responseID, itemID string, outputIndex, contentIndex int, delta string) map[string]any {
+ return map[string]any{
+ "type": "response.reasoning_text.delta",
+ "id": responseID,
+ "response_id": responseID,
+ "item_id": itemID,
+ "output_index": outputIndex,
+ "content_index": contentIndex,
+ "delta": delta,
+ }
+}
+
+func BuildResponsesReasoningTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any {
+ return map[string]any{
+ "type": "response.reasoning_text.done",
+ "id": responseID,
+ "response_id": responseID,
+ "item_id": itemID,
+ "output_index": outputIndex,
+ "content_index": contentIndex,
+ "text": text,
+ }
+}
+
+func BuildResponsesToolCallDeltaPayload(responseID string, toolCalls []map[string]any) map[string]any {
+ return map[string]any{
+ "type": "response.output_tool_call.delta",
+ "id": responseID,
+ "response_id": responseID,
+ "tool_calls": toolCalls,
+ }
+}
+
+func BuildResponsesToolCallDonePayload(responseID string, toolCalls []map[string]any) map[string]any {
+ return map[string]any{
+ "type": "response.output_tool_call.done",
+ "id": responseID,
+ "response_id": responseID,
+ "tool_calls": toolCalls,
+ }
+}
+
+func BuildResponsesFunctionCallArgumentsDeltaPayload(responseID, itemID string, outputIndex int, callID, delta string) map[string]any {
+ return map[string]any{
+ "type": "response.function_call_arguments.delta",
+ "id": responseID,
+ "response_id": responseID,
+ "item_id": itemID,
+ "output_index": outputIndex,
+ "call_id": callID,
+ "delta": delta,
+ }
+}
+
+func BuildResponsesFunctionCallArgumentsDonePayload(responseID, itemID string, outputIndex int, callID, name, arguments string) map[string]any {
+ return map[string]any{
+ "type": "response.function_call_arguments.done",
+ "id": responseID,
+ "response_id": responseID,
+ "item_id": itemID,
+ "output_index": outputIndex,
+ "call_id": callID,
+ "name": name,
+ "arguments": normalizeJSONString(arguments),
+ }
+}
+
+func BuildResponsesCompletedPayload(response map[string]any) map[string]any {
+ responseID, _ := response["id"].(string)
+ return map[string]any{
+ "type": "response.completed",
+ "response_id": responseID,
+ "response": response,
+ }
+}
diff --git a/internal/format/openai/render_usage.go b/internal/format/openai/render_usage.go
new file mode 100644
index 0000000..b328d20
--- /dev/null
+++ b/internal/format/openai/render_usage.go
@@ -0,0 +1,28 @@
+package openai
+
+import "ds2api/internal/util"
+
+func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any {
+ promptTokens := util.EstimateTokens(finalPrompt)
+ reasoningTokens := util.EstimateTokens(finalThinking)
+ completionTokens := util.EstimateTokens(finalText)
+ return map[string]any{
+ "prompt_tokens": promptTokens,
+ "completion_tokens": reasoningTokens + completionTokens,
+ "total_tokens": promptTokens + reasoningTokens + completionTokens,
+ "completion_tokens_details": map[string]any{
+ "reasoning_tokens": reasoningTokens,
+ },
+ }
+}
+
+func BuildResponsesUsage(finalPrompt, finalThinking, finalText string) map[string]any {
+ promptTokens := util.EstimateTokens(finalPrompt)
+ reasoningTokens := util.EstimateTokens(finalThinking)
+ completionTokens := util.EstimateTokens(finalText)
+ return map[string]any{
+ "input_tokens": promptTokens,
+ "output_tokens": reasoningTokens + completionTokens,
+ "total_tokens": promptTokens + reasoningTokens + completionTokens,
+ }
+}
diff --git a/internal/testsuite/edge_cases.go b/internal/testsuite/edge_cases.go
index cba0b5a..50bc8ac 100644
--- a/internal/testsuite/edge_cases.go
+++ b/internal/testsuite/edge_cases.go
@@ -1,7 +1,6 @@
package testsuite
import (
- "bytes"
"context"
"encoding/json"
"fmt"
@@ -125,72 +124,6 @@ func (r *Runner) caseStreamAbortRelease(ctx context.Context, cc *caseContext) er
return nil
}
-func (cc *caseContext) abortStreamRequest(ctx context.Context, spec requestSpec) error {
- cc.seq++
- traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), cc.seq)
- cc.traceIDsSet[traceID] = struct{}{}
- fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID)
- if err != nil {
- return err
- }
- headers := map[string]string{}
- for k, v := range spec.Headers {
- headers[k] = v
- }
- headers["X-Ds2-Test-Trace"] = traceID
- bodyBytes, _ := json.Marshal(spec.Body)
- headers["Content-Type"] = "application/json"
- cc.requests = append(cc.requests, requestLog{
- Seq: cc.seq,
- Attempt: 1,
- TraceID: traceID,
- Method: spec.Method,
- URL: fullURL,
- Headers: headers,
- Body: spec.Body,
- Timestamp: time.Now().Format(time.RFC3339Nano),
- })
-
- reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout)
- defer cancel()
- req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes))
- if err != nil {
- return err
- }
- for k, v := range headers {
- req.Header.Set(k, v)
- }
- start := time.Now()
- resp, err := cc.runner.httpClient.Do(req)
- if err != nil {
- cc.responses = append(cc.responses, responseLog{
- Seq: cc.seq,
- Attempt: 1,
- TraceID: traceID,
- StatusCode: 0,
- DurationMS: time.Since(start).Milliseconds(),
- NetworkErr: err.Error(),
- ReceivedAt: time.Now().Format(time.RFC3339Nano),
- })
- return err
- }
- defer resp.Body.Close()
- buf := make([]byte, 512)
- _, _ = resp.Body.Read(buf)
- _ = resp.Body.Close()
- cc.responses = append(cc.responses, responseLog{
- Seq: cc.seq,
- Attempt: 1,
- TraceID: traceID,
- StatusCode: resp.StatusCode,
- Headers: resp.Header,
- BodyText: "aborted_after_first_chunk",
- DurationMS: time.Since(start).Milliseconds(),
- ReceivedAt: time.Now().Format(time.RFC3339Nano),
- })
- return nil
-}
-
func (r *Runner) caseToolcallStreamMixed(ctx context.Context, cc *caseContext) error {
payload := toolcallPayload(true)
payload["messages"] = []map[string]any{
@@ -293,167 +226,6 @@ func (r *Runner) caseSSEJSONIntegrity(ctx context.Context, cc *caseContext) erro
return nil
}
-func (r *Runner) caseInvalidModel(ctx context.Context, cc *caseContext) error {
- resp, err := cc.requestOnce(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/v1/chat/completions",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- },
- Body: map[string]any{
- "model": "deepseek-not-exists",
- "messages": []map[string]any{
- {"role": "user", "content": "hi"},
- },
- "stream": false,
- },
- Retryable: false,
- }, 1)
- if err != nil {
- return err
- }
- cc.assert("status_503", resp.StatusCode == http.StatusServiceUnavailable, fmt.Sprintf("status=%d", resp.StatusCode))
- var m map[string]any
- _ = json.Unmarshal(resp.Body, &m)
- e, _ := m["error"].(map[string]any)
- cc.assert("error_type_service_unavailable", asString(e["type"]) == "service_unavailable_error", fmt.Sprintf("body=%s", string(resp.Body)))
- return nil
-}
-
-func (r *Runner) caseMissingMessages(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/v1/chat/completions",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- },
- Body: map[string]any{
- "model": "deepseek-chat",
- "stream": false,
- },
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_400", resp.StatusCode == http.StatusBadRequest, fmt.Sprintf("status=%d", resp.StatusCode))
- var m map[string]any
- _ = json.Unmarshal(resp.Body, &m)
- e, _ := m["error"].(map[string]any)
- cc.assert("error_type_invalid_request", asString(e["type"]) == "invalid_request_error", fmt.Sprintf("body=%s", string(resp.Body)))
- return nil
-}
-
-func (r *Runner) caseAdminUnauthorized(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodGet,
- Path: "/admin/config",
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode))
- return nil
-}
-
-func (r *Runner) caseTokenRefreshManagedAccount(ctx context.Context, cc *caseContext) error {
- if len(r.configRaw.Accounts) == 0 {
- cc.assert("account_present", false, "no account in config")
- return nil
- }
- acc := r.configRaw.Accounts[0]
- id := strings.TrimSpace(acc.Email)
- if id == "" {
- id = strings.TrimSpace(acc.Mobile)
- }
- if id == "" {
- cc.assert("account_identifier", false, "first account has no identifier")
- return nil
- }
- if strings.TrimSpace(acc.Password) == "" {
- r.warnings = append(r.warnings, "token refresh edge case skipped strict check: first account password empty")
- cc.assert("account_password_present", true, "skipped strict refresh check due empty password")
- return nil
- }
- invalidToken := "invalid-testsuite-refresh-token-" + sanitizeID(r.runID)
- update := map[string]any{
- "keys": r.configRaw.Keys,
- "accounts": []map[string]any{
- {
- "email": acc.Email,
- "mobile": acc.Mobile,
- "password": acc.Password,
- "token": invalidToken,
- },
- },
- }
- updResp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/admin/config",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.adminJWT,
- },
- Body: update,
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("update_config_status_200", updResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", updResp.StatusCode))
-
- chatResp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/v1/chat/completions",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- "X-Ds2-Target-Account": id,
- },
- Body: map[string]any{
- "model": "deepseek-chat",
- "messages": []map[string]any{
- {"role": "user", "content": "token refresh test"},
- },
- "stream": false,
- },
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("chat_status_200", chatResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d body=%s", chatResp.StatusCode, string(chatResp.Body)))
-
- cfgResp, err := cc.request(ctx, requestSpec{
- Method: http.MethodGet,
- Path: "/admin/config",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.adminJWT,
- },
- Retryable: true,
- })
- if err != nil {
- return err
- }
- var cfg map[string]any
- _ = json.Unmarshal(cfgResp.Body, &cfg)
- accounts, _ := cfg["accounts"].([]any)
- preview := ""
- hasToken := false
- for _, item := range accounts {
- m, _ := item.(map[string]any)
- e := asString(m["email"])
- mo := asString(m["mobile"])
- if e == acc.Email && mo == acc.Mobile {
- preview = asString(m["token_preview"])
- hasToken, _ = m["has_token"].(bool)
- break
- }
- }
- cc.assert("has_token_after_refresh", hasToken, fmt.Sprintf("config=%s", string(cfgResp.Body)))
- cc.assert("token_preview_changed_from_invalid", !strings.HasPrefix(preview, invalidToken[:20]), fmt.Sprintf("preview=%s invalid_prefix=%s", preview, invalidToken[:20]))
- return nil
-}
-
func (r *Runner) fetchQueueStatus(ctx context.Context, cc *caseContext) (map[string]any, error) {
resp, err := cc.request(ctx, requestSpec{
Method: http.MethodGet,
diff --git a/internal/testsuite/edge_cases_abort.go b/internal/testsuite/edge_cases_abort.go
new file mode 100644
index 0000000..2cc1fc1
--- /dev/null
+++ b/internal/testsuite/edge_cases_abort.go
@@ -0,0 +1,76 @@
+package testsuite
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "time"
+)
+
+func (cc *caseContext) abortStreamRequest(ctx context.Context, spec requestSpec) error {
+ cc.seq++
+ traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), cc.seq)
+ cc.traceIDsSet[traceID] = struct{}{}
+ fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID)
+ if err != nil {
+ return err
+ }
+ headers := map[string]string{}
+ for k, v := range spec.Headers {
+ headers[k] = v
+ }
+ headers["X-Ds2-Test-Trace"] = traceID
+ bodyBytes, _ := json.Marshal(spec.Body)
+ headers["Content-Type"] = "application/json"
+ cc.requests = append(cc.requests, requestLog{
+ Seq: cc.seq,
+ Attempt: 1,
+ TraceID: traceID,
+ Method: spec.Method,
+ URL: fullURL,
+ Headers: headers,
+ Body: spec.Body,
+ Timestamp: time.Now().Format(time.RFC3339Nano),
+ })
+
+ reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout)
+ defer cancel()
+ req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes))
+ if err != nil {
+ return err
+ }
+ for k, v := range headers {
+ req.Header.Set(k, v)
+ }
+ start := time.Now()
+ resp, err := cc.runner.httpClient.Do(req)
+ if err != nil {
+ cc.responses = append(cc.responses, responseLog{
+ Seq: cc.seq,
+ Attempt: 1,
+ TraceID: traceID,
+ StatusCode: 0,
+ DurationMS: time.Since(start).Milliseconds(),
+ NetworkErr: err.Error(),
+ ReceivedAt: time.Now().Format(time.RFC3339Nano),
+ })
+ return err
+ }
+ defer resp.Body.Close()
+ buf := make([]byte, 512)
+ _, _ = resp.Body.Read(buf)
+ _ = resp.Body.Close()
+ cc.responses = append(cc.responses, responseLog{
+ Seq: cc.seq,
+ Attempt: 1,
+ TraceID: traceID,
+ StatusCode: resp.StatusCode,
+ Headers: resp.Header,
+ BodyText: "aborted_after_first_chunk",
+ DurationMS: time.Since(start).Milliseconds(),
+ ReceivedAt: time.Now().Format(time.RFC3339Nano),
+ })
+ return nil
+}
diff --git a/internal/testsuite/edge_cases_error_contract.go b/internal/testsuite/edge_cases_error_contract.go
new file mode 100644
index 0000000..d65ce6d
--- /dev/null
+++ b/internal/testsuite/edge_cases_error_contract.go
@@ -0,0 +1,170 @@
+package testsuite
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+)
+
+func (r *Runner) caseInvalidModel(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.requestOnce(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/v1/chat/completions",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ },
+ Body: map[string]any{
+ "model": "deepseek-not-exists",
+ "messages": []map[string]any{
+ {"role": "user", "content": "hi"},
+ },
+ "stream": false,
+ },
+ Retryable: false,
+ }, 1)
+ if err != nil {
+ return err
+ }
+ cc.assert("status_503", resp.StatusCode == http.StatusServiceUnavailable, fmt.Sprintf("status=%d", resp.StatusCode))
+ var m map[string]any
+ _ = json.Unmarshal(resp.Body, &m)
+ e, _ := m["error"].(map[string]any)
+ cc.assert("error_type_service_unavailable", asString(e["type"]) == "service_unavailable_error", fmt.Sprintf("body=%s", string(resp.Body)))
+ return nil
+}
+
+func (r *Runner) caseMissingMessages(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/v1/chat/completions",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ },
+ Body: map[string]any{
+ "model": "deepseek-chat",
+ "stream": false,
+ },
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_400", resp.StatusCode == http.StatusBadRequest, fmt.Sprintf("status=%d", resp.StatusCode))
+ var m map[string]any
+ _ = json.Unmarshal(resp.Body, &m)
+ e, _ := m["error"].(map[string]any)
+ cc.assert("error_type_invalid_request", asString(e["type"]) == "invalid_request_error", fmt.Sprintf("body=%s", string(resp.Body)))
+ return nil
+}
+
+func (r *Runner) caseAdminUnauthorized(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodGet,
+ Path: "/admin/config",
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode))
+ return nil
+}
+
+func (r *Runner) caseTokenRefreshManagedAccount(ctx context.Context, cc *caseContext) error {
+ if len(r.configRaw.Accounts) == 0 {
+ cc.assert("account_present", false, "no account in config")
+ return nil
+ }
+ acc := r.configRaw.Accounts[0]
+ id := strings.TrimSpace(acc.Email)
+ if id == "" {
+ id = strings.TrimSpace(acc.Mobile)
+ }
+ if id == "" {
+ cc.assert("account_identifier", false, "first account has no identifier")
+ return nil
+ }
+ if strings.TrimSpace(acc.Password) == "" {
+ r.warnings = append(r.warnings, "token refresh edge case skipped strict check: first account password empty")
+ cc.assert("account_password_present", true, "skipped strict refresh check due empty password")
+ return nil
+ }
+ invalidToken := "invalid-testsuite-refresh-token-" + sanitizeID(r.runID)
+ update := map[string]any{
+ "keys": r.configRaw.Keys,
+ "accounts": []map[string]any{
+ {
+ "email": acc.Email,
+ "mobile": acc.Mobile,
+ "password": acc.Password,
+ "token": invalidToken,
+ },
+ },
+ }
+ updResp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/admin/config",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.adminJWT,
+ },
+ Body: update,
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("update_config_status_200", updResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", updResp.StatusCode))
+
+ chatResp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/v1/chat/completions",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ "X-Ds2-Target-Account": id,
+ },
+ Body: map[string]any{
+ "model": "deepseek-chat",
+ "messages": []map[string]any{
+ {"role": "user", "content": "token refresh test"},
+ },
+ "stream": false,
+ },
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("chat_status_200", chatResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d body=%s", chatResp.StatusCode, string(chatResp.Body)))
+
+ cfgResp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodGet,
+ Path: "/admin/config",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.adminJWT,
+ },
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ var cfg map[string]any
+ _ = json.Unmarshal(cfgResp.Body, &cfg)
+ accounts, _ := cfg["accounts"].([]any)
+ preview := ""
+ hasToken := false
+ for _, item := range accounts {
+ m, _ := item.(map[string]any)
+ e := asString(m["email"])
+ mo := asString(m["mobile"])
+ if e == acc.Email && mo == acc.Mobile {
+ preview = asString(m["token_preview"])
+ hasToken, _ = m["has_token"].(bool)
+ break
+ }
+ }
+ cc.assert("has_token_after_refresh", hasToken, fmt.Sprintf("config=%s", string(cfgResp.Body)))
+ cc.assert("token_preview_changed_from_invalid", !strings.HasPrefix(preview, invalidToken[:20]), fmt.Sprintf("preview=%s invalid_prefix=%s", preview, invalidToken[:20]))
+ return nil
+}
diff --git a/internal/testsuite/runner.go b/internal/testsuite/runner.go
deleted file mode 100644
index 33e7580..0000000
--- a/internal/testsuite/runner.go
+++ /dev/null
@@ -1,1766 +0,0 @@
-package testsuite
-
-import (
- "bytes"
- "context"
- "crypto/sha256"
- "encoding/hex"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net"
- "net/http"
- "net/url"
- "os"
- "os/exec"
- "path/filepath"
- "runtime"
- "sort"
- "strconv"
- "strings"
- "sync"
- "time"
-)
-
-type Options struct {
- ConfigPath string
- AdminKey string
- OutputDir string
- Port int
- Timeout time.Duration
- Retries int
- NoPreflight bool
- MaxKeepRuns int
-}
-
-type runSummary struct {
- RunID string `json:"run_id"`
- StartedAt string `json:"started_at"`
- EndedAt string `json:"ended_at"`
- DurationMS int64 `json:"duration_ms"`
- Stats map[string]any `json:"stats"`
- Environment map[string]any `json:"environment"`
- Cases []caseResult `json:"cases"`
- Warnings []string `json:"warnings,omitempty"`
-}
-
-type caseResult struct {
- CaseID string `json:"case_id"`
- Passed bool `json:"passed"`
- DurationMS int64 `json:"duration_ms"`
- TraceIDs []string `json:"trace_ids"`
- StatusCodes []int `json:"status_codes"`
- Error string `json:"error,omitempty"`
- ArtifactPath string `json:"artifact_path"`
- Assertions []assertionResult `json:"assertions"`
-}
-
-type assertionResult struct {
- Name string `json:"name"`
- Passed bool `json:"passed"`
- Detail string `json:"detail,omitempty"`
-}
-
-type requestLog struct {
- Seq int `json:"seq"`
- Attempt int `json:"attempt"`
- TraceID string `json:"trace_id"`
- Method string `json:"method"`
- URL string `json:"url"`
- Headers map[string]string `json:"headers"`
- Body any `json:"body,omitempty"`
- Timestamp string `json:"timestamp"`
-}
-
-type responseLog struct {
- Seq int `json:"seq"`
- Attempt int `json:"attempt"`
- TraceID string `json:"trace_id"`
- StatusCode int `json:"status_code"`
- Headers map[string][]string `json:"headers"`
- BodyText string `json:"body_text"`
- DurationMS int64 `json:"duration_ms"`
- NetworkErr string `json:"network_error,omitempty"`
- ReceivedAt string `json:"received_at"`
-}
-
-type caseContext struct {
- runner *Runner
- id string
- dir string
- startedAt time.Time
- mu sync.Mutex
- seq int
- assertions []assertionResult
- requests []requestLog
- responses []responseLog
- streamRaw strings.Builder
- traceIDsSet map[string]struct{}
-}
-
-type requestSpec struct {
- Method string
- Path string
- Headers map[string]string
- Body any
- Stream bool
- Retryable bool
-}
-
-type responseResult struct {
- StatusCode int
- Headers http.Header
- Body []byte
- TraceID string
- URL string
-}
-
-type Runner struct {
- opts Options
-
- runID string
- runDir string
- serverLog string
- preflightLog string
-
- baseURL string
- httpClient *http.Client
- serverCmd *exec.Cmd
- serverLogFd *os.File
-
- configCopyPath string
- originalConfigPath string
- originalConfigHash string
-
- configRaw runConfig
- apiKey string
- adminKey string
- adminJWT string
- accountID string
-
- warnings []string
- results []caseResult
-}
-
-type runConfig struct {
- Keys []string `json:"keys"`
- Accounts []struct {
- Email string `json:"email,omitempty"`
- Mobile string `json:"mobile,omitempty"`
- Password string `json:"password,omitempty"`
- Token string `json:"token,omitempty"`
- } `json:"accounts"`
-}
-
-func DefaultOptions() Options {
- return Options{
- ConfigPath: "config.json",
- AdminKey: strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")),
- OutputDir: "artifacts/testsuite",
- Port: 0,
- Timeout: 120 * time.Second,
- Retries: 2,
- NoPreflight: false,
- MaxKeepRuns: 5,
- }
-}
-
-func Run(ctx context.Context, opts Options) error {
- r, err := newRunner(opts)
- if err != nil {
- return err
- }
- start := time.Now()
- defer func() {
- _ = r.stopServer()
- }()
-
- if err := r.prepareRunDir(); err != nil {
- return err
- }
-
- if !r.opts.NoPreflight {
- if err := r.runPreflight(ctx); err != nil {
- _ = r.writeSummary(start, time.Now())
- return err
- }
- }
-
- if err := r.prepareConfigIsolation(); err != nil {
- _ = r.writeSummary(start, time.Now())
- return err
- }
-
- if err := r.startServer(ctx); err != nil {
- _ = r.writeSummary(start, time.Now())
- return err
- }
-
- if err := r.prepareAuth(ctx); err != nil {
- r.warnings = append(r.warnings, "auth prepare failed: "+err.Error())
- }
-
- for _, c := range r.cases() {
- r.runCase(ctx, c)
- }
-
- if err := r.ensureOriginalConfigUntouched(); err != nil {
- r.warnings = append(r.warnings, err.Error())
- }
-
- end := time.Now()
- if err := r.writeSummary(start, end); err != nil {
- return err
- }
-
- // Prune old test runs, keeping only the most recent N.
- if err := r.pruneOldRuns(); err != nil {
- r.warnings = append(r.warnings, "prune old runs: "+err.Error())
- }
-
- failed := 0
- for _, cs := range r.results {
- if !cs.Passed {
- failed++
- }
- }
- if failed > 0 {
- return fmt.Errorf("testsuite failed: %d case(s) failed, see %s", failed, filepath.Join(r.runDir, "summary.md"))
- }
- return nil
-}
-
-func newRunner(opts Options) (*Runner, error) {
- if strings.TrimSpace(opts.ConfigPath) == "" {
- opts.ConfigPath = "config.json"
- }
- if strings.TrimSpace(opts.OutputDir) == "" {
- opts.OutputDir = "artifacts/testsuite"
- }
- if opts.Timeout <= 0 {
- opts.Timeout = 120 * time.Second
- }
- if opts.Retries < 0 {
- opts.Retries = 0
- }
- adminKey := strings.TrimSpace(opts.AdminKey)
- if adminKey == "" {
- adminKey = strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY"))
- }
- if adminKey == "" {
- adminKey = "admin"
- }
- opts.AdminKey = adminKey
-
- return &Runner{
- opts: opts,
- httpClient: &http.Client{
- Timeout: 0,
- },
- runID: time.Now().UTC().Format("20060102T150405Z"),
- adminKey: adminKey,
- }, nil
-}
-
-func (r *Runner) prepareRunDir() error {
- r.runDir = filepath.Join(r.opts.OutputDir, r.runID)
- if err := os.MkdirAll(r.runDir, 0o755); err != nil {
- return err
- }
- if err := os.MkdirAll(filepath.Join(r.runDir, "cases"), 0o755); err != nil {
- return err
- }
- r.serverLog = filepath.Join(r.runDir, "server.log")
- r.preflightLog = filepath.Join(r.runDir, "preflight.log")
- return nil
-}
-
-// pruneOldRuns removes old test run directories, keeping the most recent MaxKeepRuns.
-// Run IDs use the format "20060102T150405Z", so alphabetical order == chronological order.
-func (r *Runner) pruneOldRuns() error {
- keep := r.opts.MaxKeepRuns
- if keep <= 0 {
- return nil // 0 or negative means no pruning
- }
-
- entries, err := os.ReadDir(r.opts.OutputDir)
- if err != nil {
- return err
- }
-
- // Collect only directories (each run is a directory).
- var runDirs []string
- for _, e := range entries {
- if !e.IsDir() {
- continue
- }
- runDirs = append(runDirs, e.Name())
- }
-
- sort.Strings(runDirs)
-
- if len(runDirs) <= keep {
- return nil
- }
-
- // Remove oldest runs (those at the beginning of the sorted list).
- toRemove := runDirs[:len(runDirs)-keep]
- var errs []string
- for _, name := range toRemove {
- dirPath := filepath.Join(r.opts.OutputDir, name)
- if err := os.RemoveAll(dirPath); err != nil {
- errs = append(errs, fmt.Sprintf("remove %s: %v", name, err))
- } else {
- fmt.Fprintf(os.Stdout, "pruned old test run: %s\n", name)
- }
- }
-
- if len(errs) > 0 {
- return errors.New(strings.Join(errs, "; "))
- }
- return nil
-}
-
-func (r *Runner) runPreflight(ctx context.Context) error {
- steps := [][]string{
- {"go", "test", "./...", "-count=1"},
- {"node", "--check", "api/chat-stream.js"},
- {"node", "--check", "api/helpers/stream-tool-sieve.js"},
- {"node", "--test", "api/helpers/stream-tool-sieve.test.js", "api/chat-stream.test.js", "api/compat/js_compat_test.js"},
- {"npm", "run", "build", "--prefix", "webui"},
- }
- f, err := os.OpenFile(r.preflightLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
- if err != nil {
- return err
- }
- defer f.Close()
- for _, step := range steps {
- if _, err := fmt.Fprintf(f, "\n$ %s\n", strings.Join(step, " ")); err != nil {
- return err
- }
- cmd := exec.CommandContext(ctx, step[0], step[1:]...)
- cmd.Stdout = f
- cmd.Stderr = f
- if err := cmd.Run(); err != nil {
- return fmt.Errorf("preflight failed at `%s`: %w", strings.Join(step, " "), err)
- }
- }
- return nil
-}
-
-func (r *Runner) prepareConfigIsolation() error {
- abs, err := filepath.Abs(r.opts.ConfigPath)
- if err != nil {
- return err
- }
- r.originalConfigPath = abs
- raw, err := os.ReadFile(abs)
- if err != nil {
- return err
- }
- sum := sha256.Sum256(raw)
- r.originalConfigHash = hex.EncodeToString(sum[:])
-
- tmpDir := filepath.Join(r.runDir, "tmp")
- if err := os.MkdirAll(tmpDir, 0o755); err != nil {
- return err
- }
- r.configCopyPath = filepath.Join(tmpDir, "config.json")
- if err := os.WriteFile(r.configCopyPath, raw, 0o644); err != nil {
- return err
- }
- var cfg runConfig
- if err := json.Unmarshal(raw, &cfg); err != nil {
- return fmt.Errorf("parse config failed: %w", err)
- }
- r.configRaw = cfg
- if len(cfg.Keys) > 0 {
- r.apiKey = strings.TrimSpace(cfg.Keys[0])
- }
- for _, acc := range cfg.Accounts {
- id := strings.TrimSpace(acc.Email)
- if id == "" {
- id = strings.TrimSpace(acc.Mobile)
- }
- if id != "" {
- r.accountID = id
- break
- }
- }
- return nil
-}
-
-func (r *Runner) startServer(ctx context.Context) error {
- port := r.opts.Port
- if port <= 0 {
- p, err := findFreePort()
- if err != nil {
- return err
- }
- port = p
- }
- r.baseURL = "http://127.0.0.1:" + strconv.Itoa(port)
-
- logFd, err := os.OpenFile(r.serverLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
- if err != nil {
- return err
- }
- r.serverLogFd = logFd
- cmd := exec.CommandContext(ctx, "go", "run", "./cmd/ds2api")
- cmd.Stdout = logFd
- cmd.Stderr = logFd
- cmd.Env = prepareServerEnv(os.Environ(), map[string]string{
- "PORT": strconv.Itoa(port),
- "DS2API_CONFIG_PATH": r.configCopyPath,
- "DS2API_AUTO_BUILD_WEBUI": "false",
- "DS2API_CONFIG_JSON": "",
- "CONFIG_JSON": "",
- })
- if err := cmd.Start(); err != nil {
- _ = logFd.Close()
- return err
- }
- r.serverCmd = cmd
-
- deadline := time.Now().Add(90 * time.Second)
- for time.Now().Before(deadline) {
- if r.ping("/healthz") == nil && r.ping("/readyz") == nil {
- return nil
- }
- time.Sleep(500 * time.Millisecond)
- }
- return errors.New("server readiness timeout")
-}
-
-func (r *Runner) stopServer() error {
- var errs []string
- if r.serverCmd != nil && r.serverCmd.Process != nil {
- _ = r.serverCmd.Process.Signal(os.Interrupt)
- done := make(chan error, 1)
- go func() { done <- r.serverCmd.Wait() }()
- select {
- case <-time.After(5 * time.Second):
- _ = r.serverCmd.Process.Kill()
- <-done
- case <-done:
- }
- }
- if r.serverLogFd != nil {
- if err := r.serverLogFd.Close(); err != nil {
- errs = append(errs, err.Error())
- }
- }
- if len(errs) > 0 {
- return errors.New(strings.Join(errs, "; "))
- }
- return nil
-}
-
-func (r *Runner) ping(path string) error {
- resp, err := r.httpClient.Get(r.baseURL + path)
- if err != nil {
- return err
- }
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK {
- return fmt.Errorf("status=%d", resp.StatusCode)
- }
- return nil
-}
-
-func (r *Runner) prepareAuth(ctx context.Context) error {
- reqBody := map[string]any{
- "admin_key": r.adminKey,
- "expire_hours": 24,
- }
- resp, err := r.doSimpleJSON(ctx, http.MethodPost, "/admin/login", nil, reqBody)
- if err != nil {
- return err
- }
- if resp.StatusCode != http.StatusOK {
- return fmt.Errorf("admin login status=%d body=%s", resp.StatusCode, string(resp.Body))
- }
- var m map[string]any
- if err := json.Unmarshal(resp.Body, &m); err != nil {
- return err
- }
- token, _ := m["token"].(string)
- if strings.TrimSpace(token) == "" {
- return errors.New("empty admin jwt token")
- }
- r.adminJWT = token
- return nil
-}
-
-func (r *Runner) ensureOriginalConfigUntouched() error {
- raw, err := os.ReadFile(r.originalConfigPath)
- if err != nil {
- return err
- }
- sum := sha256.Sum256(raw)
- current := hex.EncodeToString(sum[:])
- if current != r.originalConfigHash {
- return fmt.Errorf("original config changed unexpectedly: %s", r.originalConfigPath)
- }
- return nil
-}
-
-func (r *Runner) runCase(ctx context.Context, c caseDef) {
- caseDir := filepath.Join(r.runDir, "cases", c.ID)
- _ = os.MkdirAll(caseDir, 0o755)
- cc := &caseContext{
- runner: r,
- id: c.ID,
- dir: caseDir,
- startedAt: time.Now(),
- traceIDsSet: map[string]struct{}{},
- }
- err := c.Run(ctx, cc)
- duration := time.Since(cc.startedAt).Milliseconds()
-
- if err != nil {
- cc.assertions = append(cc.assertions, assertionResult{
- Name: "case_error",
- Passed: false,
- Detail: err.Error(),
- })
- }
- passed := err == nil
- for _, a := range cc.assertions {
- if !a.Passed {
- passed = false
- break
- }
- }
-
- traceIDs := make([]string, 0, len(cc.traceIDsSet))
- for t := range cc.traceIDsSet {
- traceIDs = append(traceIDs, t)
- }
- sort.Strings(traceIDs)
- statuses := uniqueStatusCodes(cc.responses)
- cs := caseResult{
- CaseID: c.ID,
- Passed: passed,
- DurationMS: duration,
- TraceIDs: traceIDs,
- StatusCodes: statuses,
- ArtifactPath: caseDir,
- Assertions: cc.assertions,
- }
- if err != nil {
- cs.Error = err.Error()
- }
- _ = cc.flushArtifacts(cs)
- r.results = append(r.results, cs)
-}
-
-func (cc *caseContext) assert(name string, ok bool, detail string) {
- cc.mu.Lock()
- defer cc.mu.Unlock()
- cc.assertions = append(cc.assertions, assertionResult{
- Name: name,
- Passed: ok,
- Detail: detail,
- })
-}
-
-func (cc *caseContext) request(ctx context.Context, spec requestSpec) (*responseResult, error) {
- retries := cc.runner.opts.Retries
- if !spec.Retryable {
- retries = 0
- }
- var lastErr error
- for attempt := 1; attempt <= retries+1; attempt++ {
- resp, err := cc.requestOnce(ctx, spec, attempt)
- if err == nil && resp.StatusCode < 500 {
- return resp, nil
- }
- if err != nil {
- lastErr = err
- } else if resp.StatusCode >= 500 {
- lastErr = fmt.Errorf("status=%d", resp.StatusCode)
- }
- if attempt <= retries {
- sleep := time.Duration(300*(1<<(attempt-1))) * time.Millisecond
- time.Sleep(sleep)
- }
- }
- return nil, lastErr
-}
-
-func (cc *caseContext) requestOnce(ctx context.Context, spec requestSpec, attempt int) (*responseResult, error) {
- cc.mu.Lock()
- cc.seq++
- seq := cc.seq
- traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), seq)
- cc.traceIDsSet[traceID] = struct{}{}
- cc.mu.Unlock()
-
- fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID)
- if err != nil {
- return nil, err
- }
-
- headers := map[string]string{}
- for k, v := range spec.Headers {
- headers[k] = v
- }
- headers["X-Ds2-Test-Trace"] = traceID
-
- var bodyBytes []byte
- var bodyAny any
- if spec.Body != nil {
- b, err := json.Marshal(spec.Body)
- if err != nil {
- return nil, err
- }
- bodyBytes = b
- bodyAny = spec.Body
- headers["Content-Type"] = "application/json"
- }
- cc.mu.Lock()
- cc.requests = append(cc.requests, requestLog{
- Seq: seq,
- Attempt: attempt,
- TraceID: traceID,
- Method: spec.Method,
- URL: fullURL,
- Headers: headers,
- Body: bodyAny,
- Timestamp: time.Now().Format(time.RFC3339Nano),
- })
- cc.mu.Unlock()
-
- reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout)
- defer cancel()
- req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes))
- if err != nil {
- return nil, err
- }
- for k, v := range headers {
- req.Header.Set(k, v)
- }
- start := time.Now()
- resp, err := cc.runner.httpClient.Do(req)
- if err != nil {
- cc.mu.Lock()
- cc.responses = append(cc.responses, responseLog{
- Seq: seq,
- Attempt: attempt,
- TraceID: traceID,
- StatusCode: 0,
- DurationMS: time.Since(start).Milliseconds(),
- NetworkErr: err.Error(),
- ReceivedAt: time.Now().Format(time.RFC3339Nano),
- })
- cc.mu.Unlock()
- return nil, err
- }
- defer resp.Body.Close()
- body, _ := io.ReadAll(resp.Body)
-
- cc.mu.Lock()
- cc.responses = append(cc.responses, responseLog{
- Seq: seq,
- Attempt: attempt,
- TraceID: traceID,
- StatusCode: resp.StatusCode,
- Headers: resp.Header,
- BodyText: string(body),
- DurationMS: time.Since(start).Milliseconds(),
- ReceivedAt: time.Now().Format(time.RFC3339Nano),
- })
-
- if spec.Stream {
- cc.streamRaw.WriteString(fmt.Sprintf("### trace=%s url=%s\n", traceID, fullURL))
- cc.streamRaw.Write(body)
- cc.streamRaw.WriteString("\n\n")
- }
- cc.mu.Unlock()
-
- return &responseResult{
- StatusCode: resp.StatusCode,
- Headers: resp.Header,
- Body: body,
- TraceID: traceID,
- URL: fullURL,
- }, nil
-}
-
-func (cc *caseContext) flushArtifacts(cs caseResult) error {
- requestPath := filepath.Join(cc.dir, "request.json")
- headersPath := filepath.Join(cc.dir, "response.headers")
- bodyPath := filepath.Join(cc.dir, "response.body")
- streamPath := filepath.Join(cc.dir, "stream.raw")
- assertPath := filepath.Join(cc.dir, "assertions.json")
- metaPath := filepath.Join(cc.dir, "meta.json")
-
- if err := writeJSONFile(requestPath, cc.requests); err != nil {
- return err
- }
- respHeaders := make([]map[string]any, 0, len(cc.responses))
- respBodies := make([]map[string]any, 0, len(cc.responses))
- for _, r := range cc.responses {
- respHeaders = append(respHeaders, map[string]any{
- "seq": r.Seq,
- "attempt": r.Attempt,
- "trace_id": r.TraceID,
- "status_code": r.StatusCode,
- "headers": r.Headers,
- })
- respBodies = append(respBodies, map[string]any{
- "seq": r.Seq,
- "attempt": r.Attempt,
- "trace_id": r.TraceID,
- "status_code": r.StatusCode,
- "body_text": r.BodyText,
- "network_error": r.NetworkErr,
- "duration_ms": r.DurationMS,
- })
- }
- if err := writeJSONFile(headersPath, respHeaders); err != nil {
- return err
- }
- if err := writeJSONFile(bodyPath, respBodies); err != nil {
- return err
- }
- if err := os.WriteFile(streamPath, []byte(cc.streamRaw.String()), 0o644); err != nil {
- return err
- }
- if err := writeJSONFile(assertPath, cc.assertions); err != nil {
- return err
- }
- meta := map[string]any{
- "case_id": cs.CaseID,
- "trace_id": strings.Join(cs.TraceIDs, ","),
- "attempt": len(cc.responses),
- "duration_ms": cs.DurationMS,
- "status": map[bool]string{true: "passed", false: "failed"}[cs.Passed],
- "status_codes": cs.StatusCodes,
- "assertions": cs.Assertions,
- "artifact_path": cs.ArtifactPath,
- }
- return writeJSONFile(metaPath, meta)
-}
-
-type caseDef struct {
- ID string
- Run func(context.Context, *caseContext) error
-}
-
-func (r *Runner) cases() []caseDef {
- return []caseDef{
- {ID: "healthz_ok", Run: r.caseHealthz},
- {ID: "readyz_ok", Run: r.caseReadyz},
- {ID: "models_openai", Run: r.caseModelsOpenAI},
- {ID: "model_openai_by_id", Run: r.caseModelOpenAIByID},
- {ID: "models_claude", Run: r.caseModelsClaude},
- {ID: "admin_login_verify", Run: r.caseAdminLoginVerify},
- {ID: "admin_queue_status", Run: r.caseAdminQueueStatus},
- {ID: "chat_nonstream_basic", Run: r.caseChatNonstream},
- {ID: "chat_stream_basic", Run: r.caseChatStream},
- {ID: "responses_nonstream_basic", Run: r.caseResponsesNonstream},
- {ID: "responses_stream_basic", Run: r.caseResponsesStream},
- {ID: "embeddings_contract", Run: r.caseEmbeddings},
- {ID: "reasoner_stream", Run: r.caseReasonerStream},
- {ID: "toolcall_nonstream", Run: r.caseToolcallNonstream},
- {ID: "toolcall_stream", Run: r.caseToolcallStream},
- {ID: "anthropic_messages_nonstream", Run: r.caseAnthropicNonstream},
- {ID: "anthropic_messages_stream", Run: r.caseAnthropicStream},
- {ID: "anthropic_count_tokens", Run: r.caseAnthropicCountTokens},
- {ID: "admin_account_test_single", Run: r.caseAdminAccountTest},
- {ID: "concurrency_burst", Run: r.caseConcurrencyBurst},
- {ID: "concurrency_threshold_limit", Run: r.caseConcurrencyThresholdLimit},
- {ID: "stream_abort_release", Run: r.caseStreamAbortRelease},
- {ID: "toolcall_stream_mixed", Run: r.caseToolcallStreamMixed},
- {ID: "sse_json_integrity", Run: r.caseSSEJSONIntegrity},
- {ID: "error_contract_invalid_model", Run: r.caseInvalidModel},
- {ID: "error_contract_missing_messages", Run: r.caseMissingMessages},
- {ID: "admin_unauthorized_contract", Run: r.caseAdminUnauthorized},
- {ID: "config_write_isolated", Run: r.caseConfigWriteIsolated},
- {ID: "token_refresh_managed_account", Run: r.caseTokenRefreshManagedAccount},
- {ID: "error_contract_invalid_key", Run: r.caseInvalidKey},
- }
-}
-
-func (r *Runner) caseHealthz(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/healthz", Retryable: true})
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- var m map[string]any
- _ = json.Unmarshal(resp.Body, &m)
- cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body)))
- return nil
-}
-
-func (r *Runner) caseReadyz(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/readyz", Retryable: true})
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- var m map[string]any
- _ = json.Unmarshal(resp.Body, &m)
- cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body)))
- return nil
-}
-
-func (r *Runner) caseModelsOpenAI(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models", Retryable: true})
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- ids := extractModelIDs(resp.Body)
- cc.assert("has_deepseek_chat", contains(ids, "deepseek-chat"), strings.Join(ids, ","))
- cc.assert("has_deepseek_reasoner", contains(ids, "deepseek-reasoner"), strings.Join(ids, ","))
- return nil
-}
-
-func (r *Runner) caseModelOpenAIByID(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models/gpt-4o", Retryable: true})
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- var m map[string]any
- _ = json.Unmarshal(resp.Body, &m)
- cc.assert("object_model", asString(m["object"]) == "model", fmt.Sprintf("body=%s", string(resp.Body)))
- cc.assert("id_deepseek_chat", asString(m["id"]) == "deepseek-chat", fmt.Sprintf("body=%s", string(resp.Body)))
- return nil
-}
-
-func (r *Runner) caseModelsClaude(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/anthropic/v1/models", Retryable: true})
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- ids := extractModelIDs(resp.Body)
- cc.assert("non_empty", len(ids) > 0, fmt.Sprintf("models=%v", ids))
- return nil
-}
-
-func (r *Runner) caseAdminLoginVerify(ctx context.Context, cc *caseContext) error {
- loginResp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/admin/login",
- Body: map[string]any{"admin_key": r.adminKey, "expire_hours": 24},
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("login_status_200", loginResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", loginResp.StatusCode))
- var payload map[string]any
- _ = json.Unmarshal(loginResp.Body, &payload)
- token := asString(payload["token"])
- cc.assert("token_exists", token != "", fmt.Sprintf("body=%s", string(loginResp.Body)))
- if token == "" {
- return nil
- }
- verifyResp, err := cc.request(ctx, requestSpec{
- Method: http.MethodGet,
- Path: "/admin/verify",
- Headers: map[string]string{
- "Authorization": "Bearer " + token,
- },
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("verify_status_200", verifyResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", verifyResp.StatusCode))
- var v map[string]any
- _ = json.Unmarshal(verifyResp.Body, &v)
- valid, _ := v["valid"].(bool)
- cc.assert("verify_valid_true", valid, fmt.Sprintf("body=%s", string(verifyResp.Body)))
- return nil
-}
-
-func (r *Runner) caseAdminQueueStatus(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodGet,
- Path: "/admin/queue/status",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.adminJWT,
- },
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- var m map[string]any
- _ = json.Unmarshal(resp.Body, &m)
- _, hasRec := m["recommended_concurrency"]
- _, hasQueue := m["max_queue_size"]
- cc.assert("has_recommended_concurrency", hasRec, fmt.Sprintf("body=%s", string(resp.Body)))
- cc.assert("has_max_queue_size", hasQueue, fmt.Sprintf("body=%s", string(resp.Body)))
- return nil
-}
-
-func (r *Runner) caseChatNonstream(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/v1/chat/completions",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- },
- Body: map[string]any{
- "model": "deepseek-chat",
- "messages": []map[string]any{
- {"role": "user", "content": "请简单回复一句话"},
- },
- "stream": false,
- },
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- var m map[string]any
- _ = json.Unmarshal(resp.Body, &m)
- cc.assert("object_chat_completion", asString(m["object"]) == "chat.completion", fmt.Sprintf("body=%s", string(resp.Body)))
- choices, _ := m["choices"].([]any)
- cc.assert("choices_non_empty", len(choices) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
- return nil
-}
-
-func (r *Runner) caseChatStream(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/v1/chat/completions",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- },
- Body: map[string]any{
- "model": "deepseek-chat",
- "messages": []map[string]any{
- {"role": "user", "content": "请流式回复一句话"},
- },
- "stream": true,
- },
- Stream: true,
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- frames, done := parseSSEFrames(resp.Body)
- cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames)))
- cc.assert("done_terminated", done, "expected [DONE]")
- return nil
-}
-
-func (r *Runner) caseResponsesNonstream(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/v1/responses",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- },
- Body: map[string]any{
- "model": "gpt-4o",
- "input": "请简要回答 hello",
- },
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- var m map[string]any
- _ = json.Unmarshal(resp.Body, &m)
- cc.assert("object_response", asString(m["object"]) == "response", fmt.Sprintf("body=%s", string(resp.Body)))
- responseID := asString(m["id"])
- cc.assert("response_id_present", responseID != "", fmt.Sprintf("body=%s", string(resp.Body)))
- if responseID != "" {
- getResp, getErr := cc.request(ctx, requestSpec{
- Method: http.MethodGet,
- Path: "/v1/responses/" + responseID,
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- },
- Retryable: true,
- })
- if getErr != nil {
- return getErr
- }
- cc.assert("get_status_200", getResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", getResp.StatusCode))
- }
- return nil
-}
-
-func (r *Runner) caseResponsesStream(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/v1/responses",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- },
- Body: map[string]any{
- "model": "gpt-4o",
- "input": "请流式回答 hello",
- "stream": true,
- },
- Stream: true,
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- frames, done := parseSSEFrames(resp.Body)
- cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames)))
- hasCreated := false
- hasCompleted := false
- for _, f := range frames {
- switch asString(f["type"]) {
- case "response.created":
- hasCreated = true
- case "response.completed":
- hasCompleted = true
- }
- }
- cc.assert("has_response_created", hasCreated, fmt.Sprintf("body=%s", string(resp.Body)))
- cc.assert("has_response_completed", hasCompleted, fmt.Sprintf("body=%s", string(resp.Body)))
- cc.assert("done_terminated", done, "expected [DONE]")
- return nil
-}
-
-func (r *Runner) caseEmbeddings(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/v1/embeddings",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- },
- Body: map[string]any{
- "model": "gpt-4o",
- "input": []string{"hello", "world"},
- },
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_200_or_501", resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNotImplemented, fmt.Sprintf("status=%d", resp.StatusCode))
- var m map[string]any
- _ = json.Unmarshal(resp.Body, &m)
- if resp.StatusCode == http.StatusOK {
- cc.assert("object_list", asString(m["object"]) == "list", fmt.Sprintf("body=%s", string(resp.Body)))
- data, _ := m["data"].([]any)
- cc.assert("data_non_empty", len(data) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
- return nil
- }
- errObj, _ := m["error"].(map[string]any)
- _, hasCode := errObj["code"]
- _, hasParam := errObj["param"]
- cc.assert("error_has_code", hasCode, fmt.Sprintf("body=%s", string(resp.Body)))
- cc.assert("error_has_param", hasParam, fmt.Sprintf("body=%s", string(resp.Body)))
- return nil
-}
-
-func (r *Runner) caseReasonerStream(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/v1/chat/completions",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- },
- Body: map[string]any{
- "model": "deepseek-reasoner",
- "messages": []map[string]any{
- {"role": "user", "content": "先思考后回答:1+1"},
- },
- "stream": true,
- },
- Stream: true,
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- frames, done := parseSSEFrames(resp.Body)
- hasReasoning := false
- for _, f := range frames {
- choices, _ := f["choices"].([]any)
- for _, c := range choices {
- ch, _ := c.(map[string]any)
- delta, _ := ch["delta"].(map[string]any)
- if asString(delta["reasoning_content"]) != "" {
- hasReasoning = true
- }
- }
- }
- cc.assert("has_reasoning_content", hasReasoning, "reasoning_content not found")
- cc.assert("done_terminated", done, "expected [DONE]")
- return nil
-}
-
-func (r *Runner) caseToolcallNonstream(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/v1/chat/completions",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- },
- Body: toolcallPayload(false),
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- var m map[string]any
- _ = json.Unmarshal(resp.Body, &m)
- choices, _ := m["choices"].([]any)
- if len(choices) == 0 {
- cc.assert("choices_non_empty", false, fmt.Sprintf("body=%s", string(resp.Body)))
- return nil
- }
- c0, _ := choices[0].(map[string]any)
- cc.assert("finish_reason_tool_calls", asString(c0["finish_reason"]) == "tool_calls", fmt.Sprintf("body=%s", string(resp.Body)))
- msg, _ := c0["message"].(map[string]any)
- tc, _ := msg["tool_calls"].([]any)
- cc.assert("tool_calls_present", len(tc) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
- return nil
-}
-
-func (r *Runner) caseToolcallStream(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/v1/chat/completions",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- },
- Body: toolcallPayload(true),
- Stream: true,
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- frames, done := parseSSEFrames(resp.Body)
- hasTool := false
- rawLeak := false
- for _, f := range frames {
- choices, _ := f["choices"].([]any)
- for _, c := range choices {
- ch, _ := c.(map[string]any)
- delta, _ := ch["delta"].(map[string]any)
- if _, ok := delta["tool_calls"]; ok {
- hasTool = true
- }
- content := asString(delta["content"])
- if strings.Contains(strings.ToLower(content), `"tool_calls"`) {
- rawLeak = true
- }
- }
- }
- cc.assert("tool_calls_delta_present", hasTool, "tool_calls delta missing")
- cc.assert("no_raw_tool_json_leak", !rawLeak, "raw tool_calls JSON leaked in content")
- cc.assert("done_terminated", done, "expected [DONE]")
- return nil
-}
-
-func (r *Runner) caseAnthropicNonstream(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/anthropic/v1/messages",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- "anthropic-version": "2023-06-01",
- "content-type": "application/json",
- },
- Body: map[string]any{
- "model": "claude-sonnet-4-5",
- "messages": []map[string]any{
- {"role": "user", "content": "hello"},
- },
- "stream": false,
- },
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- var m map[string]any
- _ = json.Unmarshal(resp.Body, &m)
- cc.assert("type_message", asString(m["type"]) == "message", fmt.Sprintf("body=%s", string(resp.Body)))
- return nil
-}
-
-func (r *Runner) caseAnthropicStream(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/anthropic/v1/messages",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- "anthropic-version": "2023-06-01",
- "content-type": "application/json",
- },
- Body: map[string]any{
- "model": "claude-sonnet-4-5",
- "messages": []map[string]any{
- {"role": "user", "content": "stream hello"},
- },
- "stream": true,
- },
- Stream: true,
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- events := parseClaudeStreamEvents(resp.Body)
- cc.assert("has_message_start", contains(events, "message_start"), fmt.Sprintf("events=%v", events))
- cc.assert("has_message_stop", contains(events, "message_stop"), fmt.Sprintf("events=%v", events))
- return nil
-}
-
-func (r *Runner) caseAnthropicCountTokens(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/anthropic/v1/messages/count_tokens",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- "anthropic-version": "2023-06-01",
- "content-type": "application/json",
- },
- Body: map[string]any{
- "model": "claude-sonnet-4-5",
- "messages": []map[string]any{
- {"role": "user", "content": "count me"},
- },
- },
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- var m map[string]any
- _ = json.Unmarshal(resp.Body, &m)
- v := toInt(m["input_tokens"])
- cc.assert("input_tokens_gt_zero", v > 0, fmt.Sprintf("body=%s", string(resp.Body)))
- return nil
-}
-
-func (r *Runner) caseAdminAccountTest(ctx context.Context, cc *caseContext) error {
- if strings.TrimSpace(r.accountID) == "" {
- cc.assert("account_present", false, "no account in config")
- return nil
- }
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/admin/accounts/test",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.adminJWT,
- },
- Body: map[string]any{
- "identifier": r.accountID,
- "model": "deepseek-chat",
- "message": "ping",
- },
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
- var m map[string]any
- _ = json.Unmarshal(resp.Body, &m)
- ok, _ := m["success"].(bool)
- cc.assert("success_true", ok, fmt.Sprintf("body=%s", string(resp.Body)))
- return nil
-}
-
-func (r *Runner) caseConcurrencyBurst(ctx context.Context, cc *caseContext) error {
- accountCount := len(r.configRaw.Accounts)
- n := accountCount*2 + 2
- if n < 2 {
- n = 2
- }
- type one struct {
- Status int
- Err string
- }
- results := make([]one, n)
- var wg sync.WaitGroup
- for i := 0; i < n; i++ {
- wg.Add(1)
- go func(idx int) {
- defer wg.Done()
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/v1/chat/completions",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.apiKey,
- },
- Body: map[string]any{
- "model": "deepseek-chat",
- "messages": []map[string]any{
- {"role": "user", "content": fmt.Sprintf("并发请求 #%d,请回复ok", idx)},
- },
- "stream": true,
- },
- Stream: true,
- Retryable: true,
- })
- if err != nil {
- results[idx] = one{Err: err.Error()}
- return
- }
- results[idx] = one{Status: resp.StatusCode}
- }(i)
- }
- wg.Wait()
-
- dist := map[int]int{}
- success := 0
- for _, it := range results {
- if it.Status > 0 {
- dist[it.Status]++
- if it.Status == http.StatusOK {
- success++
- }
- }
- }
- cc.assert("success_gt_zero", success > 0, fmt.Sprintf("distribution=%v", dist))
- _, has5xx := has5xx(dist)
- cc.assert("no_5xx", !has5xx, fmt.Sprintf("distribution=%v", dist))
- if err := r.ping("/healthz"); err != nil {
- cc.assert("server_alive", false, err.Error())
- } else {
- cc.assert("server_alive", true, "")
- }
- return nil
-}
-
-func (r *Runner) caseConfigWriteIsolated(ctx context.Context, cc *caseContext) error {
- k := "testsuite-temp-" + sanitizeID(r.runID)
- add, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/admin/keys",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.adminJWT,
- },
- Body: map[string]any{"key": k},
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("add_key_status_200", add.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", add.StatusCode))
-
- cfg1, err := cc.request(ctx, requestSpec{
- Method: http.MethodGet,
- Path: "/admin/config",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.adminJWT,
- },
- Retryable: true,
- })
- if err != nil {
- return err
- }
- containsAdded := strings.Contains(string(cfg1.Body), k)
- cc.assert("key_present_in_isolated_config", containsAdded, "added key not found in isolated config")
-
- delPath := "/admin/keys/" + url.PathEscape(k)
- del, err := cc.request(ctx, requestSpec{
- Method: http.MethodDelete,
- Path: delPath,
- Headers: map[string]string{
- "Authorization": "Bearer " + r.adminJWT,
- },
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("delete_key_status_200", del.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", del.StatusCode))
-
- cfg2, err := cc.request(ctx, requestSpec{
- Method: http.MethodGet,
- Path: "/admin/config",
- Headers: map[string]string{
- "Authorization": "Bearer " + r.adminJWT,
- },
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("key_removed_in_isolated_config", !strings.Contains(string(cfg2.Body), k), "temporary key still present")
-
- if err := r.ensureOriginalConfigUntouched(); err != nil {
- cc.assert("original_config_unchanged", false, err.Error())
- } else {
- cc.assert("original_config_unchanged", true, "")
- }
- return nil
-}
-
-func (r *Runner) caseInvalidKey(ctx context.Context, cc *caseContext) error {
- resp, err := cc.request(ctx, requestSpec{
- Method: http.MethodPost,
- Path: "/v1/chat/completions",
- Headers: map[string]string{
- "Authorization": "Bearer invalid-testsuite-key-" + sanitizeID(r.runID),
- },
- Body: map[string]any{
- "model": "deepseek-chat",
- "messages": []map[string]any{
- {"role": "user", "content": "hi"},
- },
- "stream": false,
- },
- Retryable: true,
- })
- if err != nil {
- return err
- }
- cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode))
- var m map[string]any
- _ = json.Unmarshal(resp.Body, &m)
- e, _ := m["error"].(map[string]any)
- cc.assert("error_object_present", len(e) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
- cc.assert("error_message_present", asString(e["message"]) != "", fmt.Sprintf("body=%s", string(resp.Body)))
- return nil
-}
-
-func (r *Runner) doSimpleJSON(ctx context.Context, method, path string, headers map[string]string, body any) (*responseResult, error) {
- cc := &caseContext{
- runner: r,
- id: "auth_prepare",
- traceIDsSet: map[string]struct{}{},
- }
- return cc.request(ctx, requestSpec{
- Method: method,
- Path: path,
- Headers: headers,
- Body: body,
- Retryable: true,
- })
-}
-
-func (r *Runner) writeSummary(start, end time.Time) error {
- passed := 0
- failed := 0
- for _, cs := range r.results {
- if cs.Passed {
- passed++
- } else {
- failed++
- }
- }
- summary := runSummary{
- RunID: r.runID,
- StartedAt: start.Format(time.RFC3339Nano),
- EndedAt: end.Format(time.RFC3339Nano),
- DurationMS: end.Sub(start).Milliseconds(),
- Stats: map[string]any{
- "total": len(r.results),
- "passed": passed,
- "failed": failed,
- },
- Environment: map[string]any{
- "go_version": runtime.Version(),
- "os": runtime.GOOS,
- "arch": runtime.GOARCH,
- "base_url": r.baseURL,
- "config_source": r.originalConfigPath,
- "config_isolated": r.configCopyPath,
- "server_log": r.serverLog,
- "preflight_log": r.preflightLog,
- "retries": r.opts.Retries,
- "timeout_seconds": int(r.opts.Timeout.Seconds()),
- },
- Cases: r.results,
- Warnings: r.warnings,
- }
- if err := writeJSONFile(filepath.Join(r.runDir, "summary.json"), summary); err != nil {
- return err
- }
- return os.WriteFile(filepath.Join(r.runDir, "summary.md"), []byte(r.summaryMarkdown(summary)), 0o644)
-}
-
-func (r *Runner) summaryMarkdown(s runSummary) string {
- var b strings.Builder
- b.WriteString("# DS2API Live Testsuite Summary\n\n")
- b.WriteString("**Sensitive Notice:** this run stores full raw request/response logs. Do not share artifacts publicly.\n\n")
- fmt.Fprintf(&b, "- Run ID: `%s`\n", s.RunID)
- fmt.Fprintf(&b, "- Started: `%s`\n", s.StartedAt)
- fmt.Fprintf(&b, "- Ended: `%s`\n", s.EndedAt)
- fmt.Fprintf(&b, "- Duration: `%d ms`\n", s.DurationMS)
- fmt.Fprintf(&b, "- Passed/Failed: `%d/%d`\n\n", s.Stats["passed"], s.Stats["failed"])
- if len(s.Warnings) > 0 {
- b.WriteString("## Warnings\n\n")
- for _, w := range s.Warnings {
- fmt.Fprintf(&b, "- %s\n", w)
- }
- b.WriteString("\n")
- }
- b.WriteString("## Failed Cases\n\n")
- hasFailed := false
- for _, c := range s.Cases {
- if c.Passed {
- continue
- }
- hasFailed = true
- fmt.Fprintf(&b, "- `%s`: %s\n", c.CaseID, c.Error)
- if len(c.TraceIDs) > 0 {
- fmt.Fprintf(&b, " - trace_ids: `%s`\n", strings.Join(c.TraceIDs, ", "))
- fmt.Fprintf(&b, " - grep: `rg \"%s\" %s`\n", c.TraceIDs[0], filepath.Join(r.runDir, "server.log"))
- }
- fmt.Fprintf(&b, " - artifact: `%s`\n", c.ArtifactPath)
- }
- if !hasFailed {
- b.WriteString("- none\n")
- }
- b.WriteString("\n## Case Table\n\n")
- b.WriteString("| case_id | status | duration_ms | statuses | artifact |\n")
- b.WriteString("|---|---:|---:|---|---|\n")
- for _, c := range s.Cases {
- status := "PASS"
- if !c.Passed {
- status = "FAIL"
- }
- fmt.Fprintf(&b, "| %s | %s | %d | %v | `%s` |\n", c.CaseID, status, c.DurationMS, c.StatusCodes, c.ArtifactPath)
- }
- return b.String()
-}
-
-func toolcallPayload(stream bool) map[string]any {
- return map[string]any{
- "model": "deepseek-chat",
- "messages": []map[string]any{
- {
- "role": "user",
- "content": "你必须调用工具 search 查询 golang,并仅返回工具调用。",
- },
- },
- "tools": []map[string]any{
- {
- "type": "function",
- "function": map[string]any{
- "name": "search",
- "description": "search documents",
- "parameters": map[string]any{
- "type": "object",
- "properties": map[string]any{
- "q": map[string]any{
- "type": "string",
- },
- },
- "required": []string{"q"},
- },
- },
- },
- },
- "stream": stream,
- }
-}
-
-func parseSSEFrames(body []byte) ([]map[string]any, bool) {
- lines := strings.Split(string(body), "\n")
- frames := make([]map[string]any, 0, len(lines))
- done := false
- for _, line := range lines {
- line = strings.TrimSpace(line)
- if !strings.HasPrefix(line, "data:") {
- continue
- }
- payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
- if payload == "" {
- continue
- }
- if payload == "[DONE]" {
- done = true
- continue
- }
- var m map[string]any
- if err := json.Unmarshal([]byte(payload), &m); err == nil {
- frames = append(frames, m)
- }
- }
- return frames, done
-}
-
-func parseClaudeStreamEvents(body []byte) []string {
- events := []string{}
- seen := map[string]bool{}
- lines := strings.Split(string(body), "\n")
- for _, line := range lines {
- line = strings.TrimSpace(line)
- if !strings.HasPrefix(line, "data:") {
- continue
- }
- payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
- if payload == "" {
- continue
- }
- var m map[string]any
- if err := json.Unmarshal([]byte(payload), &m); err != nil {
- continue
- }
- t := asString(m["type"])
- if t == "" || seen[t] {
- continue
- }
- seen[t] = true
- events = append(events, t)
- }
- return events
-}
-
-func extractModelIDs(body []byte) []string {
- var m map[string]any
- if err := json.Unmarshal(body, &m); err != nil {
- return nil
- }
- out := []string{}
- data, _ := m["data"].([]any)
- for _, it := range data {
- item, _ := it.(map[string]any)
- id := asString(item["id"])
- if id != "" {
- out = append(out, id)
- }
- }
- return out
-}
-
-func withTraceQuery(rawURL, traceID string) (string, error) {
- u, err := url.Parse(rawURL)
- if err != nil {
- return "", err
- }
- q := u.Query()
- q.Set("__trace_id", traceID)
- u.RawQuery = q.Encode()
- return u.String(), nil
-}
-
-func writeJSONFile(path string, v any) error {
- b, err := json.MarshalIndent(v, "", " ")
- if err != nil {
- return err
- }
- return os.WriteFile(path, b, 0o644)
-}
-
-func prepareServerEnv(base []string, overrides map[string]string) []string {
- out := make([]string, 0, len(base)+len(overrides))
- skip := map[string]struct{}{}
- for k := range overrides {
- skip[k] = struct{}{}
- }
- for _, e := range base {
- parts := strings.SplitN(e, "=", 2)
- if len(parts) != 2 {
- continue
- }
- if _, ok := skip[parts[0]]; ok {
- continue
- }
- out = append(out, e)
- }
- for k, v := range overrides {
- out = append(out, k+"="+v)
- }
- return out
-}
-
-func findFreePort() (int, error) {
- ln, err := net.Listen("tcp", "127.0.0.1:0")
- if err != nil {
- return 0, err
- }
- defer ln.Close()
- addr, ok := ln.Addr().(*net.TCPAddr)
- if !ok {
- return 0, errors.New("failed to detect tcp port")
- }
- return addr.Port, nil
-}
-
-func uniqueStatusCodes(in []responseLog) []int {
- set := map[int]struct{}{}
- for _, it := range in {
- if it.StatusCode > 0 {
- set[it.StatusCode] = struct{}{}
- }
- }
- out := make([]int, 0, len(set))
- for k := range set {
- out = append(out, k)
- }
- sort.Ints(out)
- return out
-}
-
-func has5xx(dist map[int]int) (int, bool) {
- for k := range dist {
- if k >= 500 {
- return k, true
- }
- }
- return 0, false
-}
-
-func sanitizeID(s string) string {
- s = strings.ReplaceAll(s, ":", "_")
- s = strings.ReplaceAll(s, "/", "_")
- s = strings.ReplaceAll(s, " ", "_")
- return s
-}
-
-func asString(v any) string {
- if v == nil {
- return ""
- }
- switch x := v.(type) {
- case string:
- return strings.TrimSpace(x)
- default:
- return strings.TrimSpace(fmt.Sprintf("%v", v))
- }
-}
-
-func toInt(v any) int {
- switch x := v.(type) {
- case float64:
- return int(x)
- case float32:
- return int(x)
- case int:
- return x
- case int64:
- return int(x)
- default:
- return 0
- }
-}
-
-func contains(xs []string, target string) bool {
- for _, x := range xs {
- if x == target {
- return true
- }
- }
- return false
-}
diff --git a/internal/testsuite/runner_cases_admin.go b/internal/testsuite/runner_cases_admin.go
new file mode 100644
index 0000000..d66adea
--- /dev/null
+++ b/internal/testsuite/runner_cases_admin.go
@@ -0,0 +1,161 @@
+package testsuite
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+ "strings"
+)
+
+func (r *Runner) caseAdminLoginVerify(ctx context.Context, cc *caseContext) error {
+ loginResp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/admin/login",
+ Body: map[string]any{"admin_key": r.adminKey, "expire_hours": 24},
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("login_status_200", loginResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", loginResp.StatusCode))
+ var payload map[string]any
+ _ = json.Unmarshal(loginResp.Body, &payload)
+ token := asString(payload["token"])
+ cc.assert("token_exists", token != "", fmt.Sprintf("body=%s", string(loginResp.Body)))
+ if token == "" {
+ return nil
+ }
+ verifyResp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodGet,
+ Path: "/admin/verify",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + token,
+ },
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("verify_status_200", verifyResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", verifyResp.StatusCode))
+ var v map[string]any
+ _ = json.Unmarshal(verifyResp.Body, &v)
+ valid, _ := v["valid"].(bool)
+ cc.assert("verify_valid_true", valid, fmt.Sprintf("body=%s", string(verifyResp.Body)))
+ return nil
+}
+
+func (r *Runner) caseAdminQueueStatus(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodGet,
+ Path: "/admin/queue/status",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.adminJWT,
+ },
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ var m map[string]any
+ _ = json.Unmarshal(resp.Body, &m)
+ _, hasRec := m["recommended_concurrency"]
+ _, hasQueue := m["max_queue_size"]
+ cc.assert("has_recommended_concurrency", hasRec, fmt.Sprintf("body=%s", string(resp.Body)))
+ cc.assert("has_max_queue_size", hasQueue, fmt.Sprintf("body=%s", string(resp.Body)))
+ return nil
+}
+func (r *Runner) caseAdminAccountTest(ctx context.Context, cc *caseContext) error {
+ if strings.TrimSpace(r.accountID) == "" {
+ cc.assert("account_present", false, "no account in config")
+ return nil
+ }
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/admin/accounts/test",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.adminJWT,
+ },
+ Body: map[string]any{
+ "identifier": r.accountID,
+ "model": "deepseek-chat",
+ "message": "ping",
+ },
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ var m map[string]any
+ _ = json.Unmarshal(resp.Body, &m)
+ ok, _ := m["success"].(bool)
+ cc.assert("success_true", ok, fmt.Sprintf("body=%s", string(resp.Body)))
+ return nil
+}
+func (r *Runner) caseConfigWriteIsolated(ctx context.Context, cc *caseContext) error {
+ k := "testsuite-temp-" + sanitizeID(r.runID)
+ add, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/admin/keys",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.adminJWT,
+ },
+ Body: map[string]any{"key": k},
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("add_key_status_200", add.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", add.StatusCode))
+
+ cfg1, err := cc.request(ctx, requestSpec{
+ Method: http.MethodGet,
+ Path: "/admin/config",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.adminJWT,
+ },
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ containsAdded := strings.Contains(string(cfg1.Body), k)
+ cc.assert("key_present_in_isolated_config", containsAdded, "added key not found in isolated config")
+
+ delPath := "/admin/keys/" + url.PathEscape(k)
+ del, err := cc.request(ctx, requestSpec{
+ Method: http.MethodDelete,
+ Path: delPath,
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.adminJWT,
+ },
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("delete_key_status_200", del.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", del.StatusCode))
+
+ cfg2, err := cc.request(ctx, requestSpec{
+ Method: http.MethodGet,
+ Path: "/admin/config",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.adminJWT,
+ },
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("key_removed_in_isolated_config", !strings.Contains(string(cfg2.Body), k), "temporary key still present")
+
+ if err := r.ensureOriginalConfigUntouched(); err != nil {
+ cc.assert("original_config_unchanged", false, err.Error())
+ } else {
+ cc.assert("original_config_unchanged", true, "")
+ }
+ return nil
+}
diff --git a/internal/testsuite/runner_cases_claude.go b/internal/testsuite/runner_cases_claude.go
new file mode 100644
index 0000000..590e524
--- /dev/null
+++ b/internal/testsuite/runner_cases_claude.go
@@ -0,0 +1,103 @@
+package testsuite
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+)
+
+func (r *Runner) caseModelsClaude(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/anthropic/v1/models", Retryable: true})
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ ids := extractModelIDs(resp.Body)
+ cc.assert("non_empty", len(ids) > 0, fmt.Sprintf("models=%v", ids))
+ return nil
+}
+func (r *Runner) caseAnthropicNonstream(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/anthropic/v1/messages",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ "anthropic-version": "2023-06-01",
+ "content-type": "application/json",
+ },
+ Body: map[string]any{
+ "model": "claude-sonnet-4-5",
+ "messages": []map[string]any{
+ {"role": "user", "content": "hello"},
+ },
+ "stream": false,
+ },
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ var m map[string]any
+ _ = json.Unmarshal(resp.Body, &m)
+ cc.assert("type_message", asString(m["type"]) == "message", fmt.Sprintf("body=%s", string(resp.Body)))
+ return nil
+}
+
+func (r *Runner) caseAnthropicStream(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/anthropic/v1/messages",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ "anthropic-version": "2023-06-01",
+ "content-type": "application/json",
+ },
+ Body: map[string]any{
+ "model": "claude-sonnet-4-5",
+ "messages": []map[string]any{
+ {"role": "user", "content": "stream hello"},
+ },
+ "stream": true,
+ },
+ Stream: true,
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ events := parseClaudeStreamEvents(resp.Body)
+ cc.assert("has_message_start", contains(events, "message_start"), fmt.Sprintf("events=%v", events))
+ cc.assert("has_message_stop", contains(events, "message_stop"), fmt.Sprintf("events=%v", events))
+ return nil
+}
+
+func (r *Runner) caseAnthropicCountTokens(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/anthropic/v1/messages/count_tokens",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ "anthropic-version": "2023-06-01",
+ "content-type": "application/json",
+ },
+ Body: map[string]any{
+ "model": "claude-sonnet-4-5",
+ "messages": []map[string]any{
+ {"role": "user", "content": "count me"},
+ },
+ },
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ var m map[string]any
+ _ = json.Unmarshal(resp.Body, &m)
+ v := toInt(m["input_tokens"])
+ cc.assert("input_tokens_gt_zero", v > 0, fmt.Sprintf("body=%s", string(resp.Body)))
+ return nil
+}
diff --git a/internal/testsuite/runner_cases_openai.go b/internal/testsuite/runner_cases_openai.go
new file mode 100644
index 0000000..4ca2e40
--- /dev/null
+++ b/internal/testsuite/runner_cases_openai.go
@@ -0,0 +1,221 @@
+package testsuite
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+)
+
+func (r *Runner) caseHealthz(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/healthz", Retryable: true})
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ var m map[string]any
+ _ = json.Unmarshal(resp.Body, &m)
+ cc.assert("status_ok", asString(m["status"]) == "ok", fmt.Sprintf("body=%s", string(resp.Body)))
+ return nil
+}
+
+func (r *Runner) caseReadyz(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/readyz", Retryable: true})
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ var m map[string]any
+ _ = json.Unmarshal(resp.Body, &m)
+ cc.assert("status_ready", asString(m["status"]) == "ready", fmt.Sprintf("body=%s", string(resp.Body)))
+ return nil
+}
+
+func (r *Runner) caseModelsOpenAI(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models", Retryable: true})
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ ids := extractModelIDs(resp.Body)
+ cc.assert("has_deepseek_chat", contains(ids, "deepseek-chat"), strings.Join(ids, ","))
+ cc.assert("has_deepseek_reasoner", contains(ids, "deepseek-reasoner"), strings.Join(ids, ","))
+ return nil
+}
+
+func (r *Runner) caseModelOpenAIByID(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{Method: http.MethodGet, Path: "/v1/models/gpt-4o", Retryable: true})
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ var m map[string]any
+ _ = json.Unmarshal(resp.Body, &m)
+ cc.assert("object_model", asString(m["object"]) == "model", fmt.Sprintf("body=%s", string(resp.Body)))
+ cc.assert("id_deepseek_chat", asString(m["id"]) == "deepseek-chat", fmt.Sprintf("body=%s", string(resp.Body)))
+ return nil
+}
+func (r *Runner) caseChatNonstream(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/v1/chat/completions",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ },
+ Body: map[string]any{
+ "model": "deepseek-chat",
+ "messages": []map[string]any{
+ {"role": "user", "content": "请简单回复一句话"},
+ },
+ "stream": false,
+ },
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ var m map[string]any
+ _ = json.Unmarshal(resp.Body, &m)
+ cc.assert("object_chat_completion", asString(m["object"]) == "chat.completion", fmt.Sprintf("body=%s", string(resp.Body)))
+ choices, _ := m["choices"].([]any)
+ cc.assert("choices_non_empty", len(choices) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
+ return nil
+}
+
+func (r *Runner) caseChatStream(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/v1/chat/completions",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ },
+ Body: map[string]any{
+ "model": "deepseek-chat",
+ "messages": []map[string]any{
+ {"role": "user", "content": "请流式回复一句话"},
+ },
+ "stream": true,
+ },
+ Stream: true,
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ frames, done := parseSSEFrames(resp.Body)
+ cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames)))
+ cc.assert("done_terminated", done, "expected [DONE]")
+ return nil
+}
+
+func (r *Runner) caseResponsesNonstream(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/v1/responses",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ },
+ Body: map[string]any{
+ "model": "gpt-4o",
+ "input": "请简要回答 hello",
+ },
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ var m map[string]any
+ _ = json.Unmarshal(resp.Body, &m)
+ cc.assert("object_response", asString(m["object"]) == "response", fmt.Sprintf("body=%s", string(resp.Body)))
+ responseID := asString(m["id"])
+ cc.assert("response_id_present", responseID != "", fmt.Sprintf("body=%s", string(resp.Body)))
+ if responseID != "" {
+ getResp, getErr := cc.request(ctx, requestSpec{
+ Method: http.MethodGet,
+ Path: "/v1/responses/" + responseID,
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ },
+ Retryable: true,
+ })
+ if getErr != nil {
+ return getErr
+ }
+ cc.assert("get_status_200", getResp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", getResp.StatusCode))
+ }
+ return nil
+}
+
+func (r *Runner) caseResponsesStream(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/v1/responses",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ },
+ Body: map[string]any{
+ "model": "gpt-4o",
+ "input": "请流式回答 hello",
+ "stream": true,
+ },
+ Stream: true,
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ frames, done := parseSSEFrames(resp.Body)
+ cc.assert("frames_non_empty", len(frames) > 0, fmt.Sprintf("len=%d", len(frames)))
+ hasCreated := false
+ hasCompleted := false
+ for _, f := range frames {
+ switch asString(f["type"]) {
+ case "response.created":
+ hasCreated = true
+ case "response.completed":
+ hasCompleted = true
+ }
+ }
+ cc.assert("has_response_created", hasCreated, fmt.Sprintf("body=%s", string(resp.Body)))
+ cc.assert("has_response_completed", hasCompleted, fmt.Sprintf("body=%s", string(resp.Body)))
+ cc.assert("done_terminated", done, "expected [DONE]")
+ return nil
+}
+
+func (r *Runner) caseEmbeddings(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/v1/embeddings",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ },
+ Body: map[string]any{
+ "model": "gpt-4o",
+ "input": []string{"hello", "world"},
+ },
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200_or_501", resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusNotImplemented, fmt.Sprintf("status=%d", resp.StatusCode))
+ var m map[string]any
+ _ = json.Unmarshal(resp.Body, &m)
+ if resp.StatusCode == http.StatusOK {
+ cc.assert("object_list", asString(m["object"]) == "list", fmt.Sprintf("body=%s", string(resp.Body)))
+ data, _ := m["data"].([]any)
+ cc.assert("data_non_empty", len(data) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
+ return nil
+ }
+ errObj, _ := m["error"].(map[string]any)
+ _, hasCode := errObj["code"]
+ _, hasParam := errObj["param"]
+ cc.assert("error_has_code", hasCode, fmt.Sprintf("body=%s", string(resp.Body)))
+ cc.assert("error_has_param", hasParam, fmt.Sprintf("body=%s", string(resp.Body)))
+ return nil
+}
diff --git a/internal/testsuite/runner_cases_openai_advanced.go b/internal/testsuite/runner_cases_openai_advanced.go
new file mode 100644
index 0000000..34e9f01
--- /dev/null
+++ b/internal/testsuite/runner_cases_openai_advanced.go
@@ -0,0 +1,236 @@
+package testsuite
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strings"
+ "sync"
+)
+
+func (r *Runner) caseReasonerStream(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/v1/chat/completions",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ },
+ Body: map[string]any{
+ "model": "deepseek-reasoner",
+ "messages": []map[string]any{
+ {"role": "user", "content": "先思考后回答:1+1"},
+ },
+ "stream": true,
+ },
+ Stream: true,
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ frames, done := parseSSEFrames(resp.Body)
+ hasReasoning := false
+ for _, f := range frames {
+ choices, _ := f["choices"].([]any)
+ for _, c := range choices {
+ ch, _ := c.(map[string]any)
+ delta, _ := ch["delta"].(map[string]any)
+ if asString(delta["reasoning_content"]) != "" {
+ hasReasoning = true
+ }
+ }
+ }
+ cc.assert("has_reasoning_content", hasReasoning, "reasoning_content not found")
+ cc.assert("done_terminated", done, "expected [DONE]")
+ return nil
+}
+
+func (r *Runner) caseToolcallNonstream(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/v1/chat/completions",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ },
+ Body: toolcallPayload(false),
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ var m map[string]any
+ _ = json.Unmarshal(resp.Body, &m)
+ choices, _ := m["choices"].([]any)
+ if len(choices) == 0 {
+ cc.assert("choices_non_empty", false, fmt.Sprintf("body=%s", string(resp.Body)))
+ return nil
+ }
+ c0, _ := choices[0].(map[string]any)
+ cc.assert("finish_reason_tool_calls", asString(c0["finish_reason"]) == "tool_calls", fmt.Sprintf("body=%s", string(resp.Body)))
+ msg, _ := c0["message"].(map[string]any)
+ tc, _ := msg["tool_calls"].([]any)
+ cc.assert("tool_calls_present", len(tc) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
+ return nil
+}
+
+func (r *Runner) caseToolcallStream(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/v1/chat/completions",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ },
+ Body: toolcallPayload(true),
+ Stream: true,
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_200", resp.StatusCode == http.StatusOK, fmt.Sprintf("status=%d", resp.StatusCode))
+ frames, done := parseSSEFrames(resp.Body)
+ hasTool := false
+ rawLeak := false
+ for _, f := range frames {
+ choices, _ := f["choices"].([]any)
+ for _, c := range choices {
+ ch, _ := c.(map[string]any)
+ delta, _ := ch["delta"].(map[string]any)
+ if _, ok := delta["tool_calls"]; ok {
+ hasTool = true
+ }
+ content := asString(delta["content"])
+ if strings.Contains(strings.ToLower(content), `"tool_calls"`) {
+ rawLeak = true
+ }
+ }
+ }
+ cc.assert("tool_calls_delta_present", hasTool, "tool_calls delta missing")
+ cc.assert("no_raw_tool_json_leak", !rawLeak, "raw tool_calls JSON leaked in content")
+ cc.assert("done_terminated", done, "expected [DONE]")
+ return nil
+}
+
+func (r *Runner) caseConcurrencyBurst(ctx context.Context, cc *caseContext) error {
+ accountCount := len(r.configRaw.Accounts)
+ n := accountCount*2 + 2
+ if n < 2 {
+ n = 2
+ }
+ type one struct {
+ Status int
+ Err string
+ }
+ results := make([]one, n)
+ var wg sync.WaitGroup
+ for i := 0; i < n; i++ {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/v1/chat/completions",
+ Headers: map[string]string{
+ "Authorization": "Bearer " + r.apiKey,
+ },
+ Body: map[string]any{
+ "model": "deepseek-chat",
+ "messages": []map[string]any{
+ {"role": "user", "content": fmt.Sprintf("并发请求 #%d,请回复ok", idx)},
+ },
+ "stream": true,
+ },
+ Stream: true,
+ Retryable: true,
+ })
+ if err != nil {
+ results[idx] = one{Err: err.Error()}
+ return
+ }
+ results[idx] = one{Status: resp.StatusCode}
+ }(i)
+ }
+ wg.Wait()
+
+ dist := map[int]int{}
+ success := 0
+ for _, it := range results {
+ if it.Status > 0 {
+ dist[it.Status]++
+ if it.Status == http.StatusOK {
+ success++
+ }
+ }
+ }
+ cc.assert("success_gt_zero", success > 0, fmt.Sprintf("distribution=%v", dist))
+ _, has5xx := has5xx(dist)
+ cc.assert("no_5xx", !has5xx, fmt.Sprintf("distribution=%v", dist))
+ if err := r.ping("/healthz"); err != nil {
+ cc.assert("server_alive", false, err.Error())
+ } else {
+ cc.assert("server_alive", true, "")
+ }
+ return nil
+}
+
+func (r *Runner) caseInvalidKey(ctx context.Context, cc *caseContext) error {
+ resp, err := cc.request(ctx, requestSpec{
+ Method: http.MethodPost,
+ Path: "/v1/chat/completions",
+ Headers: map[string]string{
+ "Authorization": "Bearer invalid-testsuite-key-" + sanitizeID(r.runID),
+ },
+ Body: map[string]any{
+ "model": "deepseek-chat",
+ "messages": []map[string]any{
+ {"role": "user", "content": "hi"},
+ },
+ "stream": false,
+ },
+ Retryable: true,
+ })
+ if err != nil {
+ return err
+ }
+ cc.assert("status_401", resp.StatusCode == http.StatusUnauthorized, fmt.Sprintf("status=%d", resp.StatusCode))
+ var m map[string]any
+ _ = json.Unmarshal(resp.Body, &m)
+ e, _ := m["error"].(map[string]any)
+ cc.assert("error_object_present", len(e) > 0, fmt.Sprintf("body=%s", string(resp.Body)))
+ cc.assert("error_message_present", asString(e["message"]) != "", fmt.Sprintf("body=%s", string(resp.Body)))
+ return nil
+}
+
+func toolcallPayload(stream bool) map[string]any {
+ return map[string]any{
+ "model": "deepseek-chat",
+ "messages": []map[string]any{
+ {
+ "role": "user",
+ "content": "你必须调用工具 search 查询 golang,并仅返回工具调用。",
+ },
+ },
+ "tools": []map[string]any{
+ {
+ "type": "function",
+ "function": map[string]any{
+ "name": "search",
+ "description": "search documents",
+ "parameters": map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "q": map[string]any{
+ "type": "string",
+ },
+ },
+ "required": []string{"q"},
+ },
+ },
+ },
+ },
+ "stream": stream,
+ }
+}
diff --git a/internal/testsuite/runner_core.go b/internal/testsuite/runner_core.go
new file mode 100644
index 0000000..06eafa5
--- /dev/null
+++ b/internal/testsuite/runner_core.go
@@ -0,0 +1,290 @@
+package testsuite
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+)
+
// Options configures a single testsuite run.
type Options struct {
	ConfigPath  string        // config.json to snapshot and run against (default "config.json")
	AdminKey    string        // admin key for /admin/login; falls back to DS2API_ADMIN_KEY, then "admin"
	OutputDir   string        // root directory for run artifacts (default "artifacts/testsuite")
	Port        int           // server port; <= 0 selects a free loopback port
	Timeout     time.Duration // per-request timeout (default 120s)
	Retries     int           // extra attempts for retryable requests
	NoPreflight bool          // skip the go-test / node-check / webui-build preflight
	MaxKeepRuns int           // prune old run dirs, keeping this many; <= 0 disables pruning
}
+
// runSummary is the top-level record serialized to summary.json for a run.
type runSummary struct {
	RunID       string         `json:"run_id"`
	StartedAt   string         `json:"started_at"`
	EndedAt     string         `json:"ended_at"`
	DurationMS  int64          `json:"duration_ms"`
	Stats       map[string]any `json:"stats"`       // total/passed/failed counters
	Environment map[string]any `json:"environment"` // host + run configuration snapshot
	Cases       []caseResult   `json:"cases"`
	Warnings    []string       `json:"warnings,omitempty"` // non-fatal problems observed during the run
}

// caseResult is the per-case outcome included in runSummary and meta.json.
type caseResult struct {
	CaseID       string            `json:"case_id"`
	Passed       bool              `json:"passed"`
	DurationMS   int64             `json:"duration_ms"`
	TraceIDs     []string          `json:"trace_ids"`    // sorted trace IDs of all requests the case issued
	StatusCodes  []int             `json:"status_codes"` // sorted distinct HTTP statuses observed
	Error        string            `json:"error,omitempty"`
	ArtifactPath string            `json:"artifact_path"`
	Assertions   []assertionResult `json:"assertions"`
}

// assertionResult is one named check recorded by caseContext.assert.
type assertionResult struct {
	Name   string `json:"name"`
	Passed bool   `json:"passed"`
	Detail string `json:"detail,omitempty"`
}

// requestLog is one outbound HTTP request as captured for request.json.
type requestLog struct {
	Seq       int               `json:"seq"`     // per-case monotonically increasing sequence
	Attempt   int               `json:"attempt"` // 1-based retry attempt
	TraceID   string            `json:"trace_id"`
	Method    string            `json:"method"`
	URL       string            `json:"url"`
	Headers   map[string]string `json:"headers"`
	Body      any               `json:"body,omitempty"`
	Timestamp string            `json:"timestamp"` // RFC3339Nano
}

// responseLog is one inbound HTTP response (or network failure) as captured
// for response.headers / response.body.
type responseLog struct {
	Seq        int                 `json:"seq"`
	Attempt    int                 `json:"attempt"`
	TraceID    string              `json:"trace_id"`
	StatusCode int                 `json:"status_code"` // 0 when the request failed at the network layer
	Headers    map[string][]string `json:"headers"`
	BodyText   string              `json:"body_text"`
	DurationMS int64               `json:"duration_ms"`
	NetworkErr string              `json:"network_error,omitempty"`
	ReceivedAt string              `json:"received_at"` // RFC3339Nano
}
+
// caseContext carries per-case mutable state while a case runs.
// mu guards seq, assertions, requests, responses, streamRaw, and
// traceIDsSet against concurrent access (cases may issue parallel requests).
type caseContext struct {
	runner      *Runner
	id          string
	dir         string // artifact directory for this case
	startedAt   time.Time
	mu          sync.Mutex
	seq         int // request sequence counter, incremented under mu
	assertions  []assertionResult
	requests    []requestLog
	responses   []responseLog
	streamRaw   strings.Builder // accumulated raw SSE bodies for stream.raw
	traceIDsSet map[string]struct{}
}

// requestSpec describes a single HTTP call a case wants to make.
type requestSpec struct {
	Method    string
	Path      string // appended to the runner's baseURL
	Headers   map[string]string
	Body      any  // JSON-marshaled when non-nil; also sets Content-Type
	Stream    bool // record the body into stream.raw as well
	Retryable bool // honor the runner's retry budget on 5xx/network errors
}

// responseResult is the caller-facing view of a completed request.
type responseResult struct {
	StatusCode int
	Headers    http.Header
	Body       []byte
	TraceID    string
	URL        string
}
+
// Runner owns the lifecycle of one testsuite run: the spawned server
// process, the isolated config copy, auth material, and accumulated results.
type Runner struct {
	opts Options

	runID        string // UTC timestamp, e.g. "20060102T150405Z"
	runDir       string // OutputDir/runID
	serverLog    string
	preflightLog string

	baseURL     string // http://127.0.0.1:<port> of the spawned server
	httpClient  *http.Client
	serverCmd   *exec.Cmd
	serverLogFd *os.File

	configCopyPath     string // isolated copy the server actually uses
	originalConfigPath string
	originalConfigHash string // sha256 hex of the original, verified at run end

	configRaw runConfig
	apiKey    string // first key from the config, used by API cases
	adminKey  string
	adminJWT  string // obtained by prepareAuth; may be empty on auth failure
	accountID string // first account email/mobile from the config

	warnings []string
	results  []caseResult
}

// runConfig is the subset of config.json the testsuite needs.
type runConfig struct {
	Keys     []string `json:"keys"`
	Accounts []struct {
		Email    string `json:"email,omitempty"`
		Mobile   string `json:"mobile,omitempty"`
		Password string `json:"password,omitempty"`
		Token    string `json:"token,omitempty"`
	} `json:"accounts"`
}
+
// Run executes a full live testsuite run: optional preflight checks, config
// isolation, server startup, auth bootstrap, every registered case, summary
// emission, and pruning of old run directories. It returns an error when
// setup fails or when at least one case fails.
func Run(ctx context.Context, opts Options) error {
	r, err := newRunner(opts)
	if err != nil {
		return err
	}
	start := time.Now()
	// Always tear the spawned server down, even on early returns below.
	defer func() {
		_ = r.stopServer()
	}()

	if err := r.prepareRunDir(); err != nil {
		return err
	}

	if !r.opts.NoPreflight {
		if err := r.runPreflight(ctx); err != nil {
			// Emit a (partial) summary so the failure still leaves an artifact trail.
			_ = r.writeSummary(start, time.Now())
			return err
		}
	}

	if err := r.prepareConfigIsolation(); err != nil {
		_ = r.writeSummary(start, time.Now())
		return err
	}

	if err := r.startServer(ctx); err != nil {
		_ = r.writeSummary(start, time.Now())
		return err
	}

	// Auth failure is downgraded to a warning: admin-only cases will then
	// fail individually with more specific diagnostics.
	if err := r.prepareAuth(ctx); err != nil {
		r.warnings = append(r.warnings, "auth prepare failed: "+err.Error())
	}

	for _, c := range r.cases() {
		r.runCase(ctx, c)
	}

	// Verify the user's original config file was not mutated by the run.
	if err := r.ensureOriginalConfigUntouched(); err != nil {
		r.warnings = append(r.warnings, err.Error())
	}

	end := time.Now()
	if err := r.writeSummary(start, end); err != nil {
		return err
	}

	// Prune old test runs, keeping only the most recent N.
	if err := r.pruneOldRuns(); err != nil {
		r.warnings = append(r.warnings, "prune old runs: "+err.Error())
	}

	failed := 0
	for _, cs := range r.results {
		if !cs.Passed {
			failed++
		}
	}
	if failed > 0 {
		return fmt.Errorf("testsuite failed: %d case(s) failed, see %s", failed, filepath.Join(r.runDir, "summary.md"))
	}
	return nil
}
+
+func newRunner(opts Options) (*Runner, error) {
+ if strings.TrimSpace(opts.ConfigPath) == "" {
+ opts.ConfigPath = "config.json"
+ }
+ if strings.TrimSpace(opts.OutputDir) == "" {
+ opts.OutputDir = "artifacts/testsuite"
+ }
+ if opts.Timeout <= 0 {
+ opts.Timeout = 120 * time.Second
+ }
+ if opts.Retries < 0 {
+ opts.Retries = 0
+ }
+ adminKey := strings.TrimSpace(opts.AdminKey)
+ if adminKey == "" {
+ adminKey = strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY"))
+ }
+ if adminKey == "" {
+ adminKey = "admin"
+ }
+ opts.AdminKey = adminKey
+
+ return &Runner{
+ opts: opts,
+ httpClient: &http.Client{
+ Timeout: 0,
+ },
+ runID: time.Now().UTC().Format("20060102T150405Z"),
+ adminKey: adminKey,
+ }, nil
+}
// runCase executes one case, derives pass/fail from the returned error plus
// all recorded assertions, and flushes the case's artifacts to disk.
func (r *Runner) runCase(ctx context.Context, c caseDef) {
	caseDir := filepath.Join(r.runDir, "cases", c.ID)
	_ = os.MkdirAll(caseDir, 0o755)
	cc := &caseContext{
		runner:      r,
		id:          c.ID,
		dir:         caseDir,
		startedAt:   time.Now(),
		traceIDsSet: map[string]struct{}{},
	}
	err := c.Run(ctx, cc)
	duration := time.Since(cc.startedAt).Milliseconds()

	// A case-level error is also surfaced as a failed synthetic assertion so
	// it appears in assertions.json alongside the regular checks.
	if err != nil {
		cc.assertions = append(cc.assertions, assertionResult{
			Name:   "case_error",
			Passed: false,
			Detail: err.Error(),
		})
	}
	// A case passes only when it returned nil AND every assertion passed.
	passed := err == nil
	for _, a := range cc.assertions {
		if !a.Passed {
			passed = false
			break
		}
	}

	traceIDs := make([]string, 0, len(cc.traceIDsSet))
	for t := range cc.traceIDsSet {
		traceIDs = append(traceIDs, t)
	}
	sort.Strings(traceIDs)
	statuses := uniqueStatusCodes(cc.responses)
	cs := caseResult{
		CaseID:       c.ID,
		Passed:       passed,
		DurationMS:   duration,
		TraceIDs:     traceIDs,
		StatusCodes:  statuses,
		ArtifactPath: caseDir,
		Assertions:   cc.assertions,
	}
	if err != nil {
		cs.Error = err.Error()
	}
	_ = cc.flushArtifacts(cs)
	r.results = append(r.results, cs)
}
diff --git a/internal/testsuite/runner_defaults.go b/internal/testsuite/runner_defaults.go
new file mode 100644
index 0000000..ab30bf1
--- /dev/null
+++ b/internal/testsuite/runner_defaults.go
@@ -0,0 +1,20 @@
+package testsuite
+
+import (
+ "os"
+ "strings"
+ "time"
+)
+
+func DefaultOptions() Options {
+ return Options{
+ ConfigPath: "config.json",
+ AdminKey: strings.TrimSpace(os.Getenv("DS2API_ADMIN_KEY")),
+ OutputDir: "artifacts/testsuite",
+ Port: 0,
+ Timeout: 120 * time.Second,
+ Retries: 2,
+ NoPreflight: false,
+ MaxKeepRuns: 5,
+ }
+}
diff --git a/internal/testsuite/runner_env.go b/internal/testsuite/runner_env.go
new file mode 100644
index 0000000..3ae0ba4
--- /dev/null
+++ b/internal/testsuite/runner_env.go
@@ -0,0 +1,261 @@
+package testsuite
+
+import (
+ "context"
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "sort"
+ "strconv"
+ "strings"
+ "time"
+)
+
+func (r *Runner) prepareRunDir() error {
+ r.runDir = filepath.Join(r.opts.OutputDir, r.runID)
+ if err := os.MkdirAll(r.runDir, 0o755); err != nil {
+ return err
+ }
+ if err := os.MkdirAll(filepath.Join(r.runDir, "cases"), 0o755); err != nil {
+ return err
+ }
+ r.serverLog = filepath.Join(r.runDir, "server.log")
+ r.preflightLog = filepath.Join(r.runDir, "preflight.log")
+ return nil
+}
+
+// pruneOldRuns removes old test run directories, keeping the most recent MaxKeepRuns.
+// Run IDs use the format "20060102T150405Z", so alphabetical order == chronological order.
+func (r *Runner) pruneOldRuns() error {
+ keep := r.opts.MaxKeepRuns
+ if keep <= 0 {
+ return nil // 0 or negative means no pruning
+ }
+
+ entries, err := os.ReadDir(r.opts.OutputDir)
+ if err != nil {
+ return err
+ }
+
+ // Collect only directories (each run is a directory).
+ var runDirs []string
+ for _, e := range entries {
+ if !e.IsDir() {
+ continue
+ }
+ runDirs = append(runDirs, e.Name())
+ }
+
+ sort.Strings(runDirs)
+
+ if len(runDirs) <= keep {
+ return nil
+ }
+
+ // Remove oldest runs (those at the beginning of the sorted list).
+ toRemove := runDirs[:len(runDirs)-keep]
+ var errs []string
+ for _, name := range toRemove {
+ dirPath := filepath.Join(r.opts.OutputDir, name)
+ if err := os.RemoveAll(dirPath); err != nil {
+ errs = append(errs, fmt.Sprintf("remove %s: %v", name, err))
+ } else {
+ fmt.Fprintf(os.Stdout, "pruned old test run: %s\n", name)
+ }
+ }
+
+ if len(errs) > 0 {
+ return errors.New(strings.Join(errs, "; "))
+ }
+ return nil
+}
+
// runPreflight runs the repo's fast verification suite (Go tests, Node
// syntax/unit checks, webui build) before touching the live server, logging
// all command output to preflight.log. The first failing step aborts the run.
func (r *Runner) runPreflight(ctx context.Context) error {
	steps := [][]string{
		{"go", "test", "./...", "-count=1"},
		{"node", "--check", "api/chat-stream.js"},
		{"node", "--check", "api/helpers/stream-tool-sieve.js"},
		{"node", "--test", "api/helpers/stream-tool-sieve.test.js", "api/chat-stream.test.js", "api/compat/js_compat_test.js"},
		{"npm", "run", "build", "--prefix", "webui"},
	}
	f, err := os.OpenFile(r.preflightLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
	if err != nil {
		return err
	}
	defer f.Close()
	for _, step := range steps {
		// Echo the command into the log before its output, shell-transcript style.
		if _, err := fmt.Fprintf(f, "\n$ %s\n", strings.Join(step, " ")); err != nil {
			return err
		}
		cmd := exec.CommandContext(ctx, step[0], step[1:]...)
		cmd.Stdout = f
		cmd.Stderr = f
		if err := cmd.Run(); err != nil {
			return fmt.Errorf("preflight failed at `%s`: %w", strings.Join(step, " "), err)
		}
	}
	return nil
}
+
+func (r *Runner) prepareConfigIsolation() error {
+ abs, err := filepath.Abs(r.opts.ConfigPath)
+ if err != nil {
+ return err
+ }
+ r.originalConfigPath = abs
+ raw, err := os.ReadFile(abs)
+ if err != nil {
+ return err
+ }
+ sum := sha256.Sum256(raw)
+ r.originalConfigHash = hex.EncodeToString(sum[:])
+
+ tmpDir := filepath.Join(r.runDir, "tmp")
+ if err := os.MkdirAll(tmpDir, 0o755); err != nil {
+ return err
+ }
+ r.configCopyPath = filepath.Join(tmpDir, "config.json")
+ if err := os.WriteFile(r.configCopyPath, raw, 0o644); err != nil {
+ return err
+ }
+ var cfg runConfig
+ if err := json.Unmarshal(raw, &cfg); err != nil {
+ return fmt.Errorf("parse config failed: %w", err)
+ }
+ r.configRaw = cfg
+ if len(cfg.Keys) > 0 {
+ r.apiKey = strings.TrimSpace(cfg.Keys[0])
+ }
+ for _, acc := range cfg.Accounts {
+ id := strings.TrimSpace(acc.Email)
+ if id == "" {
+ id = strings.TrimSpace(acc.Mobile)
+ }
+ if id != "" {
+ r.accountID = id
+ break
+ }
+ }
+ return nil
+}
+
// startServer launches `go run ./cmd/ds2api` against the isolated config
// copy on the chosen (or a free) loopback port, then polls /healthz and
// /readyz every 500ms for up to 90s until the server reports ready.
func (r *Runner) startServer(ctx context.Context) error {
	port := r.opts.Port
	if port <= 0 {
		p, err := findFreePort()
		if err != nil {
			return err
		}
		port = p
	}
	r.baseURL = "http://127.0.0.1:" + strconv.Itoa(port)

	logFd, err := os.OpenFile(r.serverLog, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
	if err != nil {
		return err
	}
	r.serverLogFd = logFd
	cmd := exec.CommandContext(ctx, "go", "run", "./cmd/ds2api")
	cmd.Stdout = logFd
	cmd.Stderr = logFd
	// Point the server at the isolated config copy and blank out any
	// inline-config environment variables that could override it.
	cmd.Env = prepareServerEnv(os.Environ(), map[string]string{
		"PORT":                    strconv.Itoa(port),
		"DS2API_CONFIG_PATH":      r.configCopyPath,
		"DS2API_AUTO_BUILD_WEBUI": "false",
		"DS2API_CONFIG_JSON":      "",
		"CONFIG_JSON":             "",
	})
	if err := cmd.Start(); err != nil {
		_ = logFd.Close()
		return err
	}
	r.serverCmd = cmd

	// Readiness poll: both health endpoints must answer 200.
	deadline := time.Now().Add(90 * time.Second)
	for time.Now().Before(deadline) {
		if r.ping("/healthz") == nil && r.ping("/readyz") == nil {
			return nil
		}
		time.Sleep(500 * time.Millisecond)
	}
	return errors.New("server readiness timeout")
}
+
+func (r *Runner) stopServer() error {
+ var errs []string
+ if r.serverCmd != nil && r.serverCmd.Process != nil {
+ _ = r.serverCmd.Process.Signal(os.Interrupt)
+ done := make(chan error, 1)
+ go func() { done <- r.serverCmd.Wait() }()
+ select {
+ case <-time.After(5 * time.Second):
+ _ = r.serverCmd.Process.Kill()
+ <-done
+ case <-done:
+ }
+ }
+ if r.serverLogFd != nil {
+ if err := r.serverLogFd.Close(); err != nil {
+ errs = append(errs, err.Error())
+ }
+ }
+ if len(errs) > 0 {
+ return errors.New(strings.Join(errs, "; "))
+ }
+ return nil
+}
+
+func (r *Runner) ping(path string) error {
+ resp, err := r.httpClient.Get(r.baseURL + path)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("status=%d", resp.StatusCode)
+ }
+ return nil
+}
+
+func (r *Runner) prepareAuth(ctx context.Context) error {
+ reqBody := map[string]any{
+ "admin_key": r.adminKey,
+ "expire_hours": 24,
+ }
+ resp, err := r.doSimpleJSON(ctx, http.MethodPost, "/admin/login", nil, reqBody)
+ if err != nil {
+ return err
+ }
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("admin login status=%d body=%s", resp.StatusCode, string(resp.Body))
+ }
+ var m map[string]any
+ if err := json.Unmarshal(resp.Body, &m); err != nil {
+ return err
+ }
+ token, _ := m["token"].(string)
+ if strings.TrimSpace(token) == "" {
+ return errors.New("empty admin jwt token")
+ }
+ r.adminJWT = token
+ return nil
+}
+
+func (r *Runner) ensureOriginalConfigUntouched() error {
+ raw, err := os.ReadFile(r.originalConfigPath)
+ if err != nil {
+ return err
+ }
+ sum := sha256.Sum256(raw)
+ current := hex.EncodeToString(sum[:])
+ if current != r.originalConfigHash {
+ return fmt.Errorf("original config changed unexpectedly: %s", r.originalConfigPath)
+ }
+ return nil
+}
diff --git a/internal/testsuite/runner_http.go b/internal/testsuite/runner_http.go
new file mode 100644
index 0000000..d98c60a
--- /dev/null
+++ b/internal/testsuite/runner_http.go
@@ -0,0 +1,217 @@
+package testsuite
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+ "time"
+)
+
+func (cc *caseContext) assert(name string, ok bool, detail string) {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.assertions = append(cc.assertions, assertionResult{
+ Name: name,
+ Passed: ok,
+ Detail: detail,
+ })
+}
+
// request performs the HTTP call described by spec, retrying network errors
// and 5xx responses with exponential backoff (300ms, 600ms, ...) when the
// spec is Retryable; non-retryable specs get exactly one attempt.
// NOTE(review): when every attempt ends in a 5xx, the final response object
// is discarded and only a synthetic "status=NNN" error is returned —
// confirm callers never need the 5xx body.
func (cc *caseContext) request(ctx context.Context, spec requestSpec) (*responseResult, error) {
	retries := cc.runner.opts.Retries
	if !spec.Retryable {
		retries = 0
	}
	var lastErr error
	for attempt := 1; attempt <= retries+1; attempt++ {
		resp, err := cc.requestOnce(ctx, spec, attempt)
		if err == nil && resp.StatusCode < 500 {
			return resp, nil
		}
		if err != nil {
			lastErr = err
		} else if resp.StatusCode >= 500 {
			lastErr = fmt.Errorf("status=%d", resp.StatusCode)
		}
		if attempt <= retries {
			// Exponential backoff: 300ms * 2^(attempt-1).
			sleep := time.Duration(300*(1<<(attempt-1))) * time.Millisecond
			time.Sleep(sleep)
		}
	}
	return nil, lastErr
}
+
// requestOnce performs a single attempt of spec: it allocates a per-case
// sequence number and trace ID, logs the outbound request, executes it with
// the runner's timeout, and logs the response (or network failure). The
// response body is fully read before returning; streamed bodies are also
// appended to the case's stream.raw accumulator.
func (cc *caseContext) requestOnce(ctx context.Context, spec requestSpec, attempt int) (*responseResult, error) {
	// Reserve the sequence number and register the trace ID under the lock.
	cc.mu.Lock()
	cc.seq++
	seq := cc.seq
	traceID := fmt.Sprintf("ts_%s_%s_%03d", cc.runner.runID, sanitizeID(cc.id), seq)
	cc.traceIDsSet[traceID] = struct{}{}
	cc.mu.Unlock()

	fullURL, err := withTraceQuery(cc.runner.baseURL+spec.Path, traceID)
	if err != nil {
		return nil, err
	}

	// Copy the caller's headers so the trace header never mutates spec.
	headers := map[string]string{}
	for k, v := range spec.Headers {
		headers[k] = v
	}
	headers["X-Ds2-Test-Trace"] = traceID

	var bodyBytes []byte
	var bodyAny any
	if spec.Body != nil {
		b, err := json.Marshal(spec.Body)
		if err != nil {
			return nil, err
		}
		bodyBytes = b
		bodyAny = spec.Body
		headers["Content-Type"] = "application/json"
	}
	// Log the request before sending so even hung calls leave a record.
	cc.mu.Lock()
	cc.requests = append(cc.requests, requestLog{
		Seq:       seq,
		Attempt:   attempt,
		TraceID:   traceID,
		Method:    spec.Method,
		URL:       fullURL,
		Headers:   headers,
		Body:      bodyAny,
		Timestamp: time.Now().Format(time.RFC3339Nano),
	})
	cc.mu.Unlock()

	reqCtx, cancel := context.WithTimeout(ctx, cc.runner.opts.Timeout)
	defer cancel()
	req, err := http.NewRequestWithContext(reqCtx, spec.Method, fullURL, bytes.NewReader(bodyBytes))
	if err != nil {
		return nil, err
	}
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	start := time.Now()
	resp, err := cc.runner.httpClient.Do(req)
	if err != nil {
		// Network-level failure: record it with StatusCode 0.
		cc.mu.Lock()
		cc.responses = append(cc.responses, responseLog{
			Seq:        seq,
			Attempt:    attempt,
			TraceID:    traceID,
			StatusCode: 0,
			DurationMS: time.Since(start).Milliseconds(),
			NetworkErr: err.Error(),
			ReceivedAt: time.Now().Format(time.RFC3339Nano),
		})
		cc.mu.Unlock()
		return nil, err
	}
	defer resp.Body.Close()
	// Read-all also applies to SSE responses: the whole stream is buffered
	// here and parsed after the fact by the stream cases.
	body, _ := io.ReadAll(resp.Body)

	cc.mu.Lock()
	cc.responses = append(cc.responses, responseLog{
		Seq:        seq,
		Attempt:    attempt,
		TraceID:    traceID,
		StatusCode: resp.StatusCode,
		Headers:    resp.Header,
		BodyText:   string(body),
		DurationMS: time.Since(start).Milliseconds(),
		ReceivedAt: time.Now().Format(time.RFC3339Nano),
	})

	if spec.Stream {
		cc.streamRaw.WriteString(fmt.Sprintf("### trace=%s url=%s\n", traceID, fullURL))
		cc.streamRaw.Write(body)
		cc.streamRaw.WriteString("\n\n")
	}
	cc.mu.Unlock()

	return &responseResult{
		StatusCode: resp.StatusCode,
		Headers:    resp.Header,
		Body:       body,
		TraceID:    traceID,
		URL:        fullURL,
	}, nil
}
+
+func (cc *caseContext) flushArtifacts(cs caseResult) error {
+ requestPath := filepath.Join(cc.dir, "request.json")
+ headersPath := filepath.Join(cc.dir, "response.headers")
+ bodyPath := filepath.Join(cc.dir, "response.body")
+ streamPath := filepath.Join(cc.dir, "stream.raw")
+ assertPath := filepath.Join(cc.dir, "assertions.json")
+ metaPath := filepath.Join(cc.dir, "meta.json")
+
+ if err := writeJSONFile(requestPath, cc.requests); err != nil {
+ return err
+ }
+ respHeaders := make([]map[string]any, 0, len(cc.responses))
+ respBodies := make([]map[string]any, 0, len(cc.responses))
+ for _, r := range cc.responses {
+ respHeaders = append(respHeaders, map[string]any{
+ "seq": r.Seq,
+ "attempt": r.Attempt,
+ "trace_id": r.TraceID,
+ "status_code": r.StatusCode,
+ "headers": r.Headers,
+ })
+ respBodies = append(respBodies, map[string]any{
+ "seq": r.Seq,
+ "attempt": r.Attempt,
+ "trace_id": r.TraceID,
+ "status_code": r.StatusCode,
+ "body_text": r.BodyText,
+ "network_error": r.NetworkErr,
+ "duration_ms": r.DurationMS,
+ })
+ }
+ if err := writeJSONFile(headersPath, respHeaders); err != nil {
+ return err
+ }
+ if err := writeJSONFile(bodyPath, respBodies); err != nil {
+ return err
+ }
+ if err := os.WriteFile(streamPath, []byte(cc.streamRaw.String()), 0o644); err != nil {
+ return err
+ }
+ if err := writeJSONFile(assertPath, cc.assertions); err != nil {
+ return err
+ }
+ meta := map[string]any{
+ "case_id": cs.CaseID,
+ "trace_id": strings.Join(cs.TraceIDs, ","),
+ "attempt": len(cc.responses),
+ "duration_ms": cs.DurationMS,
+ "status": map[bool]string{true: "passed", false: "failed"}[cs.Passed],
+ "status_codes": cs.StatusCodes,
+ "assertions": cs.Assertions,
+ "artifact_path": cs.ArtifactPath,
+ }
+ return writeJSONFile(metaPath, meta)
+}
+func (r *Runner) doSimpleJSON(ctx context.Context, method, path string, headers map[string]string, body any) (*responseResult, error) {
+ cc := &caseContext{
+ runner: r,
+ id: "auth_prepare",
+ traceIDsSet: map[string]struct{}{},
+ }
+ return cc.request(ctx, requestSpec{
+ Method: method,
+ Path: path,
+ Headers: headers,
+ Body: body,
+ Retryable: true,
+ })
+}
diff --git a/internal/testsuite/runner_registry.go b/internal/testsuite/runner_registry.go
new file mode 100644
index 0000000..08b602a
--- /dev/null
+++ b/internal/testsuite/runner_registry.go
@@ -0,0 +1,43 @@
+package testsuite
+
+import "context"
+
// caseDef pairs a stable case ID (used for artifact directories and
// reporting) with the function that executes it.
type caseDef struct {
	ID  string
	Run func(context.Context, *caseContext) error
}

// cases returns the ordered registry of live test cases. The order is
// intentional: infrastructure checks run first, then API contracts, then
// concurrency/abort scenarios and error-contract checks.
func (r *Runner) cases() []caseDef {
	return []caseDef{
		{ID: "healthz_ok", Run: r.caseHealthz},
		{ID: "readyz_ok", Run: r.caseReadyz},
		{ID: "models_openai", Run: r.caseModelsOpenAI},
		{ID: "model_openai_by_id", Run: r.caseModelOpenAIByID},
		{ID: "models_claude", Run: r.caseModelsClaude},
		{ID: "admin_login_verify", Run: r.caseAdminLoginVerify},
		{ID: "admin_queue_status", Run: r.caseAdminQueueStatus},
		{ID: "chat_nonstream_basic", Run: r.caseChatNonstream},
		{ID: "chat_stream_basic", Run: r.caseChatStream},
		{ID: "responses_nonstream_basic", Run: r.caseResponsesNonstream},
		{ID: "responses_stream_basic", Run: r.caseResponsesStream},
		{ID: "embeddings_contract", Run: r.caseEmbeddings},
		{ID: "reasoner_stream", Run: r.caseReasonerStream},
		{ID: "toolcall_nonstream", Run: r.caseToolcallNonstream},
		{ID: "toolcall_stream", Run: r.caseToolcallStream},
		{ID: "anthropic_messages_nonstream", Run: r.caseAnthropicNonstream},
		{ID: "anthropic_messages_stream", Run: r.caseAnthropicStream},
		{ID: "anthropic_count_tokens", Run: r.caseAnthropicCountTokens},
		{ID: "admin_account_test_single", Run: r.caseAdminAccountTest},
		{ID: "concurrency_burst", Run: r.caseConcurrencyBurst},
		{ID: "concurrency_threshold_limit", Run: r.caseConcurrencyThresholdLimit},
		{ID: "stream_abort_release", Run: r.caseStreamAbortRelease},
		{ID: "toolcall_stream_mixed", Run: r.caseToolcallStreamMixed},
		{ID: "sse_json_integrity", Run: r.caseSSEJSONIntegrity},
		{ID: "error_contract_invalid_model", Run: r.caseInvalidModel},
		{ID: "error_contract_missing_messages", Run: r.caseMissingMessages},
		{ID: "admin_unauthorized_contract", Run: r.caseAdminUnauthorized},
		{ID: "config_write_isolated", Run: r.caseConfigWriteIsolated},
		{ID: "token_refresh_managed_account", Run: r.caseTokenRefreshManagedAccount},
		{ID: "error_contract_invalid_key", Run: r.caseInvalidKey},
	}
}
diff --git a/internal/testsuite/runner_summary.go b/internal/testsuite/runner_summary.go
new file mode 100644
index 0000000..25b44a4
--- /dev/null
+++ b/internal/testsuite/runner_summary.go
@@ -0,0 +1,97 @@
+package testsuite
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "time"
+)
+
// writeSummary renders the aggregate run outcome to summary.json and
// summary.md inside the run directory.
func (r *Runner) writeSummary(start, end time.Time) error {
	passed := 0
	failed := 0
	for _, cs := range r.results {
		if cs.Passed {
			passed++
		} else {
			failed++
		}
	}
	summary := runSummary{
		RunID:      r.runID,
		StartedAt:  start.Format(time.RFC3339Nano),
		EndedAt:    end.Format(time.RFC3339Nano),
		DurationMS: end.Sub(start).Milliseconds(),
		Stats: map[string]any{
			"total":  len(r.results),
			"passed": passed,
			"failed": failed,
		},
		// Snapshot of where/how the run executed, for later debugging.
		Environment: map[string]any{
			"go_version":      runtime.Version(),
			"os":              runtime.GOOS,
			"arch":            runtime.GOARCH,
			"base_url":        r.baseURL,
			"config_source":   r.originalConfigPath,
			"config_isolated": r.configCopyPath,
			"server_log":      r.serverLog,
			"preflight_log":   r.preflightLog,
			"retries":         r.opts.Retries,
			"timeout_seconds": int(r.opts.Timeout.Seconds()),
		},
		Cases:    r.results,
		Warnings: r.warnings,
	}
	if err := writeJSONFile(filepath.Join(r.runDir, "summary.json"), summary); err != nil {
		return err
	}
	return os.WriteFile(filepath.Join(r.runDir, "summary.md"), []byte(r.summaryMarkdown(summary)), 0o644)
}
+
// summaryMarkdown renders a human-readable Markdown report of the run:
// header stats, warnings, a failed-case section with ready-to-paste grep
// commands for the server log, and a full case table.
func (r *Runner) summaryMarkdown(s runSummary) string {
	var b strings.Builder
	b.WriteString("# DS2API Live Testsuite Summary\n\n")
	b.WriteString("**Sensitive Notice:** this run stores full raw request/response logs. Do not share artifacts publicly.\n\n")
	fmt.Fprintf(&b, "- Run ID: `%s`\n", s.RunID)
	fmt.Fprintf(&b, "- Started: `%s`\n", s.StartedAt)
	fmt.Fprintf(&b, "- Ended: `%s`\n", s.EndedAt)
	fmt.Fprintf(&b, "- Duration: `%d ms`\n", s.DurationMS)
	fmt.Fprintf(&b, "- Passed/Failed: `%d/%d`\n\n", s.Stats["passed"], s.Stats["failed"])
	if len(s.Warnings) > 0 {
		b.WriteString("## Warnings\n\n")
		for _, w := range s.Warnings {
			fmt.Fprintf(&b, "- %s\n", w)
		}
		b.WriteString("\n")
	}
	b.WriteString("## Failed Cases\n\n")
	hasFailed := false
	for _, c := range s.Cases {
		if c.Passed {
			continue
		}
		hasFailed = true
		fmt.Fprintf(&b, "- `%s`: %s\n", c.CaseID, c.Error)
		if len(c.TraceIDs) > 0 {
			fmt.Fprintf(&b, "  - trace_ids: `%s`\n", strings.Join(c.TraceIDs, ", "))
			// Copy-pasteable ripgrep command to find the first trace in the server log.
			fmt.Fprintf(&b, "  - grep: `rg \"%s\" %s`\n", c.TraceIDs[0], filepath.Join(r.runDir, "server.log"))
		}
		fmt.Fprintf(&b, "  - artifact: `%s`\n", c.ArtifactPath)
	}
	if !hasFailed {
		b.WriteString("- none\n")
	}
	b.WriteString("\n## Case Table\n\n")
	b.WriteString("| case_id | status | duration_ms | statuses | artifact |\n")
	b.WriteString("|---|---:|---:|---|---|\n")
	for _, c := range s.Cases {
		status := "PASS"
		if !c.Passed {
			status = "FAIL"
		}
		fmt.Fprintf(&b, "| %s | %s | %d | %v | `%s` |\n", c.CaseID, status, c.DurationMS, c.StatusCodes, c.ArtifactPath)
	}
	return b.String()
}
diff --git a/internal/testsuite/runner_utils.go b/internal/testsuite/runner_utils.go
new file mode 100644
index 0000000..c4879c6
--- /dev/null
+++ b/internal/testsuite/runner_utils.go
@@ -0,0 +1,202 @@
+package testsuite
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net"
+ "net/url"
+ "os"
+ "sort"
+ "strings"
+)
+
// parseSSEFrames decodes an SSE response body into the JSON objects carried
// by its "data:" lines, returning the decoded frames and whether a terminal
// "[DONE]" sentinel was seen. Non-JSON data payloads are skipped silently.
func parseSSEFrames(body []byte) ([]map[string]any, bool) {
	frames := make([]map[string]any, 0)
	done := false
	for _, raw := range strings.Split(string(body), "\n") {
		trimmed := strings.TrimSpace(raw)
		if !strings.HasPrefix(trimmed, "data:") {
			continue
		}
		data := strings.TrimSpace(trimmed[len("data:"):])
		switch {
		case data == "":
			// keep-alive / empty data line
		case data == "[DONE]":
			done = true
		default:
			var frame map[string]any
			if json.Unmarshal([]byte(data), &frame) == nil {
				frames = append(frames, frame)
			}
		}
	}
	return frames, done
}
+
+func parseClaudeStreamEvents(body []byte) []string {
+ events := []string{}
+ seen := map[string]bool{}
+ lines := strings.Split(string(body), "\n")
+ for _, line := range lines {
+ line = strings.TrimSpace(line)
+ if !strings.HasPrefix(line, "data:") {
+ continue
+ }
+ payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
+ if payload == "" {
+ continue
+ }
+ var m map[string]any
+ if err := json.Unmarshal([]byte(payload), &m); err != nil {
+ continue
+ }
+ t := asString(m["type"])
+ if t == "" || seen[t] {
+ continue
+ }
+ seen[t] = true
+ events = append(events, t)
+ }
+ return events
+}
+
+func extractModelIDs(body []byte) []string {
+ var m map[string]any
+ if err := json.Unmarshal(body, &m); err != nil {
+ return nil
+ }
+ out := []string{}
+ data, _ := m["data"].([]any)
+ for _, it := range data {
+ item, _ := it.(map[string]any)
+ id := asString(item["id"])
+ if id != "" {
+ out = append(out, id)
+ }
+ }
+ return out
+}
+
// withTraceQuery returns rawURL with the __trace_id query parameter set to
// traceID, preserving any query parameters already present.
func withTraceQuery(rawURL, traceID string) (string, error) {
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return "", err
	}
	query := parsed.Query()
	query.Set("__trace_id", traceID)
	parsed.RawQuery = query.Encode()
	return parsed.String(), nil
}
+
+func writeJSONFile(path string, v any) error {
+ b, err := json.MarshalIndent(v, "", " ")
+ if err != nil {
+ return err
+ }
+ return os.WriteFile(path, b, 0o644)
+}
+
// prepareServerEnv merges the inherited environment with override variables:
// any variable named in overrides is dropped from base and re-added with the
// override value, and malformed entries without '=' are discarded.
func prepareServerEnv(base []string, overrides map[string]string) []string {
	merged := make([]string, 0, len(base)+len(overrides))
	for _, entry := range base {
		eq := strings.IndexByte(entry, '=')
		if eq < 0 {
			continue
		}
		if _, overridden := overrides[entry[:eq]]; overridden {
			continue
		}
		merged = append(merged, entry)
	}
	for name, value := range overrides {
		merged = append(merged, name+"="+value)
	}
	return merged
}
+
// findFreePort asks the kernel for an ephemeral TCP port on loopback and
// returns it. The listener is closed before returning, so the port is only
// best-effort free — another process could grab it before our server binds.
func findFreePort() (int, error) {
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return 0, err
	}
	defer listener.Close()
	tcpAddr, ok := listener.Addr().(*net.TCPAddr)
	if !ok {
		return 0, errors.New("failed to detect tcp port")
	}
	return tcpAddr.Port, nil
}
+
+func uniqueStatusCodes(in []responseLog) []int {
+ set := map[int]struct{}{}
+ for _, it := range in {
+ if it.StatusCode > 0 {
+ set[it.StatusCode] = struct{}{}
+ }
+ }
+ out := make([]int, 0, len(set))
+ for k := range set {
+ out = append(out, k)
+ }
+ sort.Ints(out)
+ return out
+}
+
// has5xx reports whether the status-code distribution contains any server
// error (>= 500), returning one such code (map-iteration order) when found.
func has5xx(dist map[int]int) (int, bool) {
	for code := range dist {
		if code >= 500 {
			return code, true
		}
	}
	return 0, false
}
+
// idSanitizer maps characters unsafe in file names and trace IDs
// (':', '/', ' ') to underscores.
var idSanitizer = strings.NewReplacer(":", "_", "/", "_", " ", "_")

// sanitizeID makes s safe for use in artifact paths and trace identifiers.
func sanitizeID(s string) string {
	return idSanitizer.Replace(s)
}
+
// asString renders v as a trimmed string: nil becomes "", strings are
// trimmed directly, and any other value goes through fmt's default %v form.
func asString(v any) string {
	switch value := v.(type) {
	case nil:
		return ""
	case string:
		return strings.TrimSpace(value)
	}
	return strings.TrimSpace(fmt.Sprintf("%v", v))
}
+
+func toInt(v any) int {
+ switch x := v.(type) {
+ case float64:
+ return int(x)
+ case float32:
+ return int(x)
+ case int:
+ return x
+ case int64:
+ return int(x)
+ default:
+ return 0
+ }
+}
+
// contains reports whether target occurs in xs.
func contains(xs []string, target string) bool {
	for i := range xs {
		if xs[i] == target {
			return true
		}
	}
	return false
}
diff --git a/internal/util/toolcalls_candidates.go b/internal/util/toolcalls_candidates.go
new file mode 100644
index 0000000..4e8afc4
--- /dev/null
+++ b/internal/util/toolcalls_candidates.go
@@ -0,0 +1,138 @@
+package util
+
+import (
+ "regexp"
+ "strings"
+)
+
// toolCallPattern matches a compact {"tool_calls":[...]} object embedded in
// free text (non-greedy, so nested brackets inside the array may truncate
// the capture — it is only used as a last-resort candidate source).
var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`)

// fencedJSONPattern captures the inner content of ``` / ```json fenced blocks.
var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")

// fencedBlockPattern matches a whole fenced code block, for removal.
var fencedBlockPattern = regexp.MustCompile("(?s)```.*?```")
+
// buildToolCallCandidates produces a deduplicated, ordered list of text
// snippets that might parse as a tool-call JSON payload. Candidate order is
// significant — callers try them in sequence: the raw trimmed text first,
// then fenced-code-block contents, then objects found around a "tool_calls"
// key, then the widest first-'{' to last-'}' slice, and finally a legacy
// regex extraction.
func buildToolCallCandidates(text string) []string {
	trimmed := strings.TrimSpace(text)
	candidates := []string{trimmed}

	// fenced code block candidates: ```json ... ```
	for _, match := range fencedJSONPattern.FindAllStringSubmatch(trimmed, -1) {
		if len(match) >= 2 {
			candidates = append(candidates, strings.TrimSpace(match[1]))
		}
	}

	// best-effort extraction around "tool_calls" key in mixed text payloads.
	candidates = append(candidates, extractToolCallObjects(trimmed)...)

	// best-effort object slice: from first '{' to last '}'
	first := strings.Index(trimmed, "{")
	last := strings.LastIndex(trimmed, "}")
	if first >= 0 && last > first {
		candidates = append(candidates, strings.TrimSpace(trimmed[first:last+1]))
	}

	// legacy regex extraction fallback
	if m := toolCallPattern.FindStringSubmatch(trimmed); len(m) >= 2 {
		candidates = append(candidates, "{"+`"tool_calls":[`+m[1]+"]}")
	}

	// Deduplicate while preserving first-seen order; drop empty candidates.
	uniq := make([]string, 0, len(candidates))
	seen := map[string]struct{}{}
	for _, c := range candidates {
		if c == "" {
			continue
		}
		if _, ok := seen[c]; ok {
			continue
		}
		seen[c] = struct{}{}
		uniq = append(uniq, c)
	}
	return uniq
}
+
// extractToolCallObjects scans mixed model output for JSON objects that
// contain a "tool_calls" key (matched case-insensitively) and returns each
// candidate object as trimmed text.
//
// For every keyword occurrence it walks backwards over '{' openers until it
// finds a balanced object that actually spans the keyword. The scan offset
// always advances past the current occurrence, so malformed input can never
// loop: the previous implementation spun forever (and grew the result
// unboundedly) when a balanced object closed *before* the keyword, e.g.
// `{"x":1} tool_calls`, because `offset = end` never moved past `idx`.
func extractToolCallObjects(text string) []string {
	if text == "" {
		return nil
	}
	lower := strings.ToLower(text)
	out := []string{}
	offset := 0
	for {
		idx := strings.Index(lower[offset:], "tool_calls")
		if idx < 0 {
			break
		}
		idx += offset
		matched := false
		for start := strings.LastIndex(text[:idx], "{"); start >= 0; start = strings.LastIndex(text[:start], "{") {
			candidate, end, ok := extractJSONObject(text, start)
			if !ok {
				continue
			}
			// Only accept an object that actually encloses this keyword;
			// otherwise the same preceding object would match repeatedly.
			if end <= idx {
				continue
			}
			out = append(out, strings.TrimSpace(candidate))
			offset = end
			matched = true
			break
		}
		if !matched {
			// No enclosing object found: skip past this occurrence.
			offset = idx + len("tool_calls")
		}
	}
	return out
}

// extractJSONObject returns the balanced JSON-ish object starting at
// text[start] (which must be '{'), the index just past its closing brace,
// and whether a balanced object was found before the text ended. The scan
// is quote- and escape-aware and tolerates single-quoted strings.
func extractJSONObject(text string, start int) (string, int, bool) {
	if start < 0 || start >= len(text) || text[start] != '{' {
		return "", 0, false
	}
	depth := 0
	quote := byte(0)
	escaped := false
	for i := start; i < len(text); i++ {
		ch := text[i]
		if quote != 0 {
			// Inside a string literal: honor escapes, close on matching quote.
			if escaped {
				escaped = false
				continue
			}
			if ch == '\\' {
				escaped = true
				continue
			}
			if ch == quote {
				quote = 0
			}
			continue
		}
		switch ch {
		case '"', '\'':
			quote = ch
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return text[start : i+1], i + 1, true
			}
		}
	}
	return "", 0, false
}
+
// looksLikeToolExampleContext reports whether text appears to contain a
// fenced code block (```), which usually means tool-call JSON shown as an
// example rather than an actual call.
func looksLikeToolExampleContext(text string) bool {
	normalized := strings.ToLower(strings.TrimSpace(text))
	return normalized != "" && strings.Contains(normalized, "```")
}
+
+func stripFencedCodeBlocks(text string) string {
+ if strings.TrimSpace(text) == "" {
+ return ""
+ }
+ return fencedBlockPattern.ReplaceAllString(text, " ")
+}
diff --git a/internal/util/toolcalls_format.go b/internal/util/toolcalls_format.go
new file mode 100644
index 0000000..8feb48f
--- /dev/null
+++ b/internal/util/toolcalls_format.go
@@ -0,0 +1,41 @@
+package util
+
+import (
+ "encoding/json"
+ "strings"
+
+ "github.com/google/uuid"
+)
+
+func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
+ out := make([]map[string]any, 0, len(calls))
+ for _, c := range calls {
+ args, _ := json.Marshal(c.Input)
+ out = append(out, map[string]any{
+ "id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
+ "type": "function",
+ "function": map[string]any{
+ "name": c.Name,
+ "arguments": string(args),
+ },
+ })
+ }
+ return out
+}
+
+func FormatOpenAIStreamToolCalls(calls []ParsedToolCall) []map[string]any {
+ out := make([]map[string]any, 0, len(calls))
+ for i, c := range calls {
+ args, _ := json.Marshal(c.Input)
+ out = append(out, map[string]any{
+ "index": i,
+ "id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
+ "type": "function",
+ "function": map[string]any{
+ "name": c.Name,
+ "arguments": string(args),
+ },
+ })
+ }
+ return out
+}
diff --git a/internal/util/toolcalls.go b/internal/util/toolcalls_parse.go
similarity index 52%
rename from internal/util/toolcalls.go
rename to internal/util/toolcalls_parse.go
index 9e44b94..ab9fe84 100644
--- a/internal/util/toolcalls.go
+++ b/internal/util/toolcalls_parse.go
@@ -2,16 +2,9 @@ package util
import (
"encoding/json"
- "regexp"
"strings"
-
- "github.com/google/uuid"
)
-var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`)
-var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")
-var fencedBlockPattern = regexp.MustCompile("(?s)```.*?```")
-
type ParsedToolCall struct {
Name string `json:"name"`
Input map[string]any `json:"input"`
@@ -102,47 +95,6 @@ func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []Par
return out
}
-func buildToolCallCandidates(text string) []string {
- trimmed := strings.TrimSpace(text)
- candidates := []string{trimmed}
-
- // fenced code block candidates: ```json ... ```
- for _, match := range fencedJSONPattern.FindAllStringSubmatch(trimmed, -1) {
- if len(match) >= 2 {
- candidates = append(candidates, strings.TrimSpace(match[1]))
- }
- }
-
- // best-effort extraction around "tool_calls" key in mixed text payloads.
- candidates = append(candidates, extractToolCallObjects(trimmed)...)
-
- // best-effort object slice: from first '{' to last '}'
- first := strings.Index(trimmed, "{")
- last := strings.LastIndex(trimmed, "}")
- if first >= 0 && last > first {
- candidates = append(candidates, strings.TrimSpace(trimmed[first:last+1]))
- }
-
- // legacy regex extraction fallback
- if m := toolCallPattern.FindStringSubmatch(trimmed); len(m) >= 2 {
- candidates = append(candidates, "{"+`"tool_calls":[`+m[1]+"]}")
- }
-
- uniq := make([]string, 0, len(candidates))
- seen := map[string]struct{}{}
- for _, c := range candidates {
- if c == "" {
- continue
- }
- if _, ok := seen[c]; ok {
- continue
- }
- seen[c] = struct{}{}
- uniq = append(uniq, c)
- }
- return uniq
-}
-
func parseToolCallsPayload(payload string) []ParsedToolCall {
var decoded any
if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
@@ -243,123 +195,3 @@ func parseToolCallInput(v any) map[string]any {
return map[string]any{}
}
}
-
-func extractToolCallObjects(text string) []string {
- if text == "" {
- return nil
- }
- lower := strings.ToLower(text)
- out := []string{}
- offset := 0
- for {
- idx := strings.Index(lower[offset:], "tool_calls")
- if idx < 0 {
- break
- }
- idx += offset
- start := strings.LastIndex(text[:idx], "{")
- for start >= 0 {
- candidate, end, ok := extractJSONObject(text, start)
- if ok {
- // Move forward to avoid repeatedly matching the same object.
- offset = end
- out = append(out, strings.TrimSpace(candidate))
- break
- }
- start = strings.LastIndex(text[:start], "{")
- }
- if start < 0 {
- offset = idx + len("tool_calls")
- }
- }
- return out
-}
-
-func extractJSONObject(text string, start int) (string, int, bool) {
- if start < 0 || start >= len(text) || text[start] != '{' {
- return "", 0, false
- }
- depth := 0
- quote := byte(0)
- escaped := false
- for i := start; i < len(text); i++ {
- ch := text[i]
- if quote != 0 {
- if escaped {
- escaped = false
- continue
- }
- if ch == '\\' {
- escaped = true
- continue
- }
- if ch == quote {
- quote = 0
- }
- continue
- }
- if ch == '"' || ch == '\'' {
- quote = ch
- continue
- }
- if ch == '{' {
- depth++
- continue
- }
- if ch == '}' {
- depth--
- if depth == 0 {
- return text[start : i+1], i + 1, true
- }
- }
- }
- return "", 0, false
-}
-
-func looksLikeToolExampleContext(text string) bool {
- t := strings.ToLower(strings.TrimSpace(text))
- if t == "" {
- return false
- }
- return strings.Contains(t, "```")
-}
-
-func stripFencedCodeBlocks(text string) string {
- if strings.TrimSpace(text) == "" {
- return ""
- }
- return fencedBlockPattern.ReplaceAllString(text, " ")
-}
-
-func FormatOpenAIToolCalls(calls []ParsedToolCall) []map[string]any {
- out := make([]map[string]any, 0, len(calls))
- for _, c := range calls {
- args, _ := json.Marshal(c.Input)
- out = append(out, map[string]any{
- "id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
- "type": "function",
- "function": map[string]any{
- "name": c.Name,
- "arguments": string(args),
- },
- })
- }
- return out
-}
-
-func FormatOpenAIStreamToolCalls(calls []ParsedToolCall) []map[string]any {
- out := make([]map[string]any, 0, len(calls))
- for i, c := range calls {
- args, _ := json.Marshal(c.Input)
- out = append(out, map[string]any{
- "index": i,
- "id": "call_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
- "type": "function",
- "function": map[string]any{
- "name": c.Name,
- "arguments": string(args),
- },
- })
- }
- return out
-}
diff --git a/plans/refactor-baseline.md b/plans/refactor-baseline.md
new file mode 100644
index 0000000..9b31ed1
--- /dev/null
+++ b/plans/refactor-baseline.md
@@ -0,0 +1,31 @@
+# DS2API Refactor Baseline (Backfilled)
+
+- Recorded at: `2026-02-22T08:53:54Z`
+- Branch: `dev`
+- HEAD: `5d3989a`
+- Scope: backend + node api + webui large-file decoupling (no behavior change)
+
+## Gate Commands
+
+1. `./tests/scripts/run-unit-all.sh`
+ - Result: PASS
+ - Includes:
+ - `go test ./...`
+ - `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js`
+2. `npm --prefix webui run build`
+ - Result: PASS
+3. `./tests/scripts/check-refactor-line-gate.sh`
+ - Result: PASS (`checked=131 missing=0 over_limit=0`)
+4. Stage gates (1-5) replay:
+ - `go test ./internal/config ./internal/admin ./internal/account ./internal/deepseek ./internal/format/openai` -> PASS
+ - `go test ./internal/adapter/openai ./internal/util ./internal/sse ./internal/compat` -> PASS
+ - `go test ./internal/adapter/claude ./internal/adapter/gemini ./internal/config` -> PASS
+ - `go test ./internal/testsuite ./cmd/ds2api-tests` -> PASS
+ - `node --test api/helpers/stream-tool-sieve.test.js api/chat-stream.test.js api/compat/js_compat_test.js` -> PASS
+5. Final full regression:
+ - `go test ./... -count=1` -> PASS
+
+## Notes
+
+- This baseline file is a backfill artifact for phase 0 process tracking.
+- Frontend manual smoke for phase 6 still requires human execution and sign-off.
diff --git a/plans/refactor-line-gate-targets.txt b/plans/refactor-line-gate-targets.txt
new file mode 100644
index 0000000..d3678ad
--- /dev/null
+++ b/plans/refactor-line-gate-targets.txt
@@ -0,0 +1,151 @@
+# Line gate targets for large-file decoupling refactor.
+# Default limit: 300 lines
+# Entry/facade limit: 120 lines (enforced in script)
+
+internal/config/config.go
+internal/config/logger.go
+internal/config/paths.go
+internal/config/codec.go
+internal/config/store.go
+internal/config/store_index.go
+internal/config/store_accessors.go
+internal/config/account.go
+
+internal/admin/handler_config_read.go
+internal/admin/handler_config_write.go
+internal/admin/handler_config_import.go
+internal/admin/handler_settings_read.go
+internal/admin/handler_settings_write.go
+internal/admin/handler_settings_parse.go
+internal/admin/handler_settings_runtime.go
+internal/admin/handler_accounts_crud.go
+internal/admin/handler_accounts_testing.go
+internal/admin/handler_accounts_queue.go
+
+internal/account/pool_core.go
+internal/account/pool_acquire.go
+internal/account/pool_waiters.go
+internal/account/pool_limits.go
+
+internal/deepseek/client_core.go
+internal/deepseek/client_auth.go
+internal/deepseek/client_completion.go
+internal/deepseek/client_http_json.go
+internal/deepseek/client_http_helpers.go
+
+internal/format/openai/render_chat.go
+internal/format/openai/render_responses.go
+internal/format/openai/render_stream_events.go
+internal/format/openai/render_usage.go
+
+internal/adapter/openai/handler_routes.go
+internal/adapter/openai/handler_chat.go
+internal/adapter/openai/handler_errors.go
+internal/adapter/openai/handler_toolcall_policy.go
+internal/adapter/openai/handler_toolcall_format.go
+internal/adapter/openai/responses_handler.go
+internal/adapter/openai/responses_input_normalize.go
+internal/adapter/openai/responses_input_items.go
+internal/adapter/openai/responses_stream_runtime_core.go
+internal/adapter/openai/responses_stream_runtime_events.go
+internal/adapter/openai/responses_stream_runtime_toolcalls.go
+internal/adapter/openai/tool_sieve_state.go
+internal/adapter/openai/tool_sieve_core.go
+internal/adapter/openai/tool_sieve_incremental.go
+internal/adapter/openai/tool_sieve_jsonscan.go
+
+internal/util/toolcalls_parse.go
+internal/util/toolcalls_candidates.go
+internal/util/toolcalls_format.go
+
+internal/adapter/claude/handler_routes.go
+internal/adapter/claude/handler_messages.go
+internal/adapter/claude/handler_tokens.go
+internal/adapter/claude/handler_errors.go
+internal/adapter/claude/handler_utils.go
+internal/adapter/claude/stream_runtime_core.go
+internal/adapter/claude/stream_runtime_emit.go
+internal/adapter/claude/stream_runtime_finalize.go
+
+internal/adapter/gemini/handler_routes.go
+internal/adapter/gemini/handler_generate.go
+internal/adapter/gemini/handler_stream_runtime.go
+internal/adapter/gemini/handler_errors.go
+internal/adapter/gemini/convert_request.go
+internal/adapter/gemini/convert_messages.go
+internal/adapter/gemini/convert_tools.go
+internal/adapter/gemini/convert_passthrough.go
+
+internal/testsuite/runner_core.go
+internal/testsuite/runner_env.go
+internal/testsuite/runner_http.go
+internal/testsuite/runner_cases_openai.go
+internal/testsuite/runner_cases_openai_advanced.go
+internal/testsuite/runner_cases_admin.go
+internal/testsuite/runner_cases_claude.go
+internal/testsuite/runner_summary.go
+internal/testsuite/runner_utils.go
+internal/testsuite/runner_defaults.go
+internal/testsuite/runner_registry.go
+internal/testsuite/edge_cases_abort.go
+internal/testsuite/edge_cases_error_contract.go
+
+api/chat-stream.js
+api/chat-stream/index.js
+api/chat-stream/vercel_stream.js
+api/chat-stream/proxy_go.js
+api/chat-stream/sse_parse.js
+api/chat-stream/http_internal.js
+api/chat-stream/toolcall_policy.js
+api/chat-stream/error_shape.js
+api/chat-stream/token_usage.js
+api/chat-stream/stream_emitter.js
+
+api/helpers/stream-tool-sieve.js
+api/helpers/stream-tool-sieve/index.js
+api/helpers/stream-tool-sieve/state.js
+api/helpers/stream-tool-sieve/sieve.js
+api/helpers/stream-tool-sieve/incremental.js
+api/helpers/stream-tool-sieve/jsonscan.js
+api/helpers/stream-tool-sieve/parse.js
+api/helpers/stream-tool-sieve/format.js
+
+webui/src/App.jsx
+webui/src/app/AppRoutes.jsx
+webui/src/app/useAdminAuth.js
+webui/src/app/useAdminConfig.js
+webui/src/layout/DashboardShell.jsx
+
+webui/src/components/AccountManager.jsx
+webui/src/features/account/AccountManagerContainer.jsx
+webui/src/features/account/useAccountsData.js
+webui/src/features/account/useAccountActions.js
+webui/src/features/account/QueueCards.jsx
+webui/src/features/account/ApiKeysPanel.jsx
+webui/src/features/account/AccountsTable.jsx
+webui/src/features/account/AddKeyModal.jsx
+webui/src/features/account/AddAccountModal.jsx
+
+webui/src/components/ApiTester.jsx
+webui/src/features/apiTester/ApiTesterContainer.jsx
+webui/src/features/apiTester/useApiTesterState.js
+webui/src/features/apiTester/useChatStreamClient.js
+webui/src/features/apiTester/ConfigPanel.jsx
+webui/src/features/apiTester/ChatPanel.jsx
+
+webui/src/components/Settings.jsx
+webui/src/features/settings/SettingsContainer.jsx
+webui/src/features/settings/useSettingsForm.js
+webui/src/features/settings/settingsApi.js
+webui/src/features/settings/SecuritySection.jsx
+webui/src/features/settings/RuntimeSection.jsx
+webui/src/features/settings/BehaviorSection.jsx
+webui/src/features/settings/ModelSection.jsx
+webui/src/features/settings/BackupSection.jsx
+
+webui/src/components/VercelSync.jsx
+webui/src/features/vercel/VercelSyncContainer.jsx
+webui/src/features/vercel/useVercelSyncState.js
+webui/src/features/vercel/VercelSyncForm.jsx
+webui/src/features/vercel/VercelSyncStatus.jsx
+webui/src/features/vercel/VercelGuide.jsx
diff --git a/plans/refactor-line-gate.md b/plans/refactor-line-gate.md
new file mode 100644
index 0000000..86f0d82
--- /dev/null
+++ b/plans/refactor-line-gate.md
@@ -0,0 +1,21 @@
+# Refactor Line Gate
+
+## Rules
+
+1. Production file default upper bound: `<= 300` lines.
+2. Entry/facade files upper bound: `<= 120` lines.
+3. Scope is limited to target files in `plans/refactor-line-gate-targets.txt`.
+4. Test files are out of scope for this gate.
+
+## Command
+
+```bash
+./tests/scripts/check-refactor-line-gate.sh
+```
+
+## Naming Note
+
+- Original split plan used `internal/admin/handler_accounts_test.go` for account probing logic.
+- In Go, `*_test.go` files are test-only compilation units and cannot host production handlers.
+- The production file is implemented as `internal/admin/handler_accounts_testing.go`.
+
diff --git a/plans/stage6-manual-smoke.md b/plans/stage6-manual-smoke.md
new file mode 100644
index 0000000..8875b54
--- /dev/null
+++ b/plans/stage6-manual-smoke.md
@@ -0,0 +1,29 @@
+# Stage 6 Manual Smoke Checklist
+
+- Date:
+- Tester:
+- Environment:
+
+## Items
+
+1. Login flow (`/admin/login`) succeeds and failure message shape unchanged.
+2. Account manager:
+ - add/edit/delete account
+ - queue status cards render and refresh
+3. API tester:
+ - non-stream request succeeds
+ - stream request receives incremental output and final state
+4. Settings:
+ - read settings
+ - save settings
+ - backup/export path works
+5. Vercel sync:
+ - status poll
+ - manual refresh
+ - sync action and status feedback text
+
+## Result
+
+- Status: `PENDING`
+- Notes:
+
diff --git a/tests/scripts/check-refactor-line-gate.sh b/tests/scripts/check-refactor-line-gate.sh
new file mode 100755
index 0000000..0568666
--- /dev/null
+++ b/tests/scripts/check-refactor-line-gate.sh
@@ -0,0 +1,62 @@
#!/usr/bin/env bash
# Enforce per-file line-count ceilings for the refactor target list.
# Reads plans/refactor-line-gate-targets.txt; prints MISSING/OVER lines and
# a summary, and exits non-zero when any target is missing or over limit.
set -euo pipefail

ROOT_DIR="$(cd "$(dirname "$0")/../.." && pwd)"
TARGETS_FILE="$ROOT_DIR/plans/refactor-line-gate-targets.txt"

DEFAULT_MAX=300   # default ceiling for production files
ENTRY_MAX=120     # stricter ceiling for entry/facade files

# is_entry_file FILE -> returns 0 when FILE is an entry/facade target
# subject to the stricter ENTRY_MAX limit.
is_entry_file() {
  case "$1" in
    api/chat-stream.js|\
    api/helpers/stream-tool-sieve.js|\
    webui/src/App.jsx|\
    webui/src/components/AccountManager.jsx|\
    webui/src/components/ApiTester.jsx|\
    webui/src/components/Settings.jsx|\
    webui/src/components/VercelSync.jsx)
      return 0
      ;;
  esac
  return 1
}

if [[ ! -f "$TARGETS_FILE" ]]; then
  echo "missing targets file: $TARGETS_FILE" >&2
  exit 1
fi

missing=0
over=0
checked=0

# `|| [[ -n "$file" ]]` keeps the final entry even when the targets file
# lacks a trailing newline (a bare `read` would silently drop that line).
while IFS= read -r file || [[ -n "$file" ]]; do
  [[ -z "$file" ]] && continue
  [[ "${file:0:1}" == "#" ]] && continue

  checked=$((checked + 1))
  abs="$ROOT_DIR/$file"
  if [[ ! -f "$abs" ]]; then
    echo "MISSING $file"
    missing=$((missing + 1))
    continue
  fi

  lines="$(wc -l < "$abs" | tr -d ' ')"
  limit="$DEFAULT_MAX"
  if is_entry_file "$file"; then
    limit="$ENTRY_MAX"
  fi

  if (( lines > limit )); then
    echo "OVER $file lines=$lines limit=$limit"
    over=$((over + 1))
  fi
done < "$TARGETS_FILE"

echo "checked=$checked missing=$missing over_limit=$over"

if (( missing > 0 || over > 0 )); then
  exit 1
fi
diff --git a/webui/src/App.jsx b/webui/src/App.jsx
index 2c3f099..8067b8b 100644
--- a/webui/src/App.jsx
+++ b/webui/src/App.jsx
@@ -1,346 +1,3 @@
-import { useState, useEffect, useCallback, useMemo } from 'react'
-import {
- Routes,
- Route,
- Navigate,
- useNavigate,
- useLocation
-} from 'react-router-dom'
-import {
- LayoutDashboard,
- Key,
- Upload,
- Cloud,
- Settings as SettingsIcon,
- LogOut,
- Menu,
- X,
- Server,
- Users
-} from 'lucide-react'
-import clsx from 'clsx'
+import AppRoutes from './app/AppRoutes'
-import AccountManager from './components/AccountManager'
-import ApiTester from './components/ApiTester'
-import BatchImport from './components/BatchImport'
-import VercelSync from './components/VercelSync'
-import Settings from './components/Settings'
-import Login from './components/Login'
-import LandingPage from './components/LandingPage'
-import LanguageToggle from './components/LanguageToggle'
-import { useI18n } from './i18n'
-import { detectRuntimeEnv } from './utils/runtimeEnv'
-
-function Dashboard({ token, onLogout, config, fetchConfig, showMessage, message, onForceLogout, isVercel }) {
- const { t } = useI18n()
- const [activeTab, setActiveTab] = useState('accounts')
- const [sidebarOpen, setSidebarOpen] = useState(false)
-
- const navItems = [
- { id: 'accounts', label: t('nav.accounts.label'), icon: Users, description: t('nav.accounts.desc') },
- { id: 'test', label: t('nav.test.label'), icon: Server, description: t('nav.test.desc') },
- { id: 'import', label: t('nav.import.label'), icon: Upload, description: t('nav.import.desc') },
- { id: 'vercel', label: t('nav.vercel.label'), icon: Cloud, description: t('nav.vercel.desc') },
- { id: 'settings', label: t('nav.settings.label'), icon: SettingsIcon, description: t('nav.settings.desc') },
- ]
-
- const authFetch = useCallback(async (url, options = {}) => {
- const headers = {
- ...options.headers,
- 'Authorization': `Bearer ${token}`
- }
- const res = await fetch(url, { ...options, headers })
-
- if (res.status === 401) {
- onLogout()
- throw new Error(t('auth.expired'))
- }
- return res
- }, [onLogout, t, token])
-
- const renderTab = () => {
- switch (activeTab) {
- case 'accounts':
- return
- {navItems.find(n => n.id === activeTab)?.description} -
-{t('auth.checking')}
-{t('auth.checking')}
+{t('accountManager.available')}
-{t('accountManager.inUse')}
-{t('accountManager.totalPool')}
-{t('accountManager.apiKeysDesc')} ({config.keys?.length || 0})
-{t('accountManager.accountsDesc')}
-{t('accountManager.generateHint')}
-- {customKeyManaged ? t('apiTester.modeManaged') : t('apiTester.modeDirect')} -
- )} -- {t('vercel.description')} -
- {pollPaused && ( -- {t('vercel.pollPaused', { count: pollFailures })} -
- -
-
{t('vercel.projectIdHint')}
-- {t('vercel.redeployHint')} -
-{result.message}
- - {result.deployment_url && ( - - )} -{t('vercel.steps.one')}
-{t('vercel.steps.two')}
-
- {t('vercel.steps.three')} DS2API_CONFIG_JSON
-
{t('vercel.steps.four')}
-{t('accountManager.accountsDesc')}
+{t('accountManager.generateHint')}
+{t('accountManager.apiKeysDesc')} ({config.keys?.length || 0})
+{t('accountManager.available')}
+{t('accountManager.inUse')}
+{t('accountManager.totalPool')}
++ {customKeyManaged ? t('apiTester.modeManaged') : t('apiTester.modeDirect')} +
+ )} +{t('vercel.steps.one')}
+{t('vercel.steps.two')}
+
+ {t('vercel.steps.three')} DS2API_CONFIG_JSON
+
{t('vercel.steps.four')}
++ {t('vercel.description')} +
+ {pollPaused && ( ++ {t('vercel.pollPaused', { count: pollFailures })} +
+ +
+
{t('vercel.projectIdHint')}
++ {t('vercel.redeployHint')} +
+{result.message}
+ + {result.deployment_url && ( + + )} ++ {navItems.find(n => n.id === activeTab)?.description} +
+