feat: integrate reasoning content into assistant tool-call messages and improve tool markup parsing for prompt compatibility

This commit is contained in:
CJACK
2026-05-09 23:16:07 +08:00
parent 9e9a7f1bec
commit 067cf465bb
15 changed files with 513 additions and 14 deletions

View File

@@ -101,6 +101,43 @@ func TestNormalizeClaudeMessagesToolUseToAssistantToolCalls(t *testing.T) {
}
}
func TestNormalizeClaudeMessagesPreservesThinkingOnToolUseHistory(t *testing.T) {
msgs := []any{
map[string]any{
"role": "assistant",
"content": []any{
map[string]any{"type": "thinking", "thinking": "need live search before answering"},
map[string]any{
"type": "tool_use",
"id": "call_1",
"name": "search_web",
"input": map[string]any{"query": "latest"},
},
},
},
}
got := normalizeClaudeMessages(msgs)
if len(got) != 1 {
t.Fatalf("expected one normalized tool-call message, got %#v", got)
}
m := got[0].(map[string]any)
if m["reasoning_content"] != "need live search before answering" {
t.Fatalf("expected thinking preserved as reasoning_content, got %#v", m)
}
tc, _ := m["tool_calls"].([]any)
if len(tc) != 1 {
t.Fatalf("expected one tool call, got %#v", m["tool_calls"])
}
prompt := buildClaudePromptTokenText(got, true)
if !containsStr(prompt, "[reasoning_content]\nneed live search before answering\n[/reasoning_content]") {
t.Fatalf("expected thinking in prompt history, got %q", prompt)
}
if !containsStr(prompt, `<|DSML|invoke name="search_web">`) {
t.Fatalf("expected tool call in prompt history, got %q", prompt)
}
}
func TestNormalizeClaudeMessagesDoesNotPromoteUserToolUse(t *testing.T) {
msgs := []any{
map[string]any{

View File

@@ -25,14 +25,21 @@ func normalizeClaudeMessages(messages []any) []any {
switch content := msg["content"].(type) {
case []any:
textParts := make([]string, 0, len(content))
pendingThinking := ""
flushText := func() {
if len(textParts) == 0 {
return
}
out = append(out, map[string]any{
message := map[string]any{
"role": role,
"content": strings.Join(textParts, "\n"),
})
}
if role == "assistant" && strings.TrimSpace(pendingThinking) != "" {
message["reasoning_content"] = pendingThinking
message["content"] = prependClaudeReasoningForPrompt(pendingThinking, safeStringValue(message["content"]))
pendingThinking = ""
}
out = append(out, message)
textParts = textParts[:0]
}
for _, block := range content {
@@ -46,10 +53,29 @@ func normalizeClaudeMessages(messages []any) []any {
if t, ok := b["text"].(string); ok {
textParts = append(textParts, t)
}
case "thinking":
if role == "assistant" {
if thinking := extractClaudeThinkingBlockText(b); thinking != "" {
if pendingThinking == "" {
pendingThinking = thinking
} else {
pendingThinking += "\n" + thinking
}
}
continue
}
if raw := strings.TrimSpace(formatClaudeUnknownBlockForPrompt(b)); raw != "" {
textParts = append(textParts, raw)
}
case "tool_use":
if role == "assistant" {
flushText()
if toolMsg := normalizeClaudeToolUseToAssistant(b, state); toolMsg != nil {
if strings.TrimSpace(pendingThinking) != "" {
toolMsg["reasoning_content"] = pendingThinking
toolMsg["content"] = prependClaudeReasoningForPrompt(pendingThinking, safeStringValue(toolMsg["content"]))
pendingThinking = ""
}
out = append(out, toolMsg)
}
continue
@@ -69,6 +95,13 @@ func normalizeClaudeMessages(messages []any) []any {
}
}
flushText()
if role == "assistant" && strings.TrimSpace(pendingThinking) != "" {
out = append(out, map[string]any{
"role": "assistant",
"reasoning_content": pendingThinking,
"content": formatClaudeReasoningForPrompt(pendingThinking),
})
}
default:
copied := cloneMap(msg)
out = append(out, copied)
@@ -77,6 +110,39 @@ func normalizeClaudeMessages(messages []any) []any {
return out
}
func prependClaudeReasoningForPrompt(reasoning, content string) string {
reasoning = strings.TrimSpace(reasoning)
content = strings.TrimSpace(content)
if reasoning == "" {
return content
}
block := formatClaudeReasoningForPrompt(reasoning)
if content == "" {
return block
}
return block + "\n\n" + content
}
func formatClaudeReasoningForPrompt(reasoning string) string {
reasoning = strings.TrimSpace(reasoning)
if reasoning == "" {
return ""
}
return "[reasoning_content]\n" + reasoning + "\n[/reasoning_content]"
}
func extractClaudeThinkingBlockText(block map[string]any) string {
if block == nil {
return ""
}
for _, key := range []string{"thinking", "text", "content"} {
if text := strings.TrimSpace(safeStringValue(block[key])); text != "" {
return text
}
}
return ""
}
func buildClaudeToolPrompt(tools []any) string {
toolSchemas := make([]string, 0, len(tools))
names := make([]string, 0, len(tools))

View File

@@ -44,14 +44,20 @@ func geminiMessagesFromRequest(req map[string]any) []any {
}
textParts := make([]string, 0, len(parts))
pendingThinking := ""
flushText := func() {
if len(textParts) == 0 {
return
}
out = append(out, map[string]any{
msg := map[string]any{
"role": role,
"content": strings.Join(textParts, "\n"),
})
}
if role == "assistant" && strings.TrimSpace(pendingThinking) != "" {
msg["reasoning_content"] = pendingThinking
pendingThinking = ""
}
out = append(out, msg)
textParts = textParts[:0]
}
@@ -61,6 +67,14 @@ func geminiMessagesFromRequest(req map[string]any) []any {
continue
}
if text := strings.TrimSpace(asString(part["text"])); text != "" {
if role == "assistant" && isGeminiThoughtPart(part) {
if pendingThinking == "" {
pendingThinking = text
} else {
pendingThinking += "\n" + text
}
continue
}
textParts = append(textParts, text)
continue
}
@@ -75,7 +89,7 @@ func geminiMessagesFromRequest(req map[string]any) []any {
}
}
lastToolCallIDByName[strings.ToLower(name)] = callID
out = append(out, map[string]any{
msg := map[string]any{
"role": "assistant",
"tool_calls": []any{
map[string]any{
@@ -87,7 +101,12 @@ func geminiMessagesFromRequest(req map[string]any) []any {
},
},
},
})
}
if strings.TrimSpace(pendingThinking) != "" {
msg["reasoning_content"] = pendingThinking
pendingThinking = ""
}
out = append(out, msg)
}
continue
}
@@ -132,10 +151,29 @@ func geminiMessagesFromRequest(req map[string]any) []any {
}
}
flushText()
if role == "assistant" && strings.TrimSpace(pendingThinking) != "" {
out = append(out, map[string]any{
"role": "assistant",
"reasoning_content": pendingThinking,
})
}
}
return out
}
func isGeminiThoughtPart(part map[string]any) bool {
if part == nil {
return false
}
if v, ok := part["thought"].(bool); ok {
return v
}
if v, ok := part["thoughtSignature"].(string); ok && strings.TrimSpace(v) != "" {
return true
}
return false
}
func normalizeGeminiSystemInstruction(raw any) string {
switch v := raw.(type) {
case string:

View File

@@ -1,6 +1,7 @@
package gemini
import (
"ds2api/internal/promptcompat"
"strings"
"testing"
)
@@ -53,6 +54,46 @@ func TestGeminiMessagesFromRequestPreservesFunctionRoundtrip(t *testing.T) {
}
}
func TestGeminiMessagesFromRequestPreservesThoughtOnFunctionCallHistory(t *testing.T) {
req := map[string]any{
"contents": []any{
map[string]any{
"role": "model",
"parts": []any{
map[string]any{"text": "need current state before answering", "thought": true},
map[string]any{
"functionCall": map[string]any{
"id": "call_g1",
"name": "search_web",
"args": map[string]any{"query": "ai"},
},
},
},
},
},
}
got := geminiMessagesFromRequest(req)
if len(got) != 1 {
t.Fatalf("expected one normalized message, got %#v", got)
}
assistant, _ := got[0].(map[string]any)
if assistant["reasoning_content"] != "need current state before answering" {
t.Fatalf("expected thought preserved as reasoning_content, got %#v", assistant)
}
tc, _ := assistant["tool_calls"].([]any)
if len(tc) != 1 {
t.Fatalf("expected one tool call, got %#v", assistant["tool_calls"])
}
prompt, _ := promptcompat.BuildOpenAIPromptForAdapter(got, nil, "", true)
if !strings.Contains(prompt, "[reasoning_content]\nneed current state before answering\n[/reasoning_content]") {
t.Fatalf("expected thought in prompt history, got %q", prompt)
}
if !strings.Contains(prompt, `<|DSML|invoke name="search_web">`) {
t.Fatalf("expected tool call in prompt history, got %q", prompt)
}
}
func TestGeminiMessagesFromRequestPreservesUnknownPartAsRawJSONText(t *testing.T) {
req := map[string]any{
"contents": []any{

View File

@@ -81,6 +81,22 @@ func (s *responsesStreamRuntime) buildCompletedResponseObject(finalThinking, fin
},
},
})
} else if len(calls) > 0 && strings.TrimSpace(finalThinking) != "" {
indexed = append(indexed, indexedItem{
index: s.ensureMessageOutputIndex(),
item: map[string]any{
"id": s.ensureMessageItemID(),
"type": "message",
"role": "assistant",
"status": "completed",
"content": []map[string]any{
{
"type": "reasoning",
"text": finalThinking,
},
},
},
})
} else if len(calls) == 0 {
content := make([]map[string]any, 0, 2)
if finalThinking != "" {