diff --git a/README.MD b/README.MD index d3dc05d..636e693 100644 --- a/README.MD +++ b/README.MD @@ -363,6 +363,8 @@ cp opencode.json.example opencode.json 3. 未在 `tools` 声明中的工具名会被严格拒绝,不会下发为有效 tool call 4. `responses` 支持并执行 `tool_choice`(`auto`/`none`/`required`/强制函数);`required` 违规时非流式返回 `422`,流式返回 `response.failed` 5. 仅在通过策略校验后才会发出有效工具调用事件,避免错误工具名进入客户端执行链 +6. strict 模式下采用“可解析即拦截”:即使 tool JSON 前后混有 prose,只要结构可提取仍会拦截 tool_calls,剩余文本继续透传 +7. 当参数字符串无法可靠修复为对象时,会保留 `{"_raw":"..."}` 回退,避免 silent corruption ## 本地开发抓包工具 @@ -476,6 +478,23 @@ go run ./cmd/ds2api-tests \ npm ci --prefix webui && npm run build --prefix webui ``` +## 测试 + +详细测试指南请参阅 [TESTING.md](TESTING.md)。 + +### 快速测试命令 + +```bash +# 运行所有单元测试 +go test ./... + +# 运行 tool calls 相关测试(调试工具调用问题) +go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/ + +# 运行端到端测试 +./tests/scripts/run-live.sh +``` + ## Release 自动构建(GitHub Actions) 工作流文件:`.github/workflows/release-artifacts.yml` diff --git a/TESTING.md b/TESTING.md index c5e13e6..bf821fe 100644 --- a/TESTING.md +++ b/TESTING.md @@ -173,6 +173,57 @@ rg "" artifacts/testsuite//server.log go test ./... ``` +### 运行特定模块的单元测试 + +```bash +# 运行 tool calls 相关测试(推荐用于调试 tool call 解析问题) +go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/ + +# 运行单个测试用例 +go test -v -run TestParseToolCallsWithDeepSeekHallucination ./internal/util/ + +# 运行 format 相关测试 +go test -v ./internal/format/... + +# 运行 adapter 相关测试 +go test -v ./internal/adapter/openai/... +``` + +### 调试 Tool Call 问题 | Debugging Tool Call Issues + +当遇到 DeepSeek 工具调用解析问题时,可以使用以下方法: + +```bash +# 1. 运行 tool calls 相关的所有测试 +go test -v -run 'TestParseToolCalls|TestRepair' ./internal/util/ + +# 2. 
查看测试输出中的详细调试信息 +go test -v -run TestParseToolCallsWithDeepSeekHallucination ./internal/util/ 2>&1 + +# 2.1 strict 模式(Go/JS)语义对齐检查:混合 prose + tool JSON 仍可拦截 +node --test tests/node/stream-tool-sieve.test.js + +# 2.2 Windows 路径与文本换行语义回归 +go test -v -run TestParseToolCallsWithInvalidBackslashes ./internal/util/ +go test -v -run TestParseToolCallsWithPathEscapesAndTextNewlines ./internal/util/ + +# 3. 检查具体测试用例的修复效果 +# 测试用例位于 internal/util/toolcalls_test.go,包含: +# - TestParseToolCallsWithDeepSeekHallucination: DeepSeek 典型幻觉输出 +# - TestRepairLooseJSONWithNestedObjects: 嵌套对象的方括号修复 +# - TestParseToolCallsWithMixedWindowsPaths: Windows 路径处理 +``` + +### 运行 Node.js 测试 + +```bash +# 运行 Node 测试 +node --test tests/node/stream-tool-sieve.test.js + +# 或使用脚本 +./tests/scripts/run-unit-node.sh +``` + ### 跑端到端测试(跳过 preflight) ```bash diff --git a/internal/adapter/openai/chat_stream_runtime.go b/internal/adapter/openai/chat_stream_runtime.go index 5cd16da..1a81660 100644 --- a/internal/adapter/openai/chat_stream_runtime.go +++ b/internal/adapter/openai/chat_stream_runtime.go @@ -98,11 +98,11 @@ func (s *chatStreamRuntime) sendDone() { func (s *chatStreamRuntime) finalize(finishReason string) { finalThinking := s.thinking.String() finalText := s.text.String() - detected := util.ParseStandaloneToolCalls(finalText, s.toolNames) - if len(detected) > 0 && !s.toolCallsDoneEmitted { + detected := util.ParseStandaloneToolCallsDetailed(finalText, s.toolNames) + if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted { finishReason = "tool_calls" delta := map[string]any{ - "tool_calls": formatFinalStreamToolCallsWithStableIDs(detected, s.streamToolCallIDs), + "tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs), } if !s.firstChunkSent { delta["role"] = "assistant" @@ -158,7 +158,7 @@ func (s *chatStreamRuntime) finalize(finishReason string) { } } - if len(detected) > 0 || s.toolCallsEmitted { + if len(detected.Calls) > 0 || s.toolCallsEmitted { finishReason = 
"tool_calls" } s.sendChunk(openaifmt.BuildChatStreamChunk( diff --git a/internal/adapter/openai/handler_toolcall_format.go b/internal/adapter/openai/handler_toolcall_format.go index 37ebaf9..3adfd15 100644 --- a/internal/adapter/openai/handler_toolcall_format.go +++ b/internal/adapter/openai/handler_toolcall_format.go @@ -53,7 +53,7 @@ func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolCh if len(toolSchemas) == 0 { return messages, names } - toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block." 
+ toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY a JSON code block like this:\n```json\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n```\n\n【EXAMPLE】\nUser: Please check the weather in Beijing and Shanghai, and update my todo list.\nAssistant:\n```json\n{\"tool_calls\": [\n {\"name\": \"get_weather\", \"input\": {\"city\": \"Beijing\"}},\n {\"name\": \"get_weather\", \"input\": {\"city\": \"Shanghai\"}},\n {\"name\": \"update_todo\", \"input\": {\"todos\": [{\"content\": \"Buy milk\"}, {\"content\": \"Write report\"}]}}\n]}\n```\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON code block. The response must start with ```json and end with ```.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block.\n5) JSON SYNTAX STRICTLY REQUIRED: All property names MUST be enclosed in double quotes (e.g., \"name\", not name).\n6) ARRAY FORMAT: If providing a list of items, you MUST enclose them in square brackets `[]` (e.g., \"todos\": [{\"item\": \"a\"}, {\"item\": \"b\"}]). DO NOT output comma-separated objects without brackets." if policy.Mode == util.ToolChoiceRequired { - toolPrompt += "\n5) For this response, you MUST call at least one tool from the allowed list." + toolPrompt += "\n7) For this response, you MUST call at least one tool from the allowed list." // numbered 7: the base prompt above already uses 5) and 6) for the JSON syntax rules
} diff --git a/internal/adapter/openai/tool_sieve_core.go b/internal/adapter/openai/tool_sieve_core.go index e7e41f8..72628e9 100644 --- a/internal/adapter/openai/tool_sieve_core.go +++ b/internal/adapter/openai/tool_sieve_core.go @@ -167,13 +167,28 @@ func findToolSegmentStart(s string) int { return -1 } lower := strings.ToLower(s) + keywords := []string{"tool_calls", "function.name:", "[tool_call_history]"} offset := 0 for { - keyRel := strings.Index(lower[offset:], "tool_calls") - if keyRel < 0 { + bestKeyIdx := -1 + matchedKeyword := "" + + for _, kw := range keywords { + idx := strings.Index(lower[offset:], kw) + if idx >= 0 { + absIdx := offset + idx + if bestKeyIdx < 0 || absIdx < bestKeyIdx { + bestKeyIdx = absIdx + matchedKeyword = kw + } + } + } + + if bestKeyIdx < 0 { return -1 } - keyIdx := offset + keyRel + + keyIdx := bestKeyIdx start := strings.LastIndex(s[:keyIdx], "{") if start < 0 { start = keyIdx @@ -181,7 +196,7 @@ func findToolSegmentStart(s string) int { if !insideCodeFence(s[:start]) { return start } - offset = keyIdx + len("tool_calls") + offset = keyIdx + len(matchedKeyword) } } @@ -191,13 +206,22 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix return "", nil, "", false } lower := strings.ToLower(captured) - keyIdx := strings.Index(lower, "tool_calls") + + keyIdx := -1 + keywords := []string{"tool_calls", "function.name:", "[tool_call_history]"} + for _, kw := range keywords { + idx := strings.Index(lower, kw) + if idx >= 0 && (keyIdx < 0 || idx < keyIdx) { + keyIdx = idx + } + } + if keyIdx < 0 { return "", nil, "", false } start := strings.LastIndex(captured[:keyIdx], "{") if start < 0 { - return "", nil, "", false + start = keyIdx } obj, end, ok := extractJSONObjectFrom(captured, start) if !ok { @@ -215,6 +239,9 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix // consume it to avoid leaking raw tool_calls JSON to user content. 
return prefixPart, nil, suffixPart, true } + // If it has obvious keywords but failed to parse even after loose repair, + // we still might want to intercept it if it looks like an attempt at tool call. + // For now, keep the original logic but rely on loose JSON repair. return captured, nil, "", true } return prefixPart, parsed.Calls, suffixPart, true diff --git a/internal/format/openai/render_chat.go b/internal/format/openai/render_chat.go index 181e8b9..bdea9b5 100644 --- a/internal/format/openai/render_chat.go +++ b/internal/format/openai/render_chat.go @@ -8,15 +8,15 @@ import ( ) func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { - detected := util.ParseStandaloneToolCalls(finalText, toolNames) + detected := util.ParseStandaloneToolCallsDetailed(finalText, toolNames) finishReason := "stop" messageObj := map[string]any{"role": "assistant", "content": finalText} if strings.TrimSpace(finalThinking) != "" { messageObj["reasoning_content"] = finalThinking } - if len(detected) > 0 { + if len(detected.Calls) > 0 { finishReason = "tool_calls" - messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected) + messageObj["tool_calls"] = util.FormatOpenAIToolCalls(detected.Calls) messageObj["content"] = nil } diff --git a/internal/format/openai/render_responses.go b/internal/format/openai/render_responses.go index 21df584..a3b37f0 100644 --- a/internal/format/openai/render_responses.go +++ b/internal/format/openai/render_responses.go @@ -13,12 +13,12 @@ import ( func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { // Strict mode: only standalone, structured tool-call payloads are treated // as executable tool calls. 
- detected := util.ParseStandaloneToolCalls(finalText, toolNames) + detected := util.ParseStandaloneToolCallsDetailed(finalText, toolNames) exposedOutputText := finalText output := make([]any, 0, 2) - if len(detected) > 0 { + if len(detected.Calls) > 0 { exposedOutputText = "" - output = append(output, toResponsesFunctionCallItems(detected)...) + output = append(output, toResponsesFunctionCallItems(detected.Calls)...) } else { content := make([]any, 0, 2) if finalThinking != "" { diff --git a/internal/js/helpers/stream-tool-sieve/sieve.js b/internal/js/helpers/stream-tool-sieve/sieve.js index c1b92a8..a3b7fd8 100644 --- a/internal/js/helpers/stream-tool-sieve/sieve.js +++ b/internal/js/helpers/stream-tool-sieve/sieve.js @@ -165,19 +165,34 @@ function findToolSegmentStart(s) { return -1; } const lower = s.toLowerCase(); + const keywords = ['tool_calls', 'function.name:', '[tool_call_history]']; let offset = 0; // eslint-disable-next-line no-constant-condition while (true) { - const keyIdx = lower.indexOf('tool_calls', offset); - if (keyIdx < 0) { + let bestKeyIdx = -1; + let matchedKeyword = ''; + + for (const kw of keywords) { + const idx = lower.indexOf(kw, offset); + if (idx >= 0) { + if (bestKeyIdx < 0 || idx < bestKeyIdx) { + bestKeyIdx = idx; + matchedKeyword = kw; + } + } + } + + if (bestKeyIdx < 0) { return -1; } + + const keyIdx = bestKeyIdx; const start = s.slice(0, keyIdx).lastIndexOf('{'); const candidateStart = start >= 0 ? 
start : keyIdx; if (!insideCodeFence(s.slice(0, candidateStart))) { return candidateStart; } - offset = keyIdx + 'tool_calls'.length; + offset = keyIdx + matchedKeyword.length; } } @@ -187,20 +202,28 @@ function consumeToolCapture(state, toolNames) { return { ready: false, prefix: '', calls: [], suffix: '' }; } const lower = captured.toLowerCase(); - const keyIdx = lower.indexOf('tool_calls'); + + let keyIdx = -1; + const keywords = ['tool_calls', 'function.name:', '[tool_call_history]']; + for (const kw of keywords) { + const idx = lower.indexOf(kw); + if (idx >= 0 && (keyIdx < 0 || idx < keyIdx)) { + keyIdx = idx; + } + } + if (keyIdx < 0) { return { ready: false, prefix: '', calls: [], suffix: '' }; } const start = captured.slice(0, keyIdx).lastIndexOf('{'); - if (start < 0) { - return { ready: false, prefix: '', calls: [], suffix: '' }; - } - const obj = extractJSONObjectFrom(captured, start); + const actualStart = start >= 0 ? start : keyIdx; + + const obj = extractJSONObjectFrom(captured, actualStart); if (!obj.ok) { return { ready: false, prefix: '', calls: [], suffix: '' }; } - const prefixPart = captured.slice(0, start); + const prefixPart = captured.slice(0, actualStart); const suffixPart = captured.slice(obj.end); if (insideCodeFence((state.recentTextTail || '') + prefixPart)) { @@ -212,16 +235,7 @@ function consumeToolCapture(state, toolNames) { }; } - if ((state.recentTextTail || '').trim() !== '' || prefixPart.trim() !== '' || suffixPart.trim() !== '') { - return { - ready: true, - prefix: captured, - calls: [], - suffix: '', - }; - } - - const parsed = parseStandaloneToolCallsDetailed(captured.slice(start, obj.end), toolNames); + const parsed = parseStandaloneToolCallsDetailed(captured.slice(actualStart, obj.end), toolNames); if (!Array.isArray(parsed.calls) || parsed.calls.length === 0) { if (parsed.sawToolCallSyntax && parsed.rejectedByPolicy) { return { diff --git a/internal/util/toolcalls_candidates.go b/internal/util/toolcalls_candidates.go 
index 4e8afc4..49db011 100644 --- a/internal/util/toolcalls_candidates.go +++ b/internal/util/toolcalls_candidates.go @@ -20,7 +20,7 @@ func buildToolCallCandidates(text string) []string { } } - // best-effort extraction around "tool_calls" key in mixed text payloads. + // best-effort extraction around tool call keywords in mixed text payloads. candidates = append(candidates, extractToolCallObjects(trimmed)...) // best-effort object slice: from first '{' to last '}' @@ -57,25 +57,65 @@ func extractToolCallObjects(text string) []string { lower := strings.ToLower(text) out := []string{} offset := 0 + keywords := []string{"tool_calls", "function.name:", "[tool_call_history]"} for { - idx := strings.Index(lower[offset:], "tool_calls") - if idx < 0 { + bestIdx := -1 + matchedKeyword := "" + for _, kw := range keywords { + idx := strings.Index(lower[offset:], kw) + if idx >= 0 { + absIdx := offset + idx + if bestIdx < 0 || absIdx < bestIdx { + bestIdx = absIdx + matchedKeyword = kw + } + } + } + + if bestIdx < 0 { break } - idx += offset - start := strings.LastIndex(text[:idx], "{") - for start >= 0 { + + idx := bestIdx + // Avoid backtracking too far to prevent OOM on malicious or very long strings + searchLimit := idx - 2000 + if searchLimit < offset { + searchLimit = offset + } + + start := strings.LastIndex(text[searchLimit:idx], "{") + if start >= 0 { + start += searchLimit + } + + if start < 0 { + offset = idx + len(matchedKeyword) + continue + } + + foundObj := false + for start >= searchLimit { candidate, end, ok := extractJSONObject(text, start) if ok { // Move forward to avoid repeatedly matching the same object. 
offset = end out = append(out, strings.TrimSpace(candidate)) + foundObj = true break } - start = strings.LastIndex(text[:start], "{") + // Try previous '{' + if start > searchLimit { + prevStart := strings.LastIndex(text[searchLimit:start], "{") + if prevStart >= 0 { + start = searchLimit + prevStart + continue + } + } + break } - if start < 0 { - offset = idx + len("tool_calls") + + if !foundObj { + offset = idx + len(matchedKeyword) } } return out @@ -88,7 +128,12 @@ func extractJSONObject(text string, start int) (string, int, bool) { depth := 0 quote := byte(0) escaped := false - for i := start; i < len(text); i++ { + // Limit scan length to avoid OOM on unclosed objects + maxLen := start + 50000 + if maxLen > len(text) { + maxLen = len(text) + } + for i := start; i < maxLen; i++ { ch := text[i] if quote != 0 { if escaped { diff --git a/internal/util/toolcalls_parse.go b/internal/util/toolcalls_parse.go index fb6d459..910c573 100644 --- a/internal/util/toolcalls_parse.go +++ b/internal/util/toolcalls_parse.go @@ -1,9 +1,6 @@ package util -import ( - "encoding/json" - "strings" -) +import "strings" type ParsedToolCall struct { Name string `json:"name"` @@ -83,31 +80,26 @@ func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) return result } result.SawToolCallSyntax = looksLikeToolCallSyntax(trimmed) - candidates := []string{trimmed} - for _, candidate := range candidates { - candidate = strings.TrimSpace(candidate) - if candidate == "" { - continue - } - parsed := parseToolCallsPayload(candidate) - if len(parsed) == 0 { - parsed = parseXMLToolCalls(candidate) - } - if len(parsed) == 0 { - parsed = parseMarkupToolCalls(candidate) - } - if len(parsed) == 0 { - parsed = parseTextKVToolCalls(candidate) - } - if len(parsed) > 0 { - result.SawToolCallSyntax = true - calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames) - result.Calls = calls - result.RejectedToolNames = rejectedNames - result.RejectedByPolicy = 
len(rejectedNames) > 0 && len(calls) == 0 - return result - } + + parsed := parseToolCallsPayload(trimmed) + if len(parsed) == 0 { + parsed = parseXMLToolCalls(trimmed) } + if len(parsed) == 0 { + parsed = parseMarkupToolCalls(trimmed) + } + if len(parsed) == 0 { + parsed = parseTextKVToolCalls(trimmed) + } + if len(parsed) == 0 { + return result + } + + result.SawToolCallSyntax = true + calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames) + result.Calls = calls + result.RejectedToolNames = rejectedNames + result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0 return result } @@ -140,6 +132,7 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin } return nil, rejected } + out := make([]ParsedToolCall, 0, len(parsed)) rejectedSet := map[string]struct{}{} rejected := make([]string, 0) @@ -168,25 +161,6 @@ func resolveAllowedToolName(name string, allowed map[string]struct{}, allowedCan return resolveAllowedToolNameWithLooseMatch(name, allowed, allowedCanonical) } -func parseToolCallsPayload(payload string) []ParsedToolCall { - var decoded any - if err := json.Unmarshal([]byte(payload), &decoded); err != nil { - return nil - } - switch v := decoded.(type) { - case map[string]any: - if tc, ok := v["tool_calls"]; ok { - return parseToolCallList(tc) - } - if parsed, ok := parseToolCallItem(v); ok { - return []ParsedToolCall{parsed} - } - case []any: - return parseToolCallList(v) - } - return nil -} - func looksLikeToolCallSyntax(text string) bool { lower := strings.ToLower(text) return strings.Contains(lower, "tool_calls") || @@ -195,85 +169,3 @@ func looksLikeToolCallSyntax(text string) bool { strings.Contains(lower, " maxScanLen { + return s + } + + var out strings.Builder + out.Grow(len(s) + 8) + i := 0 + for i < len(s) { + if s[i] != ':' { + out.WriteByte(s[i]) + i++ + continue + } + out.WriteByte(':') + i++ + for i < len(s) && isJSONWhitespace(s[i]) { + out.WriteByte(s[i]) + i++ + } + if i >= len(s) || 
s[i] != '{' { + continue + } + + start := i + end := scanJSONObjectEnd(s, start) + if end < 0 { + out.WriteString(s[start:]) + break + } + cursor := end + next := skipJSONWhitespace(s, cursor) + if next >= len(s) || s[next] != ',' { + out.WriteString(s[start:end]) + i = end + continue + } + + seqEnd := end + hasMultiple := false + for { + comma := skipJSONWhitespace(s, seqEnd) + if comma >= len(s) || s[comma] != ',' { + break + } + objStart := skipJSONWhitespace(s, comma+1) + if objStart >= len(s) || s[objStart] != '{' { + break + } + objEnd := scanJSONObjectEnd(s, objStart) + if objEnd < 0 { + break + } + hasMultiple = true + seqEnd = objEnd + } + if !hasMultiple { + out.WriteString(s[start:end]) + i = end + continue + } + + out.WriteByte('[') + out.WriteString(s[start:seqEnd]) + out.WriteByte(']') + i = seqEnd + } + return out.String() +} + +func scanJSONObjectEnd(s string, start int) int { + depth := 0 + inString := false + escaped := false + for i := start; i < len(s); i++ { + c := s[i] + if inString { + if escaped { + escaped = false + continue + } + if c == '\\' { + escaped = true + continue + } + if c == '"' { + inString = false + } + continue + } + if c == '"' { + inString = true + continue + } + if c == '{' { + depth++ + continue + } + if c == '}' { + depth-- + if depth == 0 { + return i + 1 + } + } + } + return -1 +} + +func skipJSONWhitespace(s string, i int) int { + for i < len(s) && isJSONWhitespace(s[i]) { + i++ + } + return i +} + +func isJSONWhitespace(b byte) bool { + return b == ' ' || b == '\n' || b == '\r' || b == '\t' +} + +func isHex4(seq []rune) bool { + if len(seq) != 4 { + return false + } + for _, r := range seq { + if !((r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F')) { + return false + } + } + return true +} diff --git a/internal/util/toolcalls_test.go b/internal/util/toolcalls_test.go index 3ace015..10458df 100644 --- a/internal/util/toolcalls_test.go +++ b/internal/util/toolcalls_test.go @@ -1,6 +1,9 @@ 
package util -import "testing" +import ( + "strings" + "testing" +) func TestParseToolCalls(t *testing.T) { text := `prefix {"tool_calls":[{"name":"search","input":{"q":"golang"}}]} suffix` @@ -279,3 +282,242 @@ func TestParseToolCallsDoesNotAcceptMismatchedMarkupTags(t *testing.T) { t.Fatalf("expected mismatched tags to be rejected, got %#v", calls) } } + +func TestRepairInvalidJSONBackslashes(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {`{"path": "C:\Users\name"}`, `{"path": "C:\\Users\\name"}`}, + {`{"cmd": "cd D:\git_codes"}`, `{"cmd": "cd D:\\git_codes"}`}, + {`{"text": "line1\nline2"}`, `{"text": "line1\nline2"}`}, + {`{"path": "D:\\back\\slash"}`, `{"path": "D:\\back\\slash"}`}, + {`{"unicode": "\u2705"}`, `{"unicode": "\u2705"}`}, + {`{"invalid_u": "\u123"}`, `{"invalid_u": "\\u123"}`}, + } + + for _, tt := range tests { + got := repairInvalidJSONBackslashes(tt.input) + if got != tt.expected { + t.Errorf("repairInvalidJSONBackslashes(%s) = %s; want %s", tt.input, got, tt.expected) + } + } +} + +func TestRepairLooseJSON(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {`{tool_calls: [{"name": "search", "input": {"q": "go"}}]}`, `{"tool_calls": [{"name": "search", "input": {"q": "go"}}]}`}, + {`{name: "search", input: {q: "go"}}`, `{"name": "search", "input": {"q": "go"}}`}, + } + + for _, tt := range tests { + got := RepairLooseJSON(tt.input) + if got != tt.expected { + t.Errorf("RepairLooseJSON(%s) = %s; want %s", tt.input, got, tt.expected) + } + } +} + +func TestParseToolCallsWithUnquotedKeys(t *testing.T) { + text := `这里是列表:{tool_calls: [{"name": "todowrite", "input": {"todos": "test"}}]}` + availableTools := []string{"todowrite"} + + parsed := ParseToolCalls(text, availableTools) + if len(parsed) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(parsed)) + } + if parsed[0].Name != "todowrite" { + t.Errorf("expected tool todowrite, got %s", parsed[0].Name) + } +} + +func 
TestParseToolCallsWithInvalidBackslashes(t *testing.T) { + // DeepSeek sometimes outputs Windows paths with single backslashes in JSON strings + // Note: using raw string to simulate what AI actually sends in the stream + text := `好的,执行以下命令:{"name": "execute_command", "input": "{\"command\": \"cd D:\git_codes && dir\"}"}` + availableTools := []string{"execute_command"} + + parsed := ParseToolCalls(text, availableTools) + // If standard JSON fails, buildToolCallCandidates should still extract the object, + // and parseToolCallsPayload should repair it. + if len(parsed) != 1 { + // If it still fails, let's see why + candidates := buildToolCallCandidates(text) + t.Logf("Candidates: %v", candidates) + t.Fatalf("expected 1 tool call, got %d", len(parsed)) + } + + cmd, ok := parsed[0].Input["command"].(string) + if !ok { + t.Fatalf("expected command string in input, got %v", parsed[0].Input) + } + + expected := "cd D:\\git_codes && dir" + if cmd != expected { + t.Errorf("expected command %q, got %q", expected, cmd) + } +} + +func TestParseToolCallsWithDeepSeekHallucination(t *testing.T) { + // 模拟 DeepSeek 典型的幻觉输出:未加引号的键名 + 包含 Windows 路径的嵌套 JSON 字符串 + 漏掉列表的方括号 + text := `检测到实施意图——实现经典算法。需在misc/目录创建Python文件。 +关键约束: +1. Windows UTF-8编码处理 +2. 必须用绝对路径导入 +3. 
禁止write覆盖已有文件(misc/目录允许创建新文件) +将任务分解并委托: +- 研究8皇后算法模式(并行探索) +- 实现带可视化输出的解决方案(unspecified-high) +先创建todo列表追踪步骤。 +{tool_calls: [{"name": "todowrite", "input": {"todos": {"content": "研究8皇后问题算法模式(回溯法)和输出格式", "status": "pending", "priority": "high"}, {"content": "在misc/目录创建8皇后Python脚本,包含完整解决方案和可视化输出", "status": "pending", "priority": "high"}, {"content": "验证脚本正确性(运行测试)", "status": "pending", "priority": "medium"}}}]}` + + availableTools := []string{"todowrite"} + parsed := ParseToolCalls(text, availableTools) + + if len(parsed) != 1 { + cands := buildToolCallCandidates(text) + for i, c := range cands { + t.Logf("CAND %d: %s", i, c) + repaired := RepairLooseJSON(c) + t.Logf(" REPAIRED: %s", repaired) + } + t.Fatalf("expected 1 tool call, got %d. Candidates: %v", len(parsed), buildToolCallCandidates(text)) + } + + if parsed[0].Name != "todowrite" { + t.Errorf("expected tool name 'todowrite', got %q", parsed[0].Name) + } + + todos, ok := parsed[0].Input["todos"].([]any) + if !ok { + t.Fatalf("expected 'todos' to be parsed as a list, got %T: %#v", parsed[0].Input["todos"], parsed[0].Input["todos"]) + } + if len(todos) != 3 { + t.Errorf("expected 3 todo items, got %d", len(todos)) + } +} + +func TestParseToolCallsWithMixedWindowsPaths(t *testing.T) { + // 更复杂的案例:嵌套 JSON 字符串中的反斜杠未转义 + text := `关键约束: 1. Windows UTF-8编码处理 2. 
必须用绝对路径导入 D:\git_codes\ds2api\misc +{tool_calls: [{"name": "write_file", "input": "{\"path\": \"D:\\git_codes\\ds2api\\misc\\queens.py\", \"content\": \"print('hello')\"}"}]}` + + availableTools := []string{"write_file"} + parsed := ParseToolCalls(text, availableTools) + + if len(parsed) != 1 { + t.Fatalf("expected 1 tool call from mixed text with paths, got %d", len(parsed)) + } + + path, _ := parsed[0].Input["path"].(string) + // 在解析后的 Go map 中,反斜杠应该被还原 + if !strings.Contains(path, "D:\\git_codes") && !strings.Contains(path, "D:/git_codes") { + t.Errorf("expected path to contain Windows style separators, got %q", path) + } +} + +func TestParseToolCallsWithPathEscapesAndTextNewlines(t *testing.T) { + text := `{"name":"write_file","input":"{\"content\":\"line1\\nline2\",\"path\":\"D:\\tmp\\a.txt\"}"}` + availableTools := []string{"write_file"} + parsed := ParseToolCalls(text, availableTools) + if len(parsed) != 1 { + t.Fatalf("expected 1 parsed tool call, got %d", len(parsed)) + } + + content, _ := parsed[0].Input["content"].(string) + path, _ := parsed[0].Input["path"].(string) + if !strings.Contains(content, "line1\nline2") { + t.Fatalf("expected content to preserve newline semantics, got %q", content) + } + if strings.ContainsAny(path, "\n\r\t") { + t.Fatalf("expected path to avoid control chars, got %q", path) + } + if !strings.Contains(path, `D:\tmp\a.txt`) { + t.Fatalf("expected path with literal backslashes, got %q", path) + } +} + +func TestRepairLooseJSONWithNestedObjects(t *testing.T) { + // 覆盖深层嵌套对象的方括号修复,避免 regex 单层能力带来的漂移。 + tests := []struct { + name string + input string + expected string + }{ + // 1. 单层嵌套对象(核心修复目标) + { + name: "单层嵌套 - 2个元素", + input: `"todos": {"content": "研究算法", "input": {"q": "8 queens"}}, {"content": "实现", "input": {"path": "queens.py"}}`, + expected: `"todos": [{"content": "研究算法", "input": {"q": "8 queens"}}, {"content": "实现", "input": {"path": "queens.py"}}]`, + }, + // 2. 
3个单层嵌套对象 + { + name: "3个单层嵌套对象", + input: `"items": {"a": {"x":1}}, {"b": {"y":2}}, {"c": {"z":3}}`, + expected: `"items": [{"a": {"x":1}}, {"b": {"y":2}}, {"c": {"z":3}}]`, + }, + // 3. 混合嵌套:有些字段是对象,有些是原始值 + { + name: "混合嵌套 - 对象和原始值混合", + input: `"items": {"name": "test", "config": {"timeout": 30}}, {"name": "test2", "config": {"timeout": 60}}`, + expected: `"items": [{"name": "test", "config": {"timeout": 30}}, {"name": "test2", "config": {"timeout": 60}}]`, + }, + // 4. 4个嵌套对象(边界测试) + { + name: "4个嵌套对象", + input: `"todos": {"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}`, + expected: `"todos": [{"id": 1}, {"id": 2}, {"id": 3}, {"id": 4}]`, + }, + // 5. DeepSeek 典型幻觉:无空格逗号分隔 + { + name: "无空格逗号分隔", + input: `"results": {"name": "a"}, {"name": "b"}, {"name": "c"}`, + expected: `"results": [{"name": "a"}, {"name": "b"}, {"name": "c"}]`, + }, + // 6. 嵌套数组(数组在对象内,不是深层嵌套) + { + name: "对象内包含数组", + input: `"data": {"items": [1,2,3]}, {"items": [4,5,6]}`, + expected: `"data": [{"items": [1,2,3]}, {"items": [4,5,6]}]`, + }, + // 7. 真实的 DeepSeek 8皇后问题输出 + { + name: "DeepSeek 8皇后真实输出", + input: `"todos": {"content": "研究8皇后算法", "status": "pending"}, {"content": "实现Python脚本", "status": "pending"}, {"content": "验证结果", "status": "pending"}`, + expected: `"todos": [{"content": "研究8皇后算法", "status": "pending"}, {"content": "实现Python脚本", "status": "pending"}, {"content": "验证结果", "status": "pending"}]`, + }, + // 8. 简单无嵌套对象(回归测试) + { + name: "简单无嵌套对象", + input: `"items": {"a": 1}, {"b": 2}`, + expected: `"items": [{"a": 1}, {"b": 2}]`, + }, + // 9. 更复杂的单层嵌套 + { + name: "复杂单层嵌套", + input: `"functions": {"name": "execute", "input": {"command": "ls"}}, {"name": "read", "input": {"file": "a.txt"}}`, + expected: `"functions": [{"name": "execute", "input": {"command": "ls"}}, {"name": "read", "input": {"file": "a.txt"}}]`, + }, + // 10. 
5个嵌套对象 + { + name: "5个嵌套对象", + input: `"tasks": {"id":1}, {"id":2}, {"id":3}, {"id":4}, {"id":5}`, + expected: `"tasks": [{"id":1}, {"id":2}, {"id":3}, {"id":4}, {"id":5}]`, + }, + { + name: "深层嵌套对象", + input: `"todos": {"meta":{"a":{"b":1}},"content":"x"}, {"meta":{"a":{"b":2}},"content":"y"}`, + expected: `"todos": [{"meta":{"a":{"b":1}},"content":"x"}, {"meta":{"a":{"b":2}},"content":"y"}]`, + }, + } + + for _, tt := range tests { + got := RepairLooseJSON(tt.input) + if got != tt.expected { + t.Errorf("[%s] RepairLooseJSON with nested objects:\n input: %s\n got: %s\n expected: %s", tt.name, tt.input, got, tt.expected) + } + } +} diff --git a/tests/node/stream-tool-sieve.test.js b/tests/node/stream-tool-sieve.test.js index 61d72d6..8148245 100644 --- a/tests/node/stream-tool-sieve.test.js +++ b/tests/node/stream-tool-sieve.test.js @@ -259,28 +259,28 @@ test('sieve emits final tool_calls for split arguments payload without increment assert.deepEqual(finalCalls[0].input, { path: 'README.MD', mode: 'head' }); }); -test('sieve keeps tool json as text when leading prose exists (strict mode)', () => { +test('sieve intercepts tool json even when leading prose exists (strict mode)', () => { const events = runSieve( ['我将调用工具。', '{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'], ['read_file'], ); const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 'tool_call_deltas' && evt.deltas?.length > 0)); const leakedText = collectText(events); - assert.equal(hasTool, false); + assert.equal(hasTool, true); assert.equal(leakedText.includes('我将调用工具。'), true); - assert.equal(leakedText.toLowerCase().includes('tool_calls'), true); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); }); -test('sieve keeps same-chunk trailing prose payload as text in strict mode', () => { +test('sieve intercepts same-chunk payload once tool json is complete in strict mode', () => { const events = runSieve( 
['{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}然后继续解释。'], ['read_file'], ); const hasTool = events.some((evt) => (evt.type === 'tool_calls' && evt.calls?.length > 0) || (evt.type === 'tool_call_deltas' && evt.deltas?.length > 0)); const leakedText = collectText(events); - assert.equal(hasTool, false); - assert.equal(leakedText.includes('然后继续解释。'), true); - assert.equal(leakedText.toLowerCase().includes('tool_calls'), true); + assert.equal(hasTool, true); + assert.equal(leakedText.includes('然后继续解释。'), false); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); }); test('formatOpenAIStreamToolCalls reuses ids with the same idStore', () => { diff --git a/tests/repair_json_tool.go b/tests/repair_json_tool.go new file mode 100644 index 0000000..7abf952 --- /dev/null +++ b/tests/repair_json_tool.go @@ -0,0 +1,77 @@ +package main + +import ( + "fmt" + "strings" +) + +func repairInvalidJSONBackslashes(s string) string { + if !strings.Contains(s, "\\") { + return s + } + var out strings.Builder + out.Grow(len(s) + 10) + runes := []rune(s) + for i := 0; i < len(runes); i++ { + if runes[i] == '\\' { + if i+1 < len(runes) { + next := runes[i+1] + switch next { + case '"', '\\', '/', 'b', 'f', 'n', 'r', 't': + out.WriteRune('\\') + out.WriteRune(next) + i++ + continue + case 'u': + if i+5 < len(runes) { + isHex := true + for j := 1; j <= 4; j++ { + r := runes[i+1+j] + if !((r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F')) { + isHex = false + break + } + } + if isHex { + out.WriteRune('\\') + out.WriteRune('u') + for j := 1; j <= 4; j++ { + out.WriteRune(runes[i+1+j]) + } + i += 5 + continue + } + } + } + } + // Not a valid escape sequence, double it + out.WriteString("\\\\") + } else { + out.WriteRune(runes[i]) + } + } + return out.String() +} + +func main() { + tests := []struct { + input string + expected string + }{ + {`{"path": "C:\Users\name"}`, `{"path": "C:\\Users\\name"}`}, + {`{"cmd": "cd 
D:\git_codes"}`, `{"cmd": "cd D:\\git_codes"}`}, + {`{"text": "line1\nline2"}`, `{"text": "line1\nline2"}`}, + {`{"path": "D:\\back\\slash"}`, `{"path": "D:\\back\\slash"}`}, + {`{"unicode": "\u2705"}`, `{"unicode": "\u2705"}`}, + {`{"invalid_u": "\u123"}`, `{"invalid_u": "\\u123"}`}, + } + + for _, tt := range tests { + got := repairInvalidJSONBackslashes(tt.input) + if got != tt.expected { + fmt.Printf("FAIL: input=%s\n got=%s\n exp=%s\n", tt.input, got, tt.expected) + } else { + fmt.Printf("PASS: input=%s\n", tt.input) + } + } +}