Files
ds2api/internal/prompt/messages.go

98 lines
2.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package prompt
import (
"encoding/json"
"fmt"
"regexp"
"strings"
)
var markdownImagePattern = regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`)
func MessagesPrepare(messages []map[string]any) string {
type block struct {
Role string
Text string
}
processed := make([]block, 0, len(messages))
for _, m := range messages {
role, _ := m["role"].(string)
text := NormalizeContent(m["content"])
processed = append(processed, block{Role: role, Text: text})
}
if len(processed) == 0 {
return ""
}
merged := make([]block, 0, len(processed))
for _, msg := range processed {
if len(merged) > 0 && merged[len(merged)-1].Role == msg.Role {
merged[len(merged)-1].Text += "\n\n" + msg.Text
continue
}
merged = append(merged, msg)
}
parts := make([]string, 0, len(merged))
for _, m := range merged {
switch m.Role {
case "assistant":
// Keep assistant turns on their own block so the model sees a clear
// boundary between prior answer text and the EOS marker.
parts = append(parts, "<Assistant>\n"+m.Text+"\n<end▁of▁sentence>")
case "tool":
if strings.TrimSpace(m.Text) != "" {
parts = append(parts, "<Tool>\n"+m.Text)
}
case "system":
// Clear system boundary improves R1 and V3 context understanding significantly.
if text := strings.TrimSpace(m.Text); text != "" {
parts = append(parts, "<system_instructions>\n"+text+"\n</system_instructions>")
}
case "user":
// Put user turns on their own line so the role transition is explicit.
parts = append(parts, "<User>\n"+m.Text)
default:
if strings.TrimSpace(m.Text) != "" {
parts = append(parts, m.Text)
}
}
}
out := strings.Join(parts, "\n\n")
return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`)
}
func NormalizeContent(v any) string {
if v == nil {
return ""
}
switch x := v.(type) {
case string:
return x
case []any:
parts := make([]string, 0, len(x))
for _, item := range x {
m, ok := item.(map[string]any)
if !ok {
continue
}
typeStr, _ := m["type"].(string)
typeStr = strings.ToLower(strings.TrimSpace(typeStr))
if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" {
if txt, ok := m["text"].(string); ok && txt != "" {
parts = append(parts, txt)
continue
}
if txt, ok := m["content"].(string); ok && txt != "" {
parts = append(parts, txt)
}
}
}
return strings.Join(parts, "\n")
default:
b, err := json.Marshal(v)
if err != nil {
return fmt.Sprintf("%v", v)
}
return string(b)
}
}