feat: Improve OpenAI tool call handling by passing unknown tool calls as content and filtering streamed tool calls by schema.

This commit is contained in:
CJACK
2026-02-22 19:33:52 +08:00
parent 312728c8b6
commit ae7dce0b32
26 changed files with 1109 additions and 501 deletions

View File

@@ -8,12 +8,48 @@ type StandardRequest struct {
Messages []any
FinalPrompt string
ToolNames []string
ToolChoice ToolChoicePolicy
Stream bool
Thinking bool
Search bool
PassThrough map[string]any
}
type ToolChoiceMode string
const (
ToolChoiceAuto ToolChoiceMode = "auto"
ToolChoiceNone ToolChoiceMode = "none"
ToolChoiceRequired ToolChoiceMode = "required"
ToolChoiceForced ToolChoiceMode = "forced"
)
type ToolChoicePolicy struct {
Mode ToolChoiceMode
ForcedName string
Allowed map[string]struct{}
}
func DefaultToolChoicePolicy() ToolChoicePolicy {
return ToolChoicePolicy{Mode: ToolChoiceAuto}
}
func (p ToolChoicePolicy) IsNone() bool {
return p.Mode == ToolChoiceNone
}
func (p ToolChoicePolicy) IsRequired() bool {
return p.Mode == ToolChoiceRequired || p.Mode == ToolChoiceForced
}
func (p ToolChoicePolicy) Allows(name string) bool {
if len(p.Allowed) == 0 {
return true
}
_, ok := p.Allowed[name]
return ok
}
func (r StandardRequest) CompletionPayload(sessionID string) map[string]any {
payload := map[string]any{
"chat_session_id": sessionID,

View File

@@ -10,38 +10,62 @@ type ParsedToolCall struct {
Input map[string]any `json:"input"`
}
type ToolCallParseResult struct {
Calls []ParsedToolCall
SawToolCallSyntax bool
RejectedByPolicy bool
RejectedToolNames []string
}
func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
return ParseToolCallsDetailed(text, availableToolNames).Calls
}
func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
result := ToolCallParseResult{}
if strings.TrimSpace(text) == "" {
return nil
return result
}
text = stripFencedCodeBlocks(text)
if strings.TrimSpace(text) == "" {
return nil
return result
}
result.SawToolCallSyntax = strings.Contains(strings.ToLower(text), "tool_calls")
candidates := buildToolCallCandidates(text)
var parsed []ParsedToolCall
for _, candidate := range candidates {
if tc := parseToolCallsPayload(candidate); len(tc) > 0 {
parsed = tc
result.SawToolCallSyntax = true
break
}
}
if len(parsed) == 0 {
return nil
return result
}
return filterToolCalls(parsed, availableToolNames)
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
result.Calls = calls
result.RejectedToolNames = rejectedNames
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
return result
}
func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall {
return ParseStandaloneToolCallsDetailed(text, availableToolNames).Calls
}
func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
result := ToolCallParseResult{}
trimmed := strings.TrimSpace(text)
if trimmed == "" {
return nil
return result
}
if looksLikeToolExampleContext(trimmed) {
return nil
return result
}
result.SawToolCallSyntax = strings.Contains(strings.ToLower(trimmed), "tool_calls")
candidates := []string{trimmed}
for _, candidate := range candidates {
candidate = strings.TrimSpace(candidate)
@@ -52,24 +76,31 @@ func ParseStandaloneToolCalls(text string, availableToolNames []string) []Parsed
continue
}
if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 {
return filterToolCalls(parsed, availableToolNames)
result.SawToolCallSyntax = true
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
result.Calls = calls
result.RejectedToolNames = rejectedNames
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
return result
}
}
return nil
return result
}
func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []ParsedToolCall {
func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []string) ([]ParsedToolCall, []string) {
allowed := map[string]struct{}{}
for _, name := range availableToolNames {
allowed[name] = struct{}{}
}
out := make([]ParsedToolCall, 0, len(parsed))
rejectedSet := map[string]struct{}{}
for _, tc := range parsed {
if tc.Name == "" {
continue
}
if len(allowed) > 0 {
if _, ok := allowed[tc.Name]; !ok {
rejectedSet[tc.Name] = struct{}{}
continue
}
}
@@ -78,21 +109,11 @@ func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []Par
}
out = append(out, tc)
}
// If the model clearly emitted tool_calls JSON but all names are outside the
// declared set, keep the parsed calls as a fallback so upper layers can still
// intercept structured tool output instead of leaking raw JSON to users.
if len(out) == 0 && len(parsed) > 0 {
for _, tc := range parsed {
if tc.Name == "" {
continue
}
if tc.Input == nil {
tc.Input = map[string]any{}
}
out = append(out, tc)
}
rejected := make([]string, 0, len(rejectedSet))
for name := range rejectedSet {
rejected = append(rejected, name)
}
return out
return out, rejected
}
func parseToolCallsPayload(payload string) []ParsedToolCall {

View File

@@ -38,14 +38,25 @@ func TestParseToolCallsWithFunctionArgumentsString(t *testing.T) {
}
}
func TestParseToolCallsKeepsUnknownAsFallback(t *testing.T) {
func TestParseToolCallsRejectsUnknownToolName(t *testing.T) {
text := `{"tool_calls":[{"name":"unknown","input":{}}]}`
calls := ParseToolCalls(text, []string{"search"})
if len(calls) != 1 {
t.Fatalf("expected fallback 1 call, got %d", len(calls))
if len(calls) != 0 {
t.Fatalf("expected unknown tool to be rejected, got %#v", calls)
}
if calls[0].Name != "unknown" {
t.Fatalf("unexpected name: %s", calls[0].Name)
}
func TestParseToolCallsDetailedMarksPolicyRejection(t *testing.T) {
text := `{"tool_calls":[{"name":"unknown","input":{}}]}`
res := ParseToolCallsDetailed(text, []string{"search"})
if !res.SawToolCallSyntax {
t.Fatalf("expected SawToolCallSyntax=true, got %#v", res)
}
if !res.RejectedByPolicy {
t.Fatalf("expected RejectedByPolicy=true, got %#v", res)
}
if len(res.Calls) != 0 {
t.Fatalf("expected no calls after policy rejection, got %#v", res.Calls)
}
}