测试DSML

This commit is contained in:
CJACK
2026-04-27 00:21:26 +08:00
parent 645fce41c8
commit 40d5e3ebb5
50 changed files with 1112 additions and 265 deletions

View File

@@ -9,22 +9,27 @@ import (
// --- XML tool call support for the streaming sieve ---
//nolint:unused // kept as explicit tag inventory for future XML sieve refinements.
var xmlToolCallClosingTags = []string{"</tool_calls>"}
var xmlToolCallOpeningTags = []string{"<tool_calls", "<invoke"}
var xmlToolCallClosingTags = []string{"</tool_calls>", "</|dsml|tool_calls>"}
var xmlToolCallOpeningTags = []string{"<tool_calls", "<invoke", "<|dsml|tool_calls", "<|dsml|invoke"}
// xmlToolCallTagPairs maps each opening tag to its expected closing tag.
// Order matters: longer/wrapper tags must be checked first.
var xmlToolCallTagPairs = []struct{ open, close string }{
{"<|dsml|tool_calls", "</|dsml|tool_calls>"},
{"<tool_calls", "</tool_calls>"},
}
// xmlToolCallBlockPattern matches a complete canonical XML tool call block.
//
//nolint:unused // reserved for future fast-path XML block detection.
var xmlToolCallBlockPattern = regexp.MustCompile(`(?is)(<tool_calls\b[^>]*>\s*(?:.*?)\s*</tool_calls>)`)
var xmlToolCallBlockPattern = regexp.MustCompile(`(?is)((?:<tool_calls\b|<\|dsml\|tool_calls\b)[^>]*>\s*(?:.*?)\s*(?:</tool_calls>|</\|dsml\|tool_calls>))`)
// xmlToolTagsToDetect is the set of XML tag prefixes used by findToolSegmentStart.
var xmlToolTagsToDetect = []string{"<tool_calls>", "<tool_calls\n", "<tool_calls ", "<invoke ", "<invoke\n", "<invoke\t", "<invoke\r"}
var xmlToolTagsToDetect = []string{
"<|dsml|tool_calls>", "<|dsml|tool_calls\n", "<|dsml|tool_calls ",
"<|dsml|invoke ", "<|dsml|invoke\n", "<|dsml|invoke\t", "<|dsml|invoke\r",
"<tool_calls>", "<tool_calls\n", "<tool_calls ", "<invoke ", "<invoke\n", "<invoke\t", "<invoke\r",
}
// consumeXMLToolCapture tries to extract complete XML tool call blocks from captured text.
func consumeXMLToolCapture(captured string, toolNames []string) (prefix string, calls []toolcall.ParsedToolCall, suffix string, ready bool) {
@@ -56,12 +61,18 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string,
// If this block failed to become a tool call, pass it through as text.
return prefixPart + xmlBlock, nil, suffixPart, true
}
if !strings.Contains(lower, "<tool_calls") {
invokeIdx := strings.Index(lower, "<invoke")
closeIdx := findXMLCloseOutsideCDATA(captured, "</tool_calls>", invokeIdx)
if !containsAnyToolCallWrapper(lower) {
invokeIdx, dsml := firstInvokeIndex(lower)
closeTag := "</tool_calls>"
openWrapper := "<tool_calls>"
if dsml {
closeTag = "</|dsml|tool_calls>"
openWrapper = "<|DSML|tool_calls>"
}
closeIdx := findXMLCloseOutsideCDATA(captured, closeTag, invokeIdx)
if invokeIdx >= 0 && closeIdx > invokeIdx {
closeEnd := closeIdx + len("</tool_calls>")
xmlBlock := "<tool_calls>" + captured[invokeIdx:closeIdx] + "</tool_calls>"
closeEnd := closeIdx + len(closeTag)
xmlBlock := openWrapper + captured[invokeIdx:closeIdx] + closeTag
prefixPart := captured[:invokeIdx]
suffixPart := captured[closeEnd:]
parsed := toolcall.ParseToolCalls(xmlBlock, toolNames)
@@ -92,15 +103,25 @@ func hasOpenXMLToolTag(captured string) bool {
func shouldKeepBareInvokeCapture(captured string) bool {
lower := strings.ToLower(captured)
invokeIdx := strings.Index(lower, "<invoke")
if invokeIdx < 0 || strings.Contains(lower, "<tool_calls") {
invokeIdx, dsml := firstInvokeIndex(lower)
if invokeIdx < 0 || containsAnyToolCallWrapper(lower) {
return false
}
if findXMLCloseOutsideCDATA(captured, "</tool_calls>", invokeIdx) > invokeIdx {
wrapperClose := "</tool_calls>"
invokeOpenLen := len("<invoke")
invokeClose := "</invoke>"
parameterOpen := "<parameter"
if dsml {
wrapperClose = "</|dsml|tool_calls>"
invokeOpenLen = len("<|dsml|invoke")
invokeClose = "</|dsml|invoke>"
parameterOpen = "<|dsml|parameter"
}
if findXMLCloseOutsideCDATA(captured, wrapperClose, invokeIdx) > invokeIdx {
return true
}
startEnd := findXMLTagEnd(captured, invokeIdx+len("<invoke"))
startEnd := findXMLTagEnd(captured, invokeIdx+invokeOpenLen)
if startEnd < 0 {
return true
}
@@ -110,18 +131,37 @@ func shouldKeepBareInvokeCapture(captured string) bool {
return true
}
invokeCloseIdx := findXMLCloseOutsideCDATA(captured, "</invoke>", startEnd+1)
invokeCloseIdx := findXMLCloseOutsideCDATA(captured, invokeClose, startEnd+1)
if invokeCloseIdx >= 0 {
afterClose := captured[invokeCloseIdx+len("</invoke>"):]
afterClose := captured[invokeCloseIdx+len(invokeClose):]
return strings.TrimSpace(afterClose) == ""
}
trimmedLower := strings.ToLower(trimmedBody)
return strings.HasPrefix(trimmedLower, "<parameter") ||
return strings.HasPrefix(trimmedLower, parameterOpen) ||
strings.HasPrefix(trimmedLower, "{") ||
strings.HasPrefix(trimmedLower, "[")
}
func containsAnyToolCallWrapper(lower string) bool {
return strings.Contains(lower, "<tool_calls") || strings.Contains(lower, "<|dsml|tool_calls")
}
func firstInvokeIndex(lower string) (int, bool) {
xmlIdx := strings.Index(lower, "<invoke")
dsmlIdx := strings.Index(lower, "<|dsml|invoke")
switch {
case xmlIdx < 0:
return dsmlIdx, dsmlIdx >= 0
case dsmlIdx < 0:
return xmlIdx, false
case dsmlIdx < xmlIdx:
return dsmlIdx, true
default:
return xmlIdx, false
}
}
func findXMLCloseOutsideCDATA(s, closeTag string, start int) int {
if s == "" || closeTag == "" {
return -1

View File

@@ -41,6 +41,37 @@ func TestProcessToolSieveInterceptsXMLToolCallWithoutLeak(t *testing.T) {
}
}
func TestProcessToolSieveInterceptsDSMLToolCallWithoutLeak(t *testing.T) {
var state State
chunks := []string{
"<|DSML|tool",
"_calls>\n",
` <|DSML|invoke name="read_file">` + "\n",
` <|DSML|parameter name="path">README.MD</|DSML|parameter>` + "\n",
" </|DSML|invoke>\n",
"</|DSML|tool_calls>",
}
var events []Event
for _, c := range chunks {
events = append(events, ProcessChunk(&state, c, []string{"read_file"})...)
}
events = append(events, Flush(&state, []string{"read_file"})...)
var textContent string
var toolCalls int
for _, evt := range events {
textContent += evt.Content
toolCalls += len(evt.ToolCalls)
}
if strings.Contains(strings.ToLower(textContent), "dsml") || strings.Contains(textContent, "read_file") {
t.Fatalf("DSML tool call content leaked to text: %q", textContent)
}
if toolCalls != 1 {
t.Fatalf("expected one DSML tool call, got %d events=%#v", toolCalls, events)
}
}
func TestProcessToolSieveHandlesLongXMLToolCall(t *testing.T) {
var state State
const toolName = "write_to_file"