Files
ds2api/internal/toolcall/toolcalls_parse.go

164 lines
3.7 KiB
Go

package toolcall
import (
"strings"
)
// ParsedToolCall is a single tool invocation extracted from model output.
//
// Name is the invoked tool's name. Input holds the call's arguments as a
// string-keyed map. Calls returned through ToolCallParseResult.Calls always
// have a non-empty Name and a non-nil Input.
type ParsedToolCall struct {
	Name  string         `json:"name"`
	Input map[string]any `json:"input"`
}
// ToolCallParseResult is the detailed outcome of parsing text for tool calls.
type ToolCallParseResult struct {
	// Calls are the accepted tool calls, in the order they were parsed.
	Calls []ParsedToolCall
	// SawToolCallSyntax is true when the text contained something that looks
	// like tool-call markup, or when at least one call was parsed.
	SawToolCallSyntax bool
	// RejectedByPolicy is true only when some tool names were rejected and
	// no calls were accepted.
	RejectedByPolicy bool
	// RejectedToolNames lists the names of parsed calls that were rejected
	// by filtering.
	RejectedToolNames []string
}
// ParseToolCalls extracts tool calls from text and returns only the accepted
// calls. It is shorthand for ParseToolCallsDetailed(...).Calls.
func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
	detailed := ParseToolCallsDetailed(text, availableToolNames)
	return detailed.Calls
}
// ParseToolCallsDetailed extracts tool calls from text and reports the full
// parse outcome. Fenced Markdown code blocks are ignored when parsing.
//
// availableToolNames is currently not consulted: parsing is identical to
// ParseStandaloneToolCallsDetailed.
func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
	outcome := parseToolCallsDetailedXMLOnly(text)
	return outcome
}
// ParseStandaloneToolCalls extracts tool calls from text and returns only the
// accepted calls. It is shorthand for ParseStandaloneToolCallsDetailed(...).Calls.
func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall {
	detailed := ParseStandaloneToolCallsDetailed(text, availableToolNames)
	return detailed.Calls
}
// ParseStandaloneToolCallsDetailed extracts tool calls from text and reports
// the full parse outcome.
//
// availableToolNames is currently not consulted: parsing is identical to
// ParseToolCallsDetailed.
func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
	outcome := parseToolCallsDetailedXMLOnly(text)
	return outcome
}
// parseToolCallsDetailedXMLOnly is the shared implementation behind the
// exported parse functions.
//
// Steps:
//  1. Trim the text; empty input yields a zero result.
//  2. Record whether the trimmed text looks like tool-call markup. This check
//     runs before fenced code blocks are removed, so markup that appears only
//     inside a code block still sets SawToolCallSyntax.
//  3. Remove fenced code blocks and trim again; if nothing is left, return.
//  4. Try parseXMLToolCalls first and fall back to parseMarkupToolCalls.
//  5. If anything was parsed, filter the calls and fill in the result.
func parseToolCallsDetailedXMLOnly(text string) ToolCallParseResult {
	var res ToolCallParseResult

	body := strings.TrimSpace(text)
	if body == "" {
		return res
	}
	res.SawToolCallSyntax = looksLikeToolCallSyntax(body)

	body = strings.TrimSpace(stripFencedCodeBlocks(body))
	if body == "" {
		return res
	}

	parsed := parseXMLToolCalls(body)
	if len(parsed) == 0 {
		parsed = parseMarkupToolCalls(body)
	}
	if len(parsed) == 0 {
		return res
	}

	accepted, rejected := filterToolCallsDetailed(parsed)
	res.SawToolCallSyntax = true
	res.Calls = accepted
	res.RejectedToolNames = rejected
	// Only a "policy rejection" when every parsed call was rejected.
	res.RejectedByPolicy = len(accepted) == 0 && len(rejected) > 0
	return res
}
// filterToolCallsDetailed drops calls with an empty Name and gives every kept
// call a non-nil Input map. The input slice is not modified; kept calls are
// copies, in their original order.
//
// The second return value (rejected tool names) is currently always nil: no
// name-based rejection is performed here.
func filterToolCallsDetailed(parsed []ParsedToolCall) ([]ParsedToolCall, []string) {
	kept := make([]ParsedToolCall, 0, len(parsed))
	for _, call := range parsed {
		if call.Name != "" {
			if call.Input == nil {
				call.Input = make(map[string]any)
			}
			kept = append(kept, call)
		}
	}
	return kept, nil
}
// looksLikeToolCallSyntax reports whether text contains, case-insensitively,
// the opening of any tag the parser treats as tool-call markup.
func looksLikeToolCallSyntax(text string) bool {
	// Prefix matching means "<tool_call" also covers "<tool_calls" and
	// "<function_call" also covers "<function_calls".
	markers := [...]string{
		"<tool_call",
		"<function_call",
		"<invoke",
		"<tool_use",
		"<attempt_completion",
		"<ask_followup_question",
		"<new_task",
		"<result",
	}
	haystack := strings.ToLower(text)
	for _, marker := range markers {
		if strings.Contains(haystack, marker) {
			return true
		}
	}
	return false
}
func stripFencedCodeBlocks(text string) string {
if text == "" {
return ""
}
var b strings.Builder
b.Grow(len(text))
lines := strings.SplitAfter(text, "\n")
inFence := false
fenceMarker := ""
for _, line := range lines {
trimmed := strings.TrimLeft(line, " \t")
if !inFence {
if marker, ok := parseFenceOpen(trimmed); ok {
inFence = true
fenceMarker = marker
continue
}
b.WriteString(line)
continue
}
if isFenceClose(trimmed, fenceMarker) {
inFence = false
fenceMarker = ""
}
}
if inFence {
return ""
}
return b.String()
}
// parseFenceOpen reports whether line (already stripped of leading spaces and
// tabs) opens a Markdown fenced code block. An opening fence is a run of at
// least three backticks or at least three tildes. The returned marker is
// exactly that run; a closing fence must use the same character and be at
// least as long.
//
// Following CommonMark, the info string after a backtick fence may not
// contain a backtick. A line such as "```code```" is an inline code span,
// not a fence opener. Treating it as an opener would leave an unclosed fence
// and make the caller discard the whole text. Tilde fences have no such
// restriction.
func parseFenceOpen(line string) (string, bool) {
	if len(line) < 3 {
		return "", false
	}
	ch := line[0]
	if ch != '`' && ch != '~' {
		return "", false
	}
	// ch is ASCII, so line[:1] is a one-character cutset for that byte.
	count := len(line) - len(strings.TrimLeft(line, line[:1]))
	if count < 3 {
		return "", false
	}
	if ch == '`' && strings.IndexByte(line[count:], '`') >= 0 {
		return "", false
	}
	// line[:count] is the run itself; no need to rebuild it.
	return line[:count], true
}
// isFenceClose reports whether line (already stripped of leading spaces and
// tabs) closes a fence opened with marker. A closing fence starts with a run
// of marker's character that is at least as long as marker. Only whitespace
// may follow the run. An empty marker never matches.
func isFenceClose(line, marker string) bool {
	if marker == "" || line == "" || line[0] != marker[0] {
		return false
	}
	run := countLeadingFenceChars(line, marker[0])
	return run >= len(marker) && strings.TrimSpace(line[run:]) == ""
}
// countLeadingFenceChars returns how many bytes at the start of line equal ch.
func countLeadingFenceChars(line string, ch byte) int {
	for i := 0; i < len(line); i++ {
		if line[i] != ch {
			return i
		}
	}
	return len(line)
}