diff --git a/docs/prompt-compatibility.md b/docs/prompt-compatibility.md index 6d6301c..495d1cc 100644 --- a/docs/prompt-compatibility.md +++ b/docs/prompt-compatibility.md @@ -148,6 +148,7 @@ DS2API 当前的核心思路,不是把客户端传来的 `messages`、`tools` 4. 把这整段内容并入 system prompt。 工具调用正例仍只示范 canonical XML:`` → `` → ``。 +提示词会额外强调:如果要调用工具,工具块的首个非空白字符必须就是 ``,不能只输出 `` 而漏掉 opening tag。 正例中的工具名只会来自当前请求实际声明的工具;如果当前请求没有足够的已知工具形态,就省略对应的单工具、多工具或嵌套示例,避免把不可用工具名写进 prompt。 对执行类工具,脚本内容必须进入执行参数本身:`Bash` / `execute_command` 使用 `command`,`exec_command` 使用 `cmd`;不要把脚本示范成 `path` / `content` 文件写入参数。 @@ -192,6 +193,7 @@ assistant 历史 `tool_calls` 不会保留成 OpenAI 原生 JSON,而会转成 ``` 这也是当前项目里唯一受支持的 canonical tool-calling 形态;其他形态都会作为普通文本保留,不会作为可执行调用语法。 +例外是 parser 会对一个非常窄的模型失误做修复:如果 assistant 输出了 `` ... ``,但漏掉最前面的 opening ``,解析阶段会补回 wrapper 后再尝试识别。 这件事很重要,因为它决定了: diff --git a/docs/toolcall-semantics.md b/docs/toolcall-semantics.md index 2627a0a..ea5c456 100644 --- a/docs/toolcall-semantics.md +++ b/docs/toolcall-semantics.md @@ -23,9 +23,14 @@ - 工具名必须放在 `invoke` 的 `name` 属性 - 参数必须使用 `...` +兼容修复: + +- 如果模型漏掉 opening ``,但后面仍输出了一个或多个 `` 并以 `` 收尾,Go 解析链路会在解析前补回缺失的 opening wrapper。 +- 这是一个针对常见模型失误的窄修复,不改变推荐输出格式;prompt 仍要求模型直接输出完整 canonical XML。 + ## 2) 非 canonical 内容 -任何不满足上述 canonical XML 形态的内容,都会保留为普通文本,不会执行。 +任何不满足上述 canonical XML 形态的内容,都会保留为普通文本,不会执行。一个例外是上一节提到的“缺失 opening ``、但 closing `` 仍存在”的窄修复场景。 当前 parser 不把 allow-list 当作硬安全边界:即使传入了已声明工具名列表,XML 里出现未声明工具名时也会尽量解析并交给上层协议输出;真正的执行侧仍必须自行校验工具名和参数。 @@ -33,7 +38,8 @@ 在流式链路中(Go / Node 一致): -- 只有从 `` wrapper 会进入结构化捕获 +- 如果流里直接从 `` 开始,但后面补上了 ``,Go 流式筛分也会按缺失 opening wrapper 的修复路径尝试恢复 - 已识别成功的工具调用不会再次回流到普通文本 - 不符合新格式的块不会执行,并继续按原样文本透传 - fenced code block 中的 XML 示例始终按普通文本处理 @@ -43,14 +49,14 @@ `ParseToolCallsDetailed` / `parseToolCallsDetailed` 返回: - `calls`:解析出的工具调用列表(`name` + `input`) -- `sawToolCallSyntax`:只有检测到 `]*>([\s\S]*?)<\/tool_calls>/gi; -const TOOL_CALL_MARKUP_BLOCK_PATTERN = /<(?:[a-z0-9_:-]+:)?invoke\b([^>]*)>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?invoke>/gi; -const PARAMETER_BLOCK_PATTERN = /<(?:[a-z0-9_:-]+:)?parameter\b([^>]*)>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?parameter>/gi; const TOOL_CALL_MARKUP_KV_PATTERN = /<(?:[a-z0-9_:-]+:)?([a-z0-9_.-]+)\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?\1>/gi; const CDATA_PATTERN = /^$/i; const XML_ATTR_PATTERN = /\b([a-z0-9_:-]+)\s*=\s*("([^"]*)"|'([^']*)')/gi; @@ -25,9 +22,9 @@ function parseMarkupToolCalls(text) { return []; } const out = []; - for (const wrapper of raw.matchAll(TOOLS_WRAPPER_PATTERN)) { - const body = toStringSafe(wrapper[1]); - for (const block of body.matchAll(TOOL_CALL_MARKUP_BLOCK_PATTERN)) { + for (const wrapper of findXmlElementBlocks(raw, 'tool_calls')) { + const body = toStringSafe(wrapper.body); + for (const block of findXmlElementBlocks(body, 'invoke')) { const parsed = parseMarkupSingleToolCall(block); if (parsed) { out.push(parsed); @@ -38,12 +35,12 @@ function parseMarkupToolCalls(text) { } function parseMarkupSingleToolCall(block) { - const attrs = parseTagAttributes(block[1]); + const attrs = parseTagAttributes(block.attrs); const name = toStringSafe(attrs.name).trim(); if (!name) { return null; } - const inner = toStringSafe(block[2]).trim(); + const inner = toStringSafe(block.body).trim(); if (inner) { try { @@ -63,13 +60,13 @@ function parseMarkupSingleToolCall(block) { } } const input = {}; - for (const match of inner.matchAll(PARAMETER_BLOCK_PATTERN)) { - const parameterAttrs = parseTagAttributes(match[1]); + for (const match of findXmlElementBlocks(inner, 'parameter')) { + const parameterAttrs = parseTagAttributes(match.attrs); const paramName = toStringSafe(parameterAttrs.name).trim(); if (!paramName) { continue; } - appendMarkupValue(input, paramName, parseMarkupValue(match[2])); + appendMarkupValue(input, paramName, parseMarkupValue(match.body)); } if (Object.keys(input).length === 0 && inner.trim() !== '') { return null; @@ -77,6 +74,154 @@ function parseMarkupSingleToolCall(block) { return { name, input }; } +function findXmlElementBlocks(text, tag) { + const source = toStringSafe(text); + const name = toStringSafe(tag).toLowerCase(); + if (!source || !name) { + return []; + } + const out = []; + let pos = 0; + while (pos < source.length) { + const start = findXmlStartTagOutsideCDATA(source, name, pos); + if (!start) { + break; + } + const end = findMatchingXmlEndTagOutsideCDATA(source, name, start.bodyStart); + if (!end) { + break; + } + out.push({ + attrs: start.attrs, + body: source.slice(start.bodyStart, end.closeStart), + start: start.start, + end: end.closeEnd, + }); + pos = end.closeEnd; + } + return out; +} + +function findXmlStartTagOutsideCDATA(text, tag, from) { + const lower = text.toLowerCase(); + const target = `<${tag}`; + for (let i = Math.max(0, from || 0); i < text.length;) { + const skipped = skipXmlIgnoredSection(lower, i); + if (skipped.blocked) { + return null; + } + if (skipped.advanced) { + i = skipped.next; + continue; + } + if (lower.startsWith(target, i) && hasXmlTagBoundary(text, i + target.length)) { + const tagEnd = findXmlTagEnd(text, i + target.length); + if (tagEnd < 0) { + return null; + } + return { + start: i, + bodyStart: tagEnd + 1, + attrs: text.slice(i + target.length, tagEnd), + }; + } + i += 1; + } + return null; +} + +function findMatchingXmlEndTagOutsideCDATA(text, tag, from) { + const lower = text.toLowerCase(); + const openTarget = `<${tag}`; + const closeTarget = `', i + ''.length }; + } + if (lower.startsWith('', i + ''.length }; + } + return { advanced: false, blocked: false, next: i }; +} + +function findXmlTagEnd(text, from) { + let quote = ''; + for (let i = Math.max(0, from || 0); i < text.length; i += 1) { + const ch = text[i]; + if (quote) { + if (ch === quote) { + quote = ''; + } + continue; + } + if (ch === '"' || ch === "'") { + quote = ch; + continue; + } + if (ch === '>') { + return i; + } + } + return -1; +} + +function hasXmlTagBoundary(text, idx) { + if (idx >= text.length) { + return true; + } + return [' ', '\t', '\n', '\r', '>', '/'].includes(text[idx]); +} + +function isSelfClosingXmlTag(startTag) { + return toStringSafe(startTag).trim().endsWith('/'); +} + function parseMarkupInput(raw) { const s = toStringSafe(raw).trim(); if (!s) { @@ -120,6 +265,10 @@ function parseMarkupKVObject(text) { } function parseMarkupValue(raw) { + const cdata = extractStandaloneCDATA(raw); + if (cdata.ok) { + return cdata.value; + } const s = toStringSafe(extractRawTagValue(raw)).trim(); if (!s) { return ''; @@ -152,9 +301,9 @@ function extractRawTagValue(inner) { } // 1. Check for CDATA - const cdataMatch = s.match(CDATA_PATTERN); - if (cdataMatch && cdataMatch[1] !== undefined) { - return cdataMatch[1]; + const cdata = extractStandaloneCDATA(s); + if (cdata.ok) { + return cdata.value; } // 2. Fallback to unescaping standard HTML entities @@ -172,6 +321,15 @@ function unescapeHtml(safe) { .replace(/'/g, "'"); } +function extractStandaloneCDATA(inner) { + const s = toStringSafe(inner).trim(); + const cdataMatch = s.match(CDATA_PATTERN); + if (cdataMatch && cdataMatch[1] !== undefined) { + return { ok: true, value: cdataMatch[1] }; + } + return { ok: false, value: '' }; +} + function parseTagAttributes(raw) { const source = toStringSafe(raw); const out = {}; diff --git a/internal/js/helpers/stream-tool-sieve/sieve-xml.js b/internal/js/helpers/stream-tool-sieve/sieve-xml.js index cc8ee43..90ea280 100644 --- a/internal/js/helpers/stream-tool-sieve/sieve-xml.js +++ b/internal/js/helpers/stream-tool-sieve/sieve-xml.js @@ -16,9 +16,10 @@ function consumeXMLToolCapture(captured, toolNames, trimWrappingJSONFence) { if (openIdx < 0) { continue; } - // Find the LAST occurrence of the specific closing tag. - const closeIdx = lower.lastIndexOf(pair.close); - if (closeIdx < openIdx) { + // Ignore closing tags that appear inside CDATA payloads, such as + // write-file content containing tool-call documentation examples. + const closeIdx = findXMLCloseOutsideCDATA(captured, pair.close, openIdx + pair.open.length); + if (closeIdx < 0) { // Opening tag present but specific closing tag hasn't arrived. // Return not-ready so buffering continues until the wrapper closes. return { ready: false, prefix: '', calls: [], suffix: '' }; @@ -46,8 +47,9 @@ function consumeXMLToolCapture(captured, toolNames, trimWrappingJSONFence) { function hasOpenXMLToolTag(captured) { const lower = captured.toLowerCase(); for (const pair of XML_TOOL_TAG_PAIRS) { - if (lower.includes(pair.open)) { - if (!lower.includes(pair.close)) { + const openIdx = lower.indexOf(pair.open); + if (openIdx >= 0) { + if (findXMLCloseOutsideCDATA(captured, pair.close, openIdx + pair.open.length) < 0) { return true; } } @@ -74,6 +76,38 @@ function findPartialXMLToolTagStart(s) { return -1; } +function findXMLCloseOutsideCDATA(s, closeTag, start) { + const text = typeof s === 'string' ? s : ''; + const target = String(closeTag || '').toLowerCase(); + if (!text || !target) { + return -1; + } + const lower = text.toLowerCase(); + for (let i = Math.max(0, start || 0); i < text.length;) { + if (lower.startsWith('', i + ''.length; + continue; + } + if (lower.startsWith('', i + ''.length; + continue; + } + if (lower.startsWith(target, i)) { + return i; + } + i += 1; + } + return -1; +} + module.exports = { consumeXMLToolCapture, hasOpenXMLToolTag, diff --git a/internal/stream/engine.go b/internal/stream/engine.go index c63cd7b..1623946 100644 --- a/internal/stream/engine.go +++ b/internal/stream/engine.go @@ -71,15 +71,30 @@ func ConsumeSSE(cfg ConsumeConfig, hooks ConsumeHooks) { hooks.OnFinalize(reason, scannerErr) } } + contextDone := func() bool { + if cfg.Context.Err() == nil { + return false + } + if hooks.OnContextDone != nil { + hooks.OnContextDone() + } + return true + } for { + if contextDone() { + return + } select { case <-cfg.Context.Done(): - if hooks.OnContextDone != nil { - hooks.OnContextDone() + if contextDone() { + return } return case <-tickCh(ticker): + if contextDone() { + return + } if !hasContent { keepaliveCount++ if cfg.MaxKeepAliveNoInput > 0 && keepaliveCount >= cfg.MaxKeepAliveNoInput { @@ -95,6 +110,9 @@ func ConsumeSSE(cfg ConsumeConfig, hooks ConsumeHooks) { hooks.OnKeepAlive() } case parsed, ok := <-parsedLines: + if contextDone() { + return + } if !ok { finalize(StopReasonUpstreamCompleted, <-done) return diff --git a/internal/stream/engine_test.go b/internal/stream/engine_test.go new file mode 100644 index 0000000..b23474b --- /dev/null +++ b/internal/stream/engine_test.go @@ -0,0 +1,47 @@ +package stream + +import ( + "context" + "strings" + "testing" + + "ds2api/internal/sse" +) + +func TestConsumeSSEPrefersContextCancellationOverReadyParsedLines(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + var finalized bool + var contextDone bool + var parsedCalled bool + + ConsumeSSE(ConsumeConfig{ + Context: ctx, + Body: strings.NewReader("data: {\"p\":\"response/content\",\"v\":\"hello\"}\n\ndata: [DONE]\n"), + ThinkingEnabled: false, + InitialType: "text", + KeepAliveInterval: 0, + }, ConsumeHooks{ + OnParsed: func(_ sse.LineResult) ParsedDecision { + parsedCalled = true + return ParsedDecision{} + }, + OnFinalize: func(_ StopReason, _ error) { + finalized = true + }, + OnContextDone: func() { + contextDone = true + }, + }) + + if !contextDone { + t.Fatal("expected OnContextDone to run for an already-cancelled context") + } + if finalized { + t.Fatal("expected OnFinalize not to run after context cancellation wins") + } + if parsedCalled { + t.Fatal("expected parsed lines not to be processed after context cancellation wins") + } +} diff --git a/internal/toolcall/tool_prompt.go b/internal/toolcall/tool_prompt.go index 7f405d2..aa556e8 100644 --- a/internal/toolcall/tool_prompt.go +++ b/internal/toolcall/tool_prompt.go @@ -27,6 +27,8 @@ RULES: 7) Numbers, booleans, and null stay plain text. 8) Use only the parameter names in the tool schema. Do not invent fields. 9) Do NOT wrap XML in markdown fences. Do NOT output explanations, role markers, or internal monologue. +10) If you call a tool, the first non-whitespace characters of that tool block must be exactly . +11) Never omit the opening tag, even if you already plan to close with . PARAMETER SHAPES: - string => @@ -42,6 +44,9 @@ Wrong 2 — Markdown code fences: ` + "```xml" + ` ... ` + "```" + ` +Wrong 3 — missing opening wrapper: + ... + Remember: The ONLY valid way to use tools is the ... XML block at the end of your response. diff --git a/internal/toolcall/tool_prompt_test.go b/internal/toolcall/tool_prompt_test.go index 8b0e8cf..d482d52 100644 --- a/internal/toolcall/tool_prompt_test.go +++ b/internal/toolcall/tool_prompt_test.go @@ -109,6 +109,16 @@ func TestBuildToolCallInstructions_WriteUsesFilePathAndContent(t *testing.T) { } } +func TestBuildToolCallInstructions_AnchorsMissingOpeningWrapperFailureMode(t *testing.T) { + out := BuildToolCallInstructions([]string{"read_file"}) + if !strings.Contains(out, "Never omit the opening tag") { + t.Fatalf("expected explicit missing-opening-tag warning, got: %s", out) + } + if !strings.Contains(out, "Wrong 3 — missing opening wrapper") { + t.Fatalf("expected missing-opening-wrapper negative example, got: %s", out) + } +} + func findInvokeBlocks(text, name string) []string { open := `` remaining := text diff --git a/internal/toolcall/toolcalls_markup.go b/internal/toolcall/toolcalls_markup.go index 3d8e657..b01ba21 100644 --- a/internal/toolcall/toolcalls_markup.go +++ b/internal/toolcall/toolcalls_markup.go @@ -43,6 +43,9 @@ func parseMarkupKVObject(text string) map[string]any { } func parseMarkupValue(inner string) any { + if value, ok := extractStandaloneCDATA(inner); ok { + return value + } value := strings.TrimSpace(extractRawTagValue(inner)) if value == "" { return "" @@ -89,8 +92,8 @@ func extractRawTagValue(inner string) string { } // 1. Check for CDATA - if present, it's the ultimate "safe" container. - if cdataMatches := cdataPattern.FindStringSubmatch(trimmed); len(cdataMatches) >= 2 { - return cdataMatches[1] // Return raw content between CDATA brackets + if value, ok := extractStandaloneCDATA(trimmed); ok { + return value // Return raw content between CDATA brackets } // 2. If no CDATA, we still want to be robust. @@ -102,3 +105,11 @@ func extractRawTagValue(inner string) string { // but for KV objects we usually want the value. return html.UnescapeString(inner) } + +func extractStandaloneCDATA(inner string) (string, bool) { + trimmed := strings.TrimSpace(inner) + if cdataMatches := cdataPattern.FindStringSubmatch(trimmed); len(cdataMatches) >= 2 { + return cdataMatches[1], true + } + return "", false +} diff --git a/internal/toolcall/toolcalls_parse.go b/internal/toolcall/toolcalls_parse.go index 16743ac..a950c2c 100644 --- a/internal/toolcall/toolcalls_parse.go +++ b/internal/toolcall/toolcalls_parse.go @@ -87,7 +87,13 @@ func stripFencedCodeBlocks(text string) string { lines := strings.SplitAfter(text, "\n") inFence := false fenceMarker := "" + inCDATA := false for _, line := range lines { + if inCDATA || cdataStartsBeforeFence(line) { + b.WriteString(line) + inCDATA = updateCDATAState(inCDATA, line) + continue + } trimmed := strings.TrimLeft(line, " \t") if !inFence { if marker, ok := parseFenceOpen(trimmed); ok { @@ -111,6 +117,54 @@ func stripFencedCodeBlocks(text string) string { return b.String() } +func cdataStartsBeforeFence(line string) bool { + cdataIdx := strings.Index(strings.ToLower(line), "") + if end < 0 { + return true + } + pos += end + len("]]>") + state = false + continue + } + start := strings.Index(lower[pos:], "]*>\s*(.*?)\s*`) -var xmlInvokePattern = regexp.MustCompile(`(?is)]*)>\s*(.*?)\s*`) -var xmlParameterPattern = regexp.MustCompile(`(?is)]*)>\s*(.*?)\s*`) var xmlAttrPattern = regexp.MustCompile(`(?is)\b([a-z0-9_:-]+)\s*=\s*("([^"]*)"|'([^']*)')`) +var xmlToolCallsClosePattern = regexp.MustCompile(`(?is)`) +var xmlInvokeStartPattern = regexp.MustCompile(`(?is)]*\bname\s*=\s*("([^"]*)"|'([^']*)')`) func parseXMLToolCalls(text string) []ParsedToolCall { - wrappers := xmlToolCallsWrapperPattern.FindAllStringSubmatch(text, -1) + wrappers := findXMLElementBlocks(text, "tool_calls") + if len(wrappers) == 0 { + repaired := repairMissingXMLToolCallsOpeningWrapper(text) + if repaired != text { + wrappers = findXMLElementBlocks(repaired, "tool_calls") + } + } if len(wrappers) == 0 { return nil } out := make([]ParsedToolCall, 0, len(wrappers)) for _, wrapper := range wrappers { - if len(wrapper) < 2 { - continue - } - for _, block := range xmlInvokePattern.FindAllStringSubmatch(wrapper[1], -1) { + for _, block := range findXMLElementBlocks(wrapper.Body, "invoke") { call, ok := parseSingleXMLToolCall(block) if !ok { continue @@ -36,17 +38,36 @@ func parseXMLToolCalls(text string) []ParsedToolCall { return out } -func parseSingleXMLToolCall(block []string) (ParsedToolCall, bool) { - if len(block) < 3 { - return ParsedToolCall{}, false +func repairMissingXMLToolCallsOpeningWrapper(text string) string { + lower := strings.ToLower(text) + if strings.Contains(lower, "= closeLoc[0] { + return text + } + + return text[:invokeLoc[0]] + "" + text[invokeLoc[0]:closeLoc[0]] + "" + text[closeLoc[1]:] +} + +func parseSingleXMLToolCall(block xmlElementBlock) (ParsedToolCall, bool) { + attrs := parseXMLTagAttributes(block.Attrs) name := strings.TrimSpace(html.UnescapeString(attrs["name"])) if name == "" { return ParsedToolCall{}, false } - inner := strings.TrimSpace(block[2]) + inner := strings.TrimSpace(block.Body) if strings.HasPrefix(inner, "{") { var payload map[string]any if err := json.Unmarshal([]byte(inner), &payload); err == nil { @@ -64,16 +85,13 @@ func parseSingleXMLToolCall(block []string) (ParsedToolCall, bool) { } input := map[string]any{} - for _, paramMatch := range xmlParameterPattern.FindAllStringSubmatch(inner, -1) { - if len(paramMatch) < 3 { - continue - } - paramAttrs := parseXMLTagAttributes(paramMatch[1]) + for _, paramMatch := range findXMLElementBlocks(inner, "parameter") { + paramAttrs := parseXMLTagAttributes(paramMatch.Attrs) paramName := strings.TrimSpace(html.UnescapeString(paramAttrs["name"])) if paramName == "" { continue } - value := parseInvokeParameterValue(paramMatch[2]) + value := parseInvokeParameterValue(paramMatch.Body) appendMarkupValue(input, paramName, value) } @@ -86,6 +104,168 @@ func parseSingleXMLToolCall(block []string) (ParsedToolCall, bool) { return ParsedToolCall{Name: name, Input: input}, true } +type xmlElementBlock struct { + Attrs string + Body string + Start int + End int +} + +func findXMLElementBlocks(text, tag string) []xmlElementBlock { + if text == "" || tag == "" { + return nil + } + var out []xmlElementBlock + pos := 0 + for pos < len(text) { + start, bodyStart, attrs, ok := findXMLStartTagOutsideCDATA(text, tag, pos) + if !ok { + break + } + closeStart, closeEnd, ok := findMatchingXMLEndTagOutsideCDATA(text, tag, bodyStart) + if !ok { + break + } + out = append(out, xmlElementBlock{ + Attrs: attrs, + Body: text[bodyStart:closeStart], + Start: start, + End: closeEnd, + }) + pos = closeEnd + } + return out +} + +func findXMLStartTagOutsideCDATA(text, tag string, from int) (start, bodyStart int, attrs string, ok bool) { + lower := strings.ToLower(text) + target := "<" + strings.ToLower(tag) + for i := maxInt(from, 0); i < len(text); { + next, advanced, blocked := skipXMLIgnoredSection(lower, i) + if blocked { + return -1, -1, "", false + } + if advanced { + i = next + continue + } + if strings.HasPrefix(lower[i:], target) && hasXMLTagBoundary(text, i+len(target)) { + end := findXMLTagEnd(text, i+len(target)) + if end < 0 { + return -1, -1, "", false + } + return i, end + 1, text[i+len(target) : end], true + } + i++ + } + return -1, -1, "", false +} + +func findMatchingXMLEndTagOutsideCDATA(text, tag string, from int) (closeStart, closeEnd int, ok bool) { + lower := strings.ToLower(text) + openTarget := "<" + strings.ToLower(tag) + closeTarget := "") + if end < 0 { + return 0, false, true + } + return i + len(""), true, false + case strings.HasPrefix(lower[i:], "") + if end < 0 { + return 0, false, true + } + return i + len(""), true, false + default: + return i, false, false + } +} + +func findXMLTagEnd(text string, from int) int { + quote := byte(0) + for i := maxInt(from, 0); i < len(text); i++ { + ch := text[i] + if quote != 0 { + if ch == quote { + quote = 0 + } + continue + } + if ch == '"' || ch == '\'' { + quote = ch + continue + } + if ch == '>' { + return i + } + } + return -1 +} + +func hasXMLTagBoundary(text string, idx int) bool { + if idx >= len(text) { + return true + } + switch text[idx] { + case ' ', '\t', '\n', '\r', '>', '/': + return true + default: + return false + } +} + +func isSelfClosingXMLTag(startTag string) bool { + return strings.HasSuffix(strings.TrimSpace(startTag), "/") +} + +func maxInt(a, b int) int { + if a > b { + return a + } + return b +} + func parseXMLTagAttributes(raw string) map[string]string { if strings.TrimSpace(raw) == "" { return map[string]string{} @@ -113,6 +293,9 @@ func parseInvokeParameterValue(raw string) any { if trimmed == "" { return "" } + if value, ok := extractStandaloneCDATA(trimmed); ok { + return value + } if parsed := parseStructuredToolCallInput(trimmed); len(parsed) > 0 { if len(parsed) == 1 { if rawValue, ok := parsed["_raw"].(string); ok { diff --git a/internal/toolcall/toolcalls_test.go b/internal/toolcall/toolcalls_test.go index 13d0bef..c4bfe51 100644 --- a/internal/toolcall/toolcalls_test.go +++ b/internal/toolcall/toolcalls_test.go @@ -54,6 +54,32 @@ echo "hello" } } +func TestParseToolCallsKeepsToolSyntaxInsideCDATAAsParameterText(t *testing.T) { + payload := strings.Join([]string{ + "# Release notes", + "", + "```xml", + "", + " ", + " x", + " ", + "", + "```", + }, "\n") + text := `DS2API-4.0-Release-Notes.md` + calls := ParseToolCalls(text, []string{"Write"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %#v", calls) + } + content, _ := calls[0].Input["content"].(string) + if content != payload { + t.Fatalf("expected CDATA payload with nested tool syntax to survive intact, got %q", content) + } + if calls[0].Input["file_path"] != "DS2API-4.0-Release-Notes.md" { + t.Fatalf("expected file_path parameter, got %#v", calls[0].Input) + } +} + func TestParseToolCallsSupportsInvokeParameters(t *testing.T) { text := `beijingc` calls := ParseToolCalls(text, []string{"get_weather"}) @@ -175,6 +201,26 @@ func TestParseToolCallsRejectsBareInvokeWithoutToolCallsWrapper(t *testing.T) { } } +func TestParseToolCallsRepairsMissingOpeningToolCallsWrapperWhenClosingTagExists(t *testing.T) { + text := `Before tool call +README.md + +after` + res := ParseToolCallsDetailed(text, []string{"read_file"}) + if len(res.Calls) != 1 { + t.Fatalf("expected repaired wrapper to parse exactly one call, got %#v", res) + } + if res.Calls[0].Name != "read_file" { + t.Fatalf("expected repaired wrapper to preserve tool name, got %#v", res.Calls[0]) + } + if got, _ := res.Calls[0].Input["path"].(string); got != "README.md" { + t.Fatalf("expected repaired wrapper to preserve args, got %#v", res.Calls[0].Input) + } + if !res.SawToolCallSyntax { + t.Fatalf("expected repaired wrapper to mark tool syntax seen, got %#v", res) + } +} + func TestParseToolCallsRejectsLegacyCanonicalBody(t *testing.T) { text := `read_file{"path":"README.md"}` calls := ParseToolCalls(text, []string{"read_file"}) diff --git a/internal/toolstream/tool_sieve_xml.go b/internal/toolstream/tool_sieve_xml.go index 87fb075..72cbbaa 100644 --- a/internal/toolstream/tool_sieve_xml.go +++ b/internal/toolstream/tool_sieve_xml.go @@ -10,7 +10,7 @@ import ( //nolint:unused // kept as explicit tag inventory for future XML sieve refinements. var xmlToolCallClosingTags = []string{""} -var xmlToolCallOpeningTags = []string{"]*>\s*(?:.*?)\s*)`) // xmlToolTagsToDetect is the set of XML tag prefixes used by findToolSegmentStart. -var xmlToolTagsToDetect = []string{"", "", ". + closeIdx := findXMLCloseOutsideCDATA(captured, pair.close, openIdx+len(pair.open)) + if closeIdx < 0 { // Opening tag is present but its specific closing tag hasn't arrived. // Return not-ready so we keep buffering until the canonical wrapper closes. return "", nil, "", false @@ -55,6 +56,22 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string, // If this block failed to become a tool call, pass it through as text. return prefixPart + xmlBlock, nil, suffixPart, true } + if !strings.Contains(lower, "", invokeIdx) + if invokeIdx >= 0 && closeIdx > invokeIdx { + closeEnd := closeIdx + len("") + xmlBlock := "" + captured[invokeIdx:closeIdx] + "" + prefixPart := captured[:invokeIdx] + suffixPart := captured[closeEnd:] + parsed := toolcall.ParseToolCalls(xmlBlock, toolNames) + if len(parsed) > 0 { + prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart) + return prefixPart, parsed, suffixPart, true + } + return prefixPart + captured[invokeIdx:closeEnd], nil, suffixPart, true + } + } return "", nil, "", false } @@ -63,8 +80,9 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string, func hasOpenXMLToolTag(captured string) bool { lower := strings.ToLower(captured) for _, pair := range xmlToolCallTagPairs { - if strings.Contains(lower, pair.open) { - if !strings.Contains(lower, pair.close) { + openIdx := strings.Index(lower, pair.open) + if openIdx >= 0 { + if findXMLCloseOutsideCDATA(captured, pair.close, openIdx+len(pair.open)) < 0 { return true } } @@ -72,6 +90,38 @@ func hasOpenXMLToolTag(captured string) bool { return false } +func findXMLCloseOutsideCDATA(s, closeTag string, start int) int { + if s == "" || closeTag == "" { + return -1 + } + if start < 0 { + start = 0 + } + lower := strings.ToLower(s) + target := strings.ToLower(closeTag) + for i := start; i < len(s); { + switch { + case strings.HasPrefix(lower[i:], "") + if end < 0 { + return -1 + } + i += len("") + case strings.HasPrefix(lower[i:], "") + if end < 0 { + return -1 + } + i += len("") + case strings.HasPrefix(lower[i:], target): + return i + default: + i++ + } + } + return -1 +} + // findPartialXMLToolTagStart checks if the string ends with a partial canonical // XML wrapper tag (e.g., "", + " ", + " x", + " ", + "", + "```", + "tail", + }, "\n") + innerClose := strings.Index(payload, "") + len("") + chunks := []string{ + "\n \n \n DS2API-4.0-Release-Notes.md\n \n", + } + + var events []Event + for i, c := range chunks { + next := ProcessChunk(&state, c, []string{"Write"}) + if i <= 1 { + for _, evt := range next { + if evt.Content != "" || len(evt.ToolCalls) > 0 { + t.Fatalf("expected no events before outer closing tag, chunk=%d events=%#v", i, next) + } + } + } + events = append(events, next...) + } + events = append(events, Flush(&state, []string{"Write"})...) + + var textContent strings.Builder + var gotPayload string + toolCalls := 0 + for _, evt := range events { + textContent.WriteString(evt.Content) + if len(evt.ToolCalls) > 0 { + toolCalls += len(evt.ToolCalls) + gotPayload, _ = evt.ToolCalls[0].Input["content"].(string) + } + } + + if toolCalls != 1 { + t.Fatalf("expected one parsed tool call, got %d events=%#v", toolCalls, events) + } + if textContent.Len() != 0 { + t.Fatalf("expected no leaked text, got %q", textContent.String()) + } + if gotPayload != payload { + t.Fatalf("expected full CDATA payload to survive intact, got len=%d want=%d", len(gotPayload), len(payload)) + } +} + func TestProcessToolSieveXMLWithLeadingText(t *testing.T) { var state State // Model outputs some prose then an XML tool call. @@ -288,6 +347,7 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) { want int }{ {"tool_calls_tag", "some text \n", 10}, + {"invoke_tag_missing_wrapper", "some text \n", 10}, {"bare_tool_call_text", "prefix \n", -1}, {"xml_inside_code_fence", "```xml\n\n```", -1}, {"no_xml", "just plain text", -1}, @@ -310,6 +370,7 @@ func TestFindPartialXMLToolTagStart(t *testing.T) { want int }{ {"partial_tool_calls", "Hello done", -1}, @@ -505,3 +566,32 @@ func TestProcessToolSievePassesThroughBareToolCallAsText(t *testing.T) { t.Fatalf("expected bare invoke to pass through unchanged, got %q", textContent.String()) } } + +func TestProcessToolSieveRepairsMissingOpeningWrapperWithoutLeakingInvokeText(t *testing.T) { + var state State + chunks := []string{ + "\n", + " README.md\n", + "\n", + "", + } + var events []Event + for _, c := range chunks { + events = append(events, ProcessChunk(&state, c, []string{"read_file"})...) + } + events = append(events, Flush(&state, []string{"read_file"})...) + + var textContent strings.Builder + toolCalls := 0 + for _, evt := range events { + textContent.WriteString(evt.Content) + toolCalls += len(evt.ToolCalls) + } + + if toolCalls != 1 { + t.Fatalf("expected repaired missing-wrapper stream to emit one tool call, got %d events=%#v", toolCalls, events) + } + if strings.Contains(textContent.String(), "") { + t.Fatalf("expected repaired missing-wrapper stream not to leak xml text, got %q", textContent.String()) + } +} diff --git a/tests/node/stream-tool-sieve.test.js b/tests/node/stream-tool-sieve.test.js index 1e5012a..cc6ae93 100644 --- a/tests/node/stream-tool-sieve.test.js +++ b/tests/node/stream-tool-sieve.test.js @@ -118,6 +118,60 @@ test('sieve keeps long XML tool calls buffered until the closing tag arrives', ( assert.equal(finalCalls[0].input.content, longContent); }); +test('sieve keeps CDATA tool examples buffered until the outer closing tag arrives', () => { + const content = [ + '# DS2API 4.0 更新内容', + '', + 'x'.repeat(4096), + '```xml', + '', + ' ', + ' x', + ' ', + '', + '```', + 'tail', + ].join('\n'); + const innerClose = content.indexOf('') + ''.length; + const state = createToolSieveState(); + const chunks = [ + '\n \n \n DS2API-4.0-Release-Notes.md\n \n', + ]; + const events = []; + chunks.forEach((chunk, idx) => { + const next = processToolSieveChunk(state, chunk, ['Write']); + if (idx <= 1) { + assert.deepEqual(next, []); + } + events.push(...next); + }); + events.push(...flushToolSieve(state, ['Write'])); + + const leakedText = collectText(events); + const finalCalls = events.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []); + assert.equal(leakedText, ''); + assert.equal(finalCalls.length, 1); + assert.equal(finalCalls[0].name, 'Write'); + assert.equal(finalCalls[0].input.content, content); +}); + +test('parseToolCalls keeps XML-looking CDATA content intact', () => { + const content = [ + '# Release notes', + '```xml', + 'x', + '```', + ].join('\n'); + const payload = `DS2API-4.0-Release-Notes.md`; + const calls = parseToolCalls(payload, ['Write']); + assert.equal(calls.length, 1); + assert.equal(calls[0].input.content, content); + assert.equal(calls[0].input.file_path, 'DS2API-4.0-Release-Notes.md'); +}); + test('sieve passes JSON tool_calls payload through as text (XML-only)', () => { const events = runSieve( ['{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'],