mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-18 15:15:08 +08:00
feat: Improve OpenAI tool call handling by passing unknown tool calls as content and filtering streamed tool calls by schema.
This commit is contained in:
@@ -10,9 +10,23 @@ import (
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) {
|
||||
func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolChoicePolicy) ([]map[string]any, []string) {
|
||||
if policy.IsNone() {
|
||||
return messages, nil
|
||||
}
|
||||
toolSchemas := make([]string, 0, len(tools))
|
||||
names := make([]string, 0, len(tools))
|
||||
isAllowed := func(name string) bool {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return false
|
||||
}
|
||||
if len(policy.Allowed) == 0 {
|
||||
return true
|
||||
}
|
||||
_, ok := policy.Allowed[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
for _, t := range tools {
|
||||
tool, ok := t.(map[string]any)
|
||||
if !ok {
|
||||
@@ -25,8 +39,9 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
|
||||
name, _ := fn["name"].(string)
|
||||
desc, _ := fn["description"].(string)
|
||||
schema, _ := fn["parameters"].(map[string]any)
|
||||
if name == "" {
|
||||
name = "unknown"
|
||||
name = strings.TrimSpace(name)
|
||||
if !isAllowed(name) {
|
||||
continue
|
||||
}
|
||||
names = append(names, name)
|
||||
if desc == "" {
|
||||
@@ -39,6 +54,13 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
|
||||
return messages, names
|
||||
}
|
||||
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nHistory markers in conversation:\n- [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] means a tool call you already made earlier.\n- [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] means the runtime returned a tool result (not user input).\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error.\n4) Do not repeat a tool call that is already satisfied by an existing [TOOL_RESULT_HISTORY] block."
|
||||
if policy.Mode == util.ToolChoiceRequired {
|
||||
toolPrompt += "\n5) For this response, you MUST call at least one tool from the allowed list."
|
||||
}
|
||||
if policy.Mode == util.ToolChoiceForced && strings.TrimSpace(policy.ForcedName) != "" {
|
||||
toolPrompt += "\n5) For this response, you MUST call exactly this tool name: " + strings.TrimSpace(policy.ForcedName)
|
||||
toolPrompt += "\n6) Do not call any other tool."
|
||||
}
|
||||
|
||||
for i := range messages {
|
||||
if messages[i]["role"] == "system" {
|
||||
@@ -85,6 +107,33 @@ func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]s
|
||||
return out
|
||||
}
|
||||
|
||||
func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNames []string, seenNames map[int]string) []toolCallDelta {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
allowed := namesToSet(allowedNames)
|
||||
out := make([]toolCallDelta, 0, len(deltas))
|
||||
for _, d := range deltas {
|
||||
if d.Name != "" {
|
||||
if len(allowed) > 0 {
|
||||
if _, ok := allowed[d.Name]; !ok {
|
||||
seenNames[d.Index] = "__blocked__"
|
||||
continue
|
||||
}
|
||||
}
|
||||
seenNames[d.Index] = d.Name
|
||||
out = append(out, d)
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(seenNames[d.Index])
|
||||
if name == "" || name == "__blocked__" {
|
||||
continue
|
||||
}
|
||||
out = append(out, d)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
|
||||
if len(calls) == 0 {
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user