diff --git a/docs/prompt-compatibility.md b/docs/prompt-compatibility.md index 6d6301c..495d1cc 100644 --- a/docs/prompt-compatibility.md +++ b/docs/prompt-compatibility.md @@ -148,6 +148,7 @@ DS2API 当前的核心思路,不是把客户端传来的 `messages`、`tools` 4. 把这整段内容并入 system prompt。 工具调用正例仍只示范 canonical XML:`` → `` → ``。 +提示词会额外强调:如果要调用工具,工具块的首个非空白字符必须就是 ``,不能只输出 `` 而漏掉 opening tag。 正例中的工具名只会来自当前请求实际声明的工具;如果当前请求没有足够的已知工具形态,就省略对应的单工具、多工具或嵌套示例,避免把不可用工具名写进 prompt。 对执行类工具,脚本内容必须进入执行参数本身:`Bash` / `execute_command` 使用 `command`,`exec_command` 使用 `cmd`;不要把脚本示范成 `path` / `content` 文件写入参数。 @@ -192,6 +193,7 @@ assistant 历史 `tool_calls` 不会保留成 OpenAI 原生 JSON,而会转成 ``` 这也是当前项目里唯一受支持的 canonical tool-calling 形态;其他形态都会作为普通文本保留,不会作为可执行调用语法。 +例外是 parser 会对一个非常窄的模型失误做修复:如果 assistant 输出了 `` ... ``,但漏掉最前面的 opening ``,解析阶段会补回 wrapper 后再尝试识别。 这件事很重要,因为它决定了: diff --git a/docs/toolcall-semantics.md b/docs/toolcall-semantics.md index 2627a0a..ea5c456 100644 --- a/docs/toolcall-semantics.md +++ b/docs/toolcall-semantics.md @@ -23,9 +23,14 @@ - 工具名必须放在 `invoke` 的 `name` 属性 - 参数必须使用 `...` +兼容修复: + +- 如果模型漏掉 opening ``,但后面仍输出了一个或多个 `` 并以 `` 收尾,Go 解析链路会在解析前补回缺失的 opening wrapper。 +- 这是一个针对常见模型失误的窄修复,不改变推荐输出格式;prompt 仍要求模型直接输出完整 canonical XML。 + ## 2) 非 canonical 内容 -任何不满足上述 canonical XML 形态的内容,都会保留为普通文本,不会执行。 +任何不满足上述 canonical XML 形态的内容,都会保留为普通文本,不会执行。一个例外是上一节提到的“缺失 opening ``、但 closing `` 仍存在”的窄修复场景。 当前 parser 不把 allow-list 当作硬安全边界:即使传入了已声明工具名列表,XML 里出现未声明工具名时也会尽量解析并交给上层协议输出;真正的执行侧仍必须自行校验工具名和参数。 @@ -33,7 +38,8 @@ 在流式链路中(Go / Node 一致): -- 只有从 `` wrapper 会进入结构化捕获 +- 如果流里直接从 `` 开始,但后面补上了 ``,Go 流式筛分也会按缺失 opening wrapper 的修复路径尝试恢复 - 已识别成功的工具调用不会再次回流到普通文本 - 不符合新格式的块不会执行,并继续按原样文本透传 - fenced code block 中的 XML 示例始终按普通文本处理 @@ -43,14 +49,14 @@ `ParseToolCallsDetailed` / `parseToolCallsDetailed` 返回: - `calls`:解析出的工具调用列表(`name` + `input`) -- `sawToolCallSyntax`:只有检测到 ` 0 && keepaliveCount >= cfg.MaxKeepAliveNoInput { @@ -95,6 +110,9 @@ func ConsumeSSE(cfg ConsumeConfig, hooks ConsumeHooks) { hooks.OnKeepAlive() } case parsed, ok := <-parsedLines: + if contextDone() { + return + } if !ok { finalize(StopReasonUpstreamCompleted, <-done) return diff --git a/internal/stream/engine_test.go b/internal/stream/engine_test.go new file mode 100644 index 0000000..b23474b --- /dev/null +++ b/internal/stream/engine_test.go @@ -0,0 +1,47 @@ +package stream + +import ( + "context" + "strings" + "testing" + + "ds2api/internal/sse" +) + +func TestConsumeSSEPrefersContextCancellationOverReadyParsedLines(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + var finalized bool + var contextDone bool + var parsedCalled bool + + ConsumeSSE(ConsumeConfig{ + Context: ctx, + Body: strings.NewReader("data: {\"p\":\"response/content\",\"v\":\"hello\"}\n\ndata: [DONE]\n"), + ThinkingEnabled: false, + InitialType: "text", + KeepAliveInterval: 0, + }, ConsumeHooks{ + OnParsed: func(_ sse.LineResult) ParsedDecision { + parsedCalled = true + return ParsedDecision{} + }, + OnFinalize: func(_ StopReason, _ error) { + finalized = true + }, + OnContextDone: func() { + contextDone = true + }, + }) + + if !contextDone { + t.Fatal("expected OnContextDone to run for an already-cancelled context") + } + if finalized { + t.Fatal("expected OnFinalize not to run after context cancellation wins") + } + if parsedCalled { + t.Fatal("expected parsed lines not to be processed after context cancellation wins") + } +} diff --git a/internal/toolcall/tool_prompt.go b/internal/toolcall/tool_prompt.go index 7f405d2..aa556e8 100644 --- a/internal/toolcall/tool_prompt.go +++ b/internal/toolcall/tool_prompt.go @@ -27,6 +27,8 @@ RULES: 7) Numbers, booleans, and null stay plain text. 8) Use only the parameter names in the tool schema. Do not invent fields. 9) Do NOT wrap XML in markdown fences. Do NOT output explanations, role markers, or internal monologue. +10) If you call a tool, the first non-whitespace characters of that tool block must be exactly . +11) Never omit the opening tag, even if you already plan to close with . PARAMETER SHAPES: - string => @@ -42,6 +44,9 @@ Wrong 2 — Markdown code fences: ` + "```xml" + ` ... ` + "```" + ` +Wrong 3 — missing opening wrapper: + ... + Remember: The ONLY valid way to use tools is the ... XML block at the end of your response. diff --git a/internal/toolcall/tool_prompt_test.go b/internal/toolcall/tool_prompt_test.go index 8b0e8cf..d482d52 100644 --- a/internal/toolcall/tool_prompt_test.go +++ b/internal/toolcall/tool_prompt_test.go @@ -109,6 +109,16 @@ func TestBuildToolCallInstructions_WriteUsesFilePathAndContent(t *testing.T) { } } +func TestBuildToolCallInstructions_AnchorsMissingOpeningWrapperFailureMode(t *testing.T) { + out := BuildToolCallInstructions([]string{"read_file"}) + if !strings.Contains(out, "Never omit the opening tag") { + t.Fatalf("expected explicit missing-opening-tag warning, got: %s", out) + } + if !strings.Contains(out, "Wrong 3 — missing opening wrapper") { + t.Fatalf("expected missing-opening-wrapper negative example, got: %s", out) + } +} + func findInvokeBlocks(text, name string) []string { open := `` remaining := text diff --git a/internal/toolcall/toolcalls_parse_markup.go b/internal/toolcall/toolcalls_parse_markup.go index a1424e8..2a9c441 100644 --- a/internal/toolcall/toolcalls_parse_markup.go +++ b/internal/toolcall/toolcalls_parse_markup.go @@ -11,9 +11,17 @@ var xmlToolCallsWrapperPattern = regexp.MustCompile(`(?is)]*>\s* var xmlInvokePattern = regexp.MustCompile(`(?is)]*)>\s*(.*?)\s*`) var xmlParameterPattern = regexp.MustCompile(`(?is)]*)>\s*(.*?)\s*`) var xmlAttrPattern = regexp.MustCompile(`(?is)\b([a-z0-9_:-]+)\s*=\s*("([^"]*)"|'([^']*)')`) +var xmlToolCallsClosePattern = regexp.MustCompile(`(?is)`) +var xmlInvokeStartPattern = regexp.MustCompile(`(?is)]*\bname\s*=\s*("([^"]*)"|'([^']*)')`) func parseXMLToolCalls(text string) []ParsedToolCall { wrappers := xmlToolCallsWrapperPattern.FindAllStringSubmatch(text, -1) + if len(wrappers) == 0 { + repaired := repairMissingXMLToolCallsOpeningWrapper(text) + if repaired != text { + wrappers = xmlToolCallsWrapperPattern.FindAllStringSubmatch(repaired, -1) + } + } if len(wrappers) == 0 { return nil } @@ -36,6 +44,28 @@ func parseXMLToolCalls(text string) []ParsedToolCall { return out } +func repairMissingXMLToolCallsOpeningWrapper(text string) string { + lower := strings.ToLower(text) + if strings.Contains(lower, "= closeLoc[0] { + return text + } + + return text[:invokeLoc[0]] + "" + text[invokeLoc[0]:closeLoc[0]] + "" + text[closeLoc[1]:] +} + func parseSingleXMLToolCall(block []string) (ParsedToolCall, bool) { if len(block) < 3 { return ParsedToolCall{}, false diff --git a/internal/toolcall/toolcalls_test.go b/internal/toolcall/toolcalls_test.go index 13d0bef..8d26f73 100644 --- a/internal/toolcall/toolcalls_test.go +++ b/internal/toolcall/toolcalls_test.go @@ -175,6 +175,26 @@ func TestParseToolCallsRejectsBareInvokeWithoutToolCallsWrapper(t *testing.T) { } } +func TestParseToolCallsRepairsMissingOpeningToolCallsWrapperWhenClosingTagExists(t *testing.T) { + text := `Before tool call +README.md + +after` + res := ParseToolCallsDetailed(text, []string{"read_file"}) + if len(res.Calls) != 1 { + t.Fatalf("expected repaired wrapper to parse exactly one call, got %#v", res) + } + if res.Calls[0].Name != "read_file" { + t.Fatalf("expected repaired wrapper to preserve tool name, got %#v", res.Calls[0]) + } + if got, _ := res.Calls[0].Input["path"].(string); got != "README.md" { + t.Fatalf("expected repaired wrapper to preserve args, got %#v", res.Calls[0].Input) + } + if !res.SawToolCallSyntax { + t.Fatalf("expected repaired wrapper to mark tool syntax seen, got %#v", res) + } +} + func TestParseToolCallsRejectsLegacyCanonicalBody(t *testing.T) { text := `read_file{"path":"README.md"}` calls := ParseToolCalls(text, []string{"read_file"}) diff --git a/internal/toolstream/tool_sieve_xml.go b/internal/toolstream/tool_sieve_xml.go index 87fb075..6d6cbc4 100644 --- a/internal/toolstream/tool_sieve_xml.go +++ b/internal/toolstream/tool_sieve_xml.go @@ -10,7 +10,7 @@ import ( //nolint:unused // kept as explicit tag inventory for future XML sieve refinements. var xmlToolCallClosingTags = []string{""} -var xmlToolCallOpeningTags = []string{"]*>\s*(?:.*?)\s*)`) // xmlToolTagsToDetect is the set of XML tag prefixes used by findToolSegmentStart. -var xmlToolTagsToDetect = []string{"", "", "") + if invokeIdx >= 0 && closeIdx > invokeIdx { + closeEnd := closeIdx + len("") + xmlBlock := "" + captured[invokeIdx:closeIdx] + "" + prefixPart := captured[:invokeIdx] + suffixPart := captured[closeEnd:] + parsed := toolcall.ParseToolCalls(xmlBlock, toolNames) + if len(parsed) > 0 { + prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart) + return prefixPart, parsed, suffixPart, true + } + return prefixPart + captured[invokeIdx:closeEnd], nil, suffixPart, true + } + } return "", nil, "", false } diff --git a/internal/toolstream/tool_sieve_xml_test.go b/internal/toolstream/tool_sieve_xml_test.go index 4b06bc3..ba4a00b 100644 --- a/internal/toolstream/tool_sieve_xml_test.go +++ b/internal/toolstream/tool_sieve_xml_test.go @@ -288,6 +288,7 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) { want int }{ {"tool_calls_tag", "some text \n", 10}, + {"invoke_tag_missing_wrapper", "some text \n", 10}, {"bare_tool_call_text", "prefix \n", -1}, {"xml_inside_code_fence", "```xml\n\n```", -1}, {"no_xml", "just plain text", -1}, @@ -310,6 +311,7 @@ func TestFindPartialXMLToolTagStart(t *testing.T) { want int }{ {"partial_tool_calls", "Hello done", -1}, @@ -505,3 +507,32 @@ func TestProcessToolSievePassesThroughBareToolCallAsText(t *testing.T) { t.Fatalf("expected bare invoke to pass through unchanged, got %q", textContent.String()) } } + +func TestProcessToolSieveRepairsMissingOpeningWrapperWithoutLeakingInvokeText(t *testing.T) { + var state State + chunks := []string{ + "\n", + " README.md\n", + "\n", + "", + } + var events []Event + for _, c := range chunks { + events = append(events, ProcessChunk(&state, c, []string{"read_file"})...) + } + events = append(events, Flush(&state, []string{"read_file"})...) + + var textContent strings.Builder + toolCalls := 0 + for _, evt := range events { + textContent.WriteString(evt.Content) + toolCalls += len(evt.ToolCalls) + } + + if toolCalls != 1 { + t.Fatalf("expected repaired missing-wrapper stream to emit one tool call, got %d events=%#v", toolCalls, events) + } + if strings.Contains(textContent.String(), "") { + t.Fatalf("expected repaired missing-wrapper stream not to leak xml text, got %q", textContent.String()) + } +}