diff --git a/README.MD b/README.MD index 0ce3ec7..d3dc05d 100644 --- a/README.MD +++ b/README.MD @@ -106,6 +106,14 @@ flowchart LR 可通过配置中的 `claude_mapping` 或 `claude_model_mapping` 覆盖映射关系。 另外,`/anthropic/v1/models` 现已包含 Claude 1.x/2.x/3.x/4.x 历史模型 ID 与常见别名,便于旧客户端直接兼容。 + +#### Claude Code 接入避坑(实测) + +- `ANTHROPIC_BASE_URL` 推荐直接指向 DS2API 根地址(例如 `http://127.0.0.1:5001`),Claude Code 会请求 `/v1/messages?beta=true`。 +- `ANTHROPIC_API_KEY` 需要与 `config.json` 中 `keys` 一致;建议同时保留常规 key 与 `sk-ant-*` 形态 key,兼容不同客户端校验习惯。 +- 若系统设置了代理,建议对 DS2API 地址配置 `NO_PROXY=127.0.0.1,localhost,<你的主机IP>`,避免本地回环请求被代理拦截。 +- 如遇“工具调用输出成文本、未执行”问题,请升级到包含 Claude 工具调用多格式解析(JSON/XML/ANTML/invoke)的版本。 + ### Gemini 接口 Gemini 适配器将模型名通过 `model_aliases` 或内置规则映射到 DeepSeek 原生模型,支持 `generateContent` 和 `streamGenerateContent` 两种调用方式,并完整支持 Tool Calling(`functionDeclarations` → `functionCall` 输出)。 diff --git a/README.en.md b/README.en.md index 72d8bd8..1c07c23 100644 --- a/README.en.md +++ b/README.en.md @@ -106,6 +106,14 @@ flowchart LR Override mapping via `claude_mapping` or `claude_model_mapping` in config. In addition, `/anthropic/v1/models` now includes historical Claude 1.x/2.x/3.x/4.x IDs and common aliases for legacy client compatibility. + +#### Claude Code integration pitfalls (validated) + +- Set `ANTHROPIC_BASE_URL` to the DS2API root URL (for example `http://127.0.0.1:5001`). Claude Code sends requests to `/v1/messages?beta=true`. +- `ANTHROPIC_API_KEY` must match an entry in `keys` from `config.json`. Keeping both a regular key and an `sk-ant-*` style key improves client compatibility. +- If your environment has proxy variables, set `NO_PROXY=127.0.0.1,localhost,` for DS2API to avoid proxy interception of local traffic. +- If tool calls are rendered as plain text and not executed, upgrade to a build that includes multi-format Claude tool-call parsing (JSON/XML/ANTML/invoke). + ### Gemini Endpoint The Gemini adapter maps model names to DeepSeek native models via `model_aliases` or built-in heuristics, supporting both `generateContent` and `streamGenerateContent` call patterns with full Tool Calling support (`functionDeclarations` → `functionCall` output). diff --git a/internal/adapter/claude/handler_stream_test.go b/internal/adapter/claude/handler_stream_test.go index ebce879..dda425a 100644 --- a/internal/adapter/claude/handler_stream_test.go +++ b/internal/adapter/claude/handler_stream_test.go @@ -315,3 +315,78 @@ func asString(v any) string { s, _ := v.(string) return s } + +func TestHandleClaudeStreamRealtimeToolSafetyAcrossStructuredFormats(t *testing.T) { + tests := []struct { + name string + payload string + }{ + {name: "xml_tool_call", payload: `Bashpwd`}, + {name: "xml_json_tool_call", payload: `{"tool":"Bash","params":{"command":"pwd"}}`}, + {name: "nested_tool_tag_style", payload: `pwd`}, + {name: "function_tag_style", payload: `Bashpwd`}, + {name: "antml_argument_style", payload: `pwd`}, + {name: "antml_function_attr_parameters", payload: `{"command":"pwd"}`}, + {name: "invoke_parameter_style", payload: `pwd`}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + h := &Handler{} + resp := makeClaudeSSEHTTPResponse( + `data: {"p":"response/content","v":"`+strings.ReplaceAll(tc.payload, `"`, `\"`)+`"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil) + + h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"Bash"}) + + frames := parseClaudeFrames(t, rec.Body.String()) + foundToolUse := false + for _, f := range findClaudeFrames(frames, "content_block_start") { + contentBlock, _ := f.Payload["content_block"].(map[string]any) + if contentBlock["type"] == "tool_use" { + foundToolUse = true + break + } + } + if !foundToolUse { + t.Fatalf("expected tool_use block for format %s, body=%s", tc.name, rec.Body.String()) + } + }) + } +} + +func TestHandleClaudeStreamRealtimeDoesNotStopOnUnclosedFencedToolExample(t *testing.T) { + h := &Handler{} + resp := makeClaudeSSEHTTPResponse( + "data: {\"p\":\"response/content\",\"v\":\"Here is an example:\\n```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"Bash\\\",\\\"input\\\":{\\\"command\\\":\\\"pwd\\\"}}]}\"}", + "data: {\"p\":\"response/content\",\"v\":\"\\n```\\nDo not execute it.\"}", + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil) + + h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "show example only"}}, false, false, []string{"Bash"}) + + frames := parseClaudeFrames(t, rec.Body.String()) + for _, f := range findClaudeFrames(frames, "content_block_start") { + contentBlock, _ := f.Payload["content_block"].(map[string]any) + if contentBlock["type"] == "tool_use" { + t.Fatalf("unexpected tool_use for fenced example, body=%s", rec.Body.String()) + } + } + + foundEndTurn := false + for _, f := range findClaudeFrames(frames, "message_delta") { + delta, _ := f.Payload["delta"].(map[string]any) + if delta["stop_reason"] == "end_turn" { + foundEndTurn = true + break + } + } + if !foundEndTurn { + t.Fatalf("expected stop_reason=end_turn, body=%s", rec.Body.String()) + } +} diff --git a/internal/adapter/claude/handler_util_test.go b/internal/adapter/claude/handler_util_test.go index f5b0ad5..b6c009a 100644 --- a/internal/adapter/claude/handler_util_test.go +++ b/internal/adapter/claude/handler_util_test.go @@ -125,8 +125,11 @@ func TestBuildClaudeToolPromptSingleTool(t *testing.T) { if !containsStr(prompt, "Search the web") { t.Fatalf("expected description in prompt") } - if !containsStr(prompt, "tool_calls") { - t.Fatalf("expected tool_calls instruction in prompt") + if !containsStr(prompt, "tool_use") { + t.Fatalf("expected tool_use instruction in prompt") + } + if containsStr(prompt, "tool_calls") { + t.Fatalf("expected prompt to avoid tool_calls JSON instruction") } } diff --git a/internal/adapter/claude/handler_utils.go b/internal/adapter/claude/handler_utils.go index 3728ffb..2f0c08a 100644 --- a/internal/adapter/claude/handler_utils.go +++ b/internal/adapter/claude/handler_utils.go @@ -51,7 +51,7 @@ func buildClaudeToolPrompt(tools []any) string { parts = append(parts, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, schema)) } parts = append(parts, - "When you need to use tools, you can call multiple tools in one response. Output ONLY JSON like {\"tool_calls\":[{\"name\":\"tool\",\"input\":{}}]}", + "When you need a tool, respond with Claude-native tool use (tool_use) using the provided tool schema. Do not print tool-call JSON in text.", "History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.", "After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.", ) diff --git a/internal/adapter/claude/stream_runtime_core.go b/internal/adapter/claude/stream_runtime_core.go index cb24bdd..fead90a 100644 --- a/internal/adapter/claude/stream_runtime_core.go +++ b/internal/adapter/claude/stream_runtime_core.go @@ -8,6 +8,7 @@ import ( "ds2api/internal/sse" streamengine "ds2api/internal/stream" + "ds2api/internal/util" ) type claudeStreamRuntime struct { @@ -116,6 +117,18 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse s.text.WriteString(p.Text) if s.bufferToolContent { + if hasUnclosedCodeFence(s.text.String()) { + continue + } + detected := util.ParseToolCalls(s.text.String(), s.toolNames) + if len(detected) > 0 { + s.finalize("tool_use") + return streamengine.ParsedDecision{ + ContentSeen: true, + Stop: true, + StopReason: streamengine.StopReason("tool_use_detected"), + } + } continue } s.closeThinkingBlock() @@ -144,3 +157,7 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse return streamengine.ParsedDecision{ContentSeen: contentSeen} } + +func hasUnclosedCodeFence(text string) bool { + return strings.Count(text, "```")%2 == 1 +} diff --git a/internal/adapter/openai/message_normalize.go b/internal/adapter/openai/message_normalize.go index 724cb9f..c4f4c4a 100644 --- a/internal/adapter/openai/message_normalize.go +++ b/internal/adapter/openai/message_normalize.go @@ -78,7 +78,7 @@ func formatAssistantToolCallsForPrompt(msg map[string]any, traceID string) strin args = normalizeOpenAIArgumentsForPrompt(fn["arguments"]) } if name == "" { - name = "unknown" + continue } if args == "" { args = normalizeOpenAIArgumentsForPrompt(call["arguments"]) diff --git a/internal/adapter/openai/message_normalize_test.go b/internal/adapter/openai/message_normalize_test.go index ecb3bbd..c9c967d 100644 --- a/internal/adapter/openai/message_normalize_test.go +++ b/internal/adapter/openai/message_normalize_test.go @@ -194,6 +194,29 @@ func TestNormalizeOpenAIMessagesForPrompt_PreservesConcatenatedToolArguments(t * } } + +func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsMissingNameAreDropped(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": "call_missing_name", + "type": "function", + "function": map[string]any{ + "arguments": `{"path":"README.MD"}`, + }, + }, + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw, "") + if len(normalized) != 0 { + t.Fatalf("expected nameless assistant tool_calls to be dropped, got %#v", normalized) + } +} + func TestNormalizeOpenAIMessagesForPrompt_AssistantNilContentDoesNotInjectNullLiteral(t *testing.T) { raw := []any{ map[string]any{ diff --git a/internal/adapter/openai/responses_stream_runtime_toolcalls.go b/internal/adapter/openai/responses_stream_runtime_toolcalls.go index 9947cbd..ad354d4 100644 --- a/internal/adapter/openai/responses_stream_runtime_toolcalls.go +++ b/internal/adapter/openai/responses_stream_runtime_toolcalls.go @@ -94,6 +94,16 @@ func (s *responsesStreamRuntime) closeMessageItem() { outputIndex := s.ensureMessageOutputIndex() text := s.visibleText.String() if s.messagePartAdded { + s.sendEvent( + "response.output_text.done", + openaifmt.BuildResponsesTextDonePayload( + s.responseID, + itemID, + outputIndex, + 0, + text, + ), + ) s.sendEvent( "response.content_part.done", openaifmt.BuildResponsesContentPartDonePayload( diff --git a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go index 6186461..f62ff13 100644 --- a/internal/adapter/openai/responses_stream_test.go +++ b/internal/adapter/openai/responses_stream_test.go @@ -226,6 +226,40 @@ func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing. } } +func TestHandleResponsesStreamEmitsOutputTextDoneBeforeContentPartDone(t *testing.T) { + h := &Handler{} + req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + rec := httptest.NewRecorder() + + sseLine := func(v string) string { + b, _ := json.Marshal(map[string]any{ + "p": "response/content", + "v": v, + }) + return "data: " + string(b) + "\n" + } + + streamBody := sseLine("hello") + "data: [DONE]\n" + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(streamBody)), + } + + h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, util.DefaultToolChoicePolicy(), "") + body := rec.Body.String() + if !strings.Contains(body, "event: response.output_text.done") { + t.Fatalf("expected response.output_text.done payload, body=%s", body) + } + textDoneIdx := strings.Index(body, "event: response.output_text.done") + partDoneIdx := strings.Index(body, "event: response.content_part.done") + if textDoneIdx < 0 || partDoneIdx < 0 { + t.Fatalf("expected output_text.done + content_part.done, body=%s", body) + } + if textDoneIdx > partDoneIdx { + t.Fatalf("expected output_text.done before content_part.done, body=%s", body) + } +} + func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) { h := &Handler{} req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) diff --git a/internal/format/openai/render_stream_events.go b/internal/format/openai/render_stream_events.go index dc13231..1e7cd09 100644 --- a/internal/format/openai/render_stream_events.go +++ b/internal/format/openai/render_stream_events.go @@ -71,6 +71,19 @@ func BuildResponsesTextDeltaPayload(responseID, itemID string, outputIndex, cont } } + +func BuildResponsesTextDonePayload(responseID, itemID string, outputIndex, contentIndex int, text string) map[string]any { + return map[string]any{ + "type": "response.output_text.done", + "id": responseID, + "response_id": responseID, + "item_id": itemID, + "output_index": outputIndex, + "content_index": contentIndex, + "text": text, + } +} + func BuildResponsesReasoningDeltaPayload(responseID, delta string) map[string]any { return map[string]any{ "type": "response.reasoning.delta", diff --git a/internal/js/helpers/stream-tool-sieve/sieve.js b/internal/js/helpers/stream-tool-sieve/sieve.js index 1f1fc59..c1b92a8 100644 --- a/internal/js/helpers/stream-tool-sieve/sieve.js +++ b/internal/js/helpers/stream-tool-sieve/sieve.js @@ -21,22 +21,14 @@ function processToolSieveChunk(state, chunk, toolNames) { } const events = []; - if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) { - const pending = state.pending || ''; - if (pending.trim() !== '') { - const content = (state.pendingToolRaw || '') + pending; - state.pending = ''; - state.pendingToolRaw = ''; - state.pendingToolCalls = []; - noteText(state, content); - events.push({ type: 'text', text: content }); - } else { - return events; - } - } - // eslint-disable-next-line no-constant-condition while (true) { + if (Array.isArray(state.pendingToolCalls) && state.pendingToolCalls.length > 0) { + events.push({ type: 'tool_calls', calls: state.pendingToolCalls }); + state.pendingToolRaw = ''; + state.pendingToolCalls = []; + continue; + } if (state.capturing) { if (state.pending) { state.capture += state.pending; diff --git a/internal/util/toolcalls_parse.go b/internal/util/toolcalls_parse.go index c3e76bf..53eac8e 100644 --- a/internal/util/toolcalls_parse.go +++ b/internal/util/toolcalls_parse.go @@ -39,6 +39,9 @@ func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallPa var parsed []ParsedToolCall for _, candidate := range candidates { tc := parseToolCallsPayload(candidate) + if len(tc) == 0 { + tc = parseXMLToolCalls(candidate) + } if len(tc) == 0 { tc = parseMarkupToolCalls(candidate) } @@ -49,7 +52,11 @@ func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallPa } } if len(parsed) == 0 { - return result + parsed = parseXMLToolCalls(text) + if len(parsed) == 0 { + return result + } + result.SawToolCallSyntax = true } calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames) @@ -80,6 +87,9 @@ func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) continue } parsed := parseToolCallsPayload(candidate) + if len(parsed) == 0 { + parsed = parseXMLToolCalls(candidate) + } if len(parsed) == 0 { parsed = parseMarkupToolCalls(candidate) } diff --git a/internal/util/toolcalls_parse_markup.go b/internal/util/toolcalls_parse_markup.go new file mode 100644 index 0000000..b7b2908 --- /dev/null +++ b/internal/util/toolcalls_parse_markup.go @@ -0,0 +1,219 @@ +package util + +import ( + "encoding/json" + "encoding/xml" + "regexp" + "strings" +) + +var xmlToolCallPattern = regexp.MustCompile(`(?is)\s*(.*?)\s*`) +var functionCallPattern = regexp.MustCompile(`(?is)\s*([^<]+?)\s*`) +var functionParamPattern = regexp.MustCompile(`(?is)\s*(.*?)\s*`) +var antmlFunctionCallPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?function_call[^>]*(?:name|function)="([^"]+)"[^>]*>\s*(.*?)\s*`) +var antmlArgumentPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?argument\s+name="([^"]+)"\s*>\s*(.*?)\s*`) +var antmlParametersPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?parameters\s*>\s*(\{.*?\})\s*`) +var invokeCallPattern = regexp.MustCompile(`(?is)(.*?)`) +var invokeParamPattern = regexp.MustCompile(`(?is)\s*(.*?)\s*`) + +func parseXMLToolCalls(text string) []ParsedToolCall { + matches := xmlToolCallPattern.FindAllString(text, -1) + out := make([]ParsedToolCall, 0, len(matches)+1) + for _, block := range matches { + call, ok := parseSingleXMLToolCall(block) + if !ok { + continue + } + out = append(out, call) + } + if len(out) > 0 { + return out + } + if call, ok := parseFunctionCallTagStyle(text); ok { + return []ParsedToolCall{call} + } + if calls := parseAntmlFunctionCallStyles(text); len(calls) > 0 { + return calls + } + if call, ok := parseInvokeFunctionCallStyle(text); ok { + return []ParsedToolCall{call} + } + return nil +} + +func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) { + inner := strings.TrimSpace(block) + inner = strings.TrimPrefix(inner, "") + inner = strings.TrimSuffix(inner, "") + inner = strings.TrimSpace(inner) + if strings.HasPrefix(inner, "{") { + var payload map[string]any + if err := json.Unmarshal([]byte(inner), &payload); err == nil { + name := strings.TrimSpace(asString(payload["tool"])) + if name == "" { + name = strings.TrimSpace(asString(payload["tool_name"])) + } + if name != "" { + input := map[string]any{} + if params, ok := payload["params"].(map[string]any); ok { + input = params + } else if params, ok := payload["parameters"].(map[string]any); ok { + input = params + } + return ParsedToolCall{Name: name, Input: input}, true + } + } + } + + dec := xml.NewDecoder(strings.NewReader(block)) + name := "" + params := map[string]any{} + inParams := false + inTool := false + for { + tok, err := dec.Token() + if err != nil { + break + } + switch t := tok.(type) { + case xml.StartElement: + tag := strings.ToLower(t.Name.Local) + switch tag { + case "tool": + inTool = true + for _, attr := range t.Attr { + if strings.EqualFold(strings.TrimSpace(attr.Name.Local), "name") && strings.TrimSpace(name) == "" { + name = strings.TrimSpace(attr.Value) + } + } + case "parameters": + inParams = true + case "tool_name", "name": + var v string + if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" { + name = strings.TrimSpace(v) + } + default: + if inParams || inTool { + var v string + if err := dec.DecodeElement(&v, &t); err == nil { + params[t.Name.Local] = strings.TrimSpace(v) + } + } + } + case xml.EndElement: + tag := strings.ToLower(t.Name.Local) + if tag == "parameters" { + inParams = false + } + if tag == "tool" { + inTool = false + } + } + } + if strings.TrimSpace(name) == "" { + return ParsedToolCall{}, false + } + return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true +} + +func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) { + m := functionCallPattern.FindStringSubmatch(text) + if len(m) < 2 { + return ParsedToolCall{}, false + } + name := strings.TrimSpace(m[1]) + if name == "" { + return ParsedToolCall{}, false + } + input := map[string]any{} + for _, pm := range functionParamPattern.FindAllStringSubmatch(text, -1) { + if len(pm) < 3 { + continue + } + key := strings.TrimSpace(pm[1]) + val := strings.TrimSpace(pm[2]) + if key != "" { + input[key] = val + } + } + return ParsedToolCall{Name: name, Input: input}, true +} + +func parseAntmlFunctionCallStyles(text string) []ParsedToolCall { + matches := antmlFunctionCallPattern.FindAllStringSubmatch(text, -1) + if len(matches) == 0 { + return nil + } + out := make([]ParsedToolCall, 0, len(matches)) + for _, m := range matches { + if call, ok := parseSingleAntmlFunctionCallMatch(m); ok { + out = append(out, call) + } + } + if len(out) == 0 { + return nil + } + return out +} + +func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) { + if len(m) < 3 { + return ParsedToolCall{}, false + } + name := strings.TrimSpace(m[1]) + if name == "" { + return ParsedToolCall{}, false + } + body := strings.TrimSpace(m[2]) + input := map[string]any{} + if strings.HasPrefix(body, "{") { + if err := json.Unmarshal([]byte(body), &input); err == nil { + return ParsedToolCall{Name: name, Input: input}, true + } + } + if pm := antmlParametersPattern.FindStringSubmatch(body); len(pm) >= 2 { + if err := json.Unmarshal([]byte(strings.TrimSpace(pm[1])), &input); err == nil { + return ParsedToolCall{Name: name, Input: input}, true + } + } + for _, am := range antmlArgumentPattern.FindAllStringSubmatch(body, -1) { + if len(am) < 3 { + continue + } + k := strings.TrimSpace(am[1]) + v := strings.TrimSpace(am[2]) + if k != "" { + input[k] = v + } + } + return ParsedToolCall{Name: name, Input: input}, true +} + +func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) { + m := invokeCallPattern.FindStringSubmatch(text) + if len(m) < 3 { + return ParsedToolCall{}, false + } + name := strings.TrimSpace(m[1]) + if name == "" { + return ParsedToolCall{}, false + } + input := map[string]any{} + for _, pm := range invokeParamPattern.FindAllStringSubmatch(m[2], -1) { + if len(pm) < 3 { + continue + } + k := strings.TrimSpace(pm[1]) + v := strings.TrimSpace(pm[2]) + if k != "" { + input[k] = v + } + } + return ParsedToolCall{Name: name, Input: input}, true +} + +func asString(v any) string { + s, _ := v.(string) + return s +} diff --git a/internal/util/toolcalls_test.go b/internal/util/toolcalls_test.go index e830092..3ace015 100644 --- a/internal/util/toolcalls_test.go +++ b/internal/util/toolcalls_test.go @@ -138,6 +138,140 @@ func TestParseToolCallsAllowsPunctuationVariantToolName(t *testing.T) { } } +func TestParseToolCallsSupportsClaudeXMLToolCall(t *testing.T) { + text := `Bashpwdshow cwd` + calls := ParseToolCalls(text, []string{"bash"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %#v", calls) + } + if calls[0].Name != "bash" { + t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + } + if calls[0].Input["command"] != "pwd" { + t.Fatalf("expected command argument, got %#v", calls[0].Input) + } +} + +func TestParseToolCallsDetailedMarksXMLToolCallSyntax(t *testing.T) { + text := `Bashpwd` + res := ParseToolCallsDetailed(text, []string{"bash"}) + if !res.SawToolCallSyntax { + t.Fatalf("expected SawToolCallSyntax=true, got %#v", res) + } + if len(res.Calls) != 1 { + t.Fatalf("expected one parsed call, got %#v", res) + } +} + +func TestParseToolCallsSupportsClaudeXMLJSONToolCall(t *testing.T) { + text := `{"tool":"Bash","params":{"command":"pwd","description":"show cwd"}}` + calls := ParseToolCalls(text, []string{"bash"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %#v", calls) + } + if calls[0].Name != "bash" { + t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + } + if calls[0].Input["command"] != "pwd" { + t.Fatalf("expected command argument, got %#v", calls[0].Input) + } +} + +func TestParseToolCallsSupportsFunctionCallTagStyle(t *testing.T) { + text := `Bashls -lalist` + calls := ParseToolCalls(text, []string{"bash"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %#v", calls) + } + if calls[0].Name != "bash" { + t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + } + if calls[0].Input["command"] != "ls -la" { + t.Fatalf("expected command argument, got %#v", calls[0].Input) + } +} + +func TestParseToolCallsSupportsAntmlFunctionCallStyle(t *testing.T) { + text := `{"command":"pwd","description":"x"}` + calls := ParseToolCalls(text, []string{"bash"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %#v", calls) + } + if calls[0].Name != "bash" { + t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + } + if calls[0].Input["command"] != "pwd" { + t.Fatalf("expected command argument, got %#v", calls[0].Input) + } +} + +func TestParseToolCallsSupportsAntmlArgumentStyle(t *testing.T) { + text := `pwdx` + calls := ParseToolCalls(text, []string{"bash"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %#v", calls) + } + if calls[0].Name != "bash" { + t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + } + if calls[0].Input["command"] != "pwd" { + t.Fatalf("expected command argument, got %#v", calls[0].Input) + } +} + +func TestParseToolCallsSupportsInvokeFunctionCallStyle(t *testing.T) { + text := `pwdd` + calls := ParseToolCalls(text, []string{"bash"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %#v", calls) + } + if calls[0].Name != "bash" { + t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + } + if calls[0].Input["command"] != "pwd" { + t.Fatalf("expected command argument, got %#v", calls[0].Input) + } +} + +func TestParseToolCallsSupportsNestedToolTagStyle(t *testing.T) { + text := `pwdshow cwd` + calls := ParseToolCalls(text, []string{"bash"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %#v", calls) + } + if calls[0].Name != "bash" { + t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + } + if calls[0].Input["command"] != "pwd" { + t.Fatalf("expected command argument, got %#v", calls[0].Input) + } +} + +func TestParseToolCallsSupportsAntmlFunctionAttributeWithParametersTag(t *testing.T) { + text := `{"command":"pwd"}` + calls := ParseToolCalls(text, []string{"bash"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %#v", calls) + } + if calls[0].Name != "bash" { + t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + } + if calls[0].Input["command"] != "pwd" { + t.Fatalf("expected command argument, got %#v", calls[0].Input) + } +} + +func TestParseToolCallsSupportsMultipleAntmlFunctionCalls(t *testing.T) { + text := `{"command":"pwd"}{"file_path":"README.md"}` + calls := ParseToolCalls(text, []string{"bash", "read"}) + if len(calls) != 2 { + t.Fatalf("expected 2 calls, got %#v", calls) + } + if calls[0].Name != "bash" || calls[1].Name != "read" { + t.Fatalf("expected canonical names [bash read], got %#v", calls) + } +} + func TestParseToolCallsDoesNotAcceptMismatchedMarkupTags(t *testing.T) { text := `read_file{"path":"README.md"}` calls := ParseToolCalls(text, []string{"read_file"}) diff --git a/tests/node/stream-tool-sieve.test.js b/tests/node/stream-tool-sieve.test.js index 20c00b8..f71279c 100644 --- a/tests/node/stream-tool-sieve.test.js +++ b/tests/node/stream-tool-sieve.test.js @@ -109,7 +109,23 @@ test('parseStandaloneToolCalls ignores fenced code block tool_call examples', () assert.equal(calls.length, 0); }); -test('sieve keeps late key convergence payload as plain text in strict mode', () => { + +test('sieve emits tool_calls in the same chunk processing tick once payload is complete', () => { + const state = createToolSieveState(); + const first = processToolSieveChunk(state, '{"', ['read_file']); + const second = processToolSieveChunk( + state, + 'tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}', + ['read_file'], + ); + const firstCalls = first.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []); + const secondCalls = second.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []); + assert.equal(firstCalls.length, 0); + assert.equal(secondCalls.length, 1); + assert.equal(secondCalls[0].name, 'read_file'); +}); + +test('sieve emits tool_calls when late key convergence forms a complete payload', () => { const events = runSieve( [ '{"', @@ -119,12 +135,11 @@ test('sieve keeps late key convergence payload as plain text in strict mode', () ['read_file'], ); const leakedText = collectText(events); - const hasToolCall = events.some((evt) => evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0); - const hasToolDelta = events.some((evt) => evt.type === 'tool_call_deltas' && Array.isArray(evt.deltas) && evt.deltas.length > 0); - assert.equal(hasToolCall || hasToolDelta, false); - assert.equal(leakedText.includes('{'), true); - assert.equal(leakedText.toLowerCase().includes('tool_calls'), true); + const finalCalls = events.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []); + assert.equal(finalCalls.length, 1); + assert.equal(finalCalls[0].name, 'read_file'); assert.equal(leakedText.includes('后置正文C。'), true); + assert.equal(leakedText.toLowerCase().includes('tool_calls'), false); }); test('sieve keeps embedded invalid tool-like json as normal text to avoid stream stalls', () => {