Files
ds2api/internal/promptcompat/tool_prompt.go
2026-04-29 01:59:05 +08:00

67 lines
1.9 KiB
Go

package promptcompat
import (
"encoding/json"
"fmt"
"strings"
"ds2api/internal/toolcall"
)
// injectToolPrompt adds a textual description of the permitted tools, followed
// by the tool-call instructions from toolcall.BuildToolCallInstructions, to the
// conversation's system prompt.
//
// Entries of tools are skipped when they are not a map[string]any, when their
// name is blank after trimming, or when policy.Allowed is non-empty and does
// not contain the name. Tools without a description get a placeholder one.
//
// The prompt is appended to the first message whose role is "system", provided
// that message's content is a string or is absent. That message map is updated
// in place, so the caller's slice element observes the change. If there is no
// system message, or the first one carries non-string content, a new system
// message holding only the tool prompt is prepended instead.
//
// It returns the resulting messages and the names of the tools that were kept.
// When policy.IsNone() reports true, messages is returned untouched together
// with nil names; when no tool survives filtering, messages is returned
// untouched together with the (empty) names slice.
func injectToolPrompt(messages []map[string]any, tools []any, policy ToolChoicePolicy) ([]map[string]any, []string) {
	if policy.IsNone() {
		return messages, nil
	}

	toolSchemas := make([]string, 0, len(tools))
	names := make([]string, 0, len(tools))
	for _, t := range tools {
		tool, ok := t.(map[string]any)
		if !ok {
			continue
		}
		name, desc, schema := toolcall.ExtractToolMeta(tool)
		name = strings.TrimSpace(name)
		if name == "" {
			continue
		}
		// An empty allow-list means every named tool is permitted.
		if len(policy.Allowed) > 0 {
			if _, allowed := policy.Allowed[name]; !allowed {
				continue
			}
		}
		names = append(names, name)
		if desc == "" {
			desc = "No description available"
		}
		b, err := json.Marshal(schema)
		if err != nil {
			// Keep the prompt well-formed rather than emitting an empty
			// "Parameters:" field when the schema cannot be serialized.
			b = []byte("{}")
		}
		toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b)))
	}
	if len(toolSchemas) == 0 {
		return messages, names
	}

	toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\n" + toolcall.BuildToolCallInstructions(names)
	forcedName := strings.TrimSpace(policy.ForcedName)
	switch {
	case policy.Mode == ToolChoiceRequired:
		toolPrompt += "\n7) For this response, you MUST call at least one tool from the allowed list."
	case policy.Mode == ToolChoiceForced && forcedName != "":
		toolPrompt += "\n7) For this response, you MUST call exactly this tool name: " + forcedName
		toolPrompt += "\n8) Do not call any other tool."
	}

	for i := range messages {
		if messages[i]["role"] != "system" {
			continue
		}
		content := messages[i]["content"]
		old, isText := content.(string)
		if isText || content == nil {
			messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt)
			return messages, names
		}
		// The first system message carries non-string content (for example a
		// list of structured parts). Overwriting it would discard that content,
		// so fall through and prepend a dedicated system message instead.
		break
	}

	messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...)
	return messages, names
}