mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-08 10:25:28 +08:00
1
This commit is contained in:
@@ -71,15 +71,30 @@ func ConsumeSSE(cfg ConsumeConfig, hooks ConsumeHooks) {
|
||||
hooks.OnFinalize(reason, scannerErr)
|
||||
}
|
||||
}
|
||||
contextDone := func() bool {
|
||||
if cfg.Context.Err() == nil {
|
||||
return false
|
||||
}
|
||||
if hooks.OnContextDone != nil {
|
||||
hooks.OnContextDone()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
for {
|
||||
if contextDone() {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-cfg.Context.Done():
|
||||
if hooks.OnContextDone != nil {
|
||||
hooks.OnContextDone()
|
||||
if contextDone() {
|
||||
return
|
||||
}
|
||||
return
|
||||
case <-tickCh(ticker):
|
||||
if contextDone() {
|
||||
return
|
||||
}
|
||||
if !hasContent {
|
||||
keepaliveCount++
|
||||
if cfg.MaxKeepAliveNoInput > 0 && keepaliveCount >= cfg.MaxKeepAliveNoInput {
|
||||
@@ -95,6 +110,9 @@ func ConsumeSSE(cfg ConsumeConfig, hooks ConsumeHooks) {
|
||||
hooks.OnKeepAlive()
|
||||
}
|
||||
case parsed, ok := <-parsedLines:
|
||||
if contextDone() {
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
finalize(StopReasonUpstreamCompleted, <-done)
|
||||
return
|
||||
|
||||
47
internal/stream/engine_test.go
Normal file
47
internal/stream/engine_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/sse"
|
||||
)
|
||||
|
||||
func TestConsumeSSEPrefersContextCancellationOverReadyParsedLines(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
var finalized bool
|
||||
var contextDone bool
|
||||
var parsedCalled bool
|
||||
|
||||
ConsumeSSE(ConsumeConfig{
|
||||
Context: ctx,
|
||||
Body: strings.NewReader("data: {\"p\":\"response/content\",\"v\":\"hello\"}\n\ndata: [DONE]\n"),
|
||||
ThinkingEnabled: false,
|
||||
InitialType: "text",
|
||||
KeepAliveInterval: 0,
|
||||
}, ConsumeHooks{
|
||||
OnParsed: func(_ sse.LineResult) ParsedDecision {
|
||||
parsedCalled = true
|
||||
return ParsedDecision{}
|
||||
},
|
||||
OnFinalize: func(_ StopReason, _ error) {
|
||||
finalized = true
|
||||
},
|
||||
OnContextDone: func() {
|
||||
contextDone = true
|
||||
},
|
||||
})
|
||||
|
||||
if !contextDone {
|
||||
t.Fatal("expected OnContextDone to run for an already-cancelled context")
|
||||
}
|
||||
if finalized {
|
||||
t.Fatal("expected OnFinalize not to run after context cancellation wins")
|
||||
}
|
||||
if parsedCalled {
|
||||
t.Fatal("expected parsed lines not to be processed after context cancellation wins")
|
||||
}
|
||||
}
|
||||
@@ -27,6 +27,8 @@ RULES:
|
||||
7) Numbers, booleans, and null stay plain text.
|
||||
8) Use only the parameter names in the tool schema. Do not invent fields.
|
||||
9) Do NOT wrap XML in markdown fences. Do NOT output explanations, role markers, or internal monologue.
|
||||
10) If you call a tool, the first non-whitespace characters of that tool block must be exactly <tool_calls>.
|
||||
11) Never omit the opening <tool_calls> tag, even if you already plan to close with </tool_calls>.
|
||||
|
||||
PARAMETER SHAPES:
|
||||
- string => <parameter name="x"><![CDATA[value]]></parameter>
|
||||
@@ -42,6 +44,9 @@ Wrong 2 — Markdown code fences:
|
||||
` + "```xml" + `
|
||||
<tool_calls>...</tool_calls>
|
||||
` + "```" + `
|
||||
Wrong 3 — missing opening wrapper:
|
||||
<invoke name="TOOL_NAME">...</invoke>
|
||||
</tool_calls>
|
||||
|
||||
Remember: The ONLY valid way to use tools is the <tool_calls>...</tool_calls> XML block at the end of your response.
|
||||
|
||||
|
||||
@@ -109,6 +109,16 @@ func TestBuildToolCallInstructions_WriteUsesFilePathAndContent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildToolCallInstructions_AnchorsMissingOpeningWrapperFailureMode(t *testing.T) {
|
||||
out := BuildToolCallInstructions([]string{"read_file"})
|
||||
if !strings.Contains(out, "Never omit the opening <tool_calls> tag") {
|
||||
t.Fatalf("expected explicit missing-opening-tag warning, got: %s", out)
|
||||
}
|
||||
if !strings.Contains(out, "Wrong 3 — missing opening wrapper") {
|
||||
t.Fatalf("expected missing-opening-wrapper negative example, got: %s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func findInvokeBlocks(text, name string) []string {
|
||||
open := `<invoke name="` + name + `">`
|
||||
remaining := text
|
||||
|
||||
@@ -11,9 +11,17 @@ var xmlToolCallsWrapperPattern = regexp.MustCompile(`(?is)<tool_calls\b[^>]*>\s*
|
||||
var xmlInvokePattern = regexp.MustCompile(`(?is)<invoke\b([^>]*)>\s*(.*?)\s*</invoke>`)
|
||||
var xmlParameterPattern = regexp.MustCompile(`(?is)<parameter\b([^>]*)>\s*(.*?)\s*</parameter>`)
|
||||
var xmlAttrPattern = regexp.MustCompile(`(?is)\b([a-z0-9_:-]+)\s*=\s*("([^"]*)"|'([^']*)')`)
|
||||
var xmlToolCallsClosePattern = regexp.MustCompile(`(?is)</tool_calls>`)
|
||||
var xmlInvokeStartPattern = regexp.MustCompile(`(?is)<invoke\b[^>]*\bname\s*=\s*("([^"]*)"|'([^']*)')`)
|
||||
|
||||
func parseXMLToolCalls(text string) []ParsedToolCall {
|
||||
wrappers := xmlToolCallsWrapperPattern.FindAllStringSubmatch(text, -1)
|
||||
if len(wrappers) == 0 {
|
||||
repaired := repairMissingXMLToolCallsOpeningWrapper(text)
|
||||
if repaired != text {
|
||||
wrappers = xmlToolCallsWrapperPattern.FindAllStringSubmatch(repaired, -1)
|
||||
}
|
||||
}
|
||||
if len(wrappers) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -36,6 +44,28 @@ func parseXMLToolCalls(text string) []ParsedToolCall {
|
||||
return out
|
||||
}
|
||||
|
||||
func repairMissingXMLToolCallsOpeningWrapper(text string) string {
|
||||
lower := strings.ToLower(text)
|
||||
if strings.Contains(lower, "<tool_calls") {
|
||||
return text
|
||||
}
|
||||
|
||||
closeMatches := xmlToolCallsClosePattern.FindAllStringIndex(text, -1)
|
||||
if len(closeMatches) == 0 {
|
||||
return text
|
||||
}
|
||||
invokeLoc := xmlInvokeStartPattern.FindStringIndex(text)
|
||||
if invokeLoc == nil {
|
||||
return text
|
||||
}
|
||||
closeLoc := closeMatches[len(closeMatches)-1]
|
||||
if invokeLoc[0] >= closeLoc[0] {
|
||||
return text
|
||||
}
|
||||
|
||||
return text[:invokeLoc[0]] + "<tool_calls>" + text[invokeLoc[0]:closeLoc[0]] + "</tool_calls>" + text[closeLoc[1]:]
|
||||
}
|
||||
|
||||
func parseSingleXMLToolCall(block []string) (ParsedToolCall, bool) {
|
||||
if len(block) < 3 {
|
||||
return ParsedToolCall{}, false
|
||||
|
||||
@@ -175,6 +175,26 @@ func TestParseToolCallsRejectsBareInvokeWithoutToolCallsWrapper(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsRepairsMissingOpeningToolCallsWrapperWhenClosingTagExists(t *testing.T) {
|
||||
text := `Before tool call
|
||||
<invoke name="read_file"><parameter name="path">README.md</parameter></invoke>
|
||||
</tool_calls>
|
||||
after`
|
||||
res := ParseToolCallsDetailed(text, []string{"read_file"})
|
||||
if len(res.Calls) != 1 {
|
||||
t.Fatalf("expected repaired wrapper to parse exactly one call, got %#v", res)
|
||||
}
|
||||
if res.Calls[0].Name != "read_file" {
|
||||
t.Fatalf("expected repaired wrapper to preserve tool name, got %#v", res.Calls[0])
|
||||
}
|
||||
if got, _ := res.Calls[0].Input["path"].(string); got != "README.md" {
|
||||
t.Fatalf("expected repaired wrapper to preserve args, got %#v", res.Calls[0].Input)
|
||||
}
|
||||
if !res.SawToolCallSyntax {
|
||||
t.Fatalf("expected repaired wrapper to mark tool syntax seen, got %#v", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsRejectsLegacyCanonicalBody(t *testing.T) {
|
||||
text := `<tool_calls><invoke name="read_file"><tool_name>read_file</tool_name><param>{"path":"README.md"}</param></invoke></tool_calls>`
|
||||
calls := ParseToolCalls(text, []string{"read_file"})
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
|
||||
//nolint:unused // kept as explicit tag inventory for future XML sieve refinements.
|
||||
var xmlToolCallClosingTags = []string{"</tool_calls>"}
|
||||
var xmlToolCallOpeningTags = []string{"<tool_calls"}
|
||||
var xmlToolCallOpeningTags = []string{"<tool_calls", "<invoke"}
|
||||
|
||||
// xmlToolCallTagPairs maps each opening tag to its expected closing tag.
|
||||
// Order matters: longer/wrapper tags must be checked first.
|
||||
@@ -24,7 +24,7 @@ var xmlToolCallTagPairs = []struct{ open, close string }{
|
||||
var xmlToolCallBlockPattern = regexp.MustCompile(`(?is)(<tool_calls\b[^>]*>\s*(?:.*?)\s*</tool_calls>)`)
|
||||
|
||||
// xmlToolTagsToDetect is the set of XML tag prefixes used by findToolSegmentStart.
|
||||
var xmlToolTagsToDetect = []string{"<tool_calls>", "<tool_calls\n", "<tool_calls "}
|
||||
var xmlToolTagsToDetect = []string{"<tool_calls>", "<tool_calls\n", "<tool_calls ", "<invoke ", "<invoke\n", "<invoke\t", "<invoke\r"}
|
||||
|
||||
// consumeXMLToolCapture tries to extract complete XML tool call blocks from captured text.
|
||||
func consumeXMLToolCapture(captured string, toolNames []string) (prefix string, calls []toolcall.ParsedToolCall, suffix string, ready bool) {
|
||||
@@ -55,6 +55,22 @@ func consumeXMLToolCapture(captured string, toolNames []string) (prefix string,
|
||||
// If this block failed to become a tool call, pass it through as text.
|
||||
return prefixPart + xmlBlock, nil, suffixPart, true
|
||||
}
|
||||
if !strings.Contains(lower, "<tool_calls") {
|
||||
invokeIdx := strings.Index(lower, "<invoke")
|
||||
closeIdx := strings.LastIndex(lower, "</tool_calls>")
|
||||
if invokeIdx >= 0 && closeIdx > invokeIdx {
|
||||
closeEnd := closeIdx + len("</tool_calls>")
|
||||
xmlBlock := "<tool_calls>" + captured[invokeIdx:closeIdx] + "</tool_calls>"
|
||||
prefixPart := captured[:invokeIdx]
|
||||
suffixPart := captured[closeEnd:]
|
||||
parsed := toolcall.ParseToolCalls(xmlBlock, toolNames)
|
||||
if len(parsed) > 0 {
|
||||
prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
|
||||
return prefixPart, parsed, suffixPart, true
|
||||
}
|
||||
return prefixPart + captured[invokeIdx:closeEnd], nil, suffixPart, true
|
||||
}
|
||||
}
|
||||
return "", nil, "", false
|
||||
}
|
||||
|
||||
|
||||
@@ -288,6 +288,7 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) {
|
||||
want int
|
||||
}{
|
||||
{"tool_calls_tag", "some text <tool_calls>\n", 10},
|
||||
{"invoke_tag_missing_wrapper", "some text <invoke name=\"read_file\">\n", 10},
|
||||
{"bare_tool_call_text", "prefix <tool_call>\n", -1},
|
||||
{"xml_inside_code_fence", "```xml\n<tool_calls><invoke name=\"read_file\"></invoke></tool_calls>\n```", -1},
|
||||
{"no_xml", "just plain text", -1},
|
||||
@@ -310,6 +311,7 @@ func TestFindPartialXMLToolTagStart(t *testing.T) {
|
||||
want int
|
||||
}{
|
||||
{"partial_tool_calls", "Hello <tool_ca", 6},
|
||||
{"partial_invoke", "Hello <inv", 6},
|
||||
{"bare_tool_call_not_held", "Hello <tool_name", -1},
|
||||
{"partial_lt_only", "Text <", 5},
|
||||
{"complete_tag", "Text <tool_calls>done", -1},
|
||||
@@ -505,3 +507,32 @@ func TestProcessToolSievePassesThroughBareToolCallAsText(t *testing.T) {
|
||||
t.Fatalf("expected bare invoke to pass through unchanged, got %q", textContent.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessToolSieveRepairsMissingOpeningWrapperWithoutLeakingInvokeText(t *testing.T) {
|
||||
var state State
|
||||
chunks := []string{
|
||||
"<invoke name=\"read_file\">\n",
|
||||
" <parameter name=\"path\">README.md</parameter>\n",
|
||||
"</invoke>\n",
|
||||
"</tool_calls>",
|
||||
}
|
||||
var events []Event
|
||||
for _, c := range chunks {
|
||||
events = append(events, ProcessChunk(&state, c, []string{"read_file"})...)
|
||||
}
|
||||
events = append(events, Flush(&state, []string{"read_file"})...)
|
||||
|
||||
var textContent strings.Builder
|
||||
toolCalls := 0
|
||||
for _, evt := range events {
|
||||
textContent.WriteString(evt.Content)
|
||||
toolCalls += len(evt.ToolCalls)
|
||||
}
|
||||
|
||||
if toolCalls != 1 {
|
||||
t.Fatalf("expected repaired missing-wrapper stream to emit one tool call, got %d events=%#v", toolCalls, events)
|
||||
}
|
||||
if strings.Contains(textContent.String(), "<invoke") || strings.Contains(textContent.String(), "</tool_calls>") {
|
||||
t.Fatalf("expected repaired missing-wrapper stream not to leak xml text, got %q", textContent.String())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user