mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-02 07:25:26 +08:00
@@ -148,6 +148,7 @@ DS2API 当前的核心思路,不是把客户端传来的 `messages`、`tools`
|
||||
4. 把这整段内容并入 system prompt。
|
||||
|
||||
工具调用正例仍只示范 canonical XML:`<tool_calls>` → `<invoke name="...">` → `<parameter name="...">`。
|
||||
提示词会额外强调:如果要调用工具,工具块的首个非空白字符必须就是 `<tool_calls>`,不能只输出 `</tool_calls>` 而漏掉 opening tag。
|
||||
正例中的工具名只会来自当前请求实际声明的工具;如果当前请求没有足够的已知工具形态,就省略对应的单工具、多工具或嵌套示例,避免把不可用工具名写进 prompt。
|
||||
对执行类工具,脚本内容必须进入执行参数本身:`Bash` / `execute_command` 使用 `command`,`exec_command` 使用 `cmd`;不要把脚本示范成 `path` / `content` 文件写入参数。
|
||||
|
||||
@@ -192,6 +193,7 @@ assistant 历史 `tool_calls` 不会保留成 OpenAI 原生 JSON,而会转成
|
||||
```
|
||||
|
||||
这也是当前项目里唯一受支持的 canonical tool-calling 形态;其他形态都会作为普通文本保留,不会作为可执行调用语法。
|
||||
例外是 parser 会对一个非常窄的模型失误做修复:如果 assistant 输出了 `<invoke ...>` ... `</tool_calls>`,但漏掉最前面的 opening `<tool_calls>`,解析阶段会补回 wrapper 后再尝试识别。
|
||||
|
||||
这件事很重要,因为它决定了:
|
||||
|
||||
|
||||
@@ -23,9 +23,14 @@
|
||||
- 工具名必须放在 `invoke` 的 `name` 属性
|
||||
- 参数必须使用 `<parameter name="...">...</parameter>`
|
||||
|
||||
兼容修复:
|
||||
|
||||
- 如果模型漏掉 opening `<tool_calls>`,但后面仍输出了一个或多个 `<invoke ...>` 并以 `</tool_calls>` 收尾,Go 解析链路会在解析前补回缺失的 opening wrapper。
|
||||
- 这是一个针对常见模型失误的窄修复,不改变推荐输出格式;prompt 仍要求模型直接输出完整 canonical XML。
|
||||
|
||||
## 2) 非 canonical 内容
|
||||
|
||||
任何不满足上述 canonical XML 形态的内容,都会保留为普通文本,不会执行。
|
||||
任何不满足上述 canonical XML 形态的内容,都会保留为普通文本,不会执行。一个例外是上一节提到的“缺失 opening `<tool_calls>`、但 closing `</tool_calls>` 仍存在”的窄修复场景。
|
||||
|
||||
当前 parser 不把 allow-list 当作硬安全边界:即使传入了已声明工具名列表,XML 里出现未声明工具名时也会尽量解析并交给上层协议输出;真正的执行侧仍必须自行校验工具名和参数。
|
||||
|
||||
@@ -33,7 +38,8 @@
|
||||
|
||||
在流式链路中(Go / Node 一致):
|
||||
|
||||
- 只有从 `<tool_calls` 开始的 canonical wrapper 才会进入结构化捕获
|
||||
- canonical `<tool_calls>` wrapper 会进入结构化捕获
|
||||
- 如果流里直接从 `<invoke ...>` 开始,但后面补上了 `</tool_calls>`,Go 流式筛分也会按缺失 opening wrapper 的修复路径尝试恢复
|
||||
- 已识别成功的工具调用不会再次回流到普通文本
|
||||
- 不符合新格式的块不会执行,并继续按原样文本透传
|
||||
- fenced code block 中的 XML 示例始终按普通文本处理
|
||||
@@ -43,14 +49,14 @@
|
||||
`ParseToolCallsDetailed` / `parseToolCallsDetailed` 返回:
|
||||
|
||||
- `calls`:解析出的工具调用列表(`name` + `input`)
|
||||
- `sawToolCallSyntax`:只有检测到 `<tool_calls` 时才会为 `true`
|
||||
- `sawToolCallSyntax`:检测到 canonical wrapper,或命中“缺失 opening wrapper 但可修复”的形态时会为 `true`
|
||||
- `rejectedByPolicy`:当前固定为 `false`
|
||||
- `rejectedToolNames`:当前固定为空数组
|
||||
|
||||
## 5) 落地建议
|
||||
|
||||
1. Prompt 里只示范 canonical XML 语法。
|
||||
2. 上游客户端需要直接输出 canonical XML;DS2API 不会把其他形态改写成工具调用。
|
||||
2. 上游客户端仍应直接输出 canonical XML;DS2API 只对“closing tag 在、opening tag 漏掉”的常见失误做窄修复,不会泛化接受其他旧格式。
|
||||
3. 不要依赖 parser 做安全控制;执行器侧仍应做工具名和参数校验。
|
||||
|
||||
## 6) 回归验证
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
'use strict';
|
||||
|
||||
const TOOLS_WRAPPER_PATTERN = /<tool_calls\b[^>]*>([\s\S]*?)<\/tool_calls>/gi;
|
||||
const TOOL_CALL_MARKUP_BLOCK_PATTERN = /<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?invoke>/gi;
|
||||
const PARAMETER_BLOCK_PATTERN = /<(?:[a-z0-9_:-]+:)?parameter\b([^>]*)>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?parameter>/gi;
|
||||
const TOOL_CALL_MARKUP_KV_PATTERN = /<(?:[a-z0-9_:-]+:)?([a-z0-9_.-]+)\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?\1>/gi;
|
||||
const CDATA_PATTERN = /^<!\[CDATA\[([\s\S]*?)]]>$/i;
|
||||
const XML_ATTR_PATTERN = /\b([a-z0-9_:-]+)\s*=\s*("([^"]*)"|'([^']*)')/gi;
|
||||
@@ -25,9 +22,9 @@ function parseMarkupToolCalls(text) {
|
||||
return [];
|
||||
}
|
||||
const out = [];
|
||||
for (const wrapper of raw.matchAll(TOOLS_WRAPPER_PATTERN)) {
|
||||
const body = toStringSafe(wrapper[1]);
|
||||
for (const block of body.matchAll(TOOL_CALL_MARKUP_BLOCK_PATTERN)) {
|
||||
for (const wrapper of findXmlElementBlocks(raw, 'tool_calls')) {
|
||||
const body = toStringSafe(wrapper.body);
|
||||
for (const block of findXmlElementBlocks(body, 'invoke')) {
|
||||
const parsed = parseMarkupSingleToolCall(block);
|
||||
if (parsed) {
|
||||
out.push(parsed);
|
||||
@@ -38,12 +35,12 @@ function parseMarkupToolCalls(text) {
|
||||
}
|
||||
|
||||
function parseMarkupSingleToolCall(block) {
|
||||
const attrs = parseTagAttributes(block[1]);
|
||||
const attrs = parseTagAttributes(block.attrs);
|
||||
const name = toStringSafe(attrs.name).trim();
|
||||
if (!name) {
|
||||
return null;
|
||||
}
|
||||
const inner = toStringSafe(block[2]).trim();
|
||||
const inner = toStringSafe(block.body).trim();
|
||||
|
||||
if (inner) {
|
||||
try {
|
||||
@@ -63,13 +60,13 @@ function parseMarkupSingleToolCall(block) {
|
||||
}
|
||||
}
|
||||
const input = {};
|
||||
for (const match of inner.matchAll(PARAMETER_BLOCK_PATTERN)) {
|
||||
const parameterAttrs = parseTagAttributes(match[1]);
|
||||
for (const match of findXmlElementBlocks(inner, 'parameter')) {
|
||||
const parameterAttrs = parseTagAttributes(match.attrs);
|
||||
const paramName = toStringSafe(parameterAttrs.name).trim();
|
||||
if (!paramName) {
|
||||
continue;
|
||||
}
|
||||
appendMarkupValue(input, paramName, parseMarkupValue(match[2]));
|
||||
appendMarkupValue(input, paramName, parseMarkupValue(match.body));
|
||||
}
|
||||
if (Object.keys(input).length === 0 && inner.trim() !== '') {
|
||||
return null;
|
||||
@@ -77,6 +74,154 @@ function parseMarkupSingleToolCall(block) {
|
||||
return { name, input };
|
||||
}
|
||||
|
||||
function findXmlElementBlocks(text, tag) {
|
||||
const source = toStringSafe(text);
|
||||
const name = toStringSafe(tag).toLowerCase();
|
||||
if (!source || !name) {
|
||||
return [];
|
||||
}
|
||||
const out = [];
|
||||
let pos = 0;
|
||||
while (pos < source.length) {
|
||||
const start = findXmlStartTagOutsideCDATA(source, name, pos);
|
||||
if (!start) {
|
||||
break;
|
||||
}
|
||||
const end = findMatchingXmlEndTagOutsideCDATA(source, name, start.bodyStart);
|
||||
if (!end) {
|
||||
break;
|
||||
}
|
||||
out.push({
|
||||
attrs: start.attrs,
|
||||
body: source.slice(start.bodyStart, end.closeStart),
|
||||
start: start.start,
|
||||
end: end.closeEnd,
|
||||
});
|
||||
pos = end.closeEnd;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function findXmlStartTagOutsideCDATA(text, tag, from) {
|
||||
const lower = text.toLowerCase();
|
||||
const target = `<${tag}`;
|
||||
for (let i = Math.max(0, from || 0); i < text.length;) {
|
||||
const skipped = skipXmlIgnoredSection(lower, i);
|
||||
if (skipped.blocked) {
|
||||
return null;
|
||||
}
|
||||
if (skipped.advanced) {
|
||||
i = skipped.next;
|
||||
continue;
|
||||
}
|
||||
if (lower.startsWith(target, i) && hasXmlTagBoundary(text, i + target.length)) {
|
||||
const tagEnd = findXmlTagEnd(text, i + target.length);
|
||||
if (tagEnd < 0) {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
start: i,
|
||||
bodyStart: tagEnd + 1,
|
||||
attrs: text.slice(i + target.length, tagEnd),
|
||||
};
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function findMatchingXmlEndTagOutsideCDATA(text, tag, from) {
|
||||
const lower = text.toLowerCase();
|
||||
const openTarget = `<${tag}`;
|
||||
const closeTarget = `</${tag}`;
|
||||
let depth = 1;
|
||||
for (let i = Math.max(0, from || 0); i < text.length;) {
|
||||
const skipped = skipXmlIgnoredSection(lower, i);
|
||||
if (skipped.blocked) {
|
||||
return null;
|
||||
}
|
||||
if (skipped.advanced) {
|
||||
i = skipped.next;
|
||||
continue;
|
||||
}
|
||||
if (lower.startsWith(closeTarget, i) && hasXmlTagBoundary(text, i + closeTarget.length)) {
|
||||
const tagEnd = findXmlTagEnd(text, i + closeTarget.length);
|
||||
if (tagEnd < 0) {
|
||||
return null;
|
||||
}
|
||||
depth -= 1;
|
||||
if (depth === 0) {
|
||||
return { closeStart: i, closeEnd: tagEnd + 1 };
|
||||
}
|
||||
i = tagEnd + 1;
|
||||
continue;
|
||||
}
|
||||
if (lower.startsWith(openTarget, i) && hasXmlTagBoundary(text, i + openTarget.length)) {
|
||||
const tagEnd = findXmlTagEnd(text, i + openTarget.length);
|
||||
if (tagEnd < 0) {
|
||||
return null;
|
||||
}
|
||||
if (!isSelfClosingXmlTag(text.slice(i, tagEnd))) {
|
||||
depth += 1;
|
||||
}
|
||||
i = tagEnd + 1;
|
||||
continue;
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function skipXmlIgnoredSection(lower, i) {
|
||||
if (lower.startsWith('<![cdata[', i)) {
|
||||
const end = lower.indexOf(']]>', i + '<![cdata['.length);
|
||||
if (end < 0) {
|
||||
return { advanced: false, blocked: true, next: i };
|
||||
}
|
||||
return { advanced: true, blocked: false, next: end + ']]>'.length };
|
||||
}
|
||||
if (lower.startsWith('<!--', i)) {
|
||||
const end = lower.indexOf('-->', i + '<!--'.length);
|
||||
if (end < 0) {
|
||||
return { advanced: false, blocked: true, next: i };
|
||||
}
|
||||
return { advanced: true, blocked: false, next: end + '-->'.length };
|
||||
}
|
||||
return { advanced: false, blocked: false, next: i };
|
||||
}
|
||||
|
||||
function findXmlTagEnd(text, from) {
|
||||
let quote = '';
|
||||
for (let i = Math.max(0, from || 0); i < text.length; i += 1) {
|
||||
const ch = text[i];
|
||||
if (quote) {
|
||||
if (ch === quote) {
|
||||
quote = '';
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (ch === '"' || ch === "'") {
|
||||
quote = ch;
|
||||
continue;
|
||||
}
|
||||
if (ch === '>') {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
function hasXmlTagBoundary(text, idx) {
|
||||
if (idx >= text.length) {
|
||||
return true;
|
||||
}
|
||||
return [' ', '\t', '\n', '\r', '>', '/'].includes(text[idx]);
|
||||
}
|
||||
|
||||
function isSelfClosingXmlTag(startTag) {
|
||||
return toStringSafe(startTag).trim().endsWith('/');
|
||||
}
|
||||
|
||||
function parseMarkupInput(raw) {
|
||||
const s = toStringSafe(raw).trim();
|
||||
if (!s) {
|
||||
@@ -120,6 +265,10 @@ function parseMarkupKVObject(text) {
|
||||
}
|
||||
|
||||
function parseMarkupValue(raw) {
|
||||
const cdata = extractStandaloneCDATA(raw);
|
||||
if (cdata.ok) {
|
||||
return cdata.value;
|
||||
}
|
||||
const s = toStringSafe(extractRawTagValue(raw)).trim();
|
||||
if (!s) {
|
||||
return '';
|
||||
@@ -152,9 +301,9 @@ function extractRawTagValue(inner) {
|
||||
}
|
||||
|
||||
// 1. Check for CDATA
|
||||
const cdataMatch = s.match(CDATA_PATTERN);
|
||||
if (cdataMatch && cdataMatch[1] !== undefined) {
|
||||
return cdataMatch[1];
|
||||
const cdata = extractStandaloneCDATA(s);
|
||||
if (cdata.ok) {
|
||||
return cdata.value;
|
||||
}
|
||||
|
||||
// 2. Fallback to unescaping standard HTML entities
|
||||
@@ -172,6 +321,15 @@ function unescapeHtml(safe) {
|
||||
.replace(/'/g, "'");
|
||||
}
|
||||
|
||||
function extractStandaloneCDATA(inner) {
|
||||
const s = toStringSafe(inner).trim();
|
||||
const cdataMatch = s.match(CDATA_PATTERN);
|
||||
if (cdataMatch && cdataMatch[1] !== undefined) {
|
||||
return { ok: true, value: cdataMatch[1] };
|
||||
}
|
||||
return { ok: false, value: '' };
|
||||
}
|
||||
|
||||
function parseTagAttributes(raw) {
|
||||
const source = toStringSafe(raw);
|
||||
const out = {};
|
||||
|
||||
@@ -16,9 +16,10 @@ function consumeXMLToolCapture(captured, toolNames, trimWrappingJSONFence) {
|
||||
if (openIdx < 0) {
|
||||
continue;
|
||||
}
|
||||
// Find the LAST occurrence of the specific closing tag.
|
||||
const closeIdx = lower.lastIndexOf(pair.close);
|
||||
if (closeIdx < openIdx) {
|
||||
// Ignore closing tags that appear inside CDATA payloads, such as
|
||||
// write-file content containing tool-call documentation examples.
|
||||
const closeIdx = findXMLCloseOutsideCDATA(captured, pair.close, openIdx + pair.open.length);
|
||||
if (closeIdx < 0) {
|
||||
// Opening tag present but specific closing tag hasn't arrived.
|
||||
// Return not-ready so buffering continues until the wrapper closes.
|
||||
return { ready: false, prefix: '', calls: [], suffix: '' };
|
||||
@@ -46,8 +47,9 @@ function consumeXMLToolCapture(captured, toolNames, trimWrappingJSONFence) {
|
||||
function hasOpenXMLToolTag(captured) {
|
||||
const lower = captured.toLowerCase();
|
||||
for (const pair of XML_TOOL_TAG_PAIRS) {
|
||||
if (lower.includes(pair.open)) {
|
||||
if (!lower.includes(pair.close)) {
|
||||
const openIdx = lower.indexOf(pair.open);
|
||||
if (openIdx >= 0) {
|
||||
if (findXMLCloseOutsideCDATA(captured, pair.close, openIdx + pair.open.length) < 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -74,6 +76,38 @@ function findPartialXMLToolTagStart(s) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
function findXMLCloseOutsideCDATA(s, closeTag, start) {
|
||||
const text = typeof s === 'string' ? s : '';
|
||||
const target = String(closeTag || '').toLowerCase();
|
||||
if (!text || !target) {
|
||||
return -1;
|
||||
}
|
||||
const lower = text.toLowerCase();
|
||||
for (let i = Math.max(0, start || 0); i < text.length;) {
|
||||
if (lower.startsWith('<![cdata[', i)) {
|
||||
const end = lower.indexOf(']]>', i + '<![cdata['.length);
|
||||
if (end < 0) {
|
||||
return -1;
|
||||
}
|
||||
i = end + ']]>'.length;
|
||||
continue;
|
||||
}
|
||||
if (lower.startsWith('<!--', i)) {
|
||||
const end = lower.indexOf('-->', i + '<!--'.length);
|
||||
if (end < 0) {
|
||||
return -1;
|
||||
}
|
||||
i = end + '-->'.length;
|
||||
continue;
|
||||
}
|
||||
if (lower.startsWith(target, i)) {
|
||||
return i;
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
consumeXMLToolCapture,
|
||||
hasOpenXMLToolTag,
|
||||
|
||||
@@ -71,15 +71,30 @@ func ConsumeSSE(cfg ConsumeConfig, hooks ConsumeHooks) {
|
||||
hooks.OnFinalize(reason, scannerErr)
|
||||
}
|
||||
}
|
||||
contextDone := func() bool {
|
||||
if cfg.Context.Err() == nil {
|
||||
return false
|
||||
}
|
||||
if hooks.OnContextDone != nil {
|
||||
hooks.OnContextDone()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
for {
|
||||
if contextDone() {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-cfg.Context.Done():
|
||||
if hooks.OnContextDone != nil {
|
||||
hooks.OnContextDone()
|
||||
if contextDone() {
|
||||
return
|
||||
}
|
||||
return
|
||||
case <-tickCh(ticker):
|
||||
if contextDone() {
|
||||
return
|
||||
}
|
||||
if !hasContent {
|
||||
keepaliveCount++
|
||||
if cfg.MaxKeepAliveNoInput > 0 && keepaliveCount >= cfg.MaxKeepAliveNoInput {
|
||||
@@ -95,6 +110,9 @@ func ConsumeSSE(cfg ConsumeConfig, hooks ConsumeHooks) {
|
||||
hooks.OnKeepAlive()
|
||||
}
|
||||
case parsed, ok := <-parsedLines:
|
||||
if contextDone() {
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
finalize(StopReasonUpstreamCompleted, <-done)
|
||||
return
|
||||
|
||||
47
internal/stream/engine_test.go
Normal file
47
internal/stream/engine_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/sse"
|
||||
)
|
||||
|
||||
func TestConsumeSSEPrefersContextCancellationOverReadyParsedLines(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
var finalized bool
|
||||
var contextDone bool
|
||||
var parsedCalled bool
|
||||
|
||||
ConsumeSSE(ConsumeConfig{
|
||||
Context: ctx,
|
||||
Body: strings.NewReader("data: {\"p\":\"response/content\",\"v\":\"hello\"}\n\ndata: [DONE]\n"),
|
||||
ThinkingEnabled: false,
|
||||
InitialType: "text",
|
||||
KeepAliveInterval: 0,
|
||||
}, ConsumeHooks{
|
||||
OnParsed: func(_ sse.LineResult) ParsedDecision {
|
||||
parsedCalled = true
|
||||
return ParsedDecision{}
|
||||
},
|
||||
OnFinalize: func(_ StopReason, _ error) {
|
||||
finalized = true
|
||||
},
|
||||
OnContextDone: func() {
|
||||
contextDone = true
|
||||
},
|
||||
})
|
||||
|
||||
if !contextDone {
|
||||
t.Fatal("expected OnContextDone to run for an already-cancelled context")
|
||||
}
|
||||
if finalized {
|
||||
t.Fatal("expected OnFinalize not to run after context cancellation wins")
|
||||
}
|
||||
if parsedCalled {
|
||||
t.Fatal("expected parsed lines not to be processed after context cancellation wins")
|
||||
}
|
||||
}
|
||||
@@ -27,6 +27,8 @@ RULES:
|
||||
7) Numbers, booleans, and null stay plain text.
|
||||
8) Use only the parameter names in the tool schema. Do not invent fields.
|
||||
9) Do NOT wrap XML in markdown fences. Do NOT output explanations, role markers, or internal monologue.
|
||||
10) If you call a tool, the first non-whitespace characters of that tool block must be exactly <tool_calls>.
|
||||
11) Never omit the opening <tool_calls> tag, even if you already plan to close with </tool_calls>.
|
||||
|
||||
PARAMETER SHAPES:
|
||||
- string => <parameter name="x"><![CDATA[value]]></parameter>
|
||||
@@ -42,6 +44,9 @@ Wrong 2 — Markdown code fences:
|
||||
` + "```xml" + `
|
||||
<tool_calls>...</tool_calls>
|
||||
` + "```" + `
|
||||
Wrong 3 — missing opening wrapper:
|
||||
<invoke name="TOOL_NAME">...</invoke>
|
||||
</tool_calls>
|
||||
|
||||
Remember: The ONLY valid way to use tools is the <tool_calls>...</tool_calls> XML block at the end of your response.
|
||||
|
||||
|
||||
@@ -109,6 +109,16 @@ func TestBuildToolCallInstructions_WriteUsesFilePathAndContent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildToolCallInstructions_AnchorsMissingOpeningWrapperFailureMode(t *testing.T) {
|
||||
out := BuildToolCallInstructions([]string{"read_file"})
|
||||
if !strings.Contains(out, "Never omit the opening <tool_calls> tag") {
|
||||
t.Fatalf("expected explicit missing-opening-tag warning, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, "Wrong 3 — missing opening wrapper") {
|
||||
t.Fatalf("expected missing-opening-wrapper negative example, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func findInvokeBlocks(text, name string) []string {
|
||||
open := `<invoke name="` + name + `">`
|
||||
remaining := text
|
||||
|
||||
@@ -43,6 +43,9 @@ func parseMarkupKVObject(text string) map[string]any {
|
||||
}
|
||||
|
||||
func parseMarkupValue(inner string) any {
|
||||
if value, ok := extractStandaloneCDATA(inner); ok {
|
||||
return value
|
||||
}
|
||||
value := strings.TrimSpace(extractRawTagValue(inner))
|
||||
if value == "" {
|
||||
return ""
|
||||
@@ -89,8 +92,8 @@ func extractRawTagValue(inner string) string {
|
||||
}
|
||||
|
||||
// 1. Check for CDATA - if present, it's the ultimate "safe" container.
|
||||
if cdataMatches := cdataPattern.FindStringSubmatch(trimmed); len(cdataMatches) >= 2 {
|
||||
return cdataMatches[1] // Return raw content between CDATA brackets
|
||||
if value, ok := extractStandaloneCDATA(trimmed); ok {
|
||||
return value // Return raw content between CDATA brackets
|
||||
}
|
||||
|
||||
// 2. If no CDATA, we still want to be robust.
|
||||
@@ -102,3 +105,11 @@ func extractRawTagValue(inner string) string {
|
||||
// but for KV objects we usually want the value.
|
||||
return html.UnescapeString(inner)
|
||||
}
|
||||
|
||||
func extractStandaloneCDATA(inner string) (string, bool) {
|
||||
trimmed := strings.TrimSpace(inner)
|
||||
if cdataMatches := cdataPattern.FindStringSubmatch(trimmed); len(cdataMatches) >= 2 {
|
||||
return cdataMatches[1], true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
@@ -87,7 +87,13 @@ func stripFencedCodeBlocks(text string) string {
|
||||
lines := strings.SplitAfter(text, "\n")
|
||||
inFence := false
|
||||
fenceMarker := ""
|
||||
inCDATA := false
|
||||
for _, line := range lines {
|
||||
if inCDATA || cdataStartsBeforeFence(line) {
|
||||
b.WriteString(line)
|
||||
inCDATA = updateCDATAState(inCDATA, line)
|
||||
continue
|
||||
}
|
||||
trimmed := strings.TrimLeft(line, " \t")
|
||||
if !inFence {
|
||||
if marker, ok := parseFenceOpen(trimmed); ok {
|
||||
@@ -111,6 +117,54 @@ func stripFencedCodeBlocks(text string) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func cdataStartsBeforeFence(line string) bool {
|
||||
cdataIdx := strings.Index(strings.ToLower(line), "<![cdata[")
|
||||
if cdataIdx < 0 {
|
||||
return false
|
||||
}
|
||||
fenceIdx := firstFenceMarkerIndex(line)
|
||||
return fenceIdx < 0 || cdataIdx < fenceIdx
|
||||
}
|
||||
|
||||
func firstFenceMarkerIndex(line string) int {
|
||||
idxBacktick := strings.Index(line, "```")
|
||||
idxTilde := strings.Index(line, "~~~")
|
||||
switch {
|
||||
case idxBacktick < 0:
|
||||
return idxTilde
|
||||
case idxTilde < 0:
|
||||
return idxBacktick
|
||||
case idxBacktick < idxTilde:
|
||||
return idxBacktick
|
||||
default:
|
||||
return idxTilde
|
||||
}
|
||||
}
|
||||
|
||||
func updateCDATAState(inCDATA bool, line string) bool {
|
||||
lower := strings.ToLower(line)
|
||||
pos := 0
|
||||
state := inCDATA
|
||||
for pos < len(lower) {
|
||||
if state {
|
||||
end := strings.Index(lower[pos:], "]]>")
|
||||
if end < 0 {
|
||||
return true
|
||||
}
|
||||
pos += end + len("]]>")
|
||||
state = false
|
||||
continue
|
||||
}
|
||||
start := strings.Index(lower[pos:], "<![cdata[")
|
||||
if start < 0 {
|
||||
return false
|
||||
}
|
||||
pos += start + len("<![cdata[")
|
||||
state = true
|
||||
}
|
||||
return state
|
||||
}
|
||||
|
||||
func parseFenceOpen(line string) (string, bool) {
|
||||
if len(line) < 3 {
|
||||
return "", false
|
||||
|
||||
@@ -7,22 +7,24 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
var xmlToolCallsWrapperPattern = regexp.MustCompile(`(?is)<tool_calls\b[^>]*>\s*(.*?)\s*</tool_calls>`)
|
||||
var xmlInvokePattern = regexp.MustCompile(`(?is)<invoke\b([^>]*)>\s*(.*?)\s*</invoke>`)
|
||||
var xmlParameterPattern = regexp.MustCompile(`(?is)<parameter\b([^>]*)>\s*(.*?)\s*</parameter>`)
|
||||
var xmlAttrPattern = regexp.MustCompile(`(?is)\b([a-z0-9_:-]+)\s*=\s*("([^"]*)"|'([^']*)')`)
|
||||
var xmlToolCallsClosePattern = regexp.MustCompile(`(?is)</tool_calls>`)
|
||||
var xmlInvokeStartPattern = regexp.MustCompile(`(?is)<invoke\b[^>]*\bname\s*=\s*("([^"]*)"|'([^']*)')`)
|
||||
|
||||
func parseXMLToolCalls(text string) []ParsedToolCall {
|
||||
wrappers := xmlToolCallsWrapperPattern.FindAllStringSubmatch(text, -1)
|
||||
wrappers := findXMLElementBlocks(text, "tool_calls")
|
||||
if len(wrappers) == 0 {
|
||||
repaired := repairMissingXMLToolCallsOpeningWrapper(text)
|
||||
if repaired != text {
|
||||
wrappers = findXMLElementBlocks(repaired, "tool_calls")
|
||||
}
|
||||
}
|
||||
if len(wrappers) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]ParsedToolCall, 0, len(wrappers))
|
||||
for _, wrapper := range wrappers {
|
||||
if len(wrapper) < 2 {
|
||||
continue
|
||||
}
|
||||
for _, block := range xmlInvokePattern.FindAllStringSubmatch(wrapper[1], -1) {
|
||||
for _, block := range findXMLElementBlocks(wrapper.Body, "invoke") {
|
||||
call, ok := parseSingleXMLToolCall(block)
|
||||
if !ok {
|
||||
continue
|
||||
@@ -36,17 +38,36 @@ func parseXMLToolCalls(text string) []ParsedToolCall {
|
||||
return out
|
||||
}
|
||||
|
||||
func parseSingleXMLToolCall(block []string) (ParsedToolCall, bool) {
|
||||
if len(block) < 3 {
|
||||
return ParsedToolCall{}, false
|
||||
func repairMissingXMLToolCallsOpeningWrapper(text string) string {
|
||||
lower := strings.ToLower(text)
|
||||
if strings.Contains(lower, "<tool_calls") {
|
||||
return text
|
||||
}
|
||||
attrs := parseXMLTagAttributes(block[1])
|
||||
|
||||
closeMatches := xmlToolCallsClosePattern.FindAllStringIndex(text, -1)
|
||||
if len(closeMatches) == 0 {
|
||||
return text
|
||||
}
|
||||
invokeLoc := xmlInvokeStartPattern.FindStringIndex(text)
|
||||
if invokeLoc == nil {
|
||||
return text
|
||||
}
|
||||
closeLoc := closeMatches[len(closeMatches)-1]
|
||||
if invokeLoc[0] >= closeLoc[0] {
|
||||
return text
|
||||
}
|
||||
|
||||
return text[:invokeLoc[0]] + "<tool_calls>" + text[invokeLoc[0]:closeLoc[0]] + "</tool_calls>" + text[closeLoc[1]:]
|
||||
}
|
||||
|
||||
func parseSingleXMLToolCall(block xmlElementBlock) (ParsedToolCall, bool) {
|
||||
attrs := parseXMLTagAttributes(block.Attrs)
|
||||
name := strings.TrimSpace(html.UnescapeString(attrs["name"]))
|
||||
if name == "" {
|
||||
return ParsedToolCall{}, false
|
||||
}
|
||||
|
||||
inner := strings.TrimSpace(block[2])
|
||||
inner := strings.TrimSpace(block.Body)
|
||||
if strings.HasPrefix(inner, "{") {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(inner), &payload); err == nil {
|
||||
@@ -64,16 +85,13 @@ func parseSingleXMLToolCall(block []string) (ParsedToolCall, bool) {
|
||||
}
|
||||
|
||||
input := map[string]any{}
|
||||
for _, paramMatch := range xmlParameterPattern.FindAllStringSubmatch(inner, -1) {
|
||||
if len(paramMatch) < 3 {
|
||||
continue
|
||||
}
|
||||
paramAttrs := parseXMLTagAttributes(paramMatch[1])
|
||||
for _, paramMatch := range findXMLElementBlocks(inner, "parameter") {
|
||||
paramAttrs := parseXMLTagAttributes(paramMatch.Attrs)
|
||||
paramName := strings.TrimSpace(html.UnescapeString(paramAttrs["name"]))
|
||||
if paramName == "" {
|
||||
continue
|
||||
}
|
||||
value := parseInvokeParameterValue(paramMatch[2])
|
||||
value := parseInvokeParameterValue(paramMatch.Body)
|
||||
appendMarkupValue(input, paramName, value)
|
||||
}
|
||||
|
||||
@@ -86,6 +104,168 @@ func parseSingleXMLToolCall(block []string) (ParsedToolCall, bool) {
|
||||
return ParsedToolCall{Name: name, Input: input}, true
|
||||
}
|
||||
|
||||
type xmlElementBlock struct {
|
||||
Attrs string
|
||||
Body string
|
||||
Start int
|
||||
End int
|
||||
}
|
||||
|
||||
func findXMLElementBlocks(text, tag string) []xmlElementBlock {
|
||||
if text == "" || tag == "" {
|
||||
return nil
|
||||
}
|
||||
var out []xmlElementBlock
|
||||
pos := 0
|
||||
for pos < len(text) {
|
||||
start, bodyStart, attrs, ok := findXMLStartTagOutsideCDATA(text, tag, pos)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
closeStart, closeEnd, ok := findMatchingXMLEndTagOutsideCDATA(text, tag, bodyStart)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
out = append(out, xmlElementBlock{
|
||||
Attrs: attrs,
|
||||
Body: text[bodyStart:closeStart],
|
||||
Start: start,
|
||||
End: closeEnd,
|
||||
})
|
||||
pos = closeEnd
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func findXMLStartTagOutsideCDATA(text, tag string, from int) (start, bodyStart int, attrs string, ok bool) {
|
||||
lower := strings.ToLower(text)
|
||||
target := "<" + strings.ToLower(tag)
|
||||
for i := maxInt(from, 0); i < len(text); {
|
||||
next, advanced, blocked := skipXMLIgnoredSection(lower, i)
|
||||
if blocked {
|
||||
return -1, -1, "", false
|
||||
}
|
||||
if advanced {
|
||||
i = next
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(lower[i:], target) && hasXMLTagBoundary(text, i+len(target)) {
|
||||
end := findXMLTagEnd(text, i+len(target))
|
||||
if end < 0 {
|
||||
return -1, -1, "", false
|
||||
}
|
||||
return i, end + 1, text[i+len(target) : end], true
|
||||
}
|
||||
i++
|
||||
}
|
||||
return -1, -1, "", false
|
||||
}
|
||||
|
||||
func findMatchingXMLEndTagOutsideCDATA(text, tag string, from int) (closeStart, closeEnd int, ok bool) {
|
||||
lower := strings.ToLower(text)
|
||||
openTarget := "<" + strings.ToLower(tag)
|
||||
closeTarget := "</" + strings.ToLower(tag)
|
||||
depth := 1
|
||||
for i := maxInt(from, 0); i < len(text); {
|
||||
next, advanced, blocked := skipXMLIgnoredSection(lower, i)
|
||||
if blocked {
|
||||
return -1, -1, false
|
||||
}
|
||||
if advanced {
|
||||
i = next
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(lower[i:], closeTarget) && hasXMLTagBoundary(text, i+len(closeTarget)) {
|
||||
end := findXMLTagEnd(text, i+len(closeTarget))
|
||||
if end < 0 {
|
||||
return -1, -1, false
|
||||
}
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return i, end + 1, true
|
||||
}
|
||||
i = end + 1
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(lower[i:], openTarget) && hasXMLTagBoundary(text, i+len(openTarget)) {
|
||||
end := findXMLTagEnd(text, i+len(openTarget))
|
||||
if end < 0 {
|
||||
return -1, -1, false
|
||||
}
|
||||
if !isSelfClosingXMLTag(text[:end]) {
|
||||
depth++
|
||||
}
|
||||
i = end + 1
|
||||
continue
|
||||
}
|
||||
i++
|
||||
}
|
||||
return -1, -1, false
|
||||
}
|
||||
|
||||
func skipXMLIgnoredSection(lower string, i int) (next int, advanced bool, blocked bool) {
|
||||
switch {
|
||||
case strings.HasPrefix(lower[i:], "<![cdata["):
|
||||
end := strings.Index(lower[i+len("<![cdata["):], "]]>")
|
||||
if end < 0 {
|
||||
return 0, false, true
|
||||
}
|
||||
return i + len("<![cdata[") + end + len("]]>"), true, false
|
||||
case strings.HasPrefix(lower[i:], "<!--"):
|
||||
end := strings.Index(lower[i+len("<!--"):], "-->")
|
||||
if end < 0 {
|
||||
return 0, false, true
|
||||
}
|
||||
return i + len("<!--") + end + len("-->"), true, false
|
||||
default:
|
||||
return i, false, false
|
||||
}
|
||||
}
|
||||
|
||||
func findXMLTagEnd(text string, from int) int {
|
||||
quote := byte(0)
|
||||
for i := maxInt(from, 0); i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '>' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func hasXMLTagBoundary(text string, idx int) bool {
|
||||
if idx >= len(text) {
|
||||
return true
|
||||
}
|
||||
switch text[idx] {
|
||||
case ' ', '\t', '\n', '\r', '>', '/':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isSelfClosingXMLTag(startTag string) bool {
|
||||
return strings.HasSuffix(strings.TrimSpace(startTag), "/")
|
||||
}
|
||||
|
||||
func maxInt(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func parseXMLTagAttributes(raw string) map[string]string {
|
||||
if strings.TrimSpace(raw) == "" {
|
||||
return map[string]string{}
|
||||
@@ -113,6 +293,9 @@ func parseInvokeParameterValue(raw string) any {
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if value, ok := extractStandaloneCDATA(trimmed); ok {
|
||||
return value
|
||||
}
|
||||
if parsed := parseStructuredToolCallInput(trimmed); len(parsed) > 0 {
|
||||
if len(parsed) == 1 {
|
||||
if rawValue, ok := parsed["_raw"].(string); ok {
|
||||
|
||||
@@ -54,6 +54,32 @@ echo "hello"
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsKeepsToolSyntaxInsideCDATAAsParameterText(t *testing.T) {
|
||||
payload := strings.Join([]string{
|
||||
"# Release notes",
|
||||
"",
|
||||
"```xml",
|
||||
"<tool_calls>",
|
||||
" <invoke name=\"demo\">",
|
||||
" <parameter name=\"value\">x</parameter>",
|
||||
" </invoke>",
|
||||
"</tool_calls>",
|
||||
"```",
|
||||
}, "\n")
|
||||
text := `<tool_calls><invoke name="Write"><parameter name="content"><![CDATA[` + payload + `]]></parameter><parameter name="file_path">DS2API-4.0-Release-Notes.md</parameter></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"Write"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
content, _ := calls[0].Input["content"].(string)
|
||||
if content != payload {
|
||||
t.Fatalf("expected CDATA payload with nested tool syntax to survive intact, got %q", content)
|
||||
}
|
||||
if calls[0].Input["file_path"] != "DS2API-4.0-Release-Notes.md" {
|
||||
t.Fatalf("expected file_path parameter, got %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsInvokeParameters(t *testing.T) {
|
||||
text := `<tool_calls><invoke name="get_weather"><parameter name="city">beijing</parameter><parameter name="unit">c</parameter></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"get_weather"})
|
||||
@@ -175,6 +201,26 @@ func TestParseToolCallsRejectsBareInvokeWithoutToolCallsWrapper(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsRepairsMissingOpeningToolCallsWrapperWhenClosingTagExists(t *testing.T) {
|
||||
text := `Before tool call
|
||||
<invoke name="read_file"><parameter name="path">README.md</parameter></invoke>
|
||||
</tool_calls>
|
||||
after`
|
||||
res := ParseToolCallsDetailed(text, []string{"read_file"})
|
||||
if len(res.Calls) != 1 {
|
||||
t.Fatalf("expected repaired wrapper to parse exactly one call, got %#v", res)
|
||||
}
|
||||
if res.Calls[0].Name != "read_file" {
|
||||
t.Fatalf("expected repaired wrapper to preserve tool name, got %#v", res.Calls[0])
|
||||
}
|
||||
if got, _ := res.Calls[0].Input["path"].(string); got != "README.md" {
|
||||
t.Fatalf("expected repaired wrapper to preserve args, got %#v", res.Calls[0].Input)
|
||||
}
|
||||
if !res.SawToolCallSyntax {
|
||||
t.Fatalf("expected repaired wrapper to mark tool syntax seen, got %#v", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsRejectsLegacyCanonicalBody(t *testing.T) {
|
||||
text := `<tool_calls><invoke name="read_file"><tool_name>read_file</tool_name><param>{"path":"README.md"}</param></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"read_file"})
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
//nolint:unused // kept as explicit tag inventory for future XML sieve refinements.
|
||||
var xmlToolCallClosingTags = []string{"</tool_calls>"}
|
||||
var xmlToolCallOpeningTags = []string{"<tool_calls"}
|
||||
var xmlToolCallOpeningTags = []string{"<tool_calls", "<invoke"}
|
||||
|
||||
// xmlToolCallTagPairs maps each opening tag to its expected closing tag.
|
||||
// Order matters: longer/wrapper tags must be checked first.
|
||||
@@ -24,7 +24,7 @@ var xmlToolCallTagPairs = []struct{ open, close string }{
|
||||
var xmlToolCallBlockPattern = regexp.MustCompile(`(?is)(<tool_calls\b[^>]*>\s*(?:.*?)\s*</tool_calls>)`)
|
||||
|
||||
// xmlToolTagsToDetect is the set of XML tag prefixes used by findToolSegmentStart.
|
||||
var xmlToolTagsToDetect = []string{"<tool_calls>", "<tool_calls\n", "<tool_calls "}
|
||||
var xmlToolTagsToDetect = []string{"<tool_calls>", "<tool_calls\n", "<tool_calls ", "<invoke ", "<invoke\n", "<invoke\t", "<invoke\r"}
|
||||
|
||||
// consumeXMLToolCapture tries to extract complete XML tool call blocks from captured text.
|
||||
func consumeXMLToolCapture(captured string, toolNames []string) (prefix string, calls []toolcall.ParsedToolCall, suffix string, ready bool) {
|
||||
@@ -35,9 +35,10 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string,
|
||||
if openIdx < 0 {
|
||||
continue
|
||||
}
|
||||
// Find the LAST occurrence of the specific closing tag to get the outermost block.
|
||||
closeIdx := strings.LastIndex(lower, pair.close)
|
||||
if closeIdx < openIdx {
|
||||
// Find the matching closing tag outside CDATA. Long write-file tool
|
||||
// calls often contain XML examples in CDATA, including </tool_calls>.
|
||||
closeIdx := findXMLCloseOutsideCDATA(captured, pair.close, openIdx+len(pair.open))
|
||||
if closeIdx < 0 {
|
||||
// Opening tag is present but its specific closing tag hasn't arrived.
|
||||
// Return not-ready so we keep buffering until the canonical wrapper closes.
|
||||
return "", nil, "", false
|
||||
@@ -55,6 +56,22 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string,
|
||||
// If this block failed to become a tool call, pass it through as text.
|
||||
return prefixPart + xmlBlock, nil, suffixPart, true
|
||||
}
|
||||
if !strings.Contains(lower, "<tool_calls") {
|
||||
invokeIdx := strings.Index(lower, "<invoke")
|
||||
closeIdx := findXMLCloseOutsideCDATA(captured, "</tool_calls>", invokeIdx)
|
||||
if invokeIdx >= 0 && closeIdx > invokeIdx {
|
||||
closeEnd := closeIdx + len("</tool_calls>")
|
||||
xmlBlock := "<tool_calls>" + captured[invokeIdx:closeIdx] + "</tool_calls>"
|
||||
prefixPart := captured[:invokeIdx]
|
||||
suffixPart := captured[closeEnd:]
|
||||
parsed := toolcall.ParseToolCalls(xmlBlock, toolNames)
|
||||
if len(parsed) > 0 {
|
||||
prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
|
||||
return prefixPart, parsed, suffixPart, true
|
||||
}
|
||||
return prefixPart + captured[invokeIdx:closeEnd], nil, suffixPart, true
|
||||
}
|
||||
}
|
||||
return "", nil, "", false
|
||||
}
|
||||
|
||||
@@ -63,8 +80,9 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string,
|
||||
func hasOpenXMLToolTag(captured string) bool {
|
||||
lower := strings.ToLower(captured)
|
||||
for _, pair := range xmlToolCallTagPairs {
|
||||
if strings.Contains(lower, pair.open) {
|
||||
if !strings.Contains(lower, pair.close) {
|
||||
openIdx := strings.Index(lower, pair.open)
|
||||
if openIdx >= 0 {
|
||||
if findXMLCloseOutsideCDATA(captured, pair.close, openIdx+len(pair.open)) < 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -72,6 +90,38 @@ func hasOpenXMLToolTag(captured string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func findXMLCloseOutsideCDATA(s, closeTag string, start int) int {
|
||||
if s == "" || closeTag == "" {
|
||||
return -1
|
||||
}
|
||||
if start < 0 {
|
||||
start = 0
|
||||
}
|
||||
lower := strings.ToLower(s)
|
||||
target := strings.ToLower(closeTag)
|
||||
for i := start; i < len(s); {
|
||||
switch {
|
||||
case strings.HasPrefix(lower[i:], "<![cdata["):
|
||||
end := strings.Index(lower[i+len("<![cdata["):], "]]>")
|
||||
if end < 0 {
|
||||
return -1
|
||||
}
|
||||
i += len("<![cdata[") + end + len("]]>")
|
||||
case strings.HasPrefix(lower[i:], "<!--"):
|
||||
end := strings.Index(lower[i+len("<!--"):], "-->")
|
||||
if end < 0 {
|
||||
return -1
|
||||
}
|
||||
i += len("<!--") + end + len("-->")
|
||||
case strings.HasPrefix(lower[i:], target):
|
||||
return i
|
||||
default:
|
||||
i++
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findPartialXMLToolTagStart checks if the string ends with a partial canonical
|
||||
// XML wrapper tag (e.g., "<too") and returns the position of the '<'.
|
||||
func findPartialXMLToolTagStart(s string) int {
|
||||
|
||||
@@ -84,6 +84,65 @@ func TestProcessToolSieveHandlesLongXMLToolCall(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveKeepsCDATAEmbeddedToolClosingBuffered(t *testing.T) {
|
||||
var state State
|
||||
payload := strings.Join([]string{
|
||||
"# DS2API 4.0 更新内容",
|
||||
"",
|
||||
strings.Repeat("x", 4096),
|
||||
"```xml",
|
||||
"<tool_calls>",
|
||||
" <invoke name=\"demo\">",
|
||||
" <parameter name=\"value\">x</parameter>",
|
||||
" </invoke>",
|
||||
"</tool_calls>",
|
||||
"```",
|
||||
"tail",
|
||||
}, "\n")
|
||||
innerClose := strings.Index(payload, "</tool_calls>") + len("</tool_calls>")
|
||||
chunks := []string{
|
||||
"<tool_calls>\n <invoke name=\"Write\">\n <parameter name=\"content\"><![CDATA[",
|
||||
payload[:innerClose],
|
||||
payload[innerClose:],
|
||||
"]]></parameter>\n <parameter name=\"file_path\">DS2API-4.0-Release-Notes.md</parameter>\n </invoke>\n</tool_calls>",
|
||||
}
|
||||
|
||||
var events []Event
|
||||
for i, c := range chunks {
|
||||
next := ProcessChunk(&state, c, []string{"Write"})
|
||||
if i <= 1 {
|
||||
for _, evt := range next {
|
||||
if evt.Content != "" || len(evt.ToolCalls) > 0 {
|
||||
t.Fatalf("expected no events before outer closing tag, chunk=%d events=%#v", i, next)
|
||||
}
|
||||
}
|
||||
}
|
||||
events = append(events, next...)
|
||||
}
|
||||
events = append(events, Flush(&state, []string{"Write"})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
var gotPayload string
|
||||
toolCalls := 0
|
||||
for _, evt := range events {
|
||||
textContent.WriteString(evt.Content)
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
gotPayload, _ = evt.ToolCalls[0].Input["content"].(string)
|
||||
}
|
||||
}
|
||||
|
||||
if toolCalls != 1 {
|
||||
t.Fatalf("expected one parsed tool call, got %d events=%#v", toolCalls, events)
|
||||
}
|
||||
if textContent.Len() != 0 {
|
||||
t.Fatalf("expected no leaked text, got %q", textContent.String())
|
||||
}
|
||||
if gotPayload != payload {
|
||||
t.Fatalf("expected full CDATA payload to survive intact, got len=%d want=%d", len(gotPayload), len(payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveXMLWithLeadingText(t *testing.T) {
|
||||
var state State
|
||||
// Model outputs some prose then an XML tool call.
|
||||
@@ -288,6 +347,7 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) {
|
||||
want int
|
||||
}{
|
||||
{"tool_calls_tag", "some text <tool_calls>\n", 10},
|
||||
{"invoke_tag_missing_wrapper", "some text <invoke name=\"read_file\">\n", 10},
|
||||
{"bare_tool_call_text", "prefix <tool_call>\n", -1},
|
||||
{"xml_inside_code_fence", "```xml\n<tool_calls><invoke name=\"read_file\"></invoke></tool_calls>\n```", -1},
|
||||
{"no_xml", "just plain text", -1},
|
||||
@@ -310,6 +370,7 @@ func TestFindPartialXMLToolTagStart(t *testing.T) {
|
||||
want int
|
||||
}{
|
||||
{"partial_tool_calls", "Hello <tool_ca", 6},
|
||||
{"partial_invoke", "Hello <inv", 6},
|
||||
{"bare_tool_call_not_held", "Hello <tool_name", -1},
|
||||
{"partial_lt_only", "Text <", 5},
|
||||
{"complete_tag", "Text <tool_calls>done", -1},
|
||||
@@ -505,3 +566,32 @@ func TestProcessToolSievePassesThroughBareToolCallAsText(t *testing.T) {
|
||||
t.Fatalf("expected bare invoke to pass through unchanged, got %q", textContent.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveRepairsMissingOpeningWrapperWithoutLeakingInvokeText(t *testing.T) {
|
||||
var state State
|
||||
chunks := []string{
|
||||
"<invoke name=\"read_file\">\n",
|
||||
" <parameter name=\"path\">README.md</parameter>\n",
|
||||
"</invoke>\n",
|
||||
"</tool_calls>",
|
||||
}
|
||||
var events []Event
|
||||
for _, c := range chunks {
|
||||
events = append(events, ProcessChunk(&state, c, []string{"read_file"})...)
|
||||
}
|
||||
events = append(events, Flush(&state, []string{"read_file"})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
toolCalls := 0
|
||||
for _, evt := range events {
|
||||
textContent.WriteString(evt.Content)
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if toolCalls != 1 {
|
||||
t.Fatalf("expected repaired missing-wrapper stream to emit one tool call, got %d events=%#v", toolCalls, events)
|
||||
}
|
||||
if strings.Contains(textContent.String(), "<invoke") || strings.Contains(textContent.String(), "</tool_calls>") {
|
||||
t.Fatalf("expected repaired missing-wrapper stream not to leak xml text, got %q", textContent.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,6 +118,60 @@ test('sieve keeps long XML tool calls buffered until the closing tag arrives', (
|
||||
assert.equal(finalCalls[0].input.content, longContent);
|
||||
});
|
||||
|
||||
test('sieve keeps CDATA tool examples buffered until the outer closing tag arrives', () => {
|
||||
const content = [
|
||||
'# DS2API 4.0 更新内容',
|
||||
'',
|
||||
'x'.repeat(4096),
|
||||
'```xml',
|
||||
'<tool_calls>',
|
||||
' <invoke name="demo">',
|
||||
' <parameter name="value">x</parameter>',
|
||||
' </invoke>',
|
||||
'</tool_calls>',
|
||||
'```',
|
||||
'tail',
|
||||
].join('\n');
|
||||
const innerClose = content.indexOf('</tool_calls>') + '</tool_calls>'.length;
|
||||
const state = createToolSieveState();
|
||||
const chunks = [
|
||||
'<tool_calls>\n <invoke name="Write">\n <parameter name="content"><![CDATA[',
|
||||
content.slice(0, innerClose),
|
||||
content.slice(innerClose),
|
||||
']]></parameter>\n <parameter name="file_path">DS2API-4.0-Release-Notes.md</parameter>\n </invoke>\n</tool_calls>',
|
||||
];
|
||||
const events = [];
|
||||
chunks.forEach((chunk, idx) => {
|
||||
const next = processToolSieveChunk(state, chunk, ['Write']);
|
||||
if (idx <= 1) {
|
||||
assert.deepEqual(next, []);
|
||||
}
|
||||
events.push(...next);
|
||||
});
|
||||
events.push(...flushToolSieve(state, ['Write']));
|
||||
|
||||
const leakedText = collectText(events);
|
||||
const finalCalls = events.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []);
|
||||
assert.equal(leakedText, '');
|
||||
assert.equal(finalCalls.length, 1);
|
||||
assert.equal(finalCalls[0].name, 'Write');
|
||||
assert.equal(finalCalls[0].input.content, content);
|
||||
});
|
||||
|
||||
test('parseToolCalls keeps XML-looking CDATA content intact', () => {
|
||||
const content = [
|
||||
'# Release notes',
|
||||
'```xml',
|
||||
'<tool_calls><invoke name="demo"><parameter name="value">x</parameter></invoke></tool_calls>',
|
||||
'```',
|
||||
].join('\n');
|
||||
const payload = `<tool_calls><invoke name="Write"><parameter name="content"><![CDATA[${content}]]></parameter><parameter name="file_path">DS2API-4.0-Release-Notes.md</parameter></invoke></tool_calls>`;
|
||||
const calls = parseToolCalls(payload, ['Write']);
|
||||
assert.equal(calls.length, 1);
|
||||
assert.equal(calls[0].input.content, content);
|
||||
assert.equal(calls[0].input.file_path, 'DS2API-4.0-Release-Notes.md');
|
||||
});
|
||||
|
||||
test('sieve passes JSON tool_calls payload through as text (XML-only)', () => {
|
||||
const events = runSieve(
|
||||
['{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'],
|
||||
|
||||
Reference in New Issue
Block a user