mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-14 05:05:09 +08:00
feat: Improve OpenAI tool call handling by passing unknown tool calls as content and filtering streamed tool calls by schema.
This commit is contained in:
@@ -8,12 +8,48 @@ type StandardRequest struct {
|
||||
Messages []any
|
||||
FinalPrompt string
|
||||
ToolNames []string
|
||||
ToolChoice ToolChoicePolicy
|
||||
Stream bool
|
||||
Thinking bool
|
||||
Search bool
|
||||
PassThrough map[string]any
|
||||
}
|
||||
|
||||
type ToolChoiceMode string
|
||||
|
||||
const (
|
||||
ToolChoiceAuto ToolChoiceMode = "auto"
|
||||
ToolChoiceNone ToolChoiceMode = "none"
|
||||
ToolChoiceRequired ToolChoiceMode = "required"
|
||||
ToolChoiceForced ToolChoiceMode = "forced"
|
||||
)
|
||||
|
||||
type ToolChoicePolicy struct {
|
||||
Mode ToolChoiceMode
|
||||
ForcedName string
|
||||
Allowed map[string]struct{}
|
||||
}
|
||||
|
||||
func DefaultToolChoicePolicy() ToolChoicePolicy {
|
||||
return ToolChoicePolicy{Mode: ToolChoiceAuto}
|
||||
}
|
||||
|
||||
func (p ToolChoicePolicy) IsNone() bool {
|
||||
return p.Mode == ToolChoiceNone
|
||||
}
|
||||
|
||||
func (p ToolChoicePolicy) IsRequired() bool {
|
||||
return p.Mode == ToolChoiceRequired || p.Mode == ToolChoiceForced
|
||||
}
|
||||
|
||||
func (p ToolChoicePolicy) Allows(name string) bool {
|
||||
if len(p.Allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
_, ok := p.Allowed[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r StandardRequest) CompletionPayload(sessionID string) map[string]any {
|
||||
payload := map[string]any{
|
||||
"chat_session_id": sessionID,
|
||||
|
||||
@@ -10,38 +10,62 @@ type ParsedToolCall struct {
|
||||
Input map[string]any `json:"input"`
|
||||
}
|
||||
|
||||
type ToolCallParseResult struct {
|
||||
Calls []ParsedToolCall
|
||||
SawToolCallSyntax bool
|
||||
RejectedByPolicy bool
|
||||
RejectedToolNames []string
|
||||
}
|
||||
|
||||
func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
return ParseToolCallsDetailed(text, availableToolNames).Calls
|
||||
}
|
||||
|
||||
func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
|
||||
result := ToolCallParseResult{}
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
text = stripFencedCodeBlocks(text)
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
result.SawToolCallSyntax = strings.Contains(strings.ToLower(text), "tool_calls")
|
||||
|
||||
candidates := buildToolCallCandidates(text)
|
||||
var parsed []ParsedToolCall
|
||||
for _, candidate := range candidates {
|
||||
if tc := parseToolCallsPayload(candidate); len(tc) > 0 {
|
||||
parsed = tc
|
||||
result.SawToolCallSyntax = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(parsed) == 0 {
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
|
||||
return filterToolCalls(parsed, availableToolNames)
|
||||
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
|
||||
result.Calls = calls
|
||||
result.RejectedToolNames = rejectedNames
|
||||
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
|
||||
return result
|
||||
}
|
||||
|
||||
func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
return ParseStandaloneToolCallsDetailed(text, availableToolNames).Calls
|
||||
}
|
||||
|
||||
func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
|
||||
result := ToolCallParseResult{}
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
if looksLikeToolExampleContext(trimmed) {
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
result.SawToolCallSyntax = strings.Contains(strings.ToLower(trimmed), "tool_calls")
|
||||
candidates := []string{trimmed}
|
||||
for _, candidate := range candidates {
|
||||
candidate = strings.TrimSpace(candidate)
|
||||
@@ -52,24 +76,31 @@ func ParseStandaloneToolCalls(text string, availableToolNames []string) []Parsed
|
||||
continue
|
||||
}
|
||||
if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 {
|
||||
return filterToolCalls(parsed, availableToolNames)
|
||||
result.SawToolCallSyntax = true
|
||||
calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames)
|
||||
result.Calls = calls
|
||||
result.RejectedToolNames = rejectedNames
|
||||
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
|
||||
return result
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return result
|
||||
}
|
||||
|
||||
func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []ParsedToolCall {
|
||||
func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []string) ([]ParsedToolCall, []string) {
|
||||
allowed := map[string]struct{}{}
|
||||
for _, name := range availableToolNames {
|
||||
allowed[name] = struct{}{}
|
||||
}
|
||||
out := make([]ParsedToolCall, 0, len(parsed))
|
||||
rejectedSet := map[string]struct{}{}
|
||||
for _, tc := range parsed {
|
||||
if tc.Name == "" {
|
||||
continue
|
||||
}
|
||||
if len(allowed) > 0 {
|
||||
if _, ok := allowed[tc.Name]; !ok {
|
||||
rejectedSet[tc.Name] = struct{}{}
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -78,21 +109,11 @@ func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []Par
|
||||
}
|
||||
out = append(out, tc)
|
||||
}
|
||||
// If the model clearly emitted tool_calls JSON but all names are outside the
|
||||
// declared set, keep the parsed calls as a fallback so upper layers can still
|
||||
// intercept structured tool output instead of leaking raw JSON to users.
|
||||
if len(out) == 0 && len(parsed) > 0 {
|
||||
for _, tc := range parsed {
|
||||
if tc.Name == "" {
|
||||
continue
|
||||
}
|
||||
if tc.Input == nil {
|
||||
tc.Input = map[string]any{}
|
||||
}
|
||||
out = append(out, tc)
|
||||
}
|
||||
rejected := make([]string, 0, len(rejectedSet))
|
||||
for name := range rejectedSet {
|
||||
rejected = append(rejected, name)
|
||||
}
|
||||
return out
|
||||
return out, rejected
|
||||
}
|
||||
|
||||
func parseToolCallsPayload(payload string) []ParsedToolCall {
|
||||
|
||||
@@ -38,14 +38,25 @@ func TestParseToolCallsWithFunctionArgumentsString(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsKeepsUnknownAsFallback(t *testing.T) {
|
||||
func TestParseToolCallsRejectsUnknownToolName(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"unknown","input":{}}]}`
|
||||
calls := ParseToolCalls(text, []string{"search"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected fallback 1 call, got %d", len(calls))
|
||||
if len(calls) != 0 {
|
||||
t.Fatalf("expected unknown tool to be rejected, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "unknown" {
|
||||
t.Fatalf("unexpected name: %s", calls[0].Name)
|
||||
}
|
||||
|
||||
func TestParseToolCallsDetailedMarksPolicyRejection(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"unknown","input":{}}]}`
|
||||
res := ParseToolCallsDetailed(text, []string{"search"})
|
||||
if !res.SawToolCallSyntax {
|
||||
t.Fatalf("expected SawToolCallSyntax=true, got %#v", res)
|
||||
}
|
||||
if !res.RejectedByPolicy {
|
||||
t.Fatalf("expected RejectedByPolicy=true, got %#v", res)
|
||||
}
|
||||
if len(res.Calls) != 0 {
|
||||
t.Fatalf("expected no calls after policy rejection, got %#v", res.Calls)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user