mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-15 13:45:10 +08:00
refactor: generalize DSML tag parsing to tolerate model noise; split tiktoken by build tags
Replace hardcoded DSML typo variant lists in Go/Node tool call parsers with generalized prefix consumption that tolerates repeated leading <, repeated DSML prefix noise, and trailing pipe terminators. Split tiktoken-dependent token counting into a build-tagged file for non-cgo platform compatibility. Add /data directory to Dockerfile for bind-mount permissions. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
"client": {
|
||||
"name": "DeepSeek",
|
||||
"platform": "android",
|
||||
"version": "2.0.3",
|
||||
"version": "2.0.4",
|
||||
"android_api_level": "35",
|
||||
"locale": "zh_CN"
|
||||
},
|
||||
@@ -24,4 +24,4 @@
|
||||
"skip_exact_paths": [
|
||||
"response/search_status"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -248,6 +248,9 @@ function replaceDSMLToolMarkupOutsideIgnored(text) {
|
||||
if (tag) {
|
||||
if (tag.dsmlLike) {
|
||||
out += `<${tag.closing ? '/' : ''}${tag.name}${raw.slice(tag.nameEnd, tag.end + 1)}`;
|
||||
if (raw[tag.end] !== '>') {
|
||||
out += '>';
|
||||
}
|
||||
} else {
|
||||
out += raw.slice(tag.start, tag.end + 1);
|
||||
}
|
||||
@@ -424,31 +427,42 @@ function scanToolMarkupTagAt(text, start) {
|
||||
}
|
||||
const lower = raw.toLowerCase();
|
||||
let i = start + 1;
|
||||
while (i < raw.length && raw[i] === '<') {
|
||||
i += 1;
|
||||
}
|
||||
const closing = raw[i] === '/';
|
||||
if (closing) {
|
||||
i += 1;
|
||||
}
|
||||
let dsmlLike = false;
|
||||
if (i < raw.length && isToolMarkupPipe(raw[i])) {
|
||||
dsmlLike = true;
|
||||
i += 1;
|
||||
}
|
||||
if (lower.startsWith('dsml', i)) {
|
||||
dsmlLike = true;
|
||||
i += 'dsml'.length;
|
||||
while (i < raw.length && isToolMarkupSeparator(raw[i])) {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
const prefix = consumeToolMarkupNamePrefix(raw, lower, i);
|
||||
i = prefix.next;
|
||||
const dsmlLike = prefix.dsmlLike;
|
||||
const { name, len } = matchToolMarkupName(lower, i);
|
||||
if (!name) {
|
||||
return null;
|
||||
}
|
||||
const nameEnd = i + len;
|
||||
const originalNameEnd = i + len;
|
||||
let nameEnd = originalNameEnd;
|
||||
while (nameEnd < raw.length && isToolMarkupPipe(raw[nameEnd])) {
|
||||
nameEnd += 1;
|
||||
}
|
||||
const hasTrailingPipe = nameEnd > originalNameEnd;
|
||||
if (!hasXmlTagBoundary(raw, nameEnd)) {
|
||||
return null;
|
||||
}
|
||||
const end = findXmlTagEnd(raw, nameEnd);
|
||||
let end = findXmlTagEnd(raw, nameEnd);
|
||||
if (end < 0) {
|
||||
if (!hasTrailingPipe) {
|
||||
return null;
|
||||
}
|
||||
end = nameEnd - 1;
|
||||
}
|
||||
if (hasTrailingPipe) {
|
||||
const nextLT = raw.indexOf('<', nameEnd);
|
||||
if (nextLT >= 0 && end >= nextLT) {
|
||||
end = nameEnd - 1;
|
||||
}
|
||||
}
|
||||
if (end < 0) {
|
||||
return null;
|
||||
}
|
||||
@@ -520,37 +534,94 @@ function findPartialToolMarkupStart(text) {
|
||||
if (lastLT < 0) {
|
||||
return -1;
|
||||
}
|
||||
const tail = raw.slice(lastLT);
|
||||
const start = includeDuplicateLeadingLessThan(raw, lastLT);
|
||||
const tail = raw.slice(start);
|
||||
if (tail.includes('>')) {
|
||||
return -1;
|
||||
}
|
||||
const lowerTail = tail.toLowerCase();
|
||||
const candidates = [
|
||||
'<tool_calls', '<invoke', '<parameter',
|
||||
'<|tool_calls', '<|invoke', '<|parameter',
|
||||
'<|tool_calls', '<|invoke', '<|parameter',
|
||||
'<|dsml|tool_calls', '<|dsml|invoke', '<|dsml|parameter',
|
||||
'<|dsml|tool_calls', '<|dsml|invoke', '<|dsml|parameter',
|
||||
'<dsmltool_calls', '<dsmlinvoke', '<dsmlparameter',
|
||||
'<dsml tool_calls', '<dsml invoke', '<dsml parameter',
|
||||
'<dsml|tool_calls', '<dsml|invoke', '<dsml|parameter',
|
||||
'<|dsmltool_calls', '<|dsmlinvoke', '<|dsmlparameter',
|
||||
'<|dsml tool_calls', '<|dsml invoke', '<|dsml parameter',
|
||||
];
|
||||
for (const candidate of candidates) {
|
||||
if (candidate.startsWith(lowerTail)) {
|
||||
return lastLT;
|
||||
}
|
||||
return isPartialToolMarkupTagPrefix(tail) ? start : -1;
|
||||
}
|
||||
|
||||
function includeDuplicateLeadingLessThan(text, idx) {
|
||||
let out = idx;
|
||||
while (out > 0 && text[out - 1] === '<') {
|
||||
out -= 1;
|
||||
}
|
||||
return -1;
|
||||
return out;
|
||||
}
|
||||
|
||||
function isToolMarkupPipe(ch) {
|
||||
return ch === '|' || ch === '|';
|
||||
}
|
||||
|
||||
function isToolMarkupSeparator(ch) {
|
||||
return ch === ' ' || ch === '\t' || ch === '\r' || ch === '\n' || isToolMarkupPipe(ch);
|
||||
function isPartialToolMarkupTagPrefix(text) {
|
||||
const raw = toStringSafe(text);
|
||||
if (!raw || raw[0] !== '<' || raw.includes('>')) {
|
||||
return false;
|
||||
}
|
||||
const lower = raw.toLowerCase();
|
||||
let i = 1;
|
||||
while (i < raw.length && raw[i] === '<') {
|
||||
i += 1;
|
||||
}
|
||||
if (i >= raw.length) {
|
||||
return true;
|
||||
}
|
||||
if (raw[i] === '/') {
|
||||
i += 1;
|
||||
}
|
||||
while (i <= raw.length) {
|
||||
if (i === raw.length) {
|
||||
return true;
|
||||
}
|
||||
if (hasToolMarkupNamePrefix(lower.slice(i))) {
|
||||
return true;
|
||||
}
|
||||
if ('dsml'.startsWith(lower.slice(i))) {
|
||||
return true;
|
||||
}
|
||||
const next = consumeToolMarkupNamePrefixOnce(raw, lower, i);
|
||||
if (!next.ok) {
|
||||
return false;
|
||||
}
|
||||
i = next.next;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function consumeToolMarkupNamePrefix(raw, lower, idx) {
|
||||
let next = idx;
|
||||
let dsmlLike = false;
|
||||
while (true) {
|
||||
const consumed = consumeToolMarkupNamePrefixOnce(raw, lower, next);
|
||||
if (!consumed.ok) {
|
||||
return { next, dsmlLike };
|
||||
}
|
||||
next = consumed.next;
|
||||
dsmlLike = true;
|
||||
}
|
||||
}
|
||||
|
||||
function consumeToolMarkupNamePrefixOnce(raw, lower, idx) {
|
||||
if (idx < raw.length && isToolMarkupPipe(raw[idx])) {
|
||||
return { next: idx + 1, ok: true };
|
||||
}
|
||||
if (idx < raw.length && [' ', '\t', '\r', '\n'].includes(raw[idx])) {
|
||||
return { next: idx + 1, ok: true };
|
||||
}
|
||||
if (lower.startsWith('dsml', idx)) {
|
||||
return { next: idx + 'dsml'.length, ok: true };
|
||||
}
|
||||
return { next: idx, ok: false };
|
||||
}
|
||||
|
||||
function hasToolMarkupNamePrefix(lowerTail) {
|
||||
for (const name of TOOL_MARKUP_NAMES) {
|
||||
if (lowerTail.startsWith(name) || name.startsWith(lowerTail)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function matchToolMarkupName(lower, start) {
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
'use strict';
|
||||
|
||||
const XML_TOOL_SEGMENT_TAGS = [
|
||||
'<|dsml|tool_calls>', '<|dsml|tool_calls\n', '<|dsml|tool_calls ',
|
||||
'<|dsml|tool_calls>', '<|dsml|tool_calls\n', '<|dsml|tool_calls ',
|
||||
'<|dsml|invoke ', '<|dsml|invoke\n', '<|dsml|invoke\t', '<|dsml|invoke\r',
|
||||
'<|dsmltool_calls>', '<|dsmltool_calls\n', '<|dsmltool_calls ',
|
||||
'<|dsmlinvoke ', '<|dsmlinvoke\n', '<|dsmlinvoke\t', '<|dsmlinvoke\r',
|
||||
'<|dsml tool_calls>', '<|dsml tool_calls\n', '<|dsml tool_calls ',
|
||||
'<|dsml invoke ', '<|dsml invoke\n', '<|dsml invoke\t', '<|dsml invoke\r',
|
||||
'<dsml|tool_calls>', '<dsml|tool_calls\n', '<dsml|tool_calls ',
|
||||
'<dsml|invoke ', '<dsml|invoke\n', '<dsml|invoke\t', '<dsml|invoke\r',
|
||||
'<dsmltool_calls>', '<dsmltool_calls\n', '<dsmltool_calls ',
|
||||
'<dsmlinvoke ', '<dsmlinvoke\n', '<dsmlinvoke\t', '<dsmlinvoke\r',
|
||||
'<dsml tool_calls>', '<dsml tool_calls\n', '<dsml tool_calls ',
|
||||
'<dsml invoke ', '<dsml invoke\n', '<dsml invoke\t', '<dsml invoke\r',
|
||||
'<|tool_calls>', '<|tool_calls\n', '<|tool_calls ',
|
||||
'<|invoke ', '<|invoke\n', '<|invoke\t', '<|invoke\r',
|
||||
'<|tool_calls>', '<|tool_calls\n', '<|tool_calls ',
|
||||
'<|invoke ', '<|invoke\n', '<|invoke\t', '<|invoke\r',
|
||||
'<tool_calls>', '<tool_calls\n', '<tool_calls ',
|
||||
'<invoke ', '<invoke\n', '<invoke\t', '<invoke\r',
|
||||
];
|
||||
|
||||
const XML_TOOL_OPENING_TAGS = [
|
||||
'<|dsml|tool_calls',
|
||||
'<|dsml|tool_calls',
|
||||
'<|dsmltool_calls',
|
||||
'<|dsml tool_calls',
|
||||
'<dsml|tool_calls',
|
||||
'<dsmltool_calls',
|
||||
'<dsml tool_calls',
|
||||
'<|tool_calls',
|
||||
'<|tool_calls',
|
||||
'<tool_calls',
|
||||
];
|
||||
|
||||
const XML_TOOL_CLOSING_TAGS = [
|
||||
'</|dsml|tool_calls>',
|
||||
'</|dsml|tool_calls>',
|
||||
'</|dsmltool_calls>',
|
||||
'</|dsml tool_calls>',
|
||||
'</dsml|tool_calls>',
|
||||
'</dsmltool_calls>',
|
||||
'</dsml tool_calls>',
|
||||
'</|tool_calls>',
|
||||
'</|tool_calls>',
|
||||
'</tool_calls>',
|
||||
];
|
||||
|
||||
module.exports = {
|
||||
XML_TOOL_SEGMENT_TAGS,
|
||||
XML_TOOL_OPENING_TAGS,
|
||||
XML_TOOL_CLOSING_TAGS,
|
||||
};
|
||||
@@ -44,6 +44,9 @@ func rewriteDSMLToolMarkupOutsideIgnored(text string) string {
|
||||
}
|
||||
b.WriteString(tag.Name)
|
||||
b.WriteString(text[tag.NameEnd : tag.End+1])
|
||||
if text[tag.End] != '>' {
|
||||
b.WriteByte('>')
|
||||
}
|
||||
i = tag.End + 1
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -128,34 +128,39 @@ func scanToolMarkupTagAt(text string, start int) (ToolMarkupTag, bool) {
|
||||
}
|
||||
lower := strings.ToLower(text)
|
||||
i := start + 1
|
||||
for i < len(text) && text[i] == '<' {
|
||||
i++
|
||||
}
|
||||
closing := false
|
||||
if i < len(text) && text[i] == '/' {
|
||||
closing = true
|
||||
i++
|
||||
}
|
||||
dsmlLike := false
|
||||
if next, ok := consumeToolMarkupPipe(text, i); ok {
|
||||
dsmlLike = true
|
||||
i = next
|
||||
}
|
||||
if strings.HasPrefix(lower[i:], "dsml") {
|
||||
dsmlLike = true
|
||||
i += len("dsml")
|
||||
for next, ok := consumeToolMarkupSeparator(text, i); ok; next, ok = consumeToolMarkupSeparator(text, i) {
|
||||
i = next
|
||||
}
|
||||
}
|
||||
i, dsmlLike := consumeToolMarkupNamePrefix(lower, text, i)
|
||||
name, nameLen := matchToolMarkupName(lower, i)
|
||||
if nameLen == 0 {
|
||||
return ToolMarkupTag{}, false
|
||||
}
|
||||
nameEnd := i + nameLen
|
||||
nameEndBeforePipes := nameEnd
|
||||
for next, ok := consumeToolMarkupPipe(text, nameEnd); ok; next, ok = consumeToolMarkupPipe(text, nameEnd) {
|
||||
nameEnd = next
|
||||
}
|
||||
hasTrailingPipe := nameEnd > nameEndBeforePipes
|
||||
if !hasToolMarkupBoundary(text, nameEnd) {
|
||||
return ToolMarkupTag{}, false
|
||||
}
|
||||
end := findXMLTagEnd(text, nameEnd)
|
||||
if end < 0 {
|
||||
return ToolMarkupTag{}, false
|
||||
if !hasTrailingPipe {
|
||||
return ToolMarkupTag{}, false
|
||||
}
|
||||
end = nameEnd - 1
|
||||
}
|
||||
if hasTrailingPipe {
|
||||
if nextLT := strings.IndexByte(text[nameEnd:], '<'); nextLT >= 0 && end >= nameEnd+nextLT {
|
||||
end = nameEnd - 1
|
||||
}
|
||||
}
|
||||
trimmed := strings.TrimSpace(text[start : end+1])
|
||||
return ToolMarkupTag{
|
||||
@@ -171,6 +176,74 @@ func scanToolMarkupTagAt(text string, start int) (ToolMarkupTag, bool) {
|
||||
}, true
|
||||
}
|
||||
|
||||
func IsPartialToolMarkupTagPrefix(text string) bool {
|
||||
if text == "" || text[0] != '<' || strings.Contains(text, ">") {
|
||||
return false
|
||||
}
|
||||
lower := strings.ToLower(text)
|
||||
i := 1
|
||||
for i < len(text) && text[i] == '<' {
|
||||
i++
|
||||
}
|
||||
if i >= len(text) {
|
||||
return true
|
||||
}
|
||||
if text[i] == '/' {
|
||||
i++
|
||||
}
|
||||
for i <= len(text) {
|
||||
if i == len(text) {
|
||||
return true
|
||||
}
|
||||
if hasToolMarkupNamePrefix(lower[i:]) {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix("dsml", lower[i:]) {
|
||||
return true
|
||||
}
|
||||
next, ok := consumeToolMarkupNamePrefixOnce(lower, text, i)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
i = next
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func consumeToolMarkupNamePrefix(lower, text string, idx int) (int, bool) {
|
||||
dsmlLike := false
|
||||
for {
|
||||
next, ok := consumeToolMarkupNamePrefixOnce(lower, text, idx)
|
||||
if !ok {
|
||||
return idx, dsmlLike
|
||||
}
|
||||
idx = next
|
||||
dsmlLike = true
|
||||
}
|
||||
}
|
||||
|
||||
func consumeToolMarkupNamePrefixOnce(lower, text string, idx int) (int, bool) {
|
||||
if next, ok := consumeToolMarkupPipe(text, idx); ok {
|
||||
return next, true
|
||||
}
|
||||
if idx < len(text) && (text[idx] == ' ' || text[idx] == '\t' || text[idx] == '\r' || text[idx] == '\n') {
|
||||
return idx + 1, true
|
||||
}
|
||||
if strings.HasPrefix(lower[idx:], "dsml") {
|
||||
return idx + len("dsml"), true
|
||||
}
|
||||
return idx, false
|
||||
}
|
||||
|
||||
func hasToolMarkupNamePrefix(lowerTail string) bool {
|
||||
for _, name := range toolMarkupNames {
|
||||
if strings.HasPrefix(lowerTail, name) || strings.HasPrefix(name, lowerTail) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func matchToolMarkupName(lower string, start int) (string, int) {
|
||||
for _, name := range toolMarkupNames {
|
||||
if strings.HasPrefix(lower[start:], name) {
|
||||
@@ -193,19 +266,6 @@ func consumeToolMarkupPipe(text string, idx int) (int, bool) {
|
||||
return idx, false
|
||||
}
|
||||
|
||||
func consumeToolMarkupSeparator(text string, idx int) (int, bool) {
|
||||
if idx >= len(text) {
|
||||
return idx, false
|
||||
}
|
||||
if text[idx] == ' ' || text[idx] == '\t' || text[idx] == '\r' || text[idx] == '\n' {
|
||||
return idx + 1, true
|
||||
}
|
||||
if next, ok := consumeToolMarkupPipe(text, idx); ok {
|
||||
return next, true
|
||||
}
|
||||
return idx, false
|
||||
}
|
||||
|
||||
func hasToolMarkupBoundary(text string, idx int) bool {
|
||||
if idx >= len(text) {
|
||||
return true
|
||||
|
||||
@@ -41,6 +41,52 @@ func TestParseToolCallsSupportsDSMLShell(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsToleratesDSMLTrailingPipeTagTerminator(t *testing.T) {
|
||||
text := strings.Join([]string{
|
||||
`<|DSML|tool_calls| `,
|
||||
` <|DSML|invoke name="terminal">`,
|
||||
` <|DSML|parameter name="command"><![CDATA[find "/home" -type d]]></|DSML|parameter>`,
|
||||
` <|DSML|parameter name="timeout"><![CDATA[10]]></|DSML|parameter>`,
|
||||
` </|DSML|invoke>`,
|
||||
`</|DSML|tool_calls>`,
|
||||
}, "\n")
|
||||
calls := ParseToolCalls(text, []string{"terminal"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected one trailing-pipe DSML call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "terminal" {
|
||||
t.Fatalf("expected terminal tool, got %#v", calls[0])
|
||||
}
|
||||
if calls[0].Input["command"] != `find "/home" -type d` {
|
||||
t.Fatalf("expected command argument, got %#v", calls[0].Input)
|
||||
}
|
||||
if calls[0].Input["timeout"] != float64(10) {
|
||||
t.Fatalf("expected numeric timeout, got %#v", calls[0].Input)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsToleratesExtraLeadingLessThanBeforeDSML(t *testing.T) {
|
||||
text := `<<|DSML|tool_calls><<|DSML|invoke name="Bash"><<|DSML|parameter name="command"><![CDATA[pwd]]></|DSML|parameter></|DSML|invoke></|DSML|tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"Bash"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected one extra-leading-less-than DSML call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "Bash" || calls[0].Input["command"] != "pwd" {
|
||||
t.Fatalf("unexpected extra-leading-less-than DSML parse result: %#v", calls[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsToleratesRepeatedDSMLPrefixNoise(t *testing.T) {
|
||||
text := `<<DSML|DSML|tool_calls><<DSML|DSML|invoke name="Bash"><<DSML|DSML|parameter name="command"><![CDATA[git status]]></DSML|DSML|parameter></DSML|DSML|invoke></DSML|DSML|tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"Bash"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected one repeated-prefix DSML call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "Bash" || calls[0].Input["command"] != "git status" {
|
||||
t.Fatalf("unexpected repeated-prefix DSML parse result: %#v", calls[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsSupportsDSMLShellWithCanonicalExampleInCDATA(t *testing.T) {
|
||||
content := `<tool_calls><invoke name="demo"><parameter name="value">x</parameter></invoke></tool_calls>`
|
||||
text := `<|DSML|tool_calls><|DSML|invoke name="Write"><|DSML|parameter name="file_path">notes.md</|DSML|parameter><|DSML|parameter name="content"><![CDATA[` + content + `]]></|DSML|parameter></|DSML|invoke></|DSML|tool_calls>`
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
package toolstream
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/toolcall"
|
||||
)
|
||||
import "ds2api/internal/toolcall"
|
||||
|
||||
func ProcessChunk(state *State, chunk string, toolNames []string) []Event {
|
||||
if state == nil {
|
||||
@@ -174,31 +170,27 @@ func findToolSegmentStart(state *State, s string) int {
|
||||
if s == "" {
|
||||
return -1
|
||||
}
|
||||
lower := strings.ToLower(s)
|
||||
offset := 0
|
||||
for {
|
||||
bestKeyIdx := -1
|
||||
matchedTag := ""
|
||||
for _, tag := range xmlToolTagsToDetect {
|
||||
idx := strings.Index(lower[offset:], tag)
|
||||
if idx >= 0 {
|
||||
idx += offset
|
||||
if bestKeyIdx < 0 || idx < bestKeyIdx {
|
||||
bestKeyIdx = idx
|
||||
matchedTag = tag
|
||||
}
|
||||
}
|
||||
}
|
||||
if bestKeyIdx < 0 {
|
||||
tag, ok := toolcall.FindToolMarkupTagOutsideIgnored(s, offset)
|
||||
if !ok {
|
||||
return -1
|
||||
}
|
||||
if !insideCodeFenceWithState(state, s[:bestKeyIdx]) {
|
||||
return bestKeyIdx
|
||||
start := includeDuplicateLeadingLessThan(s, tag.Start)
|
||||
if !insideCodeFenceWithState(state, s[:start]) {
|
||||
return start
|
||||
}
|
||||
offset = bestKeyIdx + len(matchedTag)
|
||||
offset = tag.End + 1
|
||||
}
|
||||
}
|
||||
|
||||
func includeDuplicateLeadingLessThan(s string, idx int) int {
|
||||
for idx > 0 && s[idx-1] == '<' {
|
||||
idx--
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func consumeToolCapture(state *State, toolNames []string) (prefix string, calls []toolcall.ParsedToolCall, suffix string, ready bool) {
|
||||
captured := state.capture.String()
|
||||
if captured == "" {
|
||||
|
||||
@@ -153,27 +153,14 @@ func findPartialXMLToolTagStart(s string) int {
|
||||
if lastLT < 0 {
|
||||
return -1
|
||||
}
|
||||
tail := s[lastLT:]
|
||||
start := includeDuplicateLeadingLessThan(s, lastLT)
|
||||
tail := s[start:]
|
||||
// If there's a '>' in the tail, the tag is closed — not partial.
|
||||
if strings.Contains(tail, ">") {
|
||||
return -1
|
||||
}
|
||||
lowerTail := strings.ToLower(tail)
|
||||
for _, tag := range []string{
|
||||
"<tool_calls", "<invoke", "<parameter",
|
||||
"<|tool_calls", "<|invoke", "<|parameter",
|
||||
"<|tool_calls", "<|invoke", "<|parameter",
|
||||
"<|dsml|tool_calls", "<|dsml|invoke", "<|dsml|parameter",
|
||||
"<|dsml|tool_calls", "<|dsml|invoke", "<|dsml|parameter",
|
||||
"<dsmltool_calls", "<dsmlinvoke", "<dsmlparameter",
|
||||
"<dsml tool_calls", "<dsml invoke", "<dsml parameter",
|
||||
"<dsml|tool_calls", "<dsml|invoke", "<dsml|parameter",
|
||||
"<|dsmltool_calls", "<|dsmlinvoke", "<|dsmlparameter",
|
||||
"<|dsml tool_calls", "<|dsml invoke", "<|dsml parameter",
|
||||
} {
|
||||
if strings.HasPrefix(tag, lowerTail) {
|
||||
return lastLT
|
||||
}
|
||||
if toolcall.IsPartialToolMarkupTagPrefix(tail) {
|
||||
return start
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
package toolstream
|
||||
|
||||
import "regexp"
|
||||
|
||||
// --- XML tool call support for the streaming sieve ---
|
||||
|
||||
//nolint:unused // kept as explicit tag inventory for future XML sieve refinements.
|
||||
var xmlToolCallClosingTags = []string{"</tool_calls>", "</|dsml|tool_calls>", "</|dsmltool_calls>", "</|dsml tool_calls>", "</dsml|tool_calls>", "</dsmltool_calls>", "</dsml tool_calls>", "</|tool_calls>", "</|tool_calls>"}
|
||||
|
||||
// xmlToolCallBlockPattern matches a complete canonical XML tool call block.
|
||||
//
|
||||
//nolint:unused // reserved for future fast-path XML block detection.
|
||||
var xmlToolCallBlockPattern = regexp.MustCompile(`(?is)((?:<tool_calls\b|<\|dsml\|tool_calls\b)[^>]*>\s*(?:.*?)\s*(?:</tool_calls>|</\|dsml\|tool_calls>))`)
|
||||
|
||||
// xmlToolTagsToDetect is the set of XML tag prefixes used by findToolSegmentStart.
|
||||
var xmlToolTagsToDetect = []string{
|
||||
"<|dsml|tool_calls>", "<|dsml|tool_calls\n", "<|dsml|tool_calls ",
|
||||
"<|dsml|tool_calls>", "<|dsml|tool_calls\n", "<|dsml|tool_calls ",
|
||||
"<|dsml|invoke ", "<|dsml|invoke\n", "<|dsml|invoke\t", "<|dsml|invoke\r",
|
||||
"<|dsmltool_calls>", "<|dsmltool_calls\n", "<|dsmltool_calls ",
|
||||
"<|dsmlinvoke ", "<|dsmlinvoke\n", "<|dsmlinvoke\t", "<|dsmlinvoke\r",
|
||||
"<|dsml tool_calls>", "<|dsml tool_calls\n", "<|dsml tool_calls ",
|
||||
"<|dsml invoke ", "<|dsml invoke\n", "<|dsml invoke\t", "<|dsml invoke\r",
|
||||
"<dsml|tool_calls>", "<dsml|tool_calls\n", "<dsml|tool_calls ",
|
||||
"<dsml|invoke ", "<dsml|invoke\n", "<dsml|invoke\t", "<dsml|invoke\r",
|
||||
"<dsmltool_calls>", "<dsmltool_calls\n", "<dsmltool_calls ",
|
||||
"<dsmlinvoke ", "<dsmlinvoke\n", "<dsmlinvoke\t", "<dsmlinvoke\r",
|
||||
"<dsml tool_calls>", "<dsml tool_calls\n", "<dsml tool_calls ",
|
||||
"<dsml invoke ", "<dsml invoke\n", "<dsml invoke\t", "<dsml invoke\r",
|
||||
"<|tool_calls>", "<|tool_calls\n", "<|tool_calls ",
|
||||
"<|invoke ", "<|invoke\n", "<|invoke\t", "<|invoke\r",
|
||||
"<|tool_calls>", "<|tool_calls\n", "<|tool_calls ",
|
||||
"<|invoke ", "<|invoke\n", "<|invoke\t", "<|invoke\r",
|
||||
"<tool_calls>", "<tool_calls\n", "<tool_calls ", "<invoke ", "<invoke\n", "<invoke\t", "<invoke\r",
|
||||
}
|
||||
@@ -72,6 +72,97 @@ func TestProcessToolSieveInterceptsDSMLToolCallWithoutLeak(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveInterceptsDSMLTrailingPipeToolCallWithoutLeak(t *testing.T) {
|
||||
var state State
|
||||
chunks := []string{
|
||||
"<|DSML|tool_calls| \n",
|
||||
` <|DSML|invoke name="terminal">` + "\n",
|
||||
` <|DSML|parameter name="command"><![CDATA[find "/home" -type d]]></|DSML|parameter>` + "\n",
|
||||
` <|DSML|parameter name="timeout"><![CDATA[10]]></|DSML|parameter>` + "\n",
|
||||
" </|DSML|invoke>\n",
|
||||
"</|DSML|tool_calls>",
|
||||
}
|
||||
var events []Event
|
||||
for _, c := range chunks {
|
||||
events = append(events, ProcessChunk(&state, c, []string{"terminal"})...)
|
||||
}
|
||||
events = append(events, Flush(&state, []string{"terminal"})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
var calls []any
|
||||
for _, evt := range events {
|
||||
textContent.WriteString(evt.Content)
|
||||
for _, call := range evt.ToolCalls {
|
||||
calls = append(calls, call)
|
||||
}
|
||||
}
|
||||
if text := textContent.String(); strings.Contains(strings.ToLower(text), "dsml") || strings.Contains(text, "terminal") {
|
||||
t.Fatalf("trailing-pipe DSML tool call leaked to text: %q events=%#v", text, events)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected one trailing-pipe DSML tool call, got %d events=%#v", len(calls), events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveInterceptsExtraLeadingLessThanDSMLToolCallWithoutLeak(t *testing.T) {
|
||||
var state State
|
||||
chunks := []string{
|
||||
"<<|DSML|tool_calls>\n",
|
||||
` <<|DSML|invoke name="Bash">` + "\n",
|
||||
` <<|DSML|parameter name="command"><![CDATA[pwd]]></|DSML|parameter>` + "\n",
|
||||
" </|DSML|invoke>\n",
|
||||
"</|DSML|tool_calls>",
|
||||
}
|
||||
var events []Event
|
||||
for _, c := range chunks {
|
||||
events = append(events, ProcessChunk(&state, c, []string{"Bash"})...)
|
||||
}
|
||||
events = append(events, Flush(&state, []string{"Bash"})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
toolCalls := 0
|
||||
for _, evt := range events {
|
||||
textContent.WriteString(evt.Content)
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
if text := textContent.String(); strings.Contains(text, "<") || strings.Contains(text, "Bash") {
|
||||
t.Fatalf("extra-leading-less-than DSML tool call leaked to text: %q events=%#v", text, events)
|
||||
}
|
||||
if toolCalls != 1 {
|
||||
t.Fatalf("expected one extra-leading-less-than DSML tool call, got %d events=%#v", toolCalls, events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveInterceptsRepeatedDSMLPrefixNoiseWithoutLeak(t *testing.T) {
|
||||
var state State
|
||||
chunks := []string{
|
||||
"<<DSML|DSML|tool",
|
||||
"_calls>\n",
|
||||
` <<DSML|DSML|invoke name="Bash">` + "\n",
|
||||
` <<DSML|DSML|parameter name="command"><![CDATA[git status]]></DSML|DSML|parameter>` + "\n",
|
||||
" </DSML|DSML|invoke>\n",
|
||||
"</DSML|DSML|tool_calls>",
|
||||
}
|
||||
var events []Event
|
||||
for _, c := range chunks {
|
||||
events = append(events, ProcessChunk(&state, c, []string{"Bash"})...)
|
||||
}
|
||||
events = append(events, Flush(&state, []string{"Bash"})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
toolCalls := 0
|
||||
for _, evt := range events {
|
||||
textContent.WriteString(evt.Content)
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
if text := textContent.String(); strings.Contains(strings.ToLower(text), "dsml") || strings.Contains(text, "Bash") {
|
||||
t.Fatalf("repeated-prefix DSML tool call leaked to text: %q events=%#v", text, events)
|
||||
}
|
||||
if toolCalls != 1 {
|
||||
t.Fatalf("expected one repeated-prefix DSML tool call, got %d events=%#v", toolCalls, events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveHandlesLongXMLToolCall(t *testing.T) {
|
||||
var state State
|
||||
const toolName = "write_to_file"
|
||||
@@ -442,6 +533,8 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) {
|
||||
want int
|
||||
}{
|
||||
{"tool_calls_tag", "some text <tool_calls>\n", 10},
|
||||
{"dsml_trailing_pipe_tag", "some text <|DSML|tool_calls| \n", 10},
|
||||
{"dsml_extra_leading_less_than", "some text <<|DSML|tool_calls>\n", 10},
|
||||
{"invoke_tag_missing_wrapper", "some text <invoke name=\"read_file\">\n", 10},
|
||||
{"bare_tool_call_text", "prefix <tool_call>\n", -1},
|
||||
{"xml_inside_code_fence", "```xml\n<tool_calls><invoke name=\"read_file\"></invoke></tool_calls>\n```", -1},
|
||||
@@ -465,6 +558,8 @@ func TestFindPartialXMLToolTagStart(t *testing.T) {
|
||||
want int
|
||||
}{
|
||||
{"partial_tool_calls", "Hello <tool_ca", 6},
|
||||
{"partial_dsml_trailing_pipe", "Hello <|DSML|tool_calls|", 6},
|
||||
{"partial_dsml_extra_leading_less_than", "Hello <<|DSML|tool_calls", 6},
|
||||
{"partial_invoke", "Hello <inv", 6},
|
||||
{"bare_tool_call_not_held", "Hello <tool_name", -1},
|
||||
{"partial_lt_only", "Text <", 5},
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
tiktoken "github.com/hupe1980/go-tiktoken"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTokenizerModel = "gpt-4o"
|
||||
claudeTokenizerModel = "claude"
|
||||
@@ -33,41 +27,6 @@ func CountOutputTokens(text, model string) int {
|
||||
return base
|
||||
}
|
||||
|
||||
func countWithTokenizer(text, model string) int {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
encoding, err := tiktoken.NewEncodingForModel(tokenizerModelForCount(model))
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
ids, _, err := encoding.Encode(text, nil, nil)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return len(ids)
|
||||
}
|
||||
|
||||
func tokenizerModelForCount(model string) string {
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
if model == "" {
|
||||
return defaultTokenizerModel
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(model, "claude"):
|
||||
return claudeTokenizerModel
|
||||
case strings.HasPrefix(model, "gpt-4"), strings.HasPrefix(model, "gpt-5"), strings.HasPrefix(model, "o1"), strings.HasPrefix(model, "o3"), strings.HasPrefix(model, "o4"):
|
||||
return defaultTokenizerModel
|
||||
case strings.HasPrefix(model, "deepseek-v4"):
|
||||
return defaultTokenizerModel
|
||||
case strings.HasPrefix(model, "deepseek"):
|
||||
return defaultTokenizerModel
|
||||
default:
|
||||
return defaultTokenizerModel
|
||||
}
|
||||
}
|
||||
|
||||
func conservativePromptPadding(base int) int {
|
||||
padding := base / 50
|
||||
if padding < 4 {
|
||||
|
||||
7
internal/util/token_count_heuristic.go
Normal file
7
internal/util/token_count_heuristic.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build 386 || arm || mips || mipsle || wasm
|
||||
|
||||
package util
|
||||
|
||||
func countWithTokenizer(_, _ string) int {
|
||||
return 0
|
||||
}
|
||||
44
internal/util/token_count_tiktoken.go
Normal file
44
internal/util/token_count_tiktoken.go
Normal file
@@ -0,0 +1,44 @@
|
||||
//go:build !386 && !arm && !mips && !mipsle && !wasm
|
||||
|
||||
package util
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
tiktoken "github.com/hupe1980/go-tiktoken"
|
||||
)
|
||||
|
||||
func countWithTokenizer(text, model string) int {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
encoding, err := tiktoken.NewEncodingForModel(tokenizerModelForCount(model))
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
ids, _, err := encoding.Encode(text, nil, nil)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return len(ids)
|
||||
}
|
||||
|
||||
func tokenizerModelForCount(model string) string {
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
if model == "" {
|
||||
return defaultTokenizerModel
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(model, "claude"):
|
||||
return claudeTokenizerModel
|
||||
case strings.HasPrefix(model, "gpt-4"), strings.HasPrefix(model, "gpt-5"), strings.HasPrefix(model, "o1"), strings.HasPrefix(model, "o3"), strings.HasPrefix(model, "o4"):
|
||||
return defaultTokenizerModel
|
||||
case strings.HasPrefix(model, "deepseek-v4"):
|
||||
return defaultTokenizerModel
|
||||
case strings.HasPrefix(model, "deepseek"):
|
||||
return defaultTokenizerModel
|
||||
default:
|
||||
return defaultTokenizerModel
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user