Files
ds2api/internal/promptcompat/thinking_injection.go
2026-04-26 14:21:15 +08:00

76 lines
2.5 KiB
Go

package promptcompat
import "strings"
const (
// ThinkingInjectionMarker is the leading tag of DefaultThinkingInjectionPrompt.
// AppendThinkingInjectionPromptToLatestUser also searches for it in the
// normalized user content to detect that an injection is already present.
ThinkingInjectionMarker = "【思维链格式要求】"
// DefaultThinkingInjectionPrompt is the prompt used when the caller supplies
// a blank injection prompt. The text (in Chinese) tells the model to follow
// four numbered steps inside its <think> block: analyze the user's need,
// plan the next action, decide which tool (if any) to call, and restate the
// XML tool-call format required by the system prompt.
DefaultThinkingInjectionPrompt = ThinkingInjectionMarker + "在你的思考过程(<think>标签内)中,请严格按照以下规则进行思考,不要遗漏:\n" +
"1. 分析阶段:分析用户需求是什么。\n" +
"2. 构思阶段:构思下一步动作,我要干什么。\n" +
"3. 工具调用阶段:为了满足用户需求,我需要调用什么工具;如果不需要工具,明确说明不需要调用工具。\n" +
"4. 回顾格式:完整复述一遍 System 要求的 XML 工具调用格式要求,回顾错误示例和正确示例,说明我要如何正确调用工具。"
)
// AppendThinkingInjectionToLatestUser appends the default thinking-injection
// prompt to the latest user message in messages. It returns the resulting
// message list and whether an injection was actually performed.
func AppendThinkingInjectionToLatestUser(messages []any) ([]any, bool) {
	// An empty prompt makes the callee fall back to DefaultThinkingInjectionPrompt.
	const useDefaultPrompt = ""
	return AppendThinkingInjectionPromptToLatestUser(messages, useDefaultPrompt)
}
// AppendThinkingInjectionPromptToLatestUser appends injectionPrompt to the
// content of the last entry in messages that is a map[string]any whose "role"
// equals "user" (compared after trimming whitespace and lower-casing). A blank
// injectionPrompt selects DefaultThinkingInjectionPrompt.
//
// On success it returns a shallow copy of messages in which that entry is
// replaced by a shallow copy carrying the updated "content", together with
// true. This function does not write to the input slice or the matched map.
//
// It returns messages unchanged and false when messages is empty, when no
// user message is found, or when the normalized content of the latest user
// message already contains ThinkingInjectionMarker or the prompt itself.
func AppendThinkingInjectionPromptToLatestUser(messages []any, injectionPrompt string) ([]any, bool) {
	if len(messages) == 0 {
		return messages, false
	}

	prompt := strings.TrimSpace(injectionPrompt)
	if prompt == "" {
		prompt = DefaultThinkingInjectionPrompt
	}

	// Scan backwards for the most recent user-authored message.
	idx := -1
	var target map[string]any
	for i := len(messages) - 1; i >= 0; i-- {
		candidate, isMap := messages[i].(map[string]any)
		if !isMap {
			continue
		}
		role := strings.ToLower(strings.TrimSpace(asString(candidate["role"])))
		if role == "user" {
			idx, target = i, candidate
			break
		}
	}
	if idx < 0 {
		return messages, false
	}

	// Only the latest user message is considered; if it already carries an
	// injection, nothing is changed.
	original := target["content"]
	flattened := NormalizeOpenAIContentForPrompt(original)
	alreadyInjected := strings.Contains(flattened, ThinkingInjectionMarker) ||
		strings.Contains(flattened, prompt)
	if alreadyInjected {
		return messages, false
	}

	// Copy-on-write: build a replacement message and a new slice instead of
	// mutating the caller's values.
	replacement := make(map[string]any, len(target)+1)
	for key, value := range target {
		replacement[key] = value
	}
	replacement["content"] = appendThinkingInjectionToContent(original, prompt)

	result := make([]any, len(messages))
	copy(result, messages)
	result[idx] = replacement
	return result, true
}
// appendThinkingInjectionToContent returns content with injectionPrompt added.
//
//   - string content: the prompt is appended via appendTextBlock.
//   - []any content: a copy of the list is returned with a trailing
//     {"type": "text", "text": injectionPrompt} element; the input list is
//     not modified.
//   - anything else: the content is flattened to text with
//     NormalizeOpenAIContentForPrompt and then handled like a string.
func appendThinkingInjectionToContent(content any, injectionPrompt string) any {
	if text, isString := content.(string); isString {
		return appendTextBlock(text, injectionPrompt)
	}

	if parts, isList := content.([]any); isList {
		// Reserve room for the extra element so the append below never
		// touches the caller's backing array.
		extended := make([]any, len(parts), len(parts)+1)
		copy(extended, parts)
		textPart := map[string]any{
			"type": "text",
			"text": injectionPrompt,
		}
		return append(extended, textPart)
	}

	flattened := NormalizeOpenAIContentForPrompt(content)
	return appendTextBlock(flattened, injectionPrompt)
}
// appendTextBlock joins base and addition with a blank line between them.
// Surrounding whitespace is trimmed from base first; if nothing remains,
// addition is returned on its own. addition is never trimmed.
func appendTextBlock(base, addition string) string {
	if trimmed := strings.TrimSpace(base); trimmed != "" {
		return trimmed + "\n\n" + addition
	}
	return addition
}