Merge pull request #313 from CJackHwang/dev

toolcall优化补丁
This commit is contained in:
CJACK.
2026-04-26 09:53:54 +08:00
committed by GitHub
15 changed files with 821 additions and 53 deletions

View File

@@ -148,6 +148,7 @@ DS2API 当前的核心思路,不是把客户端传来的 `messages`、`tools`
4. 把这整段内容并入 system prompt。
工具调用正例仍只示范 canonical XML:`<tool_calls>` → `<invoke name="...">` → `<parameter name="...">`。
提示词会额外强调:如果要调用工具,工具块的首个非空白字符必须就是 `<tool_calls>`,不能只输出 `</tool_calls>` 而漏掉 opening tag。
正例中的工具名只会来自当前请求实际声明的工具;如果当前请求没有足够的已知工具形态,就省略对应的单工具、多工具或嵌套示例,避免把不可用工具名写进 prompt。
对执行类工具,脚本内容必须进入执行参数本身:`Bash` / `execute_command` 使用 `command`,`exec_command` 使用 `cmd`;不要把脚本示范成 `path` / `content` 文件写入参数。
@@ -192,6 +193,7 @@ assistant 历史 `tool_calls` 不会保留成 OpenAI 原生 JSON而会转成
```
这也是当前项目里唯一受支持的 canonical tool-calling 形态;其他形态都会作为普通文本保留,不会作为可执行调用语法。
例外是 parser 会对一个非常窄的模型失误做修复:如果 assistant 输出了 `<invoke ...>` ... `</tool_calls>`,但漏掉最前面的 opening `<tool_calls>`,解析阶段会补回 wrapper 后再尝试识别。
这件事很重要,因为它决定了:

View File

@@ -23,9 +23,14 @@
- 工具名必须放在 `invoke` 的 `name` 属性
- 参数必须使用 `<parameter name="...">...</parameter>`
兼容修复:
- 如果模型漏掉 opening `<tool_calls>`,但后面仍输出了一个或多个 `<invoke ...>` 并以 `</tool_calls>` 收尾,Go 解析链路会在解析前补回缺失的 opening wrapper。
- 这是一个针对常见模型失误的窄修复,不改变推荐输出格式;prompt 仍要求模型直接输出完整 canonical XML。
## 2) 非 canonical 内容
任何不满足上述 canonical XML 形态的内容,都会保留为普通文本,不会执行。
任何不满足上述 canonical XML 形态的内容,都会保留为普通文本,不会执行。一个例外是上一节提到的“缺失 opening `<tool_calls>`、但 closing `</tool_calls>` 仍存在”的窄修复场景。
当前 parser 不把 allow-list 当作硬安全边界:即使传入了已声明工具名列表,XML 里出现未声明工具名时也会尽量解析并交给上层协议输出;真正的执行侧仍必须自行校验工具名和参数。
@@ -33,7 +38,8 @@
在流式链路中(Go / Node 一致):
- 只有从 `<tool_calls` 开始的 canonical wrapper 会进入结构化捕获
- canonical `<tool_calls>` wrapper 会进入结构化捕获
- 如果流里直接从 `<invoke ...>` 开始,但后面补上了 `</tool_calls>`,Go 流式筛分也会按缺失 opening wrapper 的修复路径尝试恢复
- 已识别成功的工具调用不会再次回流到普通文本
- 不符合新格式的块不会执行,并继续按原样文本透传
- fenced code block 中的 XML 示例始终按普通文本处理
@@ -43,14 +49,14 @@
`ParseToolCallsDetailed` / `parseToolCallsDetailed` 返回:
- `calls`:解析出的工具调用列表(`name` + `input`)
- `sawToolCallSyntax`只有检测到 `<tool_calls` 时才会为 `true`
- `sawToolCallSyntax`:检测到 canonical wrapper,或命中“缺失 opening wrapper 但可修复”的形态时会为 `true`
- `rejectedByPolicy`:当前固定为 `false`
- `rejectedToolNames`:当前固定为空数组
## 5) 落地建议
1. Prompt 里只示范 canonical XML 语法。
2. 上游客户端需要直接输出 canonical XMLDS2API 不会把其他形态改写成工具调用
2. 上游客户端仍应直接输出 canonical XML;DS2API 只对“closing tag 在、opening tag 漏掉”的常见失误做窄修复,不会泛化接受其他旧格式
3. 不要依赖 parser 做安全控制;执行器侧仍应做工具名和参数校验。
## 6) 回归验证

View File

@@ -1,8 +1,5 @@
'use strict';
const TOOLS_WRAPPER_PATTERN = /<tool_calls\b[^>]*>([\s\S]*?)<\/tool_calls>/gi;
const TOOL_CALL_MARKUP_BLOCK_PATTERN = /<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?invoke>/gi;
const PARAMETER_BLOCK_PATTERN = /<(?:[a-z0-9_:-]+:)?parameter\b([^>]*)>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?parameter>/gi;
const TOOL_CALL_MARKUP_KV_PATTERN = /<(?:[a-z0-9_:-]+:)?([a-z0-9_.-]+)\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?\1>/gi;
const CDATA_PATTERN = /^<!\[CDATA\[([\s\S]*?)]]>$/i;
const XML_ATTR_PATTERN = /\b([a-z0-9_:-]+)\s*=\s*("([^"]*)"|'([^']*)')/gi;
@@ -25,9 +22,9 @@ function parseMarkupToolCalls(text) {
return [];
}
const out = [];
for (const wrapper of raw.matchAll(TOOLS_WRAPPER_PATTERN)) {
const body = toStringSafe(wrapper[1]);
for (const block of body.matchAll(TOOL_CALL_MARKUP_BLOCK_PATTERN)) {
for (const wrapper of findXmlElementBlocks(raw, 'tool_calls')) {
const body = toStringSafe(wrapper.body);
for (const block of findXmlElementBlocks(body, 'invoke')) {
const parsed = parseMarkupSingleToolCall(block);
if (parsed) {
out.push(parsed);
@@ -38,12 +35,12 @@ function parseMarkupToolCalls(text) {
}
function parseMarkupSingleToolCall(block) {
const attrs = parseTagAttributes(block[1]);
const attrs = parseTagAttributes(block.attrs);
const name = toStringSafe(attrs.name).trim();
if (!name) {
return null;
}
const inner = toStringSafe(block[2]).trim();
const inner = toStringSafe(block.body).trim();
if (inner) {
try {
@@ -63,13 +60,13 @@ function parseMarkupSingleToolCall(block) {
}
}
const input = {};
for (const match of inner.matchAll(PARAMETER_BLOCK_PATTERN)) {
const parameterAttrs = parseTagAttributes(match[1]);
for (const match of findXmlElementBlocks(inner, 'parameter')) {
const parameterAttrs = parseTagAttributes(match.attrs);
const paramName = toStringSafe(parameterAttrs.name).trim();
if (!paramName) {
continue;
}
appendMarkupValue(input, paramName, parseMarkupValue(match[2]));
appendMarkupValue(input, paramName, parseMarkupValue(match.body));
}
if (Object.keys(input).length === 0 && inner.trim() !== '') {
return null;
@@ -77,6 +74,154 @@ function parseMarkupSingleToolCall(block) {
return { name, input };
}
/**
 * Collects, in document order, every `<tag ...>...</tag>` element found in
 * `text`, scanning left to right and resuming after each element's closing tag.
 *
 * Scanning stops at the first start tag that has no matching end tag, so
 * elements located after an unterminated one are not reported.
 *
 * @param {*} text - Markup to scan; coerced with `toStringSafe`.
 * @param {*} tag - Element name to look for; coerced and lower-cased.
 * @returns {Array<{attrs: string, body: string, start: number, end: number}>}
 *   One entry per element: the raw attribute text of the start tag, the text
 *   between the start and end tags, the index of the start tag's `<`, and the
 *   index just past the end tag's `>`.
 */
function findXmlElementBlocks(text, tag) {
  const source = toStringSafe(text);
  const name = toStringSafe(tag).toLowerCase();
  const blocks = [];
  if (!source || !name) {
    return blocks;
  }
  for (let cursor = 0; cursor < source.length;) {
    const open = findXmlStartTagOutsideCDATA(source, name, cursor);
    const close = open
      ? findMatchingXmlEndTagOutsideCDATA(source, name, open.bodyStart)
      : null;
    if (!close) {
      // Either no further start tag, or the element never closes.
      return blocks;
    }
    blocks.push({
      attrs: open.attrs,
      body: source.slice(open.bodyStart, close.closeStart),
      start: open.start,
      end: close.closeEnd,
    });
    cursor = close.closeEnd;
  }
  return blocks;
}
/**
 * Lower-cases ASCII letters only, leaving every other code unit untouched.
 *
 * `String.prototype.toLowerCase()` may change the length of a string (for
 * example U+0130 becomes two code units). The scanners below locate tags in
 * the folded copy and then slice the ORIGINAL text with the same indices, so
 * the folded copy must stay index-aligned with the original.
 *
 * @param {string} text
 * @returns {string} A string with exactly `text.length` code units.
 */
function toAsciiLowerSameLength(text) {
  return text.replace(/[A-Z]/g, (ch) => ch.toLowerCase());
}

/**
 * Finds the first `<tag ...>` start tag at or after `from`, ignoring anything
 * inside CDATA sections and XML comments.
 *
 * @param {string} text - Markup to scan.
 * @param {string} tag - Lower-case element name.
 * @param {number} [from] - Index to start scanning at (negative/falsy => 0).
 * @returns {{start: number, bodyStart: number, attrs: string} | null}
 *   `start` is the index of `<`, `bodyStart` the index just past the tag's
 *   `>`, and `attrs` the raw text between the tag name and `>`. Returns `null`
 *   when no start tag is found, when the tag has no closing `>`, or when an
 *   unterminated CDATA section/comment is reached first.
 */
function findXmlStartTagOutsideCDATA(text, tag, from) {
  const lower = toAsciiLowerSameLength(text);
  const target = `<${tag}`;
  for (let i = Math.max(0, from || 0); i < text.length;) {
    const skipped = skipXmlIgnoredSection(lower, i);
    if (skipped.blocked) {
      return null;
    }
    if (skipped.advanced) {
      i = skipped.next;
      continue;
    }
    // The boundary check rejects longer names sharing the prefix (e.g. `<invoker`).
    if (lower.startsWith(target, i) && hasXmlTagBoundary(text, i + target.length)) {
      const tagEnd = findXmlTagEnd(text, i + target.length);
      if (tagEnd < 0) {
        return null;
      }
      return {
        start: i,
        bodyStart: tagEnd + 1,
        attrs: text.slice(i + target.length, tagEnd),
      };
    }
    i += 1;
  }
  return null;
}

/**
 * Finds the end tag that closes an element whose body starts at `from`,
 * tracking nested same-name elements and ignoring CDATA sections and comments.
 *
 * @param {string} text - Markup to scan.
 * @param {string} tag - Lower-case element name.
 * @param {number} [from] - Index of the first body character.
 * @returns {{closeStart: number, closeEnd: number} | null}
 *   `closeStart` is the index of the end tag's `<`, `closeEnd` the index just
 *   past its `>`. Returns `null` when the element is not closed, a tag lacks
 *   its `>`, or an unterminated CDATA section/comment is reached.
 */
function findMatchingXmlEndTagOutsideCDATA(text, tag, from) {
  const lower = toAsciiLowerSameLength(text);
  const openTarget = `<${tag}`;
  const closeTarget = `</${tag}`;
  let depth = 1;
  for (let i = Math.max(0, from || 0); i < text.length;) {
    const skipped = skipXmlIgnoredSection(lower, i);
    if (skipped.blocked) {
      return null;
    }
    if (skipped.advanced) {
      i = skipped.next;
      continue;
    }
    if (lower.startsWith(closeTarget, i) && hasXmlTagBoundary(text, i + closeTarget.length)) {
      const tagEnd = findXmlTagEnd(text, i + closeTarget.length);
      if (tagEnd < 0) {
        return null;
      }
      depth -= 1;
      if (depth === 0) {
        return { closeStart: i, closeEnd: tagEnd + 1 };
      }
      i = tagEnd + 1;
      continue;
    }
    if (lower.startsWith(openTarget, i) && hasXmlTagBoundary(text, i + openTarget.length)) {
      const tagEnd = findXmlTagEnd(text, i + openTarget.length);
      if (tagEnd < 0) {
        return null;
      }
      // A nested self-closing element does not need its own end tag.
      if (!isSelfClosingXmlTag(text.slice(i, tagEnd))) {
        depth += 1;
      }
      i = tagEnd + 1;
      continue;
    }
    i += 1;
  }
  return null;
}

/**
 * Checks whether a CDATA section or XML comment starts at index `i` and, if
 * so, where scanning should resume.
 *
 * @param {string} lower - Lower-cased markup (must be index-aligned with the
 *   text the caller slices).
 * @param {number} i - Index to inspect.
 * @returns {{advanced: boolean, blocked: boolean, next: number}}
 *   `advanced` with `next` past the terminator when a complete section starts
 *   at `i`; `blocked` when the section starts at `i` but is never terminated;
 *   neither flag (and `next === i`) when no ignored section starts here.
 */
function skipXmlIgnoredSection(lower, i) {
  if (lower.startsWith('<![cdata[', i)) {
    const end = lower.indexOf(']]>', i + '<![cdata['.length);
    if (end < 0) {
      return { advanced: false, blocked: true, next: i };
    }
    return { advanced: true, blocked: false, next: end + ']]>'.length };
  }
  if (lower.startsWith('<!--', i)) {
    const end = lower.indexOf('-->', i + '<!--'.length);
    if (end < 0) {
      return { advanced: false, blocked: true, next: i };
    }
    return { advanced: true, blocked: false, next: end + '-->'.length };
  }
  return { advanced: false, blocked: false, next: i };
}

/**
 * Returns the index of the `>` that ends a tag, skipping any `>` that appears
 * inside a single- or double-quoted attribute value.
 *
 * @param {string} text - Markup to scan.
 * @param {number} [from] - Index to start at (inside the tag).
 * @returns {number} Index of the closing `>`, or -1 when there is none.
 */
function findXmlTagEnd(text, from) {
  let quote = '';
  for (let i = Math.max(0, from || 0); i < text.length; i += 1) {
    const ch = text[i];
    if (quote) {
      if (ch === quote) {
        quote = '';
      }
      continue;
    }
    if (ch === '"' || ch === "'") {
      quote = ch;
      continue;
    }
    if (ch === '>') {
      return i;
    }
  }
  return -1;
}

/**
 * Tells whether the character at `idx` can legally follow a tag name:
 * whitespace, `>`, `/`, or end of input.
 *
 * @param {string} text
 * @param {number} idx - Index of the character right after the tag name.
 * @returns {boolean}
 */
function hasXmlTagBoundary(text, idx) {
  if (idx >= text.length) {
    return true;
  }
  return [' ', '\t', '\n', '\r', '>', '/'].includes(text[idx]);
}
/**
 * Reports whether a start tag (given without its final `>`) is self-closing,
 * i.e. its last non-whitespace character is `/`.
 *
 * @param {*} startTag - Start-tag text up to, but not including, `>`.
 * @returns {boolean}
 */
function isSelfClosingXmlTag(startTag) {
  const trimmed = toStringSafe(startTag).trim();
  return trimmed.slice(-1) === '/';
}
function parseMarkupInput(raw) {
const s = toStringSafe(raw).trim();
if (!s) {
@@ -120,6 +265,10 @@ function parseMarkupKVObject(text) {
}
function parseMarkupValue(raw) {
const cdata = extractStandaloneCDATA(raw);
if (cdata.ok) {
return cdata.value;
}
const s = toStringSafe(extractRawTagValue(raw)).trim();
if (!s) {
return '';
@@ -152,9 +301,9 @@ function extractRawTagValue(inner) {
}
// 1. Check for CDATA
const cdataMatch = s.match(CDATA_PATTERN);
if (cdataMatch && cdataMatch[1] !== undefined) {
return cdataMatch[1];
const cdata = extractStandaloneCDATA(s);
if (cdata.ok) {
return cdata.value;
}
// 2. Fallback to unescaping standard HTML entities
@@ -172,6 +321,15 @@ function unescapeHtml(safe) {
.replace(/&#x27;/g, "'");
}
/**
 * Extracts the payload when the trimmed input matches `CDATA_PATTERN`
 * (a value that is a CDATA section and nothing else).
 *
 * @param {*} inner - Raw element content; coerced with `toStringSafe`.
 * @returns {{ok: boolean, value: string}} `ok: true` with the text captured
 *   by the pattern's first group, otherwise `ok: false` with an empty value.
 */
function extractStandaloneCDATA(inner) {
  const match = toStringSafe(inner).trim().match(CDATA_PATTERN);
  const matched = Boolean(match) && match[1] !== undefined;
  return matched ? { ok: true, value: match[1] } : { ok: false, value: '' };
}
function parseTagAttributes(raw) {
const source = toStringSafe(raw);
const out = {};

View File

@@ -16,9 +16,10 @@ function consumeXMLToolCapture(captured, toolNames, trimWrappingJSONFence) {
if (openIdx < 0) {
continue;
}
// Find the LAST occurrence of the specific closing tag.
const closeIdx = lower.lastIndexOf(pair.close);
if (closeIdx < openIdx) {
// Ignore closing tags that appear inside CDATA payloads, such as
// write-file content containing tool-call documentation examples.
const closeIdx = findXMLCloseOutsideCDATA(captured, pair.close, openIdx + pair.open.length);
if (closeIdx < 0) {
// Opening tag present but specific closing tag hasn't arrived.
// Return not-ready so buffering continues until the wrapper closes.
return { ready: false, prefix: '', calls: [], suffix: '' };
@@ -46,8 +47,9 @@ function consumeXMLToolCapture(captured, toolNames, trimWrappingJSONFence) {
function hasOpenXMLToolTag(captured) {
const lower = captured.toLowerCase();
for (const pair of XML_TOOL_TAG_PAIRS) {
if (lower.includes(pair.open)) {
if (!lower.includes(pair.close)) {
const openIdx = lower.indexOf(pair.open);
if (openIdx >= 0) {
if (findXMLCloseOutsideCDATA(captured, pair.close, openIdx + pair.open.length) < 0) {
return true;
}
}
@@ -74,6 +76,38 @@ function findPartialXMLToolTagStart(s) {
return -1;
}
/**
 * Finds the first occurrence of `closeTag` at or after `start` that is not
 * inside a CDATA section or an XML comment. Matching is ASCII
 * case-insensitive.
 *
 * @param {*} s - Captured text; non-strings are treated as empty.
 * @param {*} closeTag - Closing tag to look for, e.g. `</tool_calls>`.
 * @param {number} [start] - Index to begin scanning at (negative/falsy => 0).
 * @returns {number} Index into `s` of the closing tag, or -1 when it is
 *   absent or when an unterminated CDATA section/comment is reached first
 *   (the closing tag may still be on its way in a streamed response).
 */
function findXMLCloseOutsideCDATA(s, closeTag, start) {
  const text = typeof s === 'string' ? s : '';
  const target = String(closeTag || '').toLowerCase();
  if (!text || !target) {
    return -1;
  }
  // Fold ASCII letters only: String.prototype.toLowerCase() can change the
  // string length (e.g. U+0130), which would make the returned index point at
  // the wrong position in `s`.
  const lower = text.replace(/[A-Z]/g, (ch) => ch.toLowerCase());
  for (let i = Math.max(0, start || 0); i < text.length;) {
    if (lower.startsWith('<![cdata[', i)) {
      const end = lower.indexOf(']]>', i + '<![cdata['.length);
      if (end < 0) {
        return -1;
      }
      i = end + ']]>'.length;
      continue;
    }
    if (lower.startsWith('<!--', i)) {
      const end = lower.indexOf('-->', i + '<!--'.length);
      if (end < 0) {
        return -1;
      }
      i = end + '-->'.length;
      continue;
    }
    if (lower.startsWith(target, i)) {
      return i;
    }
    i += 1;
  }
  return -1;
}
module.exports = {
consumeXMLToolCapture,
hasOpenXMLToolTag,

View File

@@ -71,15 +71,30 @@ func ConsumeSSE(cfg ConsumeConfig, hooks ConsumeHooks) {
hooks.OnFinalize(reason, scannerErr)
}
}
contextDone := func() bool {
if cfg.Context.Err() == nil {
return false
}
if hooks.OnContextDone != nil {
hooks.OnContextDone()
}
return true
}
for {
if contextDone() {
return
}
select {
case <-cfg.Context.Done():
if hooks.OnContextDone != nil {
hooks.OnContextDone()
if contextDone() {
return
}
return
case <-tickCh(ticker):
if contextDone() {
return
}
if !hasContent {
keepaliveCount++
if cfg.MaxKeepAliveNoInput > 0 && keepaliveCount >= cfg.MaxKeepAliveNoInput {
@@ -95,6 +110,9 @@ func ConsumeSSE(cfg ConsumeConfig, hooks ConsumeHooks) {
hooks.OnKeepAlive()
}
case parsed, ok := <-parsedLines:
if contextDone() {
return
}
if !ok {
finalize(StopReasonUpstreamCompleted, <-done)
return

View File

@@ -0,0 +1,47 @@
package stream
import (
"context"
"strings"
"testing"
"ds2api/internal/sse"
)
// TestConsumeSSEPrefersContextCancellationOverReadyParsedLines calls ConsumeSSE
// with a context that is cancelled before the call and a body that already
// holds a content line and a [DONE] line. It expects only OnContextDone to
// fire: neither OnParsed nor OnFinalize may run.
func TestConsumeSSEPrefersContextCancellationOverReadyParsedLines(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
// Cancel up front so the context is already done when ConsumeSSE starts.
cancel()
var finalized bool
var contextDone bool
var parsedCalled bool
ConsumeSSE(ConsumeConfig{
Context: ctx,
Body: strings.NewReader("data: {\"p\":\"response/content\",\"v\":\"hello\"}\n\ndata: [DONE]\n"),
ThinkingEnabled: false,
InitialType: "text",
KeepAliveInterval: 0,
}, ConsumeHooks{
// Each hook only records that it was invoked.
OnParsed: func(_ sse.LineResult) ParsedDecision {
parsedCalled = true
return ParsedDecision{}
},
OnFinalize: func(_ StopReason, _ error) {
finalized = true
},
OnContextDone: func() {
contextDone = true
},
})
if !contextDone {
t.Fatal("expected OnContextDone to run for an already-cancelled context")
}
if finalized {
t.Fatal("expected OnFinalize not to run after context cancellation wins")
}
if parsedCalled {
t.Fatal("expected parsed lines not to be processed after context cancellation wins")
}
}

View File

@@ -27,6 +27,8 @@ RULES:
7) Numbers, booleans, and null stay plain text.
8) Use only the parameter names in the tool schema. Do not invent fields.
9) Do NOT wrap XML in markdown fences. Do NOT output explanations, role markers, or internal monologue.
10) If you call a tool, the first non-whitespace characters of that tool block must be exactly <tool_calls>.
11) Never omit the opening <tool_calls> tag, even if you already plan to close with </tool_calls>.
PARAMETER SHAPES:
- string => <parameter name="x"><![CDATA[value]]></parameter>
@@ -42,6 +44,9 @@ Wrong 2 — Markdown code fences:
` + "```xml" + `
<tool_calls>...</tool_calls>
` + "```" + `
Wrong 3 — missing opening wrapper:
<invoke name="TOOL_NAME">...</invoke>
</tool_calls>
Remember: The ONLY valid way to use tools is the <tool_calls>...</tool_calls> XML block at the end of your response.

View File

@@ -109,6 +109,16 @@ func TestBuildToolCallInstructions_WriteUsesFilePathAndContent(t *testing.T) {
}
}
// TestBuildToolCallInstructions_AnchorsMissingOpeningWrapperFailureMode checks
// that the instructions built for a single tool contain both the explicit
// "never omit the opening tag" rule and the "Wrong 3" negative example.
func TestBuildToolCallInstructions_AnchorsMissingOpeningWrapperFailureMode(t *testing.T) {
out := BuildToolCallInstructions([]string{"read_file"})
if !strings.Contains(out, "Never omit the opening <tool_calls> tag") {
t.Fatalf("expected explicit missing-opening-tag warning, got: %s", out)
}
if !strings.Contains(out, "Wrong 3 — missing opening wrapper") {
t.Fatalf("expected missing-opening-wrapper negative example, got: %s", out)
}
}
func findInvokeBlocks(text, name string) []string {
open := `<invoke name="` + name + `">`
remaining := text

View File

@@ -43,6 +43,9 @@ func parseMarkupKVObject(text string) map[string]any {
}
func parseMarkupValue(inner string) any {
if value, ok := extractStandaloneCDATA(inner); ok {
return value
}
value := strings.TrimSpace(extractRawTagValue(inner))
if value == "" {
return ""
@@ -89,8 +92,8 @@ func extractRawTagValue(inner string) string {
}
// 1. Check for CDATA - if present, it's the ultimate "safe" container.
if cdataMatches := cdataPattern.FindStringSubmatch(trimmed); len(cdataMatches) >= 2 {
return cdataMatches[1] // Return raw content between CDATA brackets
if value, ok := extractStandaloneCDATA(trimmed); ok {
return value // Return raw content between CDATA brackets
}
// 2. If no CDATA, we still want to be robust.
@@ -102,3 +105,11 @@ func extractRawTagValue(inner string) string {
// but for KV objects we usually want the value.
return html.UnescapeString(inner)
}
// extractStandaloneCDATA returns the first capture group of cdataPattern when
// the whitespace-trimmed input matches it, together with true. When there is
// no match it returns the empty string and false.
func extractStandaloneCDATA(inner string) (string, bool) {
	matches := cdataPattern.FindStringSubmatch(strings.TrimSpace(inner))
	if len(matches) < 2 {
		return "", false
	}
	return matches[1], true
}

View File

@@ -87,7 +87,13 @@ func stripFencedCodeBlocks(text string) string {
lines := strings.SplitAfter(text, "\n")
inFence := false
fenceMarker := ""
inCDATA := false
for _, line := range lines {
if inCDATA || cdataStartsBeforeFence(line) {
b.WriteString(line)
inCDATA = updateCDATAState(inCDATA, line)
continue
}
trimmed := strings.TrimLeft(line, " \t")
if !inFence {
if marker, ok := parseFenceOpen(trimmed); ok {
@@ -111,6 +117,54 @@ func stripFencedCodeBlocks(text string) string {
return b.String()
}
// cdataStartsBeforeFence reports whether line contains a CDATA opener
// (matched case-insensitively) that is not preceded by a fence marker on the
// same line. A line without any CDATA opener yields false.
func cdataStartsBeforeFence(line string) bool {
	cdataAt := strings.Index(strings.ToLower(line), "<![cdata[")
	if cdataAt < 0 {
		return false
	}
	if fenceAt := firstFenceMarkerIndex(line); fenceAt >= 0 && fenceAt <= cdataAt {
		// A fence marker comes first, so the fence logic owns this line.
		return false
	}
	return true
}
// firstFenceMarkerIndex returns the byte index of the earliest "```" or "~~~"
// in line, or -1 when neither marker occurs.
func firstFenceMarkerIndex(line string) int {
	earliest := -1
	for _, marker := range []string{"```", "~~~"} {
		idx := strings.Index(line, marker)
		if idx >= 0 && (earliest < 0 || idx < earliest) {
			earliest = idx
		}
	}
	return earliest
}
// updateCDATAState walks line left to right, toggling between "inside CDATA"
// and "outside CDATA" at each CDATA opener / "]]>" terminator (openers are
// matched case-insensitively), starting from inCDATA. It returns the state in
// effect at the end of the line.
func updateCDATAState(inCDATA bool, line string) bool {
	const (
		openMarker  = "<![cdata["
		closeMarker = "]]>"
	)
	rest := strings.ToLower(line)
	inside := inCDATA
	for len(rest) > 0 {
		// Inside a section we look for its terminator; outside, for the next opener.
		marker := openMarker
		if inside {
			marker = closeMarker
		}
		idx := strings.Index(rest, marker)
		if idx < 0 {
			return inside
		}
		rest = rest[idx+len(marker):]
		inside = !inside
	}
	return inside
}
func parseFenceOpen(line string) (string, bool) {
if len(line) < 3 {
return "", false

View File

@@ -7,22 +7,24 @@ import (
"strings"
)
var xmlToolCallsWrapperPattern = regexp.MustCompile(`(?is)<tool_calls\b[^>]*>\s*(.*?)\s*</tool_calls>`)
var xmlInvokePattern = regexp.MustCompile(`(?is)<invoke\b([^>]*)>\s*(.*?)\s*</invoke>`)
var xmlParameterPattern = regexp.MustCompile(`(?is)<parameter\b([^>]*)>\s*(.*?)\s*</parameter>`)
var xmlAttrPattern = regexp.MustCompile(`(?is)\b([a-z0-9_:-]+)\s*=\s*("([^"]*)"|'([^']*)')`)
var xmlToolCallsClosePattern = regexp.MustCompile(`(?is)</tool_calls>`)
var xmlInvokeStartPattern = regexp.MustCompile(`(?is)<invoke\b[^>]*\bname\s*=\s*("([^"]*)"|'([^']*)')`)
func parseXMLToolCalls(text string) []ParsedToolCall {
wrappers := xmlToolCallsWrapperPattern.FindAllStringSubmatch(text, -1)
wrappers := findXMLElementBlocks(text, "tool_calls")
if len(wrappers) == 0 {
repaired := repairMissingXMLToolCallsOpeningWrapper(text)
if repaired != text {
wrappers = findXMLElementBlocks(repaired, "tool_calls")
}
}
if len(wrappers) == 0 {
return nil
}
out := make([]ParsedToolCall, 0, len(wrappers))
for _, wrapper := range wrappers {
if len(wrapper) < 2 {
continue
}
for _, block := range xmlInvokePattern.FindAllStringSubmatch(wrapper[1], -1) {
for _, block := range findXMLElementBlocks(wrapper.Body, "invoke") {
call, ok := parseSingleXMLToolCall(block)
if !ok {
continue
@@ -36,17 +38,36 @@ func parseXMLToolCalls(text string) []ParsedToolCall {
return out
}
func parseSingleXMLToolCall(block []string) (ParsedToolCall, bool) {
if len(block) < 3 {
return ParsedToolCall{}, false
func repairMissingXMLToolCallsOpeningWrapper(text string) string {
lower := strings.ToLower(text)
if strings.Contains(lower, "<tool_calls") {
return text
}
attrs := parseXMLTagAttributes(block[1])
closeMatches := xmlToolCallsClosePattern.FindAllStringIndex(text, -1)
if len(closeMatches) == 0 {
return text
}
invokeLoc := xmlInvokeStartPattern.FindStringIndex(text)
if invokeLoc == nil {
return text
}
closeLoc := closeMatches[len(closeMatches)-1]
if invokeLoc[0] >= closeLoc[0] {
return text
}
return text[:invokeLoc[0]] + "<tool_calls>" + text[invokeLoc[0]:closeLoc[0]] + "</tool_calls>" + text[closeLoc[1]:]
}
func parseSingleXMLToolCall(block xmlElementBlock) (ParsedToolCall, bool) {
attrs := parseXMLTagAttributes(block.Attrs)
name := strings.TrimSpace(html.UnescapeString(attrs["name"]))
if name == "" {
return ParsedToolCall{}, false
}
inner := strings.TrimSpace(block[2])
inner := strings.TrimSpace(block.Body)
if strings.HasPrefix(inner, "{") {
var payload map[string]any
if err := json.Unmarshal([]byte(inner), &payload); err == nil {
@@ -64,16 +85,13 @@ func parseSingleXMLToolCall(block []string) (ParsedToolCall, bool) {
}
input := map[string]any{}
for _, paramMatch := range xmlParameterPattern.FindAllStringSubmatch(inner, -1) {
if len(paramMatch) < 3 {
continue
}
paramAttrs := parseXMLTagAttributes(paramMatch[1])
for _, paramMatch := range findXMLElementBlocks(inner, "parameter") {
paramAttrs := parseXMLTagAttributes(paramMatch.Attrs)
paramName := strings.TrimSpace(html.UnescapeString(paramAttrs["name"]))
if paramName == "" {
continue
}
value := parseInvokeParameterValue(paramMatch[2])
value := parseInvokeParameterValue(paramMatch.Body)
appendMarkupValue(input, paramName, value)
}
@@ -86,6 +104,168 @@ func parseSingleXMLToolCall(block []string) (ParsedToolCall, bool) {
return ParsedToolCall{Name: name, Input: input}, true
}
// xmlElementBlock describes one element located by findXMLElementBlocks.
type xmlElementBlock struct {
// Attrs is the raw text between the tag name and the start tag's '>'.
Attrs string
// Body is the text between the start tag and its matching end tag.
Body string
// Start is the byte index of the start tag's '<' in the scanned text.
Start int
// End is the byte index just past the end tag's '>'.
End int
}
// findXMLElementBlocks returns, in document order, every <tag ...>...</tag>
// element found in text. Scanning resumes after each element's end tag and
// stops at the first start tag that has no matching end tag. It returns nil
// when text or tag is empty or when no element is found.
func findXMLElementBlocks(text, tag string) []xmlElementBlock {
	if text == "" || tag == "" {
		return nil
	}
	var blocks []xmlElementBlock
	for cursor := 0; cursor < len(text); {
		openAt, bodyAt, attrs, found := findXMLStartTagOutsideCDATA(text, tag, cursor)
		if !found {
			return blocks
		}
		closeAt, afterClose, closed := findMatchingXMLEndTagOutsideCDATA(text, tag, bodyAt)
		if !closed {
			return blocks
		}
		blocks = append(blocks, xmlElementBlock{
			Attrs: attrs,
			Body:  text[bodyAt:closeAt],
			Start: openAt,
			End:   afterClose,
		})
		cursor = afterClose
	}
	return blocks
}
// asciiLowerSameLength folds ASCII letters A-Z to lower case and leaves every
// other byte untouched, so the result always has exactly len(s) bytes.
//
// strings.ToLower can change the byte length of a string (for example U+0130
// is two bytes and lower-cases to the one-byte "i"). The scanners below find
// tags in the folded copy and index the original text with the same offsets,
// so a length change would misalign indices or slice out of range.
func asciiLowerSameLength(s string) string {
	buf := []byte(s)
	for i, c := range buf {
		if c >= 'A' && c <= 'Z' {
			buf[i] = c + ('a' - 'A')
		}
	}
	return string(buf)
}

// findXMLStartTagOutsideCDATA finds the first <tag ...> start tag at or after
// from, ignoring CDATA sections and XML comments.
//
// On success it returns the index of '<', the index just past the tag's '>',
// the raw text between the tag name and '>', and ok=true. It returns ok=false
// when there is no start tag, when the tag has no closing '>', or when an
// unterminated CDATA section or comment is reached first.
func findXMLStartTagOutsideCDATA(text, tag string, from int) (start, bodyStart int, attrs string, ok bool) {
	lower := asciiLowerSameLength(text)
	target := "<" + strings.ToLower(tag)
	for i := maxInt(from, 0); i < len(text); {
		next, advanced, blocked := skipXMLIgnoredSection(lower, i)
		if blocked {
			return -1, -1, "", false
		}
		if advanced {
			i = next
			continue
		}
		// The boundary check rejects longer names sharing the prefix.
		if strings.HasPrefix(lower[i:], target) && hasXMLTagBoundary(text, i+len(target)) {
			end := findXMLTagEnd(text, i+len(target))
			if end < 0 {
				return -1, -1, "", false
			}
			return i, end + 1, text[i+len(target) : end], true
		}
		i++
	}
	return -1, -1, "", false
}

// findMatchingXMLEndTagOutsideCDATA finds the end tag closing an element whose
// body starts at from, tracking nested same-name elements and ignoring CDATA
// sections and XML comments.
//
// On success it returns the index of the end tag's '<', the index just past
// its '>', and ok=true. It returns ok=false when the element is never closed,
// a tag lacks its '>', or an unterminated CDATA section or comment is reached.
func findMatchingXMLEndTagOutsideCDATA(text, tag string, from int) (closeStart, closeEnd int, ok bool) {
	lower := asciiLowerSameLength(text)
	openTarget := "<" + strings.ToLower(tag)
	closeTarget := "</" + strings.ToLower(tag)
	depth := 1
	for i := maxInt(from, 0); i < len(text); {
		next, advanced, blocked := skipXMLIgnoredSection(lower, i)
		if blocked {
			return -1, -1, false
		}
		if advanced {
			i = next
			continue
		}
		if strings.HasPrefix(lower[i:], closeTarget) && hasXMLTagBoundary(text, i+len(closeTarget)) {
			end := findXMLTagEnd(text, i+len(closeTarget))
			if end < 0 {
				return -1, -1, false
			}
			depth--
			if depth == 0 {
				return i, end + 1, true
			}
			i = end + 1
			continue
		}
		if strings.HasPrefix(lower[i:], openTarget) && hasXMLTagBoundary(text, i+len(openTarget)) {
			end := findXMLTagEnd(text, i+len(openTarget))
			if end < 0 {
				return -1, -1, false
			}
			// A nested self-closing element does not need its own end tag.
			// Only the tag itself is inspected, not everything before it.
			if !isSelfClosingXMLTag(text[i:end]) {
				depth++
			}
			i = end + 1
			continue
		}
		i++
	}
	return -1, -1, false
}

// skipXMLIgnoredSection checks whether a CDATA section or XML comment starts
// at index i of lower (lower-cased markup that is index-aligned with the text
// the caller slices).
//
// It returns advanced=true and the index just past the terminator when a
// complete section starts at i, blocked=true when the section starts at i but
// is never terminated, and (i, false, false) when nothing ignorable starts at i.
func skipXMLIgnoredSection(lower string, i int) (next int, advanced bool, blocked bool) {
	switch {
	case strings.HasPrefix(lower[i:], "<![cdata["):
		end := strings.Index(lower[i+len("<![cdata["):], "]]>")
		if end < 0 {
			return 0, false, true
		}
		return i + len("<![cdata[") + end + len("]]>"), true, false
	case strings.HasPrefix(lower[i:], "<!--"):
		end := strings.Index(lower[i+len("<!--"):], "-->")
		if end < 0 {
			return 0, false, true
		}
		return i + len("<!--") + end + len("-->"), true, false
	default:
		return i, false, false
	}
}

// findXMLTagEnd returns the index of the '>' that ends a tag, scanning from
// from and skipping any '>' inside a single- or double-quoted attribute value.
// It returns -1 when no such '>' exists.
func findXMLTagEnd(text string, from int) int {
	quote := byte(0)
	for i := maxInt(from, 0); i < len(text); i++ {
		ch := text[i]
		if quote != 0 {
			if ch == quote {
				quote = 0
			}
			continue
		}
		if ch == '"' || ch == '\'' {
			quote = ch
			continue
		}
		if ch == '>' {
			return i
		}
	}
	return -1
}

// hasXMLTagBoundary reports whether the byte at idx can follow a tag name:
// whitespace, '>', '/', or end of input.
func hasXMLTagBoundary(text string, idx int) bool {
	if idx >= len(text) {
		return true
	}
	switch text[idx] {
	case ' ', '\t', '\n', '\r', '>', '/':
		return true
	default:
		return false
	}
}

// isSelfClosingXMLTag reports whether a start tag (given without its final
// '>') ends with '/' once surrounding whitespace is trimmed.
func isSelfClosingXMLTag(startTag string) bool {
	return strings.HasSuffix(strings.TrimSpace(startTag), "/")
}

// maxInt returns the larger of a and b.
func maxInt(a, b int) int {
	if a > b {
		return a
	}
	return b
}
func parseXMLTagAttributes(raw string) map[string]string {
if strings.TrimSpace(raw) == "" {
return map[string]string{}
@@ -113,6 +293,9 @@ func parseInvokeParameterValue(raw string) any {
if trimmed == "" {
return ""
}
if value, ok := extractStandaloneCDATA(trimmed); ok {
return value
}
if parsed := parseStructuredToolCallInput(trimmed); len(parsed) > 0 {
if len(parsed) == 1 {
if rawValue, ok := parsed["_raw"].(string); ok {

View File

@@ -54,6 +54,32 @@ echo "hello"
}
}
// TestParseToolCallsKeepsToolSyntaxInsideCDATAAsParameterText wraps a markdown
// payload that itself contains a fenced <tool_calls> example inside the CDATA
// of a "content" parameter. It expects exactly one parsed call whose "content"
// equals the payload verbatim and whose "file_path" parameter is still parsed.
func TestParseToolCallsKeepsToolSyntaxInsideCDATAAsParameterText(t *testing.T) {
payload := strings.Join([]string{
"# Release notes",
"",
"```xml",
"<tool_calls>",
" <invoke name=\"demo\">",
" <parameter name=\"value\">x</parameter>",
" </invoke>",
"</tool_calls>",
"```",
}, "\n")
text := `<tool_calls><invoke name="Write"><parameter name="content"><![CDATA[` + payload + `]]></parameter><parameter name="file_path">DS2API-4.0-Release-Notes.md</parameter></invoke></tool_calls>`
calls := ParseToolCalls(text, []string{"Write"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
content, _ := calls[0].Input["content"].(string)
if content != payload {
t.Fatalf("expected CDATA payload with nested tool syntax to survive intact, got %q", content)
}
if calls[0].Input["file_path"] != "DS2API-4.0-Release-Notes.md" {
t.Fatalf("expected file_path parameter, got %#v", calls[0].Input)
}
}
func TestParseToolCallsSupportsInvokeParameters(t *testing.T) {
text := `<tool_calls><invoke name="get_weather"><parameter name="city">beijing</parameter><parameter name="unit">c</parameter></invoke></tool_calls>`
calls := ParseToolCalls(text, []string{"get_weather"})
@@ -175,6 +201,26 @@ func TestParseToolCallsRejectsBareInvokeWithoutToolCallsWrapper(t *testing.T) {
}
}
// TestParseToolCallsRepairsMissingOpeningToolCallsWrapperWhenClosingTagExists
// feeds text that has an <invoke> element followed by </tool_calls> but no
// opening <tool_calls>. It expects one call with the original name and
// arguments, and SawToolCallSyntax set to true.
func TestParseToolCallsRepairsMissingOpeningToolCallsWrapperWhenClosingTagExists(t *testing.T) {
text := `Before tool call
<invoke name="read_file"><parameter name="path">README.md</parameter></invoke>
</tool_calls>
after`
res := ParseToolCallsDetailed(text, []string{"read_file"})
if len(res.Calls) != 1 {
t.Fatalf("expected repaired wrapper to parse exactly one call, got %#v", res)
}
if res.Calls[0].Name != "read_file" {
t.Fatalf("expected repaired wrapper to preserve tool name, got %#v", res.Calls[0])
}
if got, _ := res.Calls[0].Input["path"].(string); got != "README.md" {
t.Fatalf("expected repaired wrapper to preserve args, got %#v", res.Calls[0].Input)
}
if !res.SawToolCallSyntax {
t.Fatalf("expected repaired wrapper to mark tool syntax seen, got %#v", res)
}
}
func TestParseToolCallsRejectsLegacyCanonicalBody(t *testing.T) {
text := `<tool_calls><invoke name="read_file"><tool_name>read_file</tool_name><param>{"path":"README.md"}</param></invoke></tool_calls>`
calls := ParseToolCalls(text, []string{"read_file"})

View File

@@ -10,7 +10,7 @@ import (
//nolint:unused // kept as explicit tag inventory for future XML sieve refinements.
var xmlToolCallClosingTags = []string{"</tool_calls>"}
var xmlToolCallOpeningTags = []string{"<tool_calls"}
var xmlToolCallOpeningTags = []string{"<tool_calls", "<invoke"}
// xmlToolCallTagPairs maps each opening tag to its expected closing tag.
// Order matters: longer/wrapper tags must be checked first.
@@ -24,7 +24,7 @@ var xmlToolCallTagPairs = []struct{ open, close string }{
var xmlToolCallBlockPattern = regexp.MustCompile(`(?is)(<tool_calls\b[^>]*>\s*(?:.*?)\s*</tool_calls>)`)
// xmlToolTagsToDetect is the set of XML tag prefixes used by findToolSegmentStart.
var xmlToolTagsToDetect = []string{"<tool_calls>", "<tool_calls\n", "<tool_calls "}
var xmlToolTagsToDetect = []string{"<tool_calls>", "<tool_calls\n", "<tool_calls ", "<invoke ", "<invoke\n", "<invoke\t", "<invoke\r"}
// consumeXMLToolCapture tries to extract complete XML tool call blocks from captured text.
func consumeXMLToolCapture(captured string, toolNames []string) (prefix string, calls []toolcall.ParsedToolCall, suffix string, ready bool) {
@@ -35,9 +35,10 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string,
if openIdx < 0 {
continue
}
// Find the LAST occurrence of the specific closing tag to get the outermost block.
closeIdx := strings.LastIndex(lower, pair.close)
if closeIdx < openIdx {
// Find the matching closing tag outside CDATA. Long write-file tool
// calls often contain XML examples in CDATA, including </tool_calls>.
closeIdx := findXMLCloseOutsideCDATA(captured, pair.close, openIdx+len(pair.open))
if closeIdx < 0 {
// Opening tag is present but its specific closing tag hasn't arrived.
// Return not-ready so we keep buffering until the canonical wrapper closes.
return "", nil, "", false
@@ -55,6 +56,22 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string,
// If this block failed to become a tool call, pass it through as text.
return prefixPart + xmlBlock, nil, suffixPart, true
}
if !strings.Contains(lower, "<tool_calls") {
invokeIdx := strings.Index(lower, "<invoke")
closeIdx := findXMLCloseOutsideCDATA(captured, "</tool_calls>", invokeIdx)
if invokeIdx >= 0 && closeIdx > invokeIdx {
closeEnd := closeIdx + len("</tool_calls>")
xmlBlock := "<tool_calls>" + captured[invokeIdx:closeIdx] + "</tool_calls>"
prefixPart := captured[:invokeIdx]
suffixPart := captured[closeEnd:]
parsed := toolcall.ParseToolCalls(xmlBlock, toolNames)
if len(parsed) > 0 {
prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
return prefixPart, parsed, suffixPart, true
}
return prefixPart + captured[invokeIdx:closeEnd], nil, suffixPart, true
}
}
return "", nil, "", false
}
@@ -63,8 +80,9 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string,
func hasOpenXMLToolTag(captured string) bool {
lower := strings.ToLower(captured)
for _, pair := range xmlToolCallTagPairs {
if strings.Contains(lower, pair.open) {
if !strings.Contains(lower, pair.close) {
openIdx := strings.Index(lower, pair.open)
if openIdx >= 0 {
if findXMLCloseOutsideCDATA(captured, pair.close, openIdx+len(pair.open)) < 0 {
return true
}
}
@@ -72,6 +90,38 @@ func hasOpenXMLToolTag(captured string) bool {
return false
}
// findXMLCloseOutsideCDATA returns the byte index in s of the first closeTag
// at or after start that is not inside a CDATA section or an XML comment.
// Matching is ASCII case-insensitive.
//
// It returns -1 when s or closeTag is empty, when closeTag does not occur, or
// when an unterminated CDATA section or comment is reached first. A negative
// start is treated as 0.
func findXMLCloseOutsideCDATA(s, closeTag string, start int) int {
	if s == "" || closeTag == "" {
		return -1
	}
	if start < 0 {
		start = 0
	}
	// Fold ASCII letters only. strings.ToLower can change the byte length
	// (e.g. U+0130 shrinks from two bytes to one), which would make lower[i:]
	// slice out of range or return an index that does not address s.
	folded := []byte(s)
	for i, c := range folded {
		if c >= 'A' && c <= 'Z' {
			folded[i] = c + ('a' - 'A')
		}
	}
	lower := string(folded)
	target := strings.ToLower(closeTag)
	for i := start; i < len(s); {
		switch {
		case strings.HasPrefix(lower[i:], "<![cdata["):
			end := strings.Index(lower[i+len("<![cdata["):], "]]>")
			if end < 0 {
				return -1
			}
			i += len("<![cdata[") + end + len("]]>")
		case strings.HasPrefix(lower[i:], "<!--"):
			end := strings.Index(lower[i+len("<!--"):], "-->")
			if end < 0 {
				return -1
			}
			i += len("<!--") + end + len("-->")
		case strings.HasPrefix(lower[i:], target):
			return i
		default:
			i++
		}
	}
	return -1
}
// findPartialXMLToolTagStart checks if the string ends with a partial canonical
// XML wrapper tag (e.g., "<too") and returns the position of the '<'.
func findPartialXMLToolTagStart(s string) int {

View File

@@ -84,6 +84,65 @@ func TestProcessToolSieveHandlesLongXMLToolCall(t *testing.T) {
}
}
// TestProcessToolSieveKeepsCDATAEmbeddedToolClosingBuffered streams a Write
// tool call whose CDATA content contains a complete <tool_calls> example. The
// chunks are split right after the embedded </tool_calls>. The test expects no
// text or tool-call events for the first two chunks, then exactly one tool
// call overall, no leaked text, and the CDATA payload delivered unchanged.
func TestProcessToolSieveKeepsCDATAEmbeddedToolClosingBuffered(t *testing.T) {
var state State
payload := strings.Join([]string{
"# DS2API 4.0 更新内容",
"",
strings.Repeat("x", 4096),
"```xml",
"<tool_calls>",
" <invoke name=\"demo\">",
" <parameter name=\"value\">x</parameter>",
" </invoke>",
"</tool_calls>",
"```",
"tail",
}, "\n")
// Byte offset just past the </tool_calls> that lives inside the payload.
innerClose := strings.Index(payload, "</tool_calls>") + len("</tool_calls>")
chunks := []string{
"<tool_calls>\n <invoke name=\"Write\">\n <parameter name=\"content\"><![CDATA[",
payload[:innerClose],
payload[innerClose:],
"]]></parameter>\n <parameter name=\"file_path\">DS2API-4.0-Release-Notes.md</parameter>\n </invoke>\n</tool_calls>",
}
var events []Event
for i, c := range chunks {
next := ProcessChunk(&state, c, []string{"Write"})
// Chunks 0 and 1 end before the outer wrapper closes; nothing may be emitted yet.
if i <= 1 {
for _, evt := range next {
if evt.Content != "" || len(evt.ToolCalls) > 0 {
t.Fatalf("expected no events before outer closing tag, chunk=%d events=%#v", i, next)
}
}
}
events = append(events, next...)
}
events = append(events, Flush(&state, []string{"Write"})...)
var textContent strings.Builder
var gotPayload string
toolCalls := 0
for _, evt := range events {
textContent.WriteString(evt.Content)
if len(evt.ToolCalls) > 0 {
toolCalls += len(evt.ToolCalls)
gotPayload, _ = evt.ToolCalls[0].Input["content"].(string)
}
}
if toolCalls != 1 {
t.Fatalf("expected one parsed tool call, got %d events=%#v", toolCalls, events)
}
if textContent.Len() != 0 {
t.Fatalf("expected no leaked text, got %q", textContent.String())
}
if gotPayload != payload {
t.Fatalf("expected full CDATA payload to survive intact, got len=%d want=%d", len(gotPayload), len(payload))
}
}
func TestProcessToolSieveXMLWithLeadingText(t *testing.T) {
var state State
// Model outputs some prose then an XML tool call.
@@ -288,6 +347,7 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) {
want int
}{
{"tool_calls_tag", "some text <tool_calls>\n", 10},
{"invoke_tag_missing_wrapper", "some text <invoke name=\"read_file\">\n", 10},
{"bare_tool_call_text", "prefix <tool_call>\n", -1},
{"xml_inside_code_fence", "```xml\n<tool_calls><invoke name=\"read_file\"></invoke></tool_calls>\n```", -1},
{"no_xml", "just plain text", -1},
@@ -310,6 +370,7 @@ func TestFindPartialXMLToolTagStart(t *testing.T) {
want int
}{
{"partial_tool_calls", "Hello <tool_ca", 6},
{"partial_invoke", "Hello <inv", 6},
{"bare_tool_call_not_held", "Hello <tool_name", -1},
{"partial_lt_only", "Text <", 5},
{"complete_tag", "Text <tool_calls>done", -1},
@@ -505,3 +566,32 @@ func TestProcessToolSievePassesThroughBareToolCallAsText(t *testing.T) {
t.Fatalf("expected bare invoke to pass through unchanged, got %q", textContent.String())
}
}
// TestProcessToolSieveRepairsMissingOpeningWrapperWithoutLeakingInvokeText
// streams an <invoke> block that is terminated by </tool_calls> but has no
// opening <tool_calls>. It expects exactly one tool call across all emitted
// events and no "<invoke" or "</tool_calls>" fragments in the emitted text.
func TestProcessToolSieveRepairsMissingOpeningWrapperWithoutLeakingInvokeText(t *testing.T) {
	var (
		state     State
		collected []Event
	)
	for _, piece := range []string{
		"<invoke name=\"read_file\">\n",
		" <parameter name=\"path\">README.md</parameter>\n",
		"</invoke>\n",
		"</tool_calls>",
	} {
		collected = append(collected, ProcessChunk(&state, piece, []string{"read_file"})...)
	}
	collected = append(collected, Flush(&state, []string{"read_file"})...)

	var leaked strings.Builder
	callCount := 0
	for _, e := range collected {
		leaked.WriteString(e.Content)
		callCount += len(e.ToolCalls)
	}

	if callCount != 1 {
		t.Fatalf("expected repaired missing-wrapper stream to emit one tool call, got %d events=%#v", callCount, collected)
	}

	// Neither half of the raw XML may leak into the plain-text output.
	text := leaked.String()
	for _, marker := range []string{"<invoke", "</tool_calls>"} {
		if strings.Contains(text, marker) {
			t.Fatalf("expected repaired missing-wrapper stream not to leak xml text, got %q", text)
		}
	}
}

View File

@@ -118,6 +118,60 @@ test('sieve keeps long XML tool calls buffered until the closing tag arrives', (
assert.equal(finalCalls[0].input.content, longContent);
});
// Streams a Write tool call whose CDATA content embeds a fenced <tool_calls>
// example. The content is cut right after the embedded closing tag, so the
// second chunk ends on a "</tool_calls>" that is not the outer wrapper's.
// The first two chunks must emit nothing; the full stream must yield exactly
// one Write call with the content unchanged and no text output.
test('sieve keeps CDATA tool examples buffered until the outer closing tag arrives', () => {
  const closingTag = '</tool_calls>';
  const docLines = [
    '# DS2API 4.0 更新内容',
    '',
    'x'.repeat(4096),
    '```xml',
    '<tool_calls>',
    ' <invoke name="demo">',
    ' <parameter name="value">x</parameter>',
    ' </invoke>',
    '</tool_calls>',
    '```',
    'tail',
  ];
  const content = docLines.join('\n');

  // Offset just past the embedded closing tag; the stream is split here.
  const splitAt = content.indexOf(closingTag) + closingTag.length;
  const chunks = [
    '<tool_calls>\n <invoke name="Write">\n <parameter name="content"><![CDATA[',
    content.slice(0, splitAt),
    content.slice(splitAt),
    ']]></parameter>\n <parameter name="file_path">DS2API-4.0-Release-Notes.md</parameter>\n </invoke>\n</tool_calls>',
  ];

  const state = createToolSieveState();
  const events = [];
  for (const [idx, chunk] of chunks.entries()) {
    const emitted = processToolSieveChunk(state, chunk, ['Write']);
    // Chunks 0 and 1 end before the outer wrapper closes: nothing may be emitted.
    if (idx <= 1) {
      assert.deepEqual(emitted, []);
    }
    events.push(...emitted);
  }
  events.push(...flushToolSieve(state, ['Write']));

  const leakedText = collectText(events);
  const finalCalls = events.flatMap((evt) => (evt.type === 'tool_calls' ? evt.calls || [] : []));

  assert.equal(leakedText, '');
  assert.equal(finalCalls.length, 1);
  const writeCall = finalCalls[0];
  assert.equal(writeCall.name, 'Write');
  assert.equal(writeCall.input.content, content);
});
// Parses a single Write call whose CDATA-wrapped content contains a complete
// <tool_calls> example inside a code fence. The embedded XML must come back
// verbatim as the `content` input, alongside the sibling `file_path` parameter.
test('parseToolCalls keeps XML-looking CDATA content intact', () => {
  const embeddedExample =
    '<tool_calls><invoke name="demo"><parameter name="value">x</parameter></invoke></tool_calls>';
  const content = ['# Release notes', '```xml', embeddedExample, '```'].join('\n');

  // Assemble the outer call piece by piece; the joined string has no separators.
  const payload = [
    '<tool_calls><invoke name="Write">',
    `<parameter name="content"><![CDATA[${content}]]></parameter>`,
    '<parameter name="file_path">DS2API-4.0-Release-Notes.md</parameter>',
    '</invoke></tool_calls>',
  ].join('');

  const calls = parseToolCalls(payload, ['Write']);
  assert.equal(calls.length, 1);

  const parsedInput = calls[0].input;
  assert.equal(parsedInput.content, content);
  assert.equal(parsedInput.file_path, 'DS2API-4.0-Release-Notes.md');
});
test('sieve passes JSON tool_calls payload through as text (XML-only)', () => {
const events = runSieve(
['{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'],