mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-04 00:15:28 +08:00
170 lines
4.2 KiB
Go
170 lines
4.2 KiB
Go
package openai
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
"ds2api/internal/util"
|
|
)
|
|
|
|
func injectToolPrompt(messages []map[string]any, tools []any, policy util.ToolChoicePolicy) ([]map[string]any, []string) {
|
|
if policy.IsNone() {
|
|
return messages, nil
|
|
}
|
|
toolSchemas := make([]string, 0, len(tools))
|
|
names := make([]string, 0, len(tools))
|
|
isAllowed := func(name string) bool {
|
|
if strings.TrimSpace(name) == "" {
|
|
return false
|
|
}
|
|
if len(policy.Allowed) == 0 {
|
|
return true
|
|
}
|
|
_, ok := policy.Allowed[name]
|
|
return ok
|
|
}
|
|
|
|
for _, t := range tools {
|
|
tool, ok := t.(map[string]any)
|
|
if !ok {
|
|
continue
|
|
}
|
|
fn, _ := tool["function"].(map[string]any)
|
|
if len(fn) == 0 {
|
|
fn = tool
|
|
}
|
|
name, _ := fn["name"].(string)
|
|
desc, _ := fn["description"].(string)
|
|
schema, _ := fn["parameters"].(map[string]any)
|
|
name = strings.TrimSpace(name)
|
|
if !isAllowed(name) {
|
|
continue
|
|
}
|
|
names = append(names, name)
|
|
if desc == "" {
|
|
desc = "No description available"
|
|
}
|
|
b, _ := json.Marshal(schema)
|
|
toolSchemas = append(toolSchemas, fmt.Sprintf("Tool: %s\nDescription: %s\nParameters: %s", name, desc, string(b)))
|
|
}
|
|
if len(toolSchemas) == 0 {
|
|
return messages, names
|
|
}
|
|
toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\n" + buildToolCallInstructions(names)
|
|
if policy.Mode == util.ToolChoiceRequired {
|
|
toolPrompt += "\n7) For this response, you MUST call at least one tool from the allowed list."
|
|
}
|
|
if policy.Mode == util.ToolChoiceForced && strings.TrimSpace(policy.ForcedName) != "" {
|
|
toolPrompt += "\n7) For this response, you MUST call exactly this tool name: " + strings.TrimSpace(policy.ForcedName)
|
|
toolPrompt += "\n8) Do not call any other tool."
|
|
}
|
|
|
|
for i := range messages {
|
|
if messages[i]["role"] == "system" {
|
|
old, _ := messages[i]["content"].(string)
|
|
messages[i]["content"] = strings.TrimSpace(old + "\n\n" + toolPrompt)
|
|
return messages, names
|
|
}
|
|
}
|
|
messages = append([]map[string]any{{"role": "system", "content": toolPrompt}}, messages...)
|
|
return messages, names
|
|
}
|
|
|
|
// buildToolCallInstructions delegates to the shared util implementation.
|
|
func buildToolCallInstructions(toolNames []string) string {
|
|
return util.BuildToolCallInstructions(toolNames)
|
|
}
|
|
|
|
func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any {
|
|
if len(deltas) == 0 {
|
|
return nil
|
|
}
|
|
out := make([]map[string]any, 0, len(deltas))
|
|
for _, d := range deltas {
|
|
if d.Name == "" && d.Arguments == "" {
|
|
continue
|
|
}
|
|
callID, ok := ids[d.Index]
|
|
if !ok || callID == "" {
|
|
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
|
ids[d.Index] = callID
|
|
}
|
|
item := map[string]any{
|
|
"index": d.Index,
|
|
"id": callID,
|
|
"type": "function",
|
|
}
|
|
fn := map[string]any{}
|
|
if d.Name != "" {
|
|
fn["name"] = d.Name
|
|
}
|
|
if d.Arguments != "" {
|
|
fn["arguments"] = d.Arguments
|
|
}
|
|
if len(fn) > 0 {
|
|
item["function"] = fn
|
|
}
|
|
out = append(out, item)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNames []string, seenNames map[int]string) []toolCallDelta {
|
|
if len(deltas) == 0 {
|
|
return nil
|
|
}
|
|
out := make([]toolCallDelta, 0, len(deltas))
|
|
for _, d := range deltas {
|
|
if d.Name != "" {
|
|
if seenNames != nil {
|
|
seenNames[d.Index] = d.Name
|
|
}
|
|
out = append(out, d)
|
|
continue
|
|
}
|
|
if seenNames == nil {
|
|
out = append(out, d)
|
|
continue
|
|
}
|
|
name := strings.TrimSpace(seenNames[d.Index])
|
|
if name == "" {
|
|
continue
|
|
}
|
|
out = append(out, d)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func formatFinalStreamToolCallsWithStableIDs(calls []util.ParsedToolCall, ids map[int]string) []map[string]any {
|
|
if len(calls) == 0 {
|
|
return nil
|
|
}
|
|
out := make([]map[string]any, 0, len(calls))
|
|
for i, c := range calls {
|
|
callID := ""
|
|
if ids != nil {
|
|
callID = strings.TrimSpace(ids[i])
|
|
}
|
|
if callID == "" {
|
|
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
|
if ids != nil {
|
|
ids[i] = callID
|
|
}
|
|
}
|
|
args, _ := json.Marshal(c.Input)
|
|
out = append(out, map[string]any{
|
|
"index": i,
|
|
"id": callID,
|
|
"type": "function",
|
|
"function": map[string]any{
|
|
"name": c.Name,
|
|
"arguments": string(args),
|
|
},
|
|
})
|
|
}
|
|
return out
|
|
}
|