mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-17 06:35:14 +08:00
383 lines
10 KiB
Go
383 lines
10 KiB
Go
package prompt
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"regexp"
|
||
"sort"
|
||
"strings"
|
||
)
|
||
|
||
var promptXMLTextEscaper = strings.NewReplacer(
|
||
"&", "&",
|
||
"<", "<",
|
||
">", ">",
|
||
)
|
||
|
||
var promptXMLNamePattern = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_.:-]*$`)
|
||
|
||
const (
|
||
promptDSMLToolCallsOpen = "<|DSML|tool_calls>"
|
||
promptDSMLToolCallsClose = "</|DSML|tool_calls>"
|
||
promptDSMLInvokeOpen = "<|DSML|invoke"
|
||
promptDSMLInvokeClose = "</|DSML|invoke>"
|
||
promptDSMLParameterOpen = "<|DSML|parameter"
|
||
promptDSMLParameterClose = "</|DSML|parameter>"
|
||
)
|
||
|
||
// FormatToolCallsForPrompt renders a tool_calls slice into the prompt-visible
|
||
// invoke/parameter history block used across adapters.
|
||
func FormatToolCallsForPrompt(raw any) string {
|
||
calls, ok := raw.([]any)
|
||
if !ok || len(calls) == 0 {
|
||
return ""
|
||
}
|
||
|
||
blocks := make([]string, 0, len(calls))
|
||
for _, item := range calls {
|
||
call, ok := item.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
block := formatToolCallForPrompt(call)
|
||
if block != "" {
|
||
blocks = append(blocks, block)
|
||
}
|
||
}
|
||
if len(blocks) == 0 {
|
||
return ""
|
||
}
|
||
return promptDSMLToolCallsOpen + "\n" + strings.Join(blocks, "\n") + "\n" + promptDSMLToolCallsClose
|
||
}
|
||
|
||
// StringifyToolCallArguments normalizes tool arguments into a compact string
|
||
// while preserving raw concatenated payloads when they already look like model
|
||
// output rather than a single JSON object.
|
||
func StringifyToolCallArguments(v any) string {
|
||
switch x := v.(type) {
|
||
case nil:
|
||
return "{}"
|
||
case string:
|
||
s := strings.TrimSpace(x)
|
||
if s == "" {
|
||
return "{}"
|
||
}
|
||
s = normalizeToolArgumentString(s)
|
||
if s == "" {
|
||
return "{}"
|
||
}
|
||
return s
|
||
default:
|
||
b, err := json.Marshal(x)
|
||
if err != nil || len(b) == 0 {
|
||
return "{}"
|
||
}
|
||
return string(b)
|
||
}
|
||
}
|
||
|
||
func formatToolCallForPrompt(call map[string]any) string {
|
||
if call == nil {
|
||
return ""
|
||
}
|
||
|
||
name := strings.TrimSpace(asString(call["name"]))
|
||
fn, _ := call["function"].(map[string]any)
|
||
if name == "" && fn != nil {
|
||
name = strings.TrimSpace(asString(fn["name"]))
|
||
}
|
||
if name == "" {
|
||
return ""
|
||
}
|
||
|
||
argsRaw := call["arguments"]
|
||
if argsRaw == nil {
|
||
argsRaw = call["input"]
|
||
}
|
||
if argsRaw == nil && fn != nil {
|
||
argsRaw = fn["arguments"]
|
||
if argsRaw == nil {
|
||
argsRaw = fn["input"]
|
||
}
|
||
}
|
||
|
||
parameters := formatToolCallParametersForPrompt(argsRaw)
|
||
if parameters == "" {
|
||
return ` ` + promptDSMLInvokeOpen + ` name="` + escapeXMLAttribute(name) + `">` + promptDSMLInvokeClose
|
||
}
|
||
|
||
return " " + promptDSMLInvokeOpen + " name=\"" + escapeXMLAttribute(name) + "\">\n" +
|
||
parameters + "\n" +
|
||
" " + promptDSMLInvokeClose
|
||
}
|
||
|
||
func formatToolCallParametersForPrompt(raw any) string {
|
||
value := normalizePromptToolCallValue(raw)
|
||
body, ok := renderPromptToolParameters(value, " ")
|
||
if ok && strings.TrimSpace(body) != "" {
|
||
return body
|
||
}
|
||
|
||
fallback := StringifyToolCallArguments(raw)
|
||
if strings.TrimSpace(fallback) == "" {
|
||
return ""
|
||
}
|
||
return " " + promptDSMLParameterOpen + " name=\"content\">" + renderPromptXMLText(fallback) + promptDSMLParameterClose
|
||
}
|
||
|
||
func renderPromptToolParameters(value any, indent string) (string, bool) {
|
||
switch v := value.(type) {
|
||
case nil:
|
||
return "", true
|
||
case map[string]any:
|
||
if len(v) == 0 {
|
||
return "", true
|
||
}
|
||
keys := make([]string, 0, len(v))
|
||
for k := range v {
|
||
keys = append(keys, k)
|
||
}
|
||
sort.Strings(keys)
|
||
lines := make([]string, 0, len(keys))
|
||
for _, key := range keys {
|
||
rendered, ok := renderPromptParameterNode(key, v[key], indent)
|
||
if !ok {
|
||
return "", false
|
||
}
|
||
lines = append(lines, rendered)
|
||
}
|
||
return strings.Join(lines, "\n"), true
|
||
case []any:
|
||
lines := make([]string, 0, len(v))
|
||
for _, item := range v {
|
||
rendered, ok := renderPromptParameterNode("item", item, indent)
|
||
if !ok {
|
||
return "", false
|
||
}
|
||
lines = append(lines, rendered)
|
||
}
|
||
return strings.Join(lines, "\n"), true
|
||
case string:
|
||
return indent + promptDSMLParameterOpen + ` name="content">` + renderPromptXMLText(v) + promptDSMLParameterClose, true
|
||
default:
|
||
return indent + promptDSMLParameterOpen + ` name="value">` + renderPromptXMLText(fmt.Sprint(v)) + promptDSMLParameterClose, true
|
||
}
|
||
}
|
||
|
||
func renderPromptParameterNode(name string, value any, indent string) (string, bool) {
|
||
trimmedName := strings.TrimSpace(name)
|
||
if trimmedName == "" {
|
||
return "", false
|
||
}
|
||
switch v := value.(type) {
|
||
case nil:
|
||
return indent + promptDSMLParameterOpen + ` name="` + escapeXMLAttribute(trimmedName) + `">` + promptDSMLParameterClose, true
|
||
case map[string]any:
|
||
body, ok := renderPromptToolXMLBody(v, indent+" ")
|
||
if !ok {
|
||
return "", false
|
||
}
|
||
if strings.TrimSpace(body) == "" {
|
||
return indent + promptDSMLParameterOpen + ` name="` + escapeXMLAttribute(trimmedName) + `">` + promptDSMLParameterClose, true
|
||
}
|
||
return indent + promptDSMLParameterOpen + ` name="` + escapeXMLAttribute(trimmedName) + "\">\n" + body + "\n" + indent + promptDSMLParameterClose, true
|
||
case []any:
|
||
body, ok := renderPromptToolXMLArray(v, indent+" ")
|
||
if !ok {
|
||
return "", false
|
||
}
|
||
if strings.TrimSpace(body) == "" {
|
||
return indent + promptDSMLParameterOpen + ` name="` + escapeXMLAttribute(trimmedName) + `">` + promptDSMLParameterClose, true
|
||
}
|
||
return indent + promptDSMLParameterOpen + ` name="` + escapeXMLAttribute(trimmedName) + "\">\n" + body + "\n" + indent + promptDSMLParameterClose, true
|
||
case string:
|
||
return indent + promptDSMLParameterOpen + ` name="` + escapeXMLAttribute(trimmedName) + `">` + renderPromptXMLText(v) + promptDSMLParameterClose, true
|
||
default:
|
||
return indent + promptDSMLParameterOpen + ` name="` + escapeXMLAttribute(trimmedName) + `">` + renderPromptXMLText(fmt.Sprint(v)) + promptDSMLParameterClose, true
|
||
}
|
||
}
|
||
|
||
func normalizePromptToolCallValue(raw any) any {
|
||
switch x := raw.(type) {
|
||
case nil:
|
||
return nil
|
||
case string:
|
||
s := strings.TrimSpace(x)
|
||
if s == "" {
|
||
return ""
|
||
}
|
||
var parsed any
|
||
if err := json.Unmarshal([]byte(s), &parsed); err == nil {
|
||
return parsed
|
||
}
|
||
return x
|
||
default:
|
||
return x
|
||
}
|
||
}
|
||
|
||
func renderPromptToolXMLBody(value any, indent string) (string, bool) {
|
||
switch v := value.(type) {
|
||
case nil:
|
||
return "", true
|
||
case map[string]any:
|
||
return renderPromptToolXMLMap(v, indent)
|
||
case []any:
|
||
return renderPromptToolXMLArray(v, indent)
|
||
case string:
|
||
return indent + "<content>" + renderPromptXMLText(v) + "</content>", true
|
||
case bool, float32, float64, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||
return indent + "<value>" + escapeXMLText(fmt.Sprint(v)) + "</value>", true
|
||
default:
|
||
return indent + "<value>" + renderPromptXMLText(fmt.Sprint(v)) + "</value>", true
|
||
}
|
||
}
|
||
|
||
func renderPromptToolXMLMap(m map[string]any, indent string) (string, bool) {
|
||
if len(m) == 0 {
|
||
return "", true
|
||
}
|
||
keys := make([]string, 0, len(m))
|
||
for k := range m {
|
||
if !isValidPromptXMLName(k) {
|
||
return "", false
|
||
}
|
||
keys = append(keys, k)
|
||
}
|
||
sort.Strings(keys)
|
||
|
||
lines := make([]string, 0, len(keys))
|
||
for _, key := range keys {
|
||
rendered, ok := renderPromptToolXMLNode(key, m[key], indent)
|
||
if !ok {
|
||
return "", false
|
||
}
|
||
lines = append(lines, rendered)
|
||
}
|
||
return strings.Join(lines, "\n"), true
|
||
}
|
||
|
||
func renderPromptToolXMLArray(items []any, indent string) (string, bool) {
|
||
if len(items) == 0 {
|
||
return "", true
|
||
}
|
||
lines := make([]string, 0, len(items))
|
||
for _, item := range items {
|
||
rendered, ok := renderPromptToolXMLNode("item", item, indent)
|
||
if !ok {
|
||
return "", false
|
||
}
|
||
lines = append(lines, rendered)
|
||
}
|
||
return strings.Join(lines, "\n"), true
|
||
}
|
||
|
||
func renderPromptToolXMLNode(name string, value any, indent string) (string, bool) {
|
||
if !isValidPromptXMLName(name) {
|
||
return "", false
|
||
}
|
||
switch v := value.(type) {
|
||
case nil:
|
||
return indent + "<" + name + "></" + name + ">", true
|
||
case map[string]any:
|
||
inner, ok := renderPromptToolXMLMap(v, indent+" ")
|
||
if !ok {
|
||
return "", false
|
||
}
|
||
if strings.TrimSpace(inner) == "" {
|
||
return indent + "<" + name + "></" + name + ">", true
|
||
}
|
||
return indent + "<" + name + ">\n" + inner + "\n" + indent + "</" + name + ">", true
|
||
case []any:
|
||
if len(v) == 0 {
|
||
return indent + "<" + name + "></" + name + ">", true
|
||
}
|
||
lines := make([]string, 0, len(v))
|
||
for _, item := range v {
|
||
rendered, ok := renderPromptToolXMLNode(name, item, indent)
|
||
if !ok {
|
||
return "", false
|
||
}
|
||
lines = append(lines, rendered)
|
||
}
|
||
return strings.Join(lines, "\n"), true
|
||
case string:
|
||
return indent + "<" + name + ">" + renderPromptXMLText(v) + "</" + name + ">", true
|
||
case bool, float32, float64, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||
return indent + "<" + name + ">" + escapeXMLText(fmt.Sprint(v)) + "</" + name + ">", true
|
||
default:
|
||
return indent + "<" + name + ">" + renderPromptXMLText(fmt.Sprint(v)) + "</" + name + ">", true
|
||
}
|
||
}
|
||
|
||
// renderPromptXMLText emits CDATA for every string so prompt-visible tool
|
||
// history stays uniform and does not drift back toward ad-hoc escaping.
|
||
func renderPromptXMLText(text string) string {
|
||
if text == "" {
|
||
return ""
|
||
}
|
||
if strings.Contains(text, "]]>") {
|
||
return "<![CDATA[" + strings.ReplaceAll(text, "]]>", "]]]]><![CDATA[>") + "]]>"
|
||
}
|
||
return "<![CDATA[" + text + "]]>"
|
||
}
|
||
|
||
func isValidPromptXMLName(name string) bool {
|
||
return promptXMLNamePattern.MatchString(strings.TrimSpace(name))
|
||
}
|
||
|
||
func escapeXMLAttribute(text string) string {
|
||
if text == "" {
|
||
return ""
|
||
}
|
||
return strings.NewReplacer(
|
||
"&", "&",
|
||
`"`, """,
|
||
"<", "<",
|
||
">", ">",
|
||
).Replace(text)
|
||
}
|
||
|
||
func normalizeToolArgumentString(raw string) string {
|
||
trimmed := strings.TrimSpace(raw)
|
||
if trimmed == "" {
|
||
return ""
|
||
}
|
||
if looksLikeConcatenatedJSON(trimmed) {
|
||
// Keep the original payload to avoid silently rewriting model output.
|
||
return raw
|
||
}
|
||
return trimmed
|
||
}
|
||
|
||
func looksLikeConcatenatedJSON(raw string) bool {
|
||
trimmed := strings.TrimSpace(raw)
|
||
if trimmed == "" {
|
||
return false
|
||
}
|
||
if strings.Contains(trimmed, "}{") || strings.Contains(trimmed, "][") {
|
||
return true
|
||
}
|
||
dec := json.NewDecoder(strings.NewReader(trimmed))
|
||
var first any
|
||
if err := dec.Decode(&first); err != nil {
|
||
return false
|
||
}
|
||
var second any
|
||
return dec.Decode(&second) == nil
|
||
}
|
||
|
||
func asString(v any) string {
|
||
if s, ok := v.(string); ok {
|
||
return s
|
||
}
|
||
return ""
|
||
}
|
||
|
||
func escapeXMLText(v string) string {
|
||
if v == "" {
|
||
return ""
|
||
}
|
||
return promptXMLTextEscaper.Replace(v)
|
||
}
|