mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-19 23:47:45 +08:00
测试DSML
This commit is contained in:
@@ -9,22 +9,27 @@ import (
|
||||
// --- XML tool call support for the streaming sieve ---
|
||||
|
||||
//nolint:unused // kept as explicit tag inventory for future XML sieve refinements.
|
||||
var xmlToolCallClosingTags = []string{"</tool_calls>"}
|
||||
var xmlToolCallOpeningTags = []string{"<tool_calls", "<invoke"}
|
||||
var xmlToolCallClosingTags = []string{"</tool_calls>", "</|dsml|tool_calls>"}
|
||||
var xmlToolCallOpeningTags = []string{"<tool_calls", "<invoke", "<|dsml|tool_calls", "<|dsml|invoke"}
|
||||
|
||||
// xmlToolCallTagPairs maps each opening tag to its expected closing tag.
|
||||
// Order matters: longer/wrapper tags must be checked first.
|
||||
var xmlToolCallTagPairs = []struct{ open, close string }{
|
||||
{"<|dsml|tool_calls", "</|dsml|tool_calls>"},
|
||||
{"<tool_calls", "</tool_calls>"},
|
||||
}
|
||||
|
||||
// xmlToolCallBlockPattern matches a complete canonical XML tool call block.
|
||||
//
|
||||
//nolint:unused // reserved for future fast-path XML block detection.
|
||||
var xmlToolCallBlockPattern = regexp.MustCompile(`(?is)(<tool_calls\b[^>]*>\s*(?:.*?)\s*</tool_calls>)`)
|
||||
var xmlToolCallBlockPattern = regexp.MustCompile(`(?is)((?:<tool_calls\b|<\|dsml\|tool_calls\b)[^>]*>\s*(?:.*?)\s*(?:</tool_calls>|</\|dsml\|tool_calls>))`)
|
||||
|
||||
// xmlToolTagsToDetect is the set of XML tag prefixes used by findToolSegmentStart.
|
||||
var xmlToolTagsToDetect = []string{"<tool_calls>", "<tool_calls\n", "<tool_calls ", "<invoke ", "<invoke\n", "<invoke\t", "<invoke\r"}
|
||||
var xmlToolTagsToDetect = []string{
|
||||
"<|dsml|tool_calls>", "<|dsml|tool_calls\n", "<|dsml|tool_calls ",
|
||||
"<|dsml|invoke ", "<|dsml|invoke\n", "<|dsml|invoke\t", "<|dsml|invoke\r",
|
||||
"<tool_calls>", "<tool_calls\n", "<tool_calls ", "<invoke ", "<invoke\n", "<invoke\t", "<invoke\r",
|
||||
}
|
||||
|
||||
// consumeXMLToolCapture tries to extract complete XML tool call blocks from captured text.
|
||||
func consumeXMLToolCapture(captured string, toolNames []string) (prefix string, calls []toolcall.ParsedToolCall, suffix string, ready bool) {
|
||||
@@ -56,12 +61,18 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string,
|
||||
// If this block failed to become a tool call, pass it through as text.
|
||||
return prefixPart + xmlBlock, nil, suffixPart, true
|
||||
}
|
||||
if !strings.Contains(lower, "<tool_calls") {
|
||||
invokeIdx := strings.Index(lower, "<invoke")
|
||||
closeIdx := findXMLCloseOutsideCDATA(captured, "</tool_calls>", invokeIdx)
|
||||
if !containsAnyToolCallWrapper(lower) {
|
||||
invokeIdx, dsml := firstInvokeIndex(lower)
|
||||
closeTag := "</tool_calls>"
|
||||
openWrapper := "<tool_calls>"
|
||||
if dsml {
|
||||
closeTag = "</|dsml|tool_calls>"
|
||||
openWrapper = "<|DSML|tool_calls>"
|
||||
}
|
||||
closeIdx := findXMLCloseOutsideCDATA(captured, closeTag, invokeIdx)
|
||||
if invokeIdx >= 0 && closeIdx > invokeIdx {
|
||||
closeEnd := closeIdx + len("</tool_calls>")
|
||||
xmlBlock := "<tool_calls>" + captured[invokeIdx:closeIdx] + "</tool_calls>"
|
||||
closeEnd := closeIdx + len(closeTag)
|
||||
xmlBlock := openWrapper + captured[invokeIdx:closeIdx] + closeTag
|
||||
prefixPart := captured[:invokeIdx]
|
||||
suffixPart := captured[closeEnd:]
|
||||
parsed := toolcall.ParseToolCalls(xmlBlock, toolNames)
|
||||
@@ -92,15 +103,25 @@ func hasOpenXMLToolTag(captured string) bool {
|
||||
|
||||
func shouldKeepBareInvokeCapture(captured string) bool {
|
||||
lower := strings.ToLower(captured)
|
||||
invokeIdx := strings.Index(lower, "<invoke")
|
||||
if invokeIdx < 0 || strings.Contains(lower, "<tool_calls") {
|
||||
invokeIdx, dsml := firstInvokeIndex(lower)
|
||||
if invokeIdx < 0 || containsAnyToolCallWrapper(lower) {
|
||||
return false
|
||||
}
|
||||
if findXMLCloseOutsideCDATA(captured, "</tool_calls>", invokeIdx) > invokeIdx {
|
||||
wrapperClose := "</tool_calls>"
|
||||
invokeOpenLen := len("<invoke")
|
||||
invokeClose := "</invoke>"
|
||||
parameterOpen := "<parameter"
|
||||
if dsml {
|
||||
wrapperClose = "</|dsml|tool_calls>"
|
||||
invokeOpenLen = len("<|dsml|invoke")
|
||||
invokeClose = "</|dsml|invoke>"
|
||||
parameterOpen = "<|dsml|parameter"
|
||||
}
|
||||
if findXMLCloseOutsideCDATA(captured, wrapperClose, invokeIdx) > invokeIdx {
|
||||
return true
|
||||
}
|
||||
|
||||
startEnd := findXMLTagEnd(captured, invokeIdx+len("<invoke"))
|
||||
startEnd := findXMLTagEnd(captured, invokeIdx+invokeOpenLen)
|
||||
if startEnd < 0 {
|
||||
return true
|
||||
}
|
||||
@@ -110,18 +131,37 @@ func shouldKeepBareInvokeCapture(captured string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
invokeCloseIdx := findXMLCloseOutsideCDATA(captured, "</invoke>", startEnd+1)
|
||||
invokeCloseIdx := findXMLCloseOutsideCDATA(captured, invokeClose, startEnd+1)
|
||||
if invokeCloseIdx >= 0 {
|
||||
afterClose := captured[invokeCloseIdx+len("</invoke>"):]
|
||||
afterClose := captured[invokeCloseIdx+len(invokeClose):]
|
||||
return strings.TrimSpace(afterClose) == ""
|
||||
}
|
||||
|
||||
trimmedLower := strings.ToLower(trimmedBody)
|
||||
return strings.HasPrefix(trimmedLower, "<parameter") ||
|
||||
return strings.HasPrefix(trimmedLower, parameterOpen) ||
|
||||
strings.HasPrefix(trimmedLower, "{") ||
|
||||
strings.HasPrefix(trimmedLower, "[")
|
||||
}
|
||||
|
||||
// containsAnyToolCallWrapper reports whether lower contains the start of a
// tool-call wrapper tag, in either the plain XML form ("<tool_calls") or the
// DSML-prefixed form ("<|dsml|tool_calls").
//
// The comparison is against lowercase literals, so the argument must already
// be lower-cased for mixed-case tags to be recognised.
func containsAnyToolCallWrapper(lower string) bool {
	for _, wrapperOpen := range [...]string{"<tool_calls", "<|dsml|tool_calls"} {
		if strings.Contains(lower, wrapperOpen) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// firstInvokeIndex finds the earliest invoke opening tag in lower, considering
// both the plain form ("<invoke") and the DSML-prefixed form ("<|dsml|invoke").
//
// It returns the byte offset of that tag and whether it is the DSML form.
// When neither form occurs the result is (-1, false). The search uses
// lowercase literals, so the argument must already be lower-cased.
func firstInvokeIndex(lower string) (int, bool) {
	plainPos := strings.Index(lower, "<invoke")
	dsmlPos := strings.Index(lower, "<|dsml|invoke")

	// The DSML tag is chosen only when it exists and is either the sole match
	// or starts before the plain tag; every other case falls back to the plain
	// position, which is -1 when nothing matched at all.
	if dsmlPos >= 0 && (plainPos < 0 || dsmlPos < plainPos) {
		return dsmlPos, true
	}
	return plainPos, false
}
|
||||
|
||||
func findXMLCloseOutsideCDATA(s, closeTag string, start int) int {
|
||||
if s == "" || closeTag == "" {
|
||||
return -1
|
||||
|
||||
@@ -41,6 +41,37 @@ func TestProcessToolSieveInterceptsXMLToolCallWithoutLeak(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveInterceptsDSMLToolCallWithoutLeak(t *testing.T) {
|
||||
var state State
|
||||
chunks := []string{
|
||||
"<|DSML|tool",
|
||||
"_calls>\n",
|
||||
` <|DSML|invoke name="read_file">` + "\n",
|
||||
` <|DSML|parameter name="path">README.MD</|DSML|parameter>` + "\n",
|
||||
" </|DSML|invoke>\n",
|
||||
"</|DSML|tool_calls>",
|
||||
}
|
||||
var events []Event
|
||||
for _, c := range chunks {
|
||||
events = append(events, ProcessChunk(&state, c, []string{"read_file"})...)
|
||||
}
|
||||
events = append(events, Flush(&state, []string{"read_file"})...)
|
||||
|
||||
var textContent string
|
||||
var toolCalls int
|
||||
for _, evt := range events {
|
||||
textContent += evt.Content
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToLower(textContent), "dsml") || strings.Contains(textContent, "read_file") {
|
||||
t.Fatalf("DSML tool call content leaked to text: %q", textContent)
|
||||
}
|
||||
if toolCalls != 1 {
|
||||
t.Fatalf("expected one DSML tool call, got %d events=%#v", toolCalls, events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveHandlesLongXMLToolCall(t *testing.T) {
|
||||
var state State
|
||||
const toolName = "write_to_file"
|
||||
|
||||
Reference in New Issue
Block a user