Files
ds2api/internal/toolcall/toolcalls_parse.go

86 lines
2.3 KiB
Go

package toolcall
import (
"strings"
)
type ParsedToolCall struct {
Name string `json:"name"`
Input map[string]any `json:"input"`
}
type ToolCallParseResult struct {
Calls []ParsedToolCall
SawToolCallSyntax bool
RejectedByPolicy bool
RejectedToolNames []string
}
func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
return ParseToolCallsDetailed(text, availableToolNames).Calls
}
func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
return parseToolCallsDetailedXMLOnly(text)
}
func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall {
return ParseStandaloneToolCallsDetailed(text, availableToolNames).Calls
}
func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
return parseToolCallsDetailedXMLOnly(text)
}
func parseToolCallsDetailedXMLOnly(text string) ToolCallParseResult {
result := ToolCallParseResult{}
trimmed := strings.TrimSpace(text)
if trimmed == "" {
return result
}
result.SawToolCallSyntax = looksLikeToolCallSyntax(trimmed)
parsed := parseXMLToolCalls(trimmed)
if len(parsed) == 0 {
parsed = parseMarkupToolCalls(trimmed)
}
if len(parsed) == 0 {
return result
}
result.SawToolCallSyntax = true
calls, rejectedNames := filterToolCallsDetailed(parsed)
result.Calls = calls
result.RejectedToolNames = rejectedNames
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
return result
}
func filterToolCallsDetailed(parsed []ParsedToolCall) ([]ParsedToolCall, []string) {
out := make([]ParsedToolCall, 0, len(parsed))
for _, tc := range parsed {
if tc.Name == "" {
continue
}
if tc.Input == nil {
tc.Input = map[string]any{}
}
out = append(out, tc)
}
return out, nil
}
func looksLikeToolCallSyntax(text string) bool {
lower := strings.ToLower(text)
return strings.Contains(lower, "<tool_calls") ||
strings.Contains(lower, "<tool_call") ||
strings.Contains(lower, "<function_calls") ||
strings.Contains(lower, "<function_call") ||
strings.Contains(lower, "<invoke") ||
strings.Contains(lower, "<tool_use") ||
strings.Contains(lower, "<attempt_completion") ||
strings.Contains(lower, "<ask_followup_question") ||
strings.Contains(lower, "<new_task") ||
strings.Contains(lower, "<result")
}