This commit is contained in:
CJACK
2026-04-26 09:17:40 +08:00
parent 40b8182984
commit 0bfddf7943
10 changed files with 193 additions and 8 deletions

View File

@@ -71,15 +71,30 @@ func ConsumeSSE(cfg ConsumeConfig, hooks ConsumeHooks) {
hooks.OnFinalize(reason, scannerErr)
}
}
contextDone := func() bool {
if cfg.Context.Err() == nil {
return false
}
if hooks.OnContextDone != nil {
hooks.OnContextDone()
}
return true
}
for {
if contextDone() {
return
}
select {
case <-cfg.Context.Done():
if hooks.OnContextDone != nil {
hooks.OnContextDone()
if contextDone() {
return
}
return
case <-tickCh(ticker):
if contextDone() {
return
}
if !hasContent {
keepaliveCount++
if cfg.MaxKeepAliveNoInput > 0 && keepaliveCount >= cfg.MaxKeepAliveNoInput {
@@ -95,6 +110,9 @@ func ConsumeSSE(cfg ConsumeConfig, hooks ConsumeHooks) {
hooks.OnKeepAlive()
}
case parsed, ok := <-parsedLines:
if contextDone() {
return
}
if !ok {
finalize(StopReasonUpstreamCompleted, <-done)
return

View File

@@ -0,0 +1,47 @@
package stream
import (
"context"
"strings"
"testing"
"ds2api/internal/sse"
)
func TestConsumeSSEPrefersContextCancellationOverReadyParsedLines(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
var finalized bool
var contextDone bool
var parsedCalled bool
ConsumeSSE(ConsumeConfig{
Context: ctx,
Body: strings.NewReader("data: {\"p\":\"response/content\",\"v\":\"hello\"}\n\ndata: [DONE]\n"),
ThinkingEnabled: false,
InitialType: "text",
KeepAliveInterval: 0,
}, ConsumeHooks{
OnParsed: func(_ sse.LineResult) ParsedDecision {
parsedCalled = true
return ParsedDecision{}
},
OnFinalize: func(_ StopReason, _ error) {
finalized = true
},
OnContextDone: func() {
contextDone = true
},
})
if !contextDone {
t.Fatal("expected OnContextDone to run for an already-cancelled context")
}
if finalized {
t.Fatal("expected OnFinalize not to run after context cancellation wins")
}
if parsedCalled {
t.Fatal("expected parsed lines not to be processed after context cancellation wins")
}
}

View File

@@ -27,6 +27,8 @@ RULES:
7) Numbers, booleans, and null stay plain text.
8) Use only the parameter names in the tool schema. Do not invent fields.
9) Do NOT wrap XML in markdown fences. Do NOT output explanations, role markers, or internal monologue.
10) If you call a tool, the first non-whitespace characters of that tool block must be exactly <tool_calls>.
11) Never omit the opening <tool_calls> tag, even if you already plan to close with </tool_calls>.
PARAMETER SHAPES:
- string => <parameter name="x"><![CDATA[value]]></parameter>
@@ -42,6 +44,9 @@ Wrong 2 — Markdown code fences:
` + "```xml" + `
<tool_calls>...</tool_calls>
` + "```" + `
Wrong 3 — missing opening wrapper:
<invoke name="TOOL_NAME">...</invoke>
</tool_calls>
Remember: The ONLY valid way to use tools is the <tool_calls>...</tool_calls> XML block at the end of your response.

View File

@@ -109,6 +109,16 @@ func TestBuildToolCallInstructions_WriteUsesFilePathAndContent(t *testing.T) {
}
}
func TestBuildToolCallInstructions_AnchorsMissingOpeningWrapperFailureMode(t *testing.T) {
out := BuildToolCallInstructions([]string{"read_file"})
if !strings.Contains(out, "Never omit the opening <tool_calls> tag") {
t.Fatalf("expected explicit missing-opening-tag warning, got: %s", out)
}
if !strings.Contains(out, "Wrong 3 — missing opening wrapper") {
t.Fatalf("expected missing-opening-wrapper negative example, got: %s", out)
}
}
func findInvokeBlocks(text, name string) []string {
open := `<invoke name="` + name + `">`
remaining := text

View File

@@ -11,9 +11,17 @@ var xmlToolCallsWrapperPattern = regexp.MustCompile(`(?is)<tool_calls\b[^>]*>\s*
var xmlInvokePattern = regexp.MustCompile(`(?is)<invoke\b([^>]*)>\s*(.*?)\s*</invoke>`)
var xmlParameterPattern = regexp.MustCompile(`(?is)<parameter\b([^>]*)>\s*(.*?)\s*</parameter>`)
var xmlAttrPattern = regexp.MustCompile(`(?is)\b([a-z0-9_:-]+)\s*=\s*("([^"]*)"|'([^']*)')`)
var xmlToolCallsClosePattern = regexp.MustCompile(`(?is)</tool_calls>`)
var xmlInvokeStartPattern = regexp.MustCompile(`(?is)<invoke\b[^>]*\bname\s*=\s*("([^"]*)"|'([^']*)')`)
func parseXMLToolCalls(text string) []ParsedToolCall {
wrappers := xmlToolCallsWrapperPattern.FindAllStringSubmatch(text, -1)
if len(wrappers) == 0 {
repaired := repairMissingXMLToolCallsOpeningWrapper(text)
if repaired != text {
wrappers = xmlToolCallsWrapperPattern.FindAllStringSubmatch(repaired, -1)
}
}
if len(wrappers) == 0 {
return nil
}
@@ -36,6 +44,28 @@ func parseXMLToolCalls(text string) []ParsedToolCall {
return out
}
func repairMissingXMLToolCallsOpeningWrapper(text string) string {
lower := strings.ToLower(text)
if strings.Contains(lower, "<tool_calls") {
return text
}
closeMatches := xmlToolCallsClosePattern.FindAllStringIndex(text, -1)
if len(closeMatches) == 0 {
return text
}
invokeLoc := xmlInvokeStartPattern.FindStringIndex(text)
if invokeLoc == nil {
return text
}
closeLoc := closeMatches[len(closeMatches)-1]
if invokeLoc[0] >= closeLoc[0] {
return text
}
return text[:invokeLoc[0]] + "<tool_calls>" + text[invokeLoc[0]:closeLoc[0]] + "</tool_calls>" + text[closeLoc[1]:]
}
func parseSingleXMLToolCall(block []string) (ParsedToolCall, bool) {
if len(block) < 3 {
return ParsedToolCall{}, false

View File

@@ -175,6 +175,26 @@ func TestParseToolCallsRejectsBareInvokeWithoutToolCallsWrapper(t *testing.T) {
}
}
func TestParseToolCallsRepairsMissingOpeningToolCallsWrapperWhenClosingTagExists(t *testing.T) {
text := `Before tool call
<invoke name="read_file"><parameter name="path">README.md</parameter></invoke>
</tool_calls>
after`
res := ParseToolCallsDetailed(text, []string{"read_file"})
if len(res.Calls) != 1 {
t.Fatalf("expected repaired wrapper to parse exactly one call, got %#v", res)
}
if res.Calls[0].Name != "read_file" {
t.Fatalf("expected repaired wrapper to preserve tool name, got %#v", res.Calls[0])
}
if got, _ := res.Calls[0].Input["path"].(string); got != "README.md" {
t.Fatalf("expected repaired wrapper to preserve args, got %#v", res.Calls[0].Input)
}
if !res.SawToolCallSyntax {
t.Fatalf("expected repaired wrapper to mark tool syntax seen, got %#v", res)
}
}
func TestParseToolCallsRejectsLegacyCanonicalBody(t *testing.T) {
text := `<tool_calls><invoke name="read_file"><tool_name>read_file</tool_name><param>{"path":"README.md"}</param></invoke></tool_calls>`
calls := ParseToolCalls(text, []string{"read_file"})

View File

@@ -10,7 +10,7 @@ import (
//nolint:unused // kept as explicit tag inventory for future XML sieve refinements.
var xmlToolCallClosingTags = []string{"</tool_calls>"}
var xmlToolCallOpeningTags = []string{"<tool_calls"}
var xmlToolCallOpeningTags = []string{"<tool_calls", "<invoke"}
// xmlToolCallTagPairs maps each opening tag to its expected closing tag.
// Order matters: longer/wrapper tags must be checked first.
@@ -24,7 +24,7 @@ var xmlToolCallTagPairs = []struct{ open, close string }{
var xmlToolCallBlockPattern = regexp.MustCompile(`(?is)(<tool_calls\b[^>]*>\s*(?:.*?)\s*</tool_calls>)`)
// xmlToolTagsToDetect is the set of XML tag prefixes used by findToolSegmentStart.
var xmlToolTagsToDetect = []string{"<tool_calls>", "<tool_calls\n", "<tool_calls "}
var xmlToolTagsToDetect = []string{"<tool_calls>", "<tool_calls\n", "<tool_calls ", "<invoke ", "<invoke\n", "<invoke\t", "<invoke\r"}
// consumeXMLToolCapture tries to extract complete XML tool call blocks from captured text.
func consumeXMLToolCapture(captured string, toolNames []string) (prefix string, calls []toolcall.ParsedToolCall, suffix string, ready bool) {
@@ -55,6 +55,22 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string,
// If this block failed to become a tool call, pass it through as text.
return prefixPart + xmlBlock, nil, suffixPart, true
}
if !strings.Contains(lower, "<tool_calls") {
invokeIdx := strings.Index(lower, "<invoke")
closeIdx := strings.LastIndex(lower, "</tool_calls>")
if invokeIdx >= 0 && closeIdx > invokeIdx {
closeEnd := closeIdx + len("</tool_calls>")
xmlBlock := "<tool_calls>" + captured[invokeIdx:closeIdx] + "</tool_calls>"
prefixPart := captured[:invokeIdx]
suffixPart := captured[closeEnd:]
parsed := toolcall.ParseToolCalls(xmlBlock, toolNames)
if len(parsed) > 0 {
prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
return prefixPart, parsed, suffixPart, true
}
return prefixPart + captured[invokeIdx:closeEnd], nil, suffixPart, true
}
}
return "", nil, "", false
}

View File

@@ -288,6 +288,7 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) {
want int
}{
{"tool_calls_tag", "some text <tool_calls>\n", 10},
{"invoke_tag_missing_wrapper", "some text <invoke name=\"read_file\">\n", 10},
{"bare_tool_call_text", "prefix <tool_call>\n", -1},
{"xml_inside_code_fence", "```xml\n<tool_calls><invoke name=\"read_file\"></invoke></tool_calls>\n```", -1},
{"no_xml", "just plain text", -1},
@@ -310,6 +311,7 @@ func TestFindPartialXMLToolTagStart(t *testing.T) {
want int
}{
{"partial_tool_calls", "Hello <tool_ca", 6},
{"partial_invoke", "Hello <inv", 6},
{"bare_tool_call_not_held", "Hello <tool_name", -1},
{"partial_lt_only", "Text <", 5},
{"complete_tag", "Text <tool_calls>done", -1},
@@ -505,3 +507,32 @@ func TestProcessToolSievePassesThroughBareToolCallAsText(t *testing.T) {
t.Fatalf("expected bare invoke to pass through unchanged, got %q", textContent.String())
}
}
func TestProcessToolSieveRepairsMissingOpeningWrapperWithoutLeakingInvokeText(t *testing.T) {
var state State
chunks := []string{
"<invoke name=\"read_file\">\n",
" <parameter name=\"path\">README.md</parameter>\n",
"</invoke>\n",
"</tool_calls>",
}
var events []Event
for _, c := range chunks {
events = append(events, ProcessChunk(&state, c, []string{"read_file"})...)
}
events = append(events, Flush(&state, []string{"read_file"})...)
var textContent strings.Builder
toolCalls := 0
for _, evt := range events {
textContent.WriteString(evt.Content)
toolCalls += len(evt.ToolCalls)
}
if toolCalls != 1 {
t.Fatalf("expected repaired missing-wrapper stream to emit one tool call, got %d events=%#v", toolCalls, events)
}
if strings.Contains(textContent.String(), "<invoke") || strings.Contains(textContent.String(), "</tool_calls>") {
t.Fatalf("expected repaired missing-wrapper stream not to leak xml text, got %q", textContent.String())
}
}