mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-10 19:27:41 +08:00
67 lines
1.9 KiB
Go
67 lines
1.9 KiB
Go
package promptcompat
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"ds2api/internal/toolcall"
|
|
)
|
|
|
|
func injectToolPrompt(messages []map[string]any, tools []any, policy ToolChoicePolicy) ([]map[string]any, []string) {
|
|
if policy.IsNone() {
|
|
return messages, nil
|
|
}
|
|
toolSchemas := make([]string, 0, len(tools))
|
|
names := make([]string, 0, len(tools))
|
|
isAllowed := func(name string) bool {
|
|
if strings.TrimSpace(name) == "" {
|
|
return false
|
|
}
|
|
if len(policy.Allowed) == 0 {
|
|
return true
|
|
}
|
|
_, ok := policy.Allowed[name]
|
|
return ok
|
|
}
|
|
|
|
for _, t := range tools {
|
|
tool, ok := t.(map[string]any)
|
|
if !ok {
|
|
continue
|
|
}
|
|
name, desc, schema := toolcall.ExtractToolMeta(tool)
|
|
name = strings.TrimSpace(name)
|
|
if !isAllowed(name) {
|
|
continue
|
|
}
|
|
names = append(names, name)
|
|
if desc == "" {
|
|
desc = "No description available"
|
|
}
|
|
b, _ := json.Marshal(schema)
|
|
toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b)))
|
|
}
|
|
if len(toolSchemas) == 0 {
|
|
return messages, names
|
|
}
|
|
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\n" + toolcall.BuildToolCallInstructions(names)
|
|
if policy.Mode == ToolChoiceRequired {
|
|
toolPrompt += "\n7) For this response, you MUST call at least one tool from the allowed list."
|
|
}
|
|
if policy.Mode == ToolChoiceForced && strings.TrimSpace(policy.ForcedName) != "" {
|
|
toolPrompt += "\n7) For this response, you MUST call exactly this tool name: " + strings.TrimSpace(policy.ForcedName)
|
|
toolPrompt += "\n8) Do not call any other tool."
|
|
}
|
|
|
|
for i := range messages {
|
|
if messages[i]["role"] == "system" {
|
|
old, _ := messages[i]["content"].(string)
|
|
messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt)
|
|
return messages, names
|
|
}
|
|
}
|
|
messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...)
|
|
return messages, names
|
|
}
|