Files
ds2api/internal/prompt/tool_calls.go

383 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package prompt
import (
"encoding/json"
"fmt"
"regexp"
"sort"
"strings"
)
var promptXMLTextEscaper = strings.NewReplacer(
"&", "&",
"<", "&lt;",
">", "&gt;",
)
var promptXMLNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_.:-]*$`)
const (
promptDSMLToolCallsOpen = "<DSMLtool_calls>"
promptDSMLToolCallsClose = "</DSMLtool_calls>"
promptDSMLInvokeOpen = "<DSMLinvoke"
promptDSMLInvokeClose = "</DSMLinvoke>"
promptDSMLParameterOpen = "<DSMLparameter"
promptDSMLParameterClose = "</DSMLparameter>"
)
// FormatToolCallsForPrompt renders a tool_calls slice into the prompt-visible
// invoke/parameter history block used across adapters.
func FormatToolCallsForPrompt(raw any) string {
calls, ok := raw.([]any)
if !ok || len(calls) == 0 {
return ""
}
blocks := make([]string, 0, len(calls))
for _, item := range calls {
call, ok := item.(map[string]any)
if !ok {
continue
}
block := formatToolCallForPrompt(call)
if block != "" {
blocks = append(blocks, block)
}
}
if len(blocks) == 0 {
return ""
}
return promptDSMLToolCallsOpen + "\n" + strings.Join(blocks, "\n") + "\n" + promptDSMLToolCallsClose
}
// StringifyToolCallArguments normalizes tool arguments into a compact string
// while preserving raw concatenated payloads when they already look like model
// output rather than a single JSON object.
func StringifyToolCallArguments(v any) string {
switch x := v.(type) {
case nil:
return "{}"
case string:
s := strings.TrimSpace(x)
if s == "" {
return "{}"
}
s = normalizeToolArgumentString(s)
if s == "" {
return "{}"
}
return s
default:
b, err := json.Marshal(x)
if err != nil || len(b) == 0 {
return "{}"
}
return string(b)
}
}
func formatToolCallForPrompt(call map[string]any) string {
if call == nil {
return ""
}
name := strings.TrimSpace(asString(call["name"]))
fn, _ := call["function"].(map[string]any)
if name == "" && fn != nil {
name = strings.TrimSpace(asString(fn["name"]))
}
if name == "" {
return ""
}
argsRaw := call["arguments"]
if argsRaw == nil {
argsRaw = call["input"]
}
if argsRaw == nil && fn != nil {
argsRaw = fn["arguments"]
if argsRaw == nil {
argsRaw = fn["input"]
}
}
parameters := formatToolCallParametersForPrompt(argsRaw)
if parameters == "" {
return ` ` + promptDSMLInvokeOpen + ` name="` + escapeXMLAttribute(name) + `">` + promptDSMLInvokeClose
}
return " " + promptDSMLInvokeOpen + " name=\"" + escapeXMLAttribute(name) + "\">\n" +
parameters + "\n" +
" " + promptDSMLInvokeClose
}
func formatToolCallParametersForPrompt(raw any) string {
value := normalizePromptToolCallValue(raw)
body, ok := renderPromptToolParameters(value, " ")
if ok && strings.TrimSpace(body) != "" {
return body
}
fallback := StringifyToolCallArguments(raw)
if strings.TrimSpace(fallback) == "" {
return ""
}
return " " + promptDSMLParameterOpen + " name=\"content\">" + renderPromptXMLText(fallback) + promptDSMLParameterClose
}
func renderPromptToolParameters(value any, indent string) (string, bool) {
switch v := value.(type) {
case nil:
return "", true
case map[string]any:
if len(v) == 0 {
return "", true
}
keys := make([]string, 0, len(v))
for k := range v {
keys = append(keys, k)
}
sort.Strings(keys)
lines := make([]string, 0, len(keys))
for _, key := range keys {
rendered, ok := renderPromptParameterNode(key, v[key], indent)
if !ok {
return "", false
}
lines = append(lines, rendered)
}
return strings.Join(lines, "\n"), true
case []any:
lines := make([]string, 0, len(v))
for _, item := range v {
rendered, ok := renderPromptParameterNode("item", item, indent)
if !ok {
return "", false
}
lines = append(lines, rendered)
}
return strings.Join(lines, "\n"), true
case string:
return indent + promptDSMLParameterOpen + ` name="content">` + renderPromptXMLText(v) + promptDSMLParameterClose, true
default:
return indent + promptDSMLParameterOpen + ` name="value">` + renderPromptXMLText(fmt.Sprint(v)) + promptDSMLParameterClose, true
}
}
func renderPromptParameterNode(name string, value any, indent string) (string, bool) {
trimmedName := strings.TrimSpace(name)
if trimmedName == "" {
return "", false
}
switch v := value.(type) {
case nil:
return indent + promptDSMLParameterOpen + ` name="` + escapeXMLAttribute(trimmedName) + `">` + promptDSMLParameterClose, true
case map[string]any:
body, ok := renderPromptToolXMLBody(v, indent+" ")
if !ok {
return "", false
}
if strings.TrimSpace(body) == "" {
return indent + promptDSMLParameterOpen + ` name="` + escapeXMLAttribute(trimmedName) + `">` + promptDSMLParameterClose, true
}
return indent + promptDSMLParameterOpen + ` name="` + escapeXMLAttribute(trimmedName) + "\">\n" + body + "\n" + indent + promptDSMLParameterClose, true
case []any:
body, ok := renderPromptToolXMLArray(v, indent+" ")
if !ok {
return "", false
}
if strings.TrimSpace(body) == "" {
return indent + promptDSMLParameterOpen + ` name="` + escapeXMLAttribute(trimmedName) + `">` + promptDSMLParameterClose, true
}
return indent + promptDSMLParameterOpen + ` name="` + escapeXMLAttribute(trimmedName) + "\">\n" + body + "\n" + indent + promptDSMLParameterClose, true
case string:
return indent + promptDSMLParameterOpen + ` name="` + escapeXMLAttribute(trimmedName) + `">` + renderPromptXMLText(v) + promptDSMLParameterClose, true
default:
return indent + promptDSMLParameterOpen + ` name="` + escapeXMLAttribute(trimmedName) + `">` + renderPromptXMLText(fmt.Sprint(v)) + promptDSMLParameterClose, true
}
}
func normalizePromptToolCallValue(raw any) any {
switch x := raw.(type) {
case nil:
return nil
case string:
s := strings.TrimSpace(x)
if s == "" {
return ""
}
var parsed any
if err := json.Unmarshal([]byte(s), &parsed); err == nil {
return parsed
}
return x
default:
return x
}
}
func renderPromptToolXMLBody(value any, indent string) (string, bool) {
switch v := value.(type) {
case nil:
return "", true
case map[string]any:
return renderPromptToolXMLMap(v, indent)
case []any:
return renderPromptToolXMLArray(v, indent)
case string:
return indent + "<content>" + renderPromptXMLText(v) + "</content>", true
case bool, float32, float64, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return indent + "<value>" + escapeXMLText(fmt.Sprint(v)) + "</value>", true
default:
return indent + "<value>" + renderPromptXMLText(fmt.Sprint(v)) + "</value>", true
}
}
func renderPromptToolXMLMap(m map[string]any, indent string) (string, bool) {
if len(m) == 0 {
return "", true
}
keys := make([]string, 0, len(m))
for k := range m {
if !isValidPromptXMLName(k) {
return "", false
}
keys = append(keys, k)
}
sort.Strings(keys)
lines := make([]string, 0, len(keys))
for _, key := range keys {
rendered, ok := renderPromptToolXMLNode(key, m[key], indent)
if !ok {
return "", false
}
lines = append(lines, rendered)
}
return strings.Join(lines, "\n"), true
}
func renderPromptToolXMLArray(items []any, indent string) (string, bool) {
if len(items) == 0 {
return "", true
}
lines := make([]string, 0, len(items))
for _, item := range items {
rendered, ok := renderPromptToolXMLNode("item", item, indent)
if !ok {
return "", false
}
lines = append(lines, rendered)
}
return strings.Join(lines, "\n"), true
}
func renderPromptToolXMLNode(name string, value any, indent string) (string, bool) {
if !isValidPromptXMLName(name) {
return "", false
}
switch v := value.(type) {
case nil:
return indent + "<" + name + "></" + name + ">", true
case map[string]any:
inner, ok := renderPromptToolXMLMap(v, indent+" ")
if !ok {
return "", false
}
if strings.TrimSpace(inner) == "" {
return indent + "<" + name + "></" + name + ">", true
}
return indent + "<" + name + ">\n" + inner + "\n" + indent + "</" + name + ">", true
case []any:
if len(v) == 0 {
return indent + "<" + name + "></" + name + ">", true
}
lines := make([]string, 0, len(v))
for _, item := range v {
rendered, ok := renderPromptToolXMLNode(name, item, indent)
if !ok {
return "", false
}
lines = append(lines, rendered)
}
return strings.Join(lines, "\n"), true
case string:
return indent + "<" + name + ">" + renderPromptXMLText(v) + "</" + name + ">", true
case bool, float32, float64, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return indent + "<" + name + ">" + escapeXMLText(fmt.Sprint(v)) + "</" + name + ">", true
default:
return indent + "<" + name + ">" + renderPromptXMLText(fmt.Sprint(v)) + "</" + name + ">", true
}
}
// renderPromptXMLText emits CDATA for every string so prompt-visible tool
// history stays uniform and does not drift back toward ad-hoc escaping.
func renderPromptXMLText(text string) string {
if text == "" {
return ""
}
if strings.Contains(text, "]]>") {
return "<![CDATA[" + strings.ReplaceAll(text, "]]>", "]]]]><![CDATA[>") + "]]>"
}
return "<![CDATA[" + text + "]]>"
}
func isValidPromptXMLName(name string) bool {
return promptXMLNamePattern.MatchString(strings.TrimSpace(name))
}
func escapeXMLAttribute(text string) string {
if text == "" {
return ""
}
return strings.NewReplacer(
"&", "&amp;",
`"`, "&quot;",
"<", "&lt;",
">", "&gt;",
).Replace(text)
}
func normalizeToolArgumentString(raw string) string {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return ""
}
if looksLikeConcatenatedJSON(trimmed) {
// Keep the original payload to avoid silently rewriting model output.
return raw
}
return trimmed
}
func looksLikeConcatenatedJSON(raw string) bool {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return false
}
if strings.Contains(trimmed, "}{") || strings.Contains(trimmed, "][") {
return true
}
dec := json.NewDecoder(strings.NewReader(trimmed))
var first any
if err := dec.Decode(&first); err != nil {
return false
}
var second any
return dec.Decode(&second) == nil
}
func asString(v any) string {
if s, ok := v.(string); ok {
return s
}
return ""
}
func escapeXMLText(v string) string {
if v == "" {
return ""
}
return promptXMLTextEscaper.Replace(v)
}