From 25e40cc3a6f229d3b5b0678fea020dfe770e6f0a Mon Sep 17 00:00:00 2001 From: "CJACK." Date: Sat, 7 Mar 2026 17:27:29 +0800 Subject: [PATCH] Fix quality gate and expand Claude tool-call format compatibility --- .../adapter/claude/handler_stream_test.go | 42 ++++ internal/util/toolcalls_parse.go | 173 --------------- internal/util/toolcalls_parse_markup.go | 203 ++++++++++++++++++ internal/util/toolcalls_test.go | 28 +++ 4 files changed, 273 insertions(+), 173 deletions(-) create mode 100644 internal/util/toolcalls_parse_markup.go diff --git a/internal/adapter/claude/handler_stream_test.go b/internal/adapter/claude/handler_stream_test.go index ebce879..f3a8b6e 100644 --- a/internal/adapter/claude/handler_stream_test.go +++ b/internal/adapter/claude/handler_stream_test.go @@ -315,3 +315,45 @@ func asString(v any) string { s, _ := v.(string) return s } + +func TestHandleClaudeStreamRealtimeToolSafetyAcrossStructuredFormats(t *testing.T) { + tests := []struct { + name string + payload string + }{ + {name: "xml_tool_call", payload: `Bashpwd`}, + {name: "xml_json_tool_call", payload: `{"tool":"Bash","params":{"command":"pwd"}}`}, + {name: "nested_tool_tag_style", payload: `pwd`}, + {name: "function_tag_style", payload: `Bashpwd`}, + {name: "antml_argument_style", payload: `pwd`}, + {name: "antml_function_attr_parameters", payload: `{"command":"pwd"}`}, + {name: "invoke_parameter_style", payload: `pwd`}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + h := &Handler{} + resp := makeClaudeSSEHTTPResponse( + `data: {"p":"response/content","v":"`+strings.ReplaceAll(tc.payload, `"`, `\"`)+`"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil) + + h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"Bash"}) + + frames := parseClaudeFrames(t, rec.Body.String()) + foundToolUse := false + for _, f := range findClaudeFrames(frames, "content_block_start") { + contentBlock, _ := f.Payload["content_block"].(map[string]any) + if contentBlock["type"] == "tool_use" { + foundToolUse = true + break + } + } + if !foundToolUse { + t.Fatalf("expected tool_use block for format %s, body=%s", tc.name, rec.Body.String()) + } + }) + } +} diff --git a/internal/util/toolcalls_parse.go b/internal/util/toolcalls_parse.go index 2d9034a..42962bb 100644 --- a/internal/util/toolcalls_parse.go +++ b/internal/util/toolcalls_parse.go @@ -2,19 +2,11 @@ package util import ( "encoding/json" - "encoding/xml" "regexp" "strings" ) var toolNameLoosePattern = regexp.MustCompile(`[^a-z0-9]+`) -var xmlToolCallPattern = regexp.MustCompile(`(?is)\s*(.*?)\s*`) -var functionCallPattern = regexp.MustCompile(`(?is)\s*([^<]+?)\s*`) -var functionParamPattern = regexp.MustCompile(`(?is)\s*(.*?)\s*`) -var antmlFunctionCallPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?function_call[^>]*name="([^"]+)"[^>]*>\s*(.*?)\s*`) -var antmlArgumentPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?argument\s+name="([^"]+)"\s*>\s*(.*?)\s*`) -var invokeCallPattern = regexp.MustCompile(`(?is)(.*?)`) -var invokeParamPattern = regexp.MustCompile(`(?is)\s*(.*?)\s*`) type ParsedToolCall struct { Name string `json:"name"` @@ -279,168 +271,3 @@ func parseToolCallInput(v any) map[string]any { return map[string]any{} } } - -func parseXMLToolCalls(text string) []ParsedToolCall { - matches := xmlToolCallPattern.FindAllString(text, -1) - out := make([]ParsedToolCall, 0, len(matches)+1) - for _, block := range matches { - call, ok := parseSingleXMLToolCall(block) - if !ok { - continue - } - out = append(out, call) - } - if len(out) > 0 { - return out - } - if call, ok := parseFunctionCallTagStyle(text); ok { - return []ParsedToolCall{call} - } - if call, ok := parseAntmlFunctionCallStyle(text); ok { - return []ParsedToolCall{call} - } - if call, ok := parseInvokeFunctionCallStyle(text); ok { - return []ParsedToolCall{call} - } - return nil -} - -func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) { - inner := strings.TrimSpace(block) - inner = strings.TrimPrefix(inner, "") - inner = strings.TrimSuffix(inner, "") - inner = strings.TrimSpace(inner) - if strings.HasPrefix(inner, "{") { - var payload map[string]any - if err := json.Unmarshal([]byte(inner), &payload); err == nil { - name := strings.TrimSpace(asString(payload["tool"])) - if name == "" { - name = strings.TrimSpace(asString(payload["tool_name"])) - } - if name != "" { - input := map[string]any{} - if params, ok := payload["params"].(map[string]any); ok { - input = params - } else if params, ok := payload["parameters"].(map[string]any); ok { - input = params - } - return ParsedToolCall{Name: name, Input: input}, true - } - } - } - - dec := xml.NewDecoder(strings.NewReader(block)) - name := "" - params := map[string]any{} - inParams := false - for { - tok, err := dec.Token() - if err != nil { - break - } - start, ok := tok.(xml.StartElement) - if !ok { - continue - } - switch strings.ToLower(start.Name.Local) { - case "parameters": - inParams = true - case "tool_name", "name": - var v string - if err := dec.DecodeElement(&v, &start); err == nil && strings.TrimSpace(v) != "" { - name = strings.TrimSpace(v) - } - default: - if inParams { - var v string - if err := dec.DecodeElement(&v, &start); err == nil { - params[start.Name.Local] = strings.TrimSpace(v) - } - } - } - } - if strings.TrimSpace(name) == "" { - return ParsedToolCall{}, false - } - return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true -} - -func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) { - m := functionCallPattern.FindStringSubmatch(text) - if len(m) < 2 { - return ParsedToolCall{}, false - } - name := strings.TrimSpace(m[1]) - if name == "" { - return ParsedToolCall{}, false - } - input := map[string]any{} - for _, pm := range functionParamPattern.FindAllStringSubmatch(text, -1) { - if len(pm) < 3 { - continue - } - key := strings.TrimSpace(pm[1]) - val := strings.TrimSpace(pm[2]) - if key != "" { - input[key] = val - } - } - return ParsedToolCall{Name: name, Input: input}, true -} - -func parseAntmlFunctionCallStyle(text string) (ParsedToolCall, bool) { - m := antmlFunctionCallPattern.FindStringSubmatch(text) - if len(m) < 3 { - return ParsedToolCall{}, false - } - name := strings.TrimSpace(m[1]) - if name == "" { - return ParsedToolCall{}, false - } - body := strings.TrimSpace(m[2]) - input := map[string]any{} - if strings.HasPrefix(body, "{") { - if err := json.Unmarshal([]byte(body), &input); err == nil { - return ParsedToolCall{Name: name, Input: input}, true - } - } - for _, am := range antmlArgumentPattern.FindAllStringSubmatch(body, -1) { - if len(am) < 3 { - continue - } - k := strings.TrimSpace(am[1]) - v := strings.TrimSpace(am[2]) - if k != "" { - input[k] = v - } - } - return ParsedToolCall{Name: name, Input: input}, true -} - -func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) { - m := invokeCallPattern.FindStringSubmatch(text) - if len(m) < 3 { - return ParsedToolCall{}, false - } - name := strings.TrimSpace(m[1]) - if name == "" { - return ParsedToolCall{}, false - } - input := map[string]any{} - for _, pm := range invokeParamPattern.FindAllStringSubmatch(m[2], -1) { - if len(pm) < 3 { - continue - } - k := strings.TrimSpace(pm[1]) - v := strings.TrimSpace(pm[2]) - if k != "" { - input[k] = v - } - } - return ParsedToolCall{Name: name, Input: input}, true -} - -func asString(v any) string { - s, _ := v.(string) - return s -} diff --git a/internal/util/toolcalls_parse_markup.go b/internal/util/toolcalls_parse_markup.go new file mode 100644 index 0000000..262fd59 --- /dev/null +++ b/internal/util/toolcalls_parse_markup.go @@ -0,0 +1,203 @@ +package util + +import ( + "encoding/json" + "encoding/xml" + "regexp" + "strings" +) + +var xmlToolCallPattern = regexp.MustCompile(`(?is)\s*(.*?)\s*`) +var functionCallPattern = regexp.MustCompile(`(?is)\s*([^<]+?)\s*`) +var functionParamPattern = regexp.MustCompile(`(?is)\s*(.*?)\s*`) +var antmlFunctionCallPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?function_call[^>]*(?:name|function)="([^"]+)"[^>]*>\s*(.*?)\s*`) +var antmlArgumentPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?argument\s+name="([^"]+)"\s*>\s*(.*?)\s*`) +var antmlParametersPattern = regexp.MustCompile(`(?is)<(?:[a-z0-9_]+:)?parameters\s*>\s*(\{.*?\})\s*`) +var invokeCallPattern = regexp.MustCompile(`(?is)(.*?)`) +var invokeParamPattern = regexp.MustCompile(`(?is)\s*(.*?)\s*`) + +func parseXMLToolCalls(text string) []ParsedToolCall { + matches := xmlToolCallPattern.FindAllString(text, -1) + out := make([]ParsedToolCall, 0, len(matches)+1) + for _, block := range matches { + call, ok := parseSingleXMLToolCall(block) + if !ok { + continue + } + out = append(out, call) + } + if len(out) > 0 { + return out + } + if call, ok := parseFunctionCallTagStyle(text); ok { + return []ParsedToolCall{call} + } + if call, ok := parseAntmlFunctionCallStyle(text); ok { + return []ParsedToolCall{call} + } + if call, ok := parseInvokeFunctionCallStyle(text); ok { + return []ParsedToolCall{call} + } + return nil +} + +func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) { + inner := strings.TrimSpace(block) + inner = strings.TrimPrefix(inner, "") + inner = strings.TrimSuffix(inner, "") + inner = strings.TrimSpace(inner) + if strings.HasPrefix(inner, "{") { + var payload map[string]any + if err := json.Unmarshal([]byte(inner), &payload); err == nil { + name := strings.TrimSpace(asString(payload["tool"])) + if name == "" { + name = strings.TrimSpace(asString(payload["tool_name"])) + } + if name != "" { + input := map[string]any{} + if params, ok := payload["params"].(map[string]any); ok { + input = params + } else if params, ok := payload["parameters"].(map[string]any); ok { + input = params + } + return ParsedToolCall{Name: name, Input: input}, true + } + } + } + + dec := xml.NewDecoder(strings.NewReader(block)) + name := "" + params := map[string]any{} + inParams := false + inTool := false + for { + tok, err := dec.Token() + if err != nil { + break + } + switch t := tok.(type) { + case xml.StartElement: + tag := strings.ToLower(t.Name.Local) + switch tag { + case "tool": + inTool = true + for _, attr := range t.Attr { + if strings.EqualFold(strings.TrimSpace(attr.Name.Local), "name") && strings.TrimSpace(name) == "" { + name = strings.TrimSpace(attr.Value) + } + } + case "parameters": + inParams = true + case "tool_name", "name": + var v string + if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" { + name = strings.TrimSpace(v) + } + default: + if inParams || inTool { + var v string + if err := dec.DecodeElement(&v, &t); err == nil { + params[t.Name.Local] = strings.TrimSpace(v) + } + } + } + case xml.EndElement: + tag := strings.ToLower(t.Name.Local) + if tag == "parameters" { + inParams = false + } + if tag == "tool" { + inTool = false + } + } + } + if strings.TrimSpace(name) == "" { + return ParsedToolCall{}, false + } + return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true +} + +func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) { + m := functionCallPattern.FindStringSubmatch(text) + if len(m) < 2 { + return ParsedToolCall{}, false + } + name := strings.TrimSpace(m[1]) + if name == "" { + return ParsedToolCall{}, false + } + input := map[string]any{} + for _, pm := range functionParamPattern.FindAllStringSubmatch(text, -1) { + if len(pm) < 3 { + continue + } + key := strings.TrimSpace(pm[1]) + val := strings.TrimSpace(pm[2]) + if key != "" { + input[key] = val + } + } + return ParsedToolCall{Name: name, Input: input}, true +} + +func parseAntmlFunctionCallStyle(text string) (ParsedToolCall, bool) { + m := antmlFunctionCallPattern.FindStringSubmatch(text) + if len(m) < 3 { + return ParsedToolCall{}, false + } + name := strings.TrimSpace(m[1]) + if name == "" { + return ParsedToolCall{}, false + } + body := strings.TrimSpace(m[2]) + input := map[string]any{} + if strings.HasPrefix(body, "{") { + if err := json.Unmarshal([]byte(body), &input); err == nil { + return ParsedToolCall{Name: name, Input: input}, true + } + } + if pm := antmlParametersPattern.FindStringSubmatch(body); len(pm) >= 2 { + if err := json.Unmarshal([]byte(strings.TrimSpace(pm[1])), &input); err == nil { + return ParsedToolCall{Name: name, Input: input}, true + } + } + for _, am := range antmlArgumentPattern.FindAllStringSubmatch(body, -1) { + if len(am) < 3 { + continue + } + k := strings.TrimSpace(am[1]) + v := strings.TrimSpace(am[2]) + if k != "" { + input[k] = v + } + } + return ParsedToolCall{Name: name, Input: input}, true +} + +func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) { + m := invokeCallPattern.FindStringSubmatch(text) + if len(m) < 3 { + return ParsedToolCall{}, false + } + name := strings.TrimSpace(m[1]) + if name == "" { + return ParsedToolCall{}, false + } + input := map[string]any{} + for _, pm := range invokeParamPattern.FindAllStringSubmatch(m[2], -1) { + if len(pm) < 3 { + continue + } + k := strings.TrimSpace(pm[1]) + v := strings.TrimSpace(pm[2]) + if k != "" { + input[k] = v + } + } + return ParsedToolCall{Name: name, Input: input}, true +} + +func asString(v any) string { + s, _ := v.(string) + return s +} diff --git a/internal/util/toolcalls_test.go b/internal/util/toolcalls_test.go index f38dbb0..6fd3d59 100644 --- a/internal/util/toolcalls_test.go +++ b/internal/util/toolcalls_test.go @@ -232,3 +232,31 @@ func TestParseToolCallsSupportsInvokeFunctionCallStyle(t *testing.T) { t.Fatalf("expected command argument, got %#v", calls[0].Input) } } + +func TestParseToolCallsSupportsNestedToolTagStyle(t *testing.T) { + text := `pwdshow cwd` + calls := ParseToolCalls(text, []string{"bash"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %#v", calls) + } + if calls[0].Name != "bash" { + t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + } + if calls[0].Input["command"] != "pwd" { + t.Fatalf("expected command argument, got %#v", calls[0].Input) + } +} + +func TestParseToolCallsSupportsAntmlFunctionAttributeWithParametersTag(t *testing.T) { + text := `{"command":"pwd"}` + calls := ParseToolCalls(text, []string{"bash"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %#v", calls) + } + if calls[0].Name != "bash" { + t.Fatalf("expected canonical tool name bash, got %q", calls[0].Name) + } + if calls[0].Input["command"] != "pwd" { + t.Fatalf("expected command argument, got %#v", calls[0].Input) + } +}