补充工具调用行为说明并修正测试文档过时命令

This commit is contained in:
CJACK.
2026-03-03 00:39:02 +08:00
parent c329bf26b6
commit a6aa4a1839
12 changed files with 197 additions and 64 deletions

View File

@@ -2,9 +2,12 @@ package util
import (
"encoding/json"
"regexp"
"strings"
)
var toolNameLoosePattern = regexp.MustCompile(`[^a-z0-9]+`)
type ParsedToolCall struct {
Name string `json:"name"`
Input map[string]any `json:"input"`
@@ -121,12 +124,7 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin
if tc.Name == "" {
continue
}
matchedName := ""
if _, ok := allowed[tc.Name]; ok {
matchedName = tc.Name
} else if canonical, ok := allowedCanonical[strings.ToLower(tc.Name)]; ok {
matchedName = canonical
}
matchedName := resolveAllowedToolName(tc.Name, allowed, allowedCanonical)
if matchedName == "" {
rejectedSet[tc.Name] = struct{}{}
continue
@@ -144,6 +142,31 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin
return out, rejected
}
func resolveAllowedToolName(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string {
if _, ok := allowed[name]; ok {
return name
}
lower := strings.ToLower(strings.TrimSpace(name))
if canonical, ok := allowedCanonical[lower]; ok {
return canonical
}
if idx := strings.LastIndex(lower, "."); idx >= 0 && idx < len(lower)-1 {
if canonical, ok := allowedCanonical[lower[idx+1:]]; ok {
return canonical
}
}
loose := toolNameLoosePattern.ReplaceAllString(lower, "")
if loose == "" {
return ""
}
for candidateLower, canonical := range allowedCanonical {
if toolNameLoosePattern.ReplaceAllString(candidateLower, "") == loose {
return canonical
}
}
return ""
}
func parseToolCallsPayload(payload string) []ParsedToolCall {
var decoded any
if err := json.Unmarshal([]byte(payload), &decoded); err != nil {

View File

@@ -115,3 +115,25 @@ func TestParseStandaloneToolCallsIgnoresFencedCodeBlock(t *testing.T) {
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
}
}
func TestParseToolCallsAllowsQualifiedToolName(t *testing.T) {
text := `{"tool_calls":[{"name":"mcp.search_web","input":{"q":"golang"}}]}`
calls := ParseToolCalls(text, []string{"search_web"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "search_web" {
t.Fatalf("expected canonical tool name search_web, got %q", calls[0].Name)
}
}
func TestParseToolCallsAllowsPunctuationVariantToolName(t *testing.T) {
text := `{"tool_calls":[{"name":"read-file","input":{"path":"README.md"}}]}`
calls := ParseToolCalls(text, []string{"read_file"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "read_file" {
t.Fatalf("expected canonical tool name read_file, got %q", calls[0].Name)
}
}