diff --git a/api/chat-stream.js b/api/chat-stream.js index aa92b17..309c473 100644 --- a/api/chat-stream.js +++ b/api/chat-stream.js @@ -1,11 +1,13 @@ 'use strict'; +const crypto = require('crypto'); + const { extractToolNames, createToolSieveState, processToolSieveChunk, flushToolSieve, - parseToolCalls, + parseStandaloneToolCalls, formatOpenAIStreamToolCalls, } = require('./helpers/stream-tool-sieve'); @@ -90,16 +92,49 @@ module.exports = async function handler(req, res) { return; } const releaseLease = createLeaseReleaser(req, leaseID); + const upstreamController = new AbortController(); + let clientClosed = false; + let reader = null; + const markClientClosed = () => { + if (clientClosed) { + return; + } + clientClosed = true; + upstreamController.abort(); + if (reader && typeof reader.cancel === 'function') { + Promise.resolve(reader.cancel()).catch(() => {}); + } + }; + const onReqAborted = () => markClientClosed(); + const onResClose = () => { + if (!res.writableEnded) { + markClientClosed(); + } + }; + req.on('aborted', onReqAborted); + res.on('close', onResClose); try { - const completionRes = await fetch(DEEPSEEK_COMPLETION_URL, { - method: 'POST', - headers: { - ...BASE_HEADERS, - authorization: `Bearer ${deepseekToken}`, - 'x-ds-pow-response': powHeader, - }, - body: JSON.stringify(completionPayload), - }); + let completionRes; + try { + completionRes = await fetch(DEEPSEEK_COMPLETION_URL, { + method: 'POST', + headers: { + ...BASE_HEADERS, + authorization: `Bearer ${deepseekToken}`, + 'x-ds-pow-response': powHeader, + }, + body: JSON.stringify(completionPayload), + signal: upstreamController.signal, + }); + } catch (err) { + if (clientClosed || isAbortError(err)) { + return; + } + throw err; + } + if (clientClosed) { + return; + } if (!completionRes.ok || !completionRes.body) { const detail = await safeReadText(completionRes); @@ -124,12 +159,16 @@ module.exports = async function handler(req, res) { const toolSieveEnabled = toolNames.length > 0; const toolSieveState = createToolSieveState(); let toolCallsEmitted = false; + const streamToolCallIDs = new Map(); const decoder = new TextDecoder(); - const reader = completionRes.body.getReader(); + reader = completionRes.body.getReader(); let buffered = ''; let ended = false; const sendFrame = (obj) => { + if (clientClosed || res.writableEnded || res.destroyed) { + return; + } res.write(`data: ${JSON.stringify(obj)}\n\n`); if (typeof res.flush === 'function') { res.flush(); @@ -156,7 +195,11 @@ module.exports = async function handler(req, res) { return; } ended = true; - const detected = parseToolCalls(outputText, toolNames); + if (clientClosed || res.writableEnded || res.destroyed) { + await releaseLease(); + return; + } + const detected = parseStandaloneToolCalls(outputText, toolNames); if (detected.length > 0 && !toolCallsEmitted) { toolCallsEmitted = true; sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(detected) }); @@ -179,14 +222,22 @@ module.exports = async function handler(req, res) { choices: [{ delta: {}, index: 0, finish_reason: reason }], usage: buildUsage(finalPrompt, thinkingText, outputText), }); - res.write('data: [DONE]\n\n'); + if (!res.writableEnded && !res.destroyed) { + res.write('data: [DONE]\n\n'); + } await releaseLease(); - res.end(); + if (!res.writableEnded && !res.destroyed) { + res.end(); + } }; try { // eslint-disable-next-line no-constant-condition while (true) { + if (clientClosed) { + await finish('stop'); + return; + } const { value, done } = await reader.read(); if (done) { break; @@ -245,6 +296,11 @@ module.exports = async function handler(req, res) { } const events = processToolSieveChunk(toolSieveState, p.text, toolNames); for (const evt of events) { + if (evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0) { + toolCallsEmitted = true; + sendDeltaFrame({ tool_calls: formatIncrementalToolCallDeltas(evt.deltas, streamToolCallIDs) }); + continue; + } if (evt.type === 'tool_calls') { toolCallsEmitted = true; sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls) }); @@ -259,10 +315,16 @@ module.exports = async function handler(req, res) { } } await finish('stop'); - } catch (_err) { + } catch (err) { + if (clientClosed || isAbortError(err)) { + await finish('stop'); + return; + } await finish('stop'); } } finally { + req.removeListener('aborted', onReqAborted); + res.removeListener('close', onResClose); await releaseLease(); } }; @@ -656,6 +718,55 @@ function buildUsage(prompt, thinking, output) { }; } +function formatIncrementalToolCallDeltas(deltas, idStore) { + if (!Array.isArray(deltas) || deltas.length === 0) { + return []; + } + const out = []; + for (const d of deltas) { + if (!d || typeof d !== 'object') { + continue; + } + const index = Number.isInteger(d.index) ? d.index : 0; + const id = ensureStreamToolCallID(idStore, index); + const item = { + index, + id, + type: 'function', + }; + const fn = {}; + if (asString(d.name)) { + fn.name = asString(d.name); + } + if (typeof d.arguments === 'string' && d.arguments !== '') { + fn.arguments = d.arguments; + } + if (Object.keys(fn).length > 0) { + item.function = fn; + } + out.push(item); + } + return out; +} + +function ensureStreamToolCallID(idStore, index) { + const key = Number.isInteger(index) ? index : 0; + const existing = idStore.get(key); + if (existing) { + return existing; + } + const next = `call_${newCallID()}`; + idStore.set(key, next); + return next; +} + +function newCallID() { + if (typeof crypto.randomUUID === 'function') { + return crypto.randomUUID().replace(/-/g, ''); + } + return `${Date.now()}${Math.floor(Math.random() * 1e9)}`; +} + function estimateTokens(text) { const t = asString(text); if (!t) { @@ -667,44 +778,92 @@ function estimateTokens(text) { async function proxyToGo(req, res, rawBody) { const url = buildInternalGoURL(req); - - const upstream = await fetch(url.toString(), { - method: 'POST', - headers: buildInternalGoHeaders(req, { withContentType: true }), - body: rawBody, - }); - - res.statusCode = upstream.status; - upstream.headers.forEach((value, key) => { - if (key.toLowerCase() === 'content-length') { + const controller = new AbortController(); + let clientClosed = false; + const markClientClosed = () => { + if (clientClosed) { return; } - res.setHeader(key, value); - }); + clientClosed = true; + controller.abort(); + }; + const onReqAborted = () => markClientClosed(); + const onResClose = () => { + if (!res.writableEnded) { + markClientClosed(); + } + }; + req.on('aborted', onReqAborted); + res.on('close', onResClose); - if (!upstream.body || typeof upstream.body.getReader !== 'function') { - const bytes = Buffer.from(await upstream.arrayBuffer()); - res.end(bytes); - return; - } - - const reader = upstream.body.getReader(); try { - // eslint-disable-next-line no-constant-condition - while (true) { - const { value, done } = await reader.read(); - if (done) { - break; + let upstream; + try { + upstream = await fetch(url.toString(), { + method: 'POST', + headers: buildInternalGoHeaders(req, { withContentType: true }), + body: rawBody, + signal: controller.signal, + }); + } catch (err) { + if (clientClosed || isAbortError(err)) { + if (!res.writableEnded) { + res.end(); + } + return; } - if (value && value.length > 0) { - res.write(Buffer.from(value)); - if (typeof res.flush === 'function') { - res.flush(); + throw err; + } + if (clientClosed) { + if (!res.writableEnded) { + res.end(); + } + return; + } + + res.statusCode = upstream.status; + upstream.headers.forEach((value, key) => { + if (key.toLowerCase() === 'content-length') { + return; + } + res.setHeader(key, value); + }); + + if (!upstream.body || typeof upstream.body.getReader !== 'function') { + const bytes = Buffer.from(await upstream.arrayBuffer()); + res.end(bytes); + return; + } + + const reader = upstream.body.getReader(); + try { + // eslint-disable-next-line no-constant-condition + while (true) { + if (clientClosed) { + break; + } + const { value, done } = await reader.read(); + if (done) { + break; + } + if (value && value.length > 0) { + res.write(Buffer.from(value)); + if (typeof res.flush === 'function') { + res.flush(); + } } } + if (!res.writableEnded) { + res.end(); + } + } catch (err) { + if (!isAbortError(err) && !res.writableEnded) { + res.end(); + } } - res.end(); - } catch (_err) { + } finally { + req.removeListener('aborted', onReqAborted); + res.removeListener('close', onResClose); if (!res.writableEnded) { res.end(); } @@ -762,6 +921,13 @@ function asString(v) { return String(v).trim(); } +function isAbortError(err) { + if (!err || typeof err !== 'object') { + return false; + } + return err.name === 'AbortError' || err.code === 'ABORT_ERR'; +} + module.exports.__test = { parseChunkForContent, extractContentRecursive, diff --git a/api/chat-stream.test.js b/api/chat-stream.test.js index b347342..c849f7c 100644 --- a/api/chat-stream.test.js +++ b/api/chat-stream.test.js @@ -49,12 +49,13 @@ test('parseChunkForContent + sieve does not leak suspicious prefix in split tool events.push(...flushToolSieve(state, ['read_file'])); const hasToolCalls = events.some((evt) => evt.type === 'tool_calls' && evt.calls && evt.calls.length > 0); + const hasToolDeltas = events.some((evt) => evt.type === 'tool_call_deltas' && evt.deltas && evt.deltas.length > 0); const leakedText = events .filter((evt) => evt.type === 'text' && evt.text) .map((evt) => evt.text) .join(''); - assert.equal(hasToolCalls, true); + assert.equal(hasToolCalls || hasToolDeltas, true); assert.equal(leakedText.includes('{'), false); assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); }); diff --git a/api/helpers/stream-tool-sieve.js b/api/helpers/stream-tool-sieve.js index 3ced63d..8b586aa 100644 --- a/api/helpers/stream-tool-sieve.js +++ b/api/helpers/stream-tool-sieve.js @@ -2,6 +2,7 @@ const crypto = require('crypto'); const TOOL_CALL_PATTERN = /\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}/s; +const TOOL_SIEVE_CAPTURE_LIMIT = 8 * 1024; function extractToolNames(tools) { if (!Array.isArray(tools) || tools.length === 0) { @@ -26,9 +27,25 @@ function createToolSieveState() { pending: '', capture: '', capturing: false, + hasMeaningfulText: false, + toolNameSent: false, + toolName: '', + toolArgsStart: -1, + toolArgsSent: -1, + toolArgsString: false, + toolArgsDone: false, }; } +function resetIncrementalToolState(state) { + state.toolNameSent = false; + state.toolName = ''; + state.toolArgsStart = -1; + state.toolArgsSent = -1; + state.toolArgsString = false; + state.toolArgsDone = false; +} + function processToolSieveChunk(state, chunk, toolNames) { if (!state) { return []; @@ -44,13 +61,31 @@ function processToolSieveChunk(state, chunk, toolNames) { state.capture += state.pending; state.pending = ''; } - const consumed = consumeToolCapture(state.capture, toolNames); + const deltas = buildIncrementalToolDeltas(state); + if (deltas.length > 0) { + events.push({ type: 'tool_call_deltas', deltas }); + } + const consumed = consumeToolCapture(state, toolNames); if (!consumed.ready) { + if (state.capture.length > TOOL_SIEVE_CAPTURE_LIMIT) { + if (hasMeaningfulText(state.capture)) { + state.hasMeaningfulText = true; + } + events.push({ type: 'text', text: state.capture }); + state.capture = ''; + state.capturing = false; + resetIncrementalToolState(state); + continue; + } break; } state.capture = ''; state.capturing = false; + resetIncrementalToolState(state); if (consumed.prefix) { + if (hasMeaningfulText(consumed.prefix)) { + state.hasMeaningfulText = true; + } events.push({ type: 'text', text: consumed.prefix }); } if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { @@ -70,11 +105,15 @@ function processToolSieveChunk(state, chunk, toolNames) { if (start >= 0) { const prefix = state.pending.slice(0, start); if (prefix) { + if (hasMeaningfulText(prefix)) { + state.hasMeaningfulText = true; + } events.push({ type: 'text', text: prefix }); } state.capture = state.pending.slice(start); state.pending = ''; state.capturing = true; + resetIncrementalToolState(state); continue; } @@ -83,6 +122,9 @@ function processToolSieveChunk(state, chunk, toolNames) { break; } state.pending = hold; + if (hasMeaningfulText(safe)) { + state.hasMeaningfulText = true; + } events.push({ type: 'text', text: safe }); } return events; @@ -94,24 +136,37 @@ function flushToolSieve(state, toolNames) { } const events = processToolSieveChunk(state, '', toolNames); if (state.capturing) { - const consumed = consumeToolCapture(state.capture, toolNames); + const consumed = consumeToolCapture(state, toolNames); if (consumed.ready) { if (consumed.prefix) { + if (hasMeaningfulText(consumed.prefix)) { + state.hasMeaningfulText = true; + } events.push({ type: 'text', text: consumed.prefix }); } if (Array.isArray(consumed.calls) && consumed.calls.length > 0) { events.push({ type: 'tool_calls', calls: consumed.calls }); } if (consumed.suffix) { + if (hasMeaningfulText(consumed.suffix)) { + state.hasMeaningfulText = true; + } events.push({ type: 'text', text: consumed.suffix }); } } else if (state.capture) { - // Incomplete captured tool JSON at stream end: suppress raw capture. + if (hasMeaningfulText(state.capture)) { + state.hasMeaningfulText = true; + } + events.push({ type: 'text', text: state.capture }); } state.capture = ''; state.capturing = false; + resetIncrementalToolState(state); } if (state.pending) { + if (hasMeaningfulText(state.pending)) { + state.hasMeaningfulText = true; + } events.push({ type: 'text', text: state.pending }); state.pending = ''; } @@ -159,7 +214,8 @@ function findToolSegmentStart(s) { return start >= 0 ? start : keyIdx; } -function consumeToolCapture(captured, toolNames) { +function consumeToolCapture(state, toolNames) { + const captured = state.capture; if (!captured) { return { ready: false, prefix: '', calls: [], suffix: '' }; } @@ -176,25 +232,361 @@ function consumeToolCapture(captured, toolNames) { if (!obj.ok) { return { ready: false, prefix: '', calls: [], suffix: '' }; } - const parsed = parseToolCalls(captured.slice(start, obj.end), toolNames); - if (parsed.length === 0) { - // `tool_calls` key exists but strict JSON parse failed. - // Drop the captured object body to avoid leaking raw tool JSON. + const prefixPart = captured.slice(0, start); + const suffixPart = captured.slice(obj.end); + if (!state.toolNameSent && (state.hasMeaningfulText || hasMeaningfulText(prefixPart) || hasMeaningfulText(suffixPart))) { return { ready: true, - prefix: captured.slice(0, start), + prefix: captured, calls: [], - suffix: captured.slice(obj.end), + suffix: '', + }; + } + const parsed = parseStandaloneToolCalls(captured.slice(start, obj.end), toolNames); + if (parsed.length === 0) { + if (state.toolNameSent) { + return { + ready: true, + prefix: prefixPart, + calls: [], + suffix: suffixPart, + }; + } + return { + ready: true, + prefix: captured, + calls: [], + suffix: '', + }; + } + if (state.toolNameSent) { + if (parsed.length > 1) { + return { + ready: true, + prefix: prefixPart, + calls: parsed.slice(1), + suffix: suffixPart, + }; + } + return { + ready: true, + prefix: prefixPart, + calls: [], + suffix: suffixPart, }; } return { ready: true, - prefix: captured.slice(0, start), + prefix: prefixPart, calls: parsed, - suffix: captured.slice(obj.end), + suffix: suffixPart, }; } +function buildIncrementalToolDeltas(state) { + const captured = state.capture || ''; + if (!captured || state.hasMeaningfulText) { + return []; + } + const lower = captured.toLowerCase(); + const keyIdx = lower.indexOf('tool_calls'); + if (keyIdx < 0) { + return []; + } + const start = captured.slice(0, keyIdx).lastIndexOf('{'); + if (start < 0 || hasMeaningfulText(captured.slice(0, start))) { + return []; + } + const callStart = findFirstToolCallObjectStart(captured, keyIdx); + if (callStart < 0) { + return []; + } + + const deltas = []; + if (!state.toolName) { + const name = extractToolCallName(captured, callStart); + if (!name) { + return []; + } + state.toolName = name; + } + + if (state.toolArgsStart < 0) { + const args = findToolCallArgsStart(captured, callStart); + if (args) { + state.toolArgsString = Boolean(args.stringMode); + state.toolArgsStart = state.toolArgsString ? args.start + 1 : args.start; + state.toolArgsSent = state.toolArgsStart; + } + } + if (!state.toolNameSent) { + if (state.toolArgsStart < 0) { + return []; + } + state.toolNameSent = true; + deltas.push({ index: 0, name: state.toolName }); + } + if (state.toolArgsStart < 0 || state.toolArgsDone) { + return deltas; + } + const progress = scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString); + if (!progress) { + return deltas; + } + if (progress.end > state.toolArgsSent) { + deltas.push({ + index: 0, + arguments: captured.slice(state.toolArgsSent, progress.end), + }); + state.toolArgsSent = progress.end; + } + if (progress.complete) { + state.toolArgsDone = true; + } + return deltas; +} + +function findFirstToolCallObjectStart(text, keyIdx) { + const arrStart = findToolCallsArrayStart(text, keyIdx); + if (arrStart < 0) { + return -1; + } + const i = skipSpaces(text, arrStart + 1); + if (i >= text.length || text[i] !== '{') { + return -1; + } + return i; +} + +function findToolCallsArrayStart(text, keyIdx) { + let i = keyIdx + 'tool_calls'.length; + while (i < text.length && text[i] !== ':') { + i += 1; + } + if (i >= text.length) { + return -1; + } + i = skipSpaces(text, i + 1); + if (i >= text.length || text[i] !== '[') { + return -1; + } + return i; +} + +function extractToolCallName(text, callStart) { + let valueStart = findObjectFieldValueStart(text, callStart, ['name']); + if (valueStart < 0 || text[valueStart] !== '"') { + const fnStart = findFunctionObjectStart(text, callStart); + if (fnStart < 0) { + return ''; + } + valueStart = findObjectFieldValueStart(text, fnStart, ['name']); + if (valueStart < 0 || text[valueStart] !== '"') { + return ''; + } + } + const parsed = parseJSONStringLiteral(text, valueStart); + if (!parsed) { + return ''; + } + return parsed.value; +} + +function findToolCallArgsStart(text, callStart) { + const keys = ['input', 'arguments', 'args', 'parameters', 'params']; + let valueStart = findObjectFieldValueStart(text, callStart, keys); + if (valueStart < 0) { + const fnStart = findFunctionObjectStart(text, callStart); + if (fnStart < 0) { + return null; + } + valueStart = findObjectFieldValueStart(text, fnStart, keys); + if (valueStart < 0) { + return null; + } + } + if (valueStart >= text.length) { + return null; + } + const ch = text[valueStart]; + if (ch === '{' || ch === '[') { + return { start: valueStart, stringMode: false }; + } + if (ch === '"') { + return { start: valueStart, stringMode: true }; + } + return null; +} + +function scanToolCallArgsProgress(text, start, stringMode) { + if (start < 0 || start > text.length) { + return null; + } + if (stringMode) { + let escaped = false; + for (let i = start; i < text.length; i += 1) { + const ch = text[i]; + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === '"') { + return { end: i, complete: true }; + } + } + return { end: text.length, complete: false }; + } + if (start >= text.length || (text[start] !== '{' && text[start] !== '[')) { + return null; + } + let depth = 0; + let quote = ''; + let escaped = false; + for (let i = start; i < text.length; i += 1) { + const ch = text[i]; + if (quote) { + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === quote) { + quote = ''; + } + continue; + } + if (ch === '"' || ch === "'") { + quote = ch; + continue; + } + if (ch === '{' || ch === '[') { + depth += 1; + continue; + } + if (ch === '}' || ch === ']') { + depth -= 1; + if (depth === 0) { + return { end: i + 1, complete: true }; + } + } + } + return { end: text.length, complete: false }; +} + +function findObjectFieldValueStart(text, objStart, keys) { + if (!text || objStart < 0 || objStart >= text.length || text[objStart] !== '{') { + return -1; + } + let depth = 0; + let quote = ''; + let escaped = false; + for (let i = objStart; i < text.length; i += 1) { + const ch = text[i]; + if (quote) { + if (escaped) { + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === quote) { + quote = ''; + } + continue; + } + if (ch === '"' || ch === "'") { + if (depth === 1) { + const parsed = parseJSONStringLiteral(text, i); + if (!parsed) { + return -1; + } + let j = skipSpaces(text, parsed.end); + if (j >= text.length || text[j] !== ':') { + i = parsed.end - 1; + continue; + } + j = skipSpaces(text, j + 1); + if (j >= text.length) { + return -1; + } + if (keys.includes(parsed.value)) { + return j; + } + i = j - 1; + continue; + } + quote = ch; + continue; + } + if (ch === '{') { + depth += 1; + continue; + } + if (ch === '}') { + depth -= 1; + if (depth === 0) { + break; + } + } + } + return -1; +} + +function findFunctionObjectStart(text, callStart) { + const valueStart = findObjectFieldValueStart(text, callStart, ['function']); + if (valueStart < 0 || valueStart >= text.length || text[valueStart] !== '{') { + return -1; + } + return valueStart; +} + +function parseJSONStringLiteral(text, start) { + if (!text || start < 0 || start >= text.length || text[start] !== '"') { + return null; + } + let out = ''; + let escaped = false; + for (let i = start + 1; i < text.length; i += 1) { + const ch = text[i]; + if (escaped) { + out += ch; + escaped = false; + continue; + } + if (ch === '\\') { + escaped = true; + continue; + } + if (ch === '"') { + return { value: out, end: i + 1 }; + } + out += ch; + } + return null; +} + +function skipSpaces(text, i) { + let idx = i; + while (idx < text.length) { + const ch = text[idx]; + if (ch === ' ' || ch === '\t' || ch === '\n' || ch === '\r') { + idx += 1; + continue; + } + break; + } + return idx; +} + function extractJSONObjectFrom(text, start) { if (!text || start < 0 || start >= text.length || text[start] !== '{') { return { ok: false, end: 0 }; @@ -251,26 +643,35 @@ function parseToolCalls(text, toolNames) { if (parsed.length === 0) { return []; } - const allowed = new Set((toolNames || []).filter(Boolean)); - const out = []; - for (const tc of parsed) { - if (!tc || !tc.name) { - continue; - } - if (allowed.size > 0 && !allowed.has(tc.name)) { - continue; - } - out.push({ name: tc.name, input: tc.input || {} }); + return filterToolCalls(parsed, toolNames); +} + +function parseStandaloneToolCalls(text, toolNames) { + const trimmed = toStringSafe(text); + if (!trimmed) { + return []; } - if (out.length === 0 && parsed.length > 0) { - for (const tc of parsed) { - if (!tc || !tc.name) { - continue; - } - out.push({ name: tc.name, input: tc.input || {} }); + const candidates = [trimmed]; + if (trimmed.startsWith('```') && trimmed.endsWith('```')) { + const m = trimmed.match(/```(?:json)?\s*([\s\S]*?)\s*```/i); + if (m && m[1]) { + candidates.push(toStringSafe(m[1])); } } - return out; + for (const candidate of candidates) { + const c = toStringSafe(candidate); + if (!c) { + continue; + } + if (!c.startsWith('{') && !c.startsWith('[')) { + continue; + } + const parsed = parseToolCallsPayload(c); + if (parsed.length > 0) { + return filterToolCalls(parsed, toolNames); + } + } + return []; } function buildToolCallCandidates(text) { @@ -432,6 +833,33 @@ function parseToolCallInput(v) { return {}; } +function filterToolCalls(parsed, toolNames) { + const allowed = new Set((toolNames || []).filter(Boolean)); + const out = []; + for (const tc of parsed) { + if (!tc || !tc.name) { + continue; + } + if (allowed.size > 0 && !allowed.has(tc.name)) { + continue; + } + out.push({ name: tc.name, input: tc.input || {} }); + } + if (out.length === 0 && parsed.length > 0) { + for (const tc of parsed) { + if (!tc || !tc.name) { + continue; + } + out.push({ name: tc.name, input: tc.input || {} }); + } + } + return out; +} + +function hasMeaningfulText(text) { + return toStringSafe(text) !== ''; +} + function formatOpenAIStreamToolCalls(calls) { if (!Array.isArray(calls) || calls.length === 0) { return []; @@ -473,5 +901,6 @@ module.exports = { processToolSieveChunk, flushToolSieve, parseToolCalls, + parseStandaloneToolCalls, formatOpenAIStreamToolCalls, }; diff --git a/api/helpers/stream-tool-sieve.test.js b/api/helpers/stream-tool-sieve.test.js index 47b3100..ad1dc0b 100644 --- a/api/helpers/stream-tool-sieve.test.js +++ b/api/helpers/stream-tool-sieve.test.js @@ -9,6 +9,7 @@ const { processToolSieveChunk, flushToolSieve, parseToolCalls, + parseStandaloneToolCalls, } = require('./stream-tool-sieve'); function runSieve(chunks, toolNames) { @@ -73,6 +74,15 @@ test('parseToolCalls supports fenced json and function.arguments string payload' assert.deepEqual(calls[0].input, { path: 'README.md' }); }); +test('parseStandaloneToolCalls only matches standalone payload and ignores mixed prose', () => { + const mixed = '这里是示例:{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]},请勿执行。'; + const standalone = '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'; + const mixedCalls = parseStandaloneToolCalls(mixed, ['read_file']); + const standaloneCalls = parseStandaloneToolCalls(standalone, ['read_file']); + assert.equal(mixedCalls.length, 0); + assert.equal(standaloneCalls.length, 1); +}); + test('sieve emits tool_calls and does not leak suspicious prefix on late key convergence', () => { const events = runSieve( [ @@ -84,13 +94,14 @@ test('sieve emits tool_calls and does not leak suspicious prefix on late key con ); const leakedText = collectText(events); const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0); - assert.equal(hasToolCall, true); + const hasToolDelta = events.some((evt) => evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0); + assert.equal(hasToolCall || hasToolDelta, true); assert.equal(leakedText.includes('{'), false); assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); assert.equal(leakedText.includes('后置正文C。'), true); }); -test('sieve drops invalid tool json body while preserving surrounding text', () => { +test('sieve keeps embedded invalid tool-like json as normal text to avoid stream stalls', () => { const events = runSieve( [ '前置正文D。', @@ -104,18 +115,18 @@ test('sieve drops invalid tool json body while preserving surrounding text', () assert.equal(hasToolCall, false); assert.equal(leakedText.includes('前置正文D。'), true); assert.equal(leakedText.includes('后置正文E。'), true); - assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), true); }); -test('sieve suppresses incomplete captured tool json on stream finalize', () => { +test('sieve flushes incomplete captured tool json as text on stream finalize', () => { const events = runSieve( ['前置正文F。', '{"tool_calls":[{"name":"read_file"'], ['read_file'], ); const leakedText = collectText(events); assert.equal(leakedText.includes('前置正文F。'), true); - assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); - assert.equal(leakedText.includes('{'), false); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), true); + assert.equal(leakedText.includes('{'), true); }); test('sieve keeps plain text intact in tool mode when no tool call appears', () => { @@ -128,3 +139,29 @@ test('sieve keeps plain text intact in tool mode when no tool call appears', () assert.equal(hasToolCall, false); assert.equal(leakedText, '你好,这是普通文本回复。请继续。'); }); + +test('sieve emits incremental tool_call_deltas for split arguments payload', () => { + const state = createToolSieveState(); + const first = processToolSieveChunk( + state, + '{"tool_calls":[{"name":"read_file","input":{"path":"READ', + ['read_file'], + ); + const second = processToolSieveChunk( + state, + 'ME.MD","mode":"head"}}]}', + ['read_file'], + ); + const tail = flushToolSieve(state, ['read_file']); + const events = [...first, ...second, ...tail]; + const deltaEvents = events.filter((evt) => evt.type === 'tool_call_deltas'); + assert.equal(deltaEvents.length > 0, true); + const merged = deltaEvents.flatMap((evt) => evt.deltas || []); + const hasName = merged.some((d) => d.name === 'read_file'); + const argsJoined = merged + .map((d) => d.arguments || '') + .join(''); + assert.equal(hasName, true); + assert.equal(argsJoined.includes('"path":"README.MD"'), true); + assert.equal(argsJoined.includes('"mode":"head"'), true); +}); diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index 1602cf6..962e450 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -11,6 +11,7 @@ import ( "time" "github.com/go-chi/chi/v5" + "github.com/google/uuid" "ds2api/internal/auth" "ds2api/internal/config" @@ -134,7 +135,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re finalThinking := result.Thinking finalText := result.Text - detected := util.ParseToolCalls(finalText, toolNames) + detected := util.ParseStandaloneToolCalls(finalText, toolNames) finishReason := "stop" messageObj := map[string]any{"role": "assistant", "content": finalText} if thinkingEnabled && finalThinking != "" { @@ -188,6 +189,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt bufferToolContent := len(toolNames) > 0 var toolSieve toolStreamSieveState toolCallsEmitted := false + streamToolCallIDs := map[int]string{} initialType := "text" if thinkingEnabled { initialType = "thinking" @@ -220,7 +222,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt finalize := func(finishReason string) { finalThinking := thinking.String() finalText := text.String() - detected := util.ParseToolCalls(finalText, toolNames) + detected := util.ParseStandaloneToolCalls(finalText, toolNames) if len(detected) > 0 && !toolCallsEmitted { finishReason = "tool_calls" delta := map[string]any{ @@ -352,6 +354,21 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt // Keep thinking delta only frame. } for _, evt := range events { + if len(evt.ToolCallDeltas) > 0 { + toolCallsEmitted = true + tcDelta := map[string]any{ + "tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs), + } + if !firstChunkSent { + tcDelta["role"] = "assistant" + firstChunkSent = true + } + newChoices = append(newChoices, map[string]any{ + "delta": tcDelta, + "index": 0, + }) + continue + } if len(evt.ToolCalls) > 0 { toolCallsEmitted = true tcDelta := map[string]any{ @@ -441,6 +458,40 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, return messages, names } +func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any { + if len(deltas) == 0 { + return nil + } + out := make([]map[string]any, 0, len(deltas)) + for _, d := range deltas { + if d.Name == "" && d.Arguments == "" { + continue + } + callID, ok := ids[d.Index] + if !ok || callID == "" { + callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "") + ids[d.Index] = callID + } + item := map[string]any{ + "index": d.Index, + "id": callID, + "type": "function", + } + fn := map[string]any{} + if d.Name != "" { + fn["name"] = d.Name + } + if d.Arguments != "" { + fn["arguments"] = d.Arguments + } + if len(fn) > 0 { + item["function"] = fn + } + out = append(out, item) + } + return out +} + func writeOpenAIError(w http.ResponseWriter, status int, message string) { writeJSON(w, status, map[string]any{ "error": map[string]any{ diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index f9c44dd..30197d7 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -100,6 +100,26 @@ func streamFinishReason(frames []map[string]any) string { return "" } +func streamToolCallArgumentChunks(frames []map[string]any) []string { + out := make([]string, 0, 4) + for _, frame := range frames { + choices, _ := frame["choices"].([]any) + for _, item := range choices { + choice, _ := item.(map[string]any) + delta, _ := choice["delta"].(map[string]any) + toolCalls, _ := delta["tool_calls"].([]any) + for _, tc := range toolCalls { + tcm, _ := tc.(map[string]any) + fn, _ := tcm["function"].(map[string]any) + if args, ok := fn["arguments"].(string); ok && args != "" { + out = append(out, args) + } + } + } + } + return out +} + func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( @@ -190,6 +210,37 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) { } } +func TestHandleNonStreamEmbeddedToolCallExampleNotIntercepted(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"下面是示例:"}`, + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`, + `data: {"p":"response/content","v":"请勿执行。"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + + h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, false, []string{"search"}) + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rec.Code) + } + + out := decodeJSONBody(t, rec.Body.String()) + choices, _ := out["choices"].([]any) + choice, _ := choices[0].(map[string]any) + if choice["finish_reason"] != "stop" { + t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"]) + } + msg, _ := choice["message"].(map[string]any) + if _, ok := msg["tool_calls"]; ok { + t.Fatalf("did not expect tool_calls field for embedded example: %#v", msg["tool_calls"]) + } + content, _ := msg["content"].(string) + if !strings.Contains(content, "示例") || !strings.Contains(content, `"tool_calls"`) { + t.Fatalf("expected embedded example to pass through as text, got %q", content) + } +} + func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( @@ -391,11 +442,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { if !done { t.Fatalf("expected [DONE], body=%s", rec.Body.String()) } - if !streamHasToolCallsDelta(frames) { - t.Fatalf("expected tool_calls delta in mixed stream, body=%s", rec.Body.String()) - } - if streamHasRawToolJSONContent(frames) { - t.Fatalf("raw tool_calls JSON leaked in mixed stream: %s", rec.Body.String()) + if streamHasToolCallsDelta(frames) { + t.Fatalf("did not expect tool_calls delta in mixed prose stream, body=%s", rec.Body.String()) } content := strings.Builder{} for _, frame := range frames { @@ -412,8 +460,11 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) { if !strings.Contains(got, "前置正文A。") || !strings.Contains(got, "后置正文B。") { t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got) } - if streamFinishReason(frames) != "tool_calls" { - t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) + if !strings.Contains(got, `"tool_calls"`) { + t.Fatalf("expected mixed stream to preserve embedded tool_calls example text, got=%q", got) + } + if streamFinishReason(frames) != "stop" { + t.Fatalf("expected finish_reason=stop for mixed prose, body=%s", rec.Body.String()) } } @@ -495,16 +546,16 @@ func TestHandleStreamInvalidToolJSONDoesNotLeakRawObject(t *testing.T) { } } } - got := strings.ToLower(content.String()) - if strings.Contains(got, "tool_calls") { - t.Fatalf("unexpected raw tool_calls leak in content: %q", content.String()) - } - if !strings.Contains(content.String(), "前置正文D。") || !strings.Contains(content.String(), "后置正文E。") { + got := content.String() + if !strings.Contains(got, "前置正文D。") || !strings.Contains(got, "后置正文E。") { t.Fatalf("expected pre/post plain text to remain, got=%q", content.String()) } + if !strings.Contains(strings.ToLower(got), "tool_calls") { + t.Fatalf("expected invalid embedded tool-like json to pass through as text, got=%q", got) + } } -func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.T) { +func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`, @@ -533,7 +584,42 @@ func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing. } } } - if strings.Contains(strings.ToLower(content.String()), "tool_calls") || strings.Contains(content.String(), "{") { - t.Fatalf("unexpected incomplete tool json leak in content: %q", content.String()) + if !strings.Contains(strings.ToLower(content.String()), "tool_calls") || !strings.Contains(content.String(), "{") { + t.Fatalf("expected incomplete capture to flush as plain text instead of stalling, got=%q", content.String()) + } +} + +func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`, + `data: {"p":"response/content","v":"lang\",\"page\":1}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + h.handleStream(rec, req, resp, "cid11", "deepseek-chat", "prompt", false, false, []string{"search"}) + + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if !streamHasToolCallsDelta(frames) { + t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String()) + } + if streamHasRawToolJSONContent(frames) { + t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String()) + } + argChunks := streamToolCallArgumentChunks(frames) + if len(argChunks) < 2 { + t.Fatalf("expected incremental arguments chunks, got=%v body=%s", argChunks, rec.Body.String()) + } + joined := strings.Join(argChunks, "") + if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) { + t.Fatalf("unexpected merged arguments stream: %q", joined) + } + if streamFinishReason(frames) != "tool_calls" { + t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String()) } } diff --git a/internal/adapter/openai/tool_sieve.go b/internal/adapter/openai/tool_sieve.go index d1a9014..d890314 100644 --- a/internal/adapter/openai/tool_sieve.go +++ b/internal/adapter/openai/tool_sieve.go @@ -7,14 +7,39 @@ import ( ) type toolStreamSieveState struct { - pending strings.Builder - capture strings.Builder - capturing bool + pending strings.Builder + capture strings.Builder + capturing bool + hasMeaningfulText bool + toolNameSent bool + toolName string + toolArgsStart int + toolArgsSent int + toolArgsString bool + toolArgsDone bool } type toolStreamEvent struct { - Content string - ToolCalls []util.ParsedToolCall + Content string + ToolCalls []util.ParsedToolCall + ToolCallDeltas []toolCallDelta +} + +type toolCallDelta struct { + Index int + Name string + Arguments string +} + +const toolSieveCaptureLimit = 8 * 1024 + +func (s *toolStreamSieveState) resetIncrementalToolState() { + s.toolNameSent = false + s.toolName = "" + s.toolArgsStart = -1 + s.toolArgsSent = -1 + s.toolArgsString = false + s.toolArgsDone = false } func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent { @@ -32,13 +57,31 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames state.capture.WriteString(state.pending.String()) state.pending.Reset() } - prefix, calls, suffix, ready := consumeToolCapture(state.capture.String(), toolNames) + if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 { + events = append(events, toolStreamEvent{ToolCallDeltas: deltas}) + } + prefix, calls, suffix, ready := consumeToolCapture(state, toolNames) if !ready { + if state.capture.Len() > toolSieveCaptureLimit { + content := state.capture.String() + state.capture.Reset() + state.capturing = false + state.resetIncrementalToolState() + if strings.TrimSpace(content) != "" { + state.hasMeaningfulText = true + } + events = append(events, toolStreamEvent{Content: content}) + continue + } break } state.capture.Reset() state.capturing = false + state.resetIncrementalToolState() if prefix != "" { + if strings.TrimSpace(prefix) != "" { + state.hasMeaningfulText = true + } events = append(events, toolStreamEvent{Content: prefix}) } if len(calls) > 0 { @@ -58,11 +101,15 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames if start >= 0 { prefix := pending[:start] if prefix != "" { + if strings.TrimSpace(prefix) != "" { + state.hasMeaningfulText = true + } events = append(events, toolStreamEvent{Content: prefix}) } state.pending.Reset() state.capture.WriteString(pending[start:]) state.capturing = true + state.resetIncrementalToolState() continue } @@ -72,6 +119,9 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames } state.pending.Reset() state.pending.WriteString(hold) + if strings.TrimSpace(safe) != "" { + state.hasMeaningfulText = true + } events = append(events, toolStreamEvent{Content: safe}) } @@ -84,25 +134,42 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea } events := processToolSieveChunk(state, "", toolNames) if state.capturing { - consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state.capture.String(), toolNames) + consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames) if ready { if consumedPrefix != "" { + if strings.TrimSpace(consumedPrefix) != "" { + state.hasMeaningfulText = true + } events = append(events, toolStreamEvent{Content: consumedPrefix}) } if len(consumedCalls) > 0 { events = append(events, toolStreamEvent{ToolCalls: consumedCalls}) } if consumedSuffix != "" { + if strings.TrimSpace(consumedSuffix) != "" { + state.hasMeaningfulText = true + } events = append(events, toolStreamEvent{Content: consumedSuffix}) } } else { - // Incomplete captured tool JSON at stream end: suppress raw capture. + content := state.capture.String() + if content != "" { + if strings.TrimSpace(content) != "" { + state.hasMeaningfulText = true + } + events = append(events, toolStreamEvent{Content: content}) + } } state.capture.Reset() state.capturing = false + state.resetIncrementalToolState() } if state.pending.Len() > 0 { - events = append(events, toolStreamEvent{Content: state.pending.String()}) + content := state.pending.String() + if strings.TrimSpace(content) != "" { + state.hasMeaningfulText = true + } + events = append(events, toolStreamEvent{Content: content}) state.pending.Reset() } return events @@ -154,7 +221,8 @@ func findToolSegmentStart(s string) int { return keyIdx } -func consumeToolCapture(captured string, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) { +func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) { + captured := state.capture.String() if captured == "" { return "", nil, "", false } @@ -171,13 +239,25 @@ func consumeToolCapture(captured string, toolNames []string) (prefix string, cal if !ok { return "", nil, "", false } - parsed := util.ParseToolCalls(obj, toolNames) - if len(parsed) == 0 { - // `tool_calls` key exists but strict JSON parse failed. - // Drop the captured object body to avoid leaking raw tool JSON. - return captured[:start], nil, captured[end:], true + prefixPart := captured[:start] + suffixPart := captured[end:] + if !state.toolNameSent && (state.hasMeaningfulText || strings.TrimSpace(prefixPart) != "" || strings.TrimSpace(suffixPart) != "") { + return captured, nil, "", true } - return captured[:start], parsed, captured[end:], true + parsed := util.ParseStandaloneToolCalls(obj, toolNames) + if len(parsed) == 0 { + if state.toolNameSent { + return prefixPart, nil, suffixPart, true + } + return captured, nil, "", true + } + if state.toolNameSent { + if len(parsed) > 1 { + return prefixPart, parsed[1:], suffixPart, true + } + return prefixPart, nil, suffixPart, true + } + return prefixPart, parsed, suffixPart, true } func extractJSONObjectFrom(text string, start int) (string, int, bool) { @@ -221,3 +301,320 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) { } return "", 0, false } + +func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta { + captured := state.capture.String() + if captured == "" || state.hasMeaningfulText { + return nil + } + lower := strings.ToLower(captured) + keyIdx := strings.Index(lower, "tool_calls") + if keyIdx < 0 { + return nil + } + start := strings.LastIndex(captured[:keyIdx], "{") + if start < 0 || strings.TrimSpace(captured[:start]) != "" { + return nil + } + callStart, ok := findFirstToolCallObjectStart(captured, keyIdx) + if !ok { + return nil + } + deltas := make([]toolCallDelta, 0, 2) + if state.toolName == "" { + name, ok := extractToolCallName(captured, callStart) + if !ok || name == "" { + return nil + } + state.toolName = name + } + if state.toolArgsStart < 0 { + argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart) + if ok { + state.toolArgsString = stringMode + if stringMode { + state.toolArgsStart = argsStart + 1 + } else { + state.toolArgsStart = argsStart + } + state.toolArgsSent = state.toolArgsStart + } + } + if !state.toolNameSent { + if state.toolArgsStart < 0 { + return nil + } + state.toolNameSent = true + deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName}) + } + if state.toolArgsStart < 0 || state.toolArgsDone { + return deltas + } + end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString) + if !ok { + return deltas + } + if end > state.toolArgsSent { + deltas = append(deltas, toolCallDelta{ + Index: 0, + Arguments: captured[state.toolArgsSent:end], + }) + state.toolArgsSent = end + } + if complete { + state.toolArgsDone = true + } + return deltas +} + +func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) { + arrStart, ok := findToolCallsArrayStart(text, keyIdx) + if !ok { + return -1, false + } + i := skipSpaces(text, arrStart+1) + if i >= len(text) || text[i] != '{' { + return -1, false + } + return i, true +} + +func findToolCallsArrayStart(text string, keyIdx int) (int, bool) { + i := keyIdx + len("tool_calls") + for i < len(text) && text[i] != ':' { + i++ + } + if i >= len(text) { + return -1, false + } + i = skipSpaces(text, i+1) + if i >= len(text) || text[i] != '[' { + return -1, false + } + return i, true +} + +func extractToolCallName(text string, callStart int) (string, bool) { + valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"}) + if !ok || valueStart >= len(text) || text[valueStart] != '"' { + fnStart, fnOK := findFunctionObjectStart(text, callStart) + if !fnOK { + return "", false + } + valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"}) + if !ok || valueStart >= len(text) || text[valueStart] != '"' { + return "", false + } + } + name, _, ok := parseJSONStringLiteral(text, valueStart) + if !ok { + return "", false + } + return name, true +} + +func findToolCallArgsStart(text string, callStart int) (int, bool, bool) { + keys := []string{"input", "arguments", "args", "parameters", "params"} + valueStart, ok := findObjectFieldValueStart(text, callStart, keys) + if !ok { + fnStart, fnOK := findFunctionObjectStart(text, callStart) + if !fnOK { + return -1, false, false + } + valueStart, ok = findObjectFieldValueStart(text, fnStart, keys) + if !ok { + return -1, false, false + } + } + if valueStart >= len(text) { + return -1, false, false + } + ch := text[valueStart] + if ch == '{' || ch == '[' { + return valueStart, false, true + } + if ch == '"' { + return valueStart, true, true + } + return -1, false, false +} + +func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) { + if start < 0 || start > len(text) { + return 0, false, false + } + if stringMode { + escaped := false + for i := start; i < len(text); i++ { + ch := text[i] + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' { + return i, true, true + } + } + return len(text), false, true + } + if start >= len(text) { + return start, false, false + } + if text[start] != '{' && text[start] != '[' { + return 0, false, false + } + depth := 0 + quote := byte(0) + escaped := false + for i := start; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + quote = ch + continue + } + if ch == '{' || ch == '[' { + depth++ + continue + } + if ch == '}' || ch == ']' { + depth-- + if depth == 0 { + return i + 1, true, true + } + } + } + return len(text), false, true +} + +func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) { + if objStart < 0 || objStart >= len(text) || text[objStart] != '{' { + return 0, false + } + depth := 0 + quote := byte(0) + escaped := false + for i := objStart; i < len(text); i++ { + ch := text[i] + if quote != 0 { + if escaped { + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + if depth == 1 { + key, end, ok := parseJSONStringLiteral(text, i) + if !ok { + return 0, false + } + j := skipSpaces(text, end) + if j >= len(text) || text[j] != ':' { + i = end - 1 + continue + } + j = skipSpaces(text, j+1) + if j >= len(text) { + return 0, false + } + if containsKey(keys, key) { + return j, true + } + i = j - 1 + continue + } + quote = ch + continue + } + if ch == '{' { + depth++ + continue + } + if ch == '}' { + depth-- + if depth == 0 { + break + } + } + } + return 0, false +} + +func findFunctionObjectStart(text string, callStart int) (int, bool) { + valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"}) + if !ok || valueStart >= len(text) || text[valueStart] != '{' { + return -1, false + } + return valueStart, true +} + +func parseJSONStringLiteral(text string, start int) (string, int, bool) { + if start < 0 || start >= len(text) || text[start] != '"' { + return "", 0, false + } + var b strings.Builder + escaped := false + for i := start + 1; i < len(text); i++ { + ch := text[i] + if escaped { + b.WriteByte(ch) + escaped = false + continue + } + if ch == '\\' { + escaped = true + continue + } + if ch == '"' { + return b.String(), i + 1, true + } + b.WriteByte(ch) + } + return "", 0, false +} + +func containsKey(keys []string, value string) bool { + for _, k := range keys { + if k == value { + return true + } + } + return false +} + +func skipSpaces(text string, i int) int { + for i < len(text) { + switch text[i] { + case ' ', '\t', '\n', '\r': + i++ + default: + return i + } + } + return i +} diff --git a/internal/util/toolcalls.go b/internal/util/toolcalls.go index 9b9d4e6..4760546 100644 --- a/internal/util/toolcalls.go +++ b/internal/util/toolcalls.go @@ -33,6 +33,36 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall { return nil } + return filterToolCalls(parsed, availableToolNames) +} + +func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall { + trimmed := strings.TrimSpace(text) + if trimmed == "" { + return nil + } + candidates := []string{trimmed} + if strings.HasPrefix(trimmed, "```") && strings.HasSuffix(trimmed, "```") { + if m := fencedJSONPattern.FindStringSubmatch(trimmed); len(m) >= 2 { + candidates = append(candidates, strings.TrimSpace(m[1])) + } + } + for _, candidate := range candidates { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + continue + } + if !strings.HasPrefix(candidate, "{") && !strings.HasPrefix(candidate, "[") { + continue + } + if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 { + return filterToolCalls(parsed, availableToolNames) + } + } + return nil +} + +func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []ParsedToolCall { allowed := map[string]struct{}{} for _, name := range availableToolNames { allowed[name] = struct{}{} diff --git a/internal/util/toolcalls_test.go b/internal/util/toolcalls_test.go index 8c44320..8a29a18 100644 --- a/internal/util/toolcalls_test.go +++ b/internal/util/toolcalls_test.go @@ -62,3 +62,16 @@ func TestFormatOpenAIToolCalls(t *testing.T) { t.Fatalf("unexpected function name: %#v", fn) } } + +func TestParseStandaloneToolCallsOnlyMatchesStandalonePayload(t *testing.T) { + mixed := `这里是示例:{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` + if calls := ParseStandaloneToolCalls(mixed, []string{"search"}); len(calls) != 0 { + t.Fatalf("expected standalone parser to ignore mixed prose, got %#v", calls) + } + + standalone := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}` + calls := ParseStandaloneToolCalls(standalone, []string{"search"}) + if len(calls) != 1 { + t.Fatalf("expected standalone parser to match, got %#v", calls) + } +}