feat: integrate reasoning content into assistant tool-call messages and improve tool markup parsing for prompt compatibility

This commit is contained in:
CJACK
2026-05-09 23:16:07 +08:00
parent 9e9a7f1bec
commit 067cf465bb
15 changed files with 513 additions and 14 deletions

View File

@@ -21,6 +21,18 @@ func BuildResponseObjectWithToolCalls(responseID, model, finalPrompt, finalThink
output := make([]any, 0, 2)
if len(detected) > 0 {
exposedOutputText = ""
if strings.TrimSpace(finalThinking) != "" {
output = append(output, map[string]any{
"type": "message",
"id": "msg_" + strings.ReplaceAll(uuid.NewString(), "-", ""),
"role": "assistant",
"status": "completed",
"content": []any{map[string]any{
"type": "reasoning",
"text": finalThinking,
}},
})
}
output = append(output, toResponsesFunctionCallItems(detected, toolsRaw)...)
} else {
content := make([]any, 0, 2)

View File

@@ -85,12 +85,24 @@ func TestBuildResponseObjectPromotesToolCallFromThinkingWhenTextEmpty(t *testing
)
output, _ := obj["output"].([]any)
if len(output) != 1 {
t.Fatalf("expected one output item, got %#v", obj["output"])
if len(output) != 2 {
t.Fatalf("expected reasoning message plus function_call output, got %#v", obj["output"])
}
first, _ := output[0].(map[string]any)
if first["type"] != "function_call" {
t.Fatalf("expected function_call output, got %#v", first["type"])
if first["type"] != "message" {
t.Fatalf("expected reasoning message output first, got %#v", first["type"])
}
content, _ := first["content"].([]any)
if len(content) != 1 {
t.Fatalf("expected reasoning content, got %#v", first["content"])
}
block0, _ := content[0].(map[string]any)
if block0["type"] != "reasoning" {
t.Fatalf("expected reasoning block, got %#v", block0["type"])
}
second, _ := output[1].(map[string]any)
if second["type"] != "function_call" {
t.Fatalf("expected function_call output, got %#v", second["type"])
}
}

View File

@@ -101,6 +101,43 @@ func TestNormalizeClaudeMessagesToolUseToAssistantToolCalls(t *testing.T) {
}
}
// TestNormalizeClaudeMessagesPreservesThinkingOnToolUseHistory feeds an
// assistant turn made of a thinking block followed by a tool_use block and
// checks that normalization yields a single tool-call message carrying the
// thinking text as reasoning_content, and that both the reasoning block and
// the tool invocation show up in the prompt text built from it.
func TestNormalizeClaudeMessagesPreservesThinkingOnToolUseHistory(t *testing.T) {
	const thinking = "need live search before answering"

	thinkingBlock := map[string]any{"type": "thinking", "thinking": thinking}
	toolUseBlock := map[string]any{
		"type":  "tool_use",
		"id":    "call_1",
		"name":  "search_web",
		"input": map[string]any{"query": "latest"},
	}
	history := []any{
		map[string]any{
			"role":    "assistant",
			"content": []any{thinkingBlock, toolUseBlock},
		},
	}

	normalized := normalizeClaudeMessages(history)
	if len(normalized) != 1 {
		t.Fatalf("expected one normalized tool-call message, got %#v", normalized)
	}

	toolMsg := normalized[0].(map[string]any)
	if toolMsg["reasoning_content"] != thinking {
		t.Fatalf("expected thinking preserved as reasoning_content, got %#v", toolMsg)
	}
	if calls, _ := toolMsg["tool_calls"].([]any); len(calls) != 1 {
		t.Fatalf("expected one tool call, got %#v", toolMsg["tool_calls"])
	}

	rendered := buildClaudePromptTokenText(normalized, true)
	wantReasoning := "[reasoning_content]\n" + thinking + "\n[/reasoning_content]"
	if !containsStr(rendered, wantReasoning) {
		t.Fatalf("expected thinking in prompt history, got %q", rendered)
	}
	wantInvoke := `<|DSML|invoke name="search_web">`
	if !containsStr(rendered, wantInvoke) {
		t.Fatalf("expected tool call in prompt history, got %q", rendered)
	}
}
func TestNormalizeClaudeMessagesDoesNotPromoteUserToolUse(t *testing.T) {
msgs := []any{
map[string]any{

View File

@@ -25,14 +25,21 @@ func normalizeClaudeMessages(messages []any) []any {
switch content := msg["content"].(type) {
case []any:
textParts := make([]string, 0, len(content))
pendingThinking := ""
flushText := func() {
if len(textParts) == 0 {
return
}
out = append(out, map[string]any{
message := map[string]any{
"role": role,
"content": strings.Join(textParts, "\n"),
})
}
if role == "assistant" && strings.TrimSpace(pendingThinking) != "" {
message["reasoning_content"] = pendingThinking
message["content"] = prependClaudeReasoningForPrompt(pendingThinking, safeStringValue(message["content"]))
pendingThinking = ""
}
out = append(out, message)
textParts = textParts[:0]
}
for _, block := range content {
@@ -46,10 +53,29 @@ func normalizeClaudeMessages(messages []any) []any {
if t, ok := b["text"].(string); ok {
textParts = append(textParts, t)
}
case "thinking":
if role == "assistant" {
if thinking := extractClaudeThinkingBlockText(b); thinking != "" {
if pendingThinking == "" {
pendingThinking = thinking
} else {
pendingThinking += "\n" + thinking
}
}
continue
}
if raw := strings.TrimSpace(formatClaudeUnknownBlockForPrompt(b)); raw != "" {
textParts = append(textParts, raw)
}
case "tool_use":
if role == "assistant" {
flushText()
if toolMsg := normalizeClaudeToolUseToAssistant(b, state); toolMsg != nil {
if strings.TrimSpace(pendingThinking) != "" {
toolMsg["reasoning_content"] = pendingThinking
toolMsg["content"] = prependClaudeReasoningForPrompt(pendingThinking, safeStringValue(toolMsg["content"]))
pendingThinking = ""
}
out = append(out, toolMsg)
}
continue
@@ -69,6 +95,13 @@ func normalizeClaudeMessages(messages []any) []any {
}
}
flushText()
if role == "assistant" && strings.TrimSpace(pendingThinking) != "" {
out = append(out, map[string]any{
"role": "assistant",
"reasoning_content": pendingThinking,
"content": formatClaudeReasoningForPrompt(pendingThinking),
})
}
default:
copied := cloneMap(msg)
out = append(out, copied)
@@ -77,6 +110,39 @@ func normalizeClaudeMessages(messages []any) []any {
return out
}
// prependClaudeReasoningForPrompt places the formatted reasoning block in
// front of content, separated by one blank line. Both arguments are trimmed of
// surrounding whitespace first. With no reasoning the trimmed content is
// returned unchanged; with no content only the reasoning block is returned.
func prependClaudeReasoningForPrompt(reasoning, content string) string {
	body := strings.TrimSpace(content)
	thought := strings.TrimSpace(reasoning)
	switch {
	case thought == "":
		return body
	case body == "":
		return formatClaudeReasoningForPrompt(thought)
	default:
		return formatClaudeReasoningForPrompt(thought) + "\n\n" + body
	}
}
// formatClaudeReasoningForPrompt wraps trimmed reasoning text between
// "[reasoning_content]" and "[/reasoning_content]" marker lines. Reasoning
// that is empty or whitespace-only produces an empty string, so no empty
// marker pair is ever emitted.
func formatClaudeReasoningForPrompt(reasoning string) string {
	trimmed := strings.TrimSpace(reasoning)
	if trimmed == "" {
		return ""
	}
	lines := []string{"[reasoning_content]", trimmed, "[/reasoning_content]"}
	return strings.Join(lines, "\n")
}
// extractClaudeThinkingBlockText pulls the reasoning text out of a thinking
// block. It probes the "thinking", "text" and "content" keys in that order and
// returns the first value that is non-empty after passing through
// safeStringValue and trimming whitespace. A nil block, or one where none of
// the keys yields text, gives "".
func extractClaudeThinkingBlockText(block map[string]any) string {
	if block == nil {
		return ""
	}
	candidates := [...]string{"thinking", "text", "content"}
	for i := range candidates {
		text := strings.TrimSpace(safeStringValue(block[candidates[i]]))
		if text == "" {
			continue
		}
		return text
	}
	return ""
}
func buildClaudeToolPrompt(tools []any) string {
toolSchemas := make([]string, 0, len(tools))
names := make([]string, 0, len(tools))

View File

@@ -44,14 +44,20 @@ func geminiMessagesFromRequest(req map[string]any) []any {
}
textParts := make([]string, 0, len(parts))
pendingThinking := ""
flushText := func() {
if len(textParts) == 0 {
return
}
out = append(out, map[string]any{
msg := map[string]any{
"role": role,
"content": strings.Join(textParts, "\n"),
})
}
if role == "assistant" && strings.TrimSpace(pendingThinking) != "" {
msg["reasoning_content"] = pendingThinking
pendingThinking = ""
}
out = append(out, msg)
textParts = textParts[:0]
}
@@ -61,6 +67,14 @@ func geminiMessagesFromRequest(req map[string]any) []any {
continue
}
if text := strings.TrimSpace(asString(part["text"])); text != "" {
if role == "assistant" && isGeminiThoughtPart(part) {
if pendingThinking == "" {
pendingThinking = text
} else {
pendingThinking += "\n" + text
}
continue
}
textParts = append(textParts, text)
continue
}
@@ -75,7 +89,7 @@ func geminiMessagesFromRequest(req map[string]any) []any {
}
}
lastToolCallIDByName[strings.ToLower(name)] = callID
out = append(out, map[string]any{
msg := map[string]any{
"role": "assistant",
"tool_calls": []any{
map[string]any{
@@ -87,7 +101,12 @@ func geminiMessagesFromRequest(req map[string]any) []any {
},
},
},
})
}
if strings.TrimSpace(pendingThinking) != "" {
msg["reasoning_content"] = pendingThinking
pendingThinking = ""
}
out = append(out, msg)
}
continue
}
@@ -132,10 +151,29 @@ func geminiMessagesFromRequest(req map[string]any) []any {
}
}
flushText()
if role == "assistant" && strings.TrimSpace(pendingThinking) != "" {
out = append(out, map[string]any{
"role": "assistant",
"reasoning_content": pendingThinking,
})
}
}
return out
}
// isGeminiThoughtPart reports whether a Gemini content part is a model thought
// (a reasoning summary) rather than user-visible output.
//
// Only an explicit boolean "thought": true marks a part as a thought. A
// "thoughtSignature" is deliberately NOT treated as a marker: it is an opaque
// continuation token that Gemini can attach to ordinary text parts and to
// functionCall parts. Using its presence as a thought indicator would move
// visible assistant text into reasoning_content and drop it from the message
// content.
//
// A nil part, a missing "thought" key, or a non-bool "thought" value all
// yield false.
func isGeminiThoughtPart(part map[string]any) bool {
	// Indexing a nil map is safe in Go and yields the zero value, so no
	// separate nil guard is required.
	isThought, ok := part["thought"].(bool)
	return ok && isThought
}
func normalizeGeminiSystemInstruction(raw any) string {
switch v := raw.(type) {
case string:

View File

@@ -1,6 +1,7 @@
package gemini
import (
"ds2api/internal/promptcompat"
"strings"
"testing"
)
@@ -53,6 +54,46 @@ func TestGeminiMessagesFromRequestPreservesFunctionRoundtrip(t *testing.T) {
}
}
// TestGeminiMessagesFromRequestPreservesThoughtOnFunctionCallHistory sends a
// model turn holding a thought-flagged text part followed by a functionCall
// part. It expects one normalized assistant message that carries the thought
// as reasoning_content alongside the tool call, and expects both to appear in
// the prompt built by promptcompat.BuildOpenAIPromptForAdapter.
func TestGeminiMessagesFromRequestPreservesThoughtOnFunctionCallHistory(t *testing.T) {
	const thought = "need current state before answering"

	thoughtPart := map[string]any{"text": thought, "thought": true}
	callPart := map[string]any{
		"functionCall": map[string]any{
			"id":   "call_g1",
			"name": "search_web",
			"args": map[string]any{"query": "ai"},
		},
	}
	request := map[string]any{
		"contents": []any{
			map[string]any{
				"role":  "model",
				"parts": []any{thoughtPart, callPart},
			},
		},
	}

	messages := geminiMessagesFromRequest(request)
	if len(messages) != 1 {
		t.Fatalf("expected one normalized message, got %#v", messages)
	}

	assistant, _ := messages[0].(map[string]any)
	if assistant["reasoning_content"] != thought {
		t.Fatalf("expected thought preserved as reasoning_content, got %#v", assistant)
	}
	if calls, _ := assistant["tool_calls"].([]any); len(calls) != 1 {
		t.Fatalf("expected one tool call, got %#v", assistant["tool_calls"])
	}

	rendered, _ := promptcompat.BuildOpenAIPromptForAdapter(messages, nil, "", true)
	wantReasoning := "[reasoning_content]\n" + thought + "\n[/reasoning_content]"
	if !strings.Contains(rendered, wantReasoning) {
		t.Fatalf("expected thought in prompt history, got %q", rendered)
	}
	wantInvoke := `<|DSML|invoke name="search_web">`
	if !strings.Contains(rendered, wantInvoke) {
		t.Fatalf("expected tool call in prompt history, got %q", rendered)
	}
}
func TestGeminiMessagesFromRequestPreservesUnknownPartAsRawJSONText(t *testing.T) {
req := map[string]any{
"contents": []any{

View File

@@ -81,6 +81,22 @@ func (s *responsesStreamRuntime) buildCompletedResponseObject(finalThinking, fin
},
},
})
} else if len(calls) > 0 && strings.TrimSpace(finalThinking) != "" {
indexed = append(indexed, indexedItem{
index: s.ensureMessageOutputIndex(),
item: map[string]any{
"id": s.ensureMessageItemID(),
"type": "message",
"role": "assistant",
"status": "completed",
"content": []map[string]any{
{
"type": "reasoning",
"text": finalThinking,
},
},
},
})
} else if len(calls) == 0 {
content := make([]map[string]any, 0, 2)
if finalThinking != "" {

View File

@@ -616,14 +616,55 @@ function consumeToolMarkupNamePrefixOnce(raw, lower, idx) {
}
if (lower.startsWith('dsml', idx)) {
let next = idx + 'dsml'.length;
if (next < raw.length && raw[next] === '-') {
if (next < raw.length && (raw[next] === '-' || raw[next] === '_')) {
next += 1;
}
return { next, ok: true };
}
const arbitrary = consumeArbitraryToolMarkupNamePrefix(raw, lower, idx);
if (arbitrary.ok) {
return arbitrary;
}
return { next: idx, ok: false };
}
/**
 * Tries to consume a free-form vendor prefix that precedes a tool markup name,
 * e.g. the `abc|`, `vendor_` or `agent - ` part of a prefixed tag.
 *
 * Accepted shape, starting at `idx`:
 *   1. one or more prefix-segment characters (per isToolMarkupPrefixSegmentChar),
 *   2. optional spaces / tabs / CR / LF,
 *   3. exactly one separator: a pipe accepted by isToolMarkupPipe, `_`, or `-`,
 *   4. optional spaces / tabs / CR / LF,
 *   5. a tail for which hasToolMarkupNamePrefix(lower.slice(pos)) is true.
 *
 * @param {string} raw   Original text being scanned.
 * @param {string} lower Text used for the name lookup; sliced at the same
 *                       offsets as `raw`.
 * @param {number} idx   Offset at which the candidate prefix begins.
 * @returns {{next: number, ok: boolean}} On success `next` is the offset just
 *   past the prefix and trailing whitespace; on failure `next` is `idx`.
 */
function consumeArbitraryToolMarkupNamePrefix(raw, lower, idx) {
  const miss = { next: idx, ok: false };
  if (idx < 0 || idx >= raw.length || !isToolMarkupPrefixSegmentChar(raw[idx])) {
    return miss;
  }

  const isBlank = (ch) => ch === ' ' || ch === '\t' || ch === '\r' || ch === '\n';

  // Step over the remaining prefix characters, then any whitespace after them.
  let pos = idx + 1;
  while (pos < raw.length && isToolMarkupPrefixSegmentChar(raw[pos])) {
    pos += 1;
  }
  while (pos < raw.length && isBlank(raw[pos])) {
    pos += 1;
  }

  // A single separator must follow: pipe first, then underscore / hyphen.
  if (pos >= raw.length) {
    return miss;
  }
  const separator = raw[pos];
  if (!isToolMarkupPipe(separator) && separator !== '_' && separator !== '-') {
    return miss;
  }
  pos += 1;

  while (pos < raw.length && isBlank(raw[pos])) {
    pos += 1;
  }

  // Only treat the run as a prefix when a tool markup name follows it.
  if (!hasToolMarkupNamePrefix(lower.slice(pos))) {
    return miss;
  }
  return { next: pos, ok: true };
}
/**
 * Reports whether `ch` is exactly one ASCII letter or digit.
 *
 * @param {*} ch Candidate character; it is coerced to a string first.
 * @returns {boolean} true only for a single character in 0-9, A-Z or a-z.
 */
function isToolMarkupPrefixSegmentChar(ch) {
  const text = `${ch}`;
  if (text.length !== 1) {
    return false;
  }
  const isDigit = text >= '0' && text <= '9';
  const isUpper = text >= 'A' && text <= 'Z';
  const isLower = text >= 'a' && text <= 'z';
  return isDigit || isUpper || isLower;
}
function hasToolMarkupNamePrefix(lowerTail) {
for (const name of TOOL_MARKUP_NAMES) {
if (lowerTail.startsWith(name.raw) || name.raw.startsWith(lowerTail)) {

View File

@@ -1,6 +1,9 @@
package promptcompat
import "testing"
import (
"strings"
"testing"
)
func TestNormalizeResponsesInputItemPreservesAssistantReasoningContent(t *testing.T) {
item := map[string]any{
@@ -48,3 +51,44 @@ func TestNormalizeResponsesInputItemAssistantMessageWithReasoningBlocks(t *testi
t.Fatalf("expected content blocks preserved, got %#v", got["content"])
}
}
// TestNormalizeResponsesInputArrayMergesReasoningMessageIntoFunctionCallHistory
// passes a reasoning-only assistant message followed by a function_call item
// through NormalizeResponsesInputAsMessages. It expects a single assistant
// message that holds the reasoning as reasoning_content next to one tool call,
// and expects the history transcript to contain both.
func TestNormalizeResponsesInputArrayMergesReasoningMessageIntoFunctionCallHistory(t *testing.T) {
	const reasoning = "need fresh docs before answering"

	reasoningItem := map[string]any{
		"type": "message",
		"role": "assistant",
		"content": []any{
			map[string]any{"type": "reasoning", "text": reasoning},
		},
	}
	callItem := map[string]any{
		"type":      "function_call",
		"call_id":   "call_search",
		"name":      "search_web",
		"arguments": `{"query":"docs"}`,
	}

	messages := NormalizeResponsesInputAsMessages([]any{reasoningItem, callItem})
	if len(messages) != 1 {
		t.Fatalf("expected reasoning and function_call merged into one assistant message, got %#v", messages)
	}

	merged, _ := messages[0].(map[string]any)
	if merged["role"] != "assistant" {
		t.Fatalf("expected assistant message, got %#v", merged)
	}
	if merged["reasoning_content"] != reasoning {
		t.Fatalf("expected reasoning_content on tool-call message, got %#v", merged)
	}
	if calls, _ := merged["tool_calls"].([]any); len(calls) != 1 {
		t.Fatalf("expected one tool call, got %#v", merged["tool_calls"])
	}

	transcript := BuildOpenAIHistoryTranscript(messages)
	wantReasoning := "[reasoning_content]\n" + reasoning + "\n[/reasoning_content]"
	if !strings.Contains(transcript, wantReasoning) {
		t.Fatalf("expected reasoning in history transcript, got %q", transcript)
	}
	wantInvoke := `<|DSML|invoke name="search_web">`
	if !strings.Contains(transcript, wantInvoke) {
		t.Fatalf("expected tool call in history transcript, got %q", transcript)
	}
}

View File

@@ -61,19 +61,52 @@ func normalizeResponsesInputArray(items []any) []any {
out := make([]any, 0, len(items))
callNameByID := map[string]string{}
fallbackParts := make([]string, 0, len(items))
pendingAssistantReasoning := ""
flushFallback := func() {
if len(fallbackParts) == 0 {
return
}
if pendingAssistantReasoning != "" {
out = append(out, map[string]any{"role": "assistant", "reasoning_content": pendingAssistantReasoning})
pendingAssistantReasoning = ""
}
out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")})
fallbackParts = fallbackParts[:0]
}
flushPendingReasoning := func() {
if pendingAssistantReasoning == "" {
return
}
out = append(out, map[string]any{"role": "assistant", "reasoning_content": pendingAssistantReasoning})
pendingAssistantReasoning = ""
}
for _, item := range items {
switch x := item.(type) {
case map[string]any:
if msg := normalizeResponsesInputItemWithState(x, callNameByID); msg != nil {
if reasoning := assistantReasoningOnlyContent(msg); reasoning != "" {
if pendingAssistantReasoning == "" {
pendingAssistantReasoning = reasoning
} else {
pendingAssistantReasoning += "\n" + reasoning
}
continue
}
if isAssistantToolCallMessage(msg) && pendingAssistantReasoning != "" {
if strings.TrimSpace(normalizeOpenAIReasoningContentForPrompt(msg["reasoning_content"])) == "" {
msg["reasoning_content"] = pendingAssistantReasoning
}
pendingAssistantReasoning = ""
} else {
flushPendingReasoning()
}
flushFallback()
if isAssistantToolCallMessage(msg) && len(out) > 0 {
if merged := mergeResponsesAssistantToolCalls(out[len(out)-1], msg); merged {
continue
}
}
out = append(out, msg)
continue
}
@@ -86,9 +119,55 @@ func normalizeResponsesInputArray(items []any) []any {
}
}
}
flushPendingReasoning()
flushFallback()
if len(out) == 0 {
return nil
}
return out
}
// assistantReasoningOnlyContent returns the reasoning text of an assistant
// message that carries nothing but reasoning, or "" when the message does not
// qualify.
//
// A message qualifies only if it is an assistant message without tool calls.
// When it has a "content" key, the prompt-normalized content must either be
// empty or equal the reasoning extracted from that content; otherwise the
// message has other text and "" is returned. Reasoning found in the content
// takes precedence; failing that, the trimmed normalized "reasoning_content"
// field is returned.
func assistantReasoningOnlyContent(msg map[string]any) string {
	switch {
	case !isAssistantMessage(msg), isAssistantToolCallMessage(msg):
		return ""
	}

	if rawContent, present := msg["content"]; present {
		visible := strings.TrimSpace(NormalizeOpenAIContentForPrompt(rawContent))
		embedded := strings.TrimSpace(extractOpenAIReasoningContentFromMessage(rawContent))
		switch {
		case visible != "" && visible != embedded:
			// The content holds something besides the reasoning itself.
			return ""
		case embedded != "":
			return embedded
		}
	}

	fieldReasoning := normalizeOpenAIReasoningContentForPrompt(msg["reasoning_content"])
	return strings.TrimSpace(fieldReasoning)
}
// isAssistantMessage reports whether the message's "role", after asString
// conversion and whitespace trimming, equals "assistant" ignoring case.
func isAssistantMessage(msg map[string]any) bool {
	role := strings.TrimSpace(asString(msg["role"]))
	return strings.EqualFold(role, "assistant")
}
// isAssistantToolCallMessage reports whether msg is an assistant message whose
// "tool_calls" value is a non-empty []any.
func isAssistantToolCallMessage(msg map[string]any) bool {
	if isAssistantMessage(msg) {
		if calls, isList := msg["tool_calls"].([]any); isList {
			return len(calls) > 0
		}
	}
	return false
}
// mergeResponsesAssistantToolCalls folds the tool calls of next into prev when
// both are assistant tool-call messages. prev is mutated in place: next's
// calls are appended to its "tool_calls". If prev has no (non-blank)
// reasoning_content, it adopts next's trimmed reasoning when that is
// non-empty; reasoning already on prev is left untouched.
//
// It returns true when the merge happened and false when prev is not a
// map[string]any or either side is not an assistant tool-call message.
func mergeResponsesAssistantToolCalls(prev any, next map[string]any) bool {
	target, isMap := prev.(map[string]any)
	if !isMap {
		return false
	}
	if !isAssistantToolCallMessage(target) || !isAssistantToolCallMessage(next) {
		return false
	}

	existing, _ := target["tool_calls"].([]any)
	incoming, _ := next["tool_calls"].([]any)
	target["tool_calls"] = append(existing, incoming...)

	current := normalizeOpenAIReasoningContentForPrompt(target["reasoning_content"])
	if strings.TrimSpace(current) != "" {
		return true
	}
	inherited := strings.TrimSpace(normalizeOpenAIReasoningContentForPrompt(next["reasoning_content"]))
	if inherited != "" {
		target["reasoning_content"] = inherited
	}
	return true
}

View File

@@ -242,9 +242,47 @@ func consumeToolMarkupNamePrefixOnce(text string, idx int) (int, bool) {
}
return next, true
}
if next, ok := consumeArbitraryToolMarkupNamePrefix(text, idx); ok {
return next, true
}
return idx, false
}
// consumeArbitraryToolMarkupNamePrefix tries to consume a free-form vendor
// prefix that precedes a tool markup name, such as the "abc|", "vendor_" or
// "agent - " part of a prefixed tag.
//
// Starting at idx it requires, in order: one or more bytes accepted by
// isToolMarkupPrefixSegmentByte; optional space/tab/CR/LF; one separator that
// is either whatever consumeToolMarkupPipe accepts or a single '_' or '-';
// optional space/tab/CR/LF; and finally a position at which
// hasToolMarkupNamePrefix reports true.
//
// On success it returns the offset just past the prefix (and trailing
// whitespace) and true. On any failure it returns idx and false.
func consumeArbitraryToolMarkupNamePrefix(text string, idx int) (int, bool) {
	if idx < 0 || idx >= len(text) || !isToolMarkupPrefixSegmentByte(text[idx]) {
		return idx, false
	}

	isBlank := func(b byte) bool {
		switch b {
		case ' ', '\t', '\r', '\n':
			return true
		}
		return false
	}

	// Step over the rest of the prefix segment, then whitespace after it.
	pos := idx + 1
	for pos < len(text) && isToolMarkupPrefixSegmentByte(text[pos]) {
		pos++
	}
	for pos < len(text) && isBlank(text[pos]) {
		pos++
	}

	// The separator is a pipe if one can be consumed, else '_' or '-'.
	after, matched := consumeToolMarkupPipe(text, pos)
	if !matched {
		if pos >= len(text) || (text[pos] != '_' && text[pos] != '-') {
			return idx, false
		}
		after = pos + 1
	}

	for after < len(text) && isBlank(text[after]) {
		after++
	}

	// Only accept the run as a prefix when a tool markup name follows.
	if !hasToolMarkupNamePrefix(text, after) {
		return idx, false
	}
	return after, true
}
// isToolMarkupPrefixSegmentByte reports whether b is an ASCII digit or an
// ASCII letter of either case.
func isToolMarkupPrefixSegmentByte(b byte) bool {
	switch {
	case '0' <= b && b <= '9':
		return true
	case 'A' <= b && b <= 'Z':
		return true
	case 'a' <= b && b <= 'z':
		return true
	default:
		return false
	}
}
func hasASCIIPartialPrefixFoldAt(text string, start int, prefix string) bool {
remain := len(text) - start
if remain <= 0 || remain > len(prefix) {

View File

@@ -72,6 +72,45 @@ EOF
}
}
// TestParseToolCallsSupportsUnderscoredDSMLShell checks that tool markup using
// the "dsml_" underscore prefix on every tag is parsed into two calls, with
// the CDATA parameter values unwrapped.
func TestParseToolCallsSupportsUnderscoredDSMLShell(t *testing.T) {
	const markup = `<dsml_tool_calls>
<dsml_invoke name="search_web">
<dsml_parameter name="query"><![CDATA[2026年5月 热点事件]]></dsml_parameter>
<dsml_parameter name="topic"><![CDATA[news]]></dsml_parameter>
</dsml_invoke>
<dsml_invoke name="eval_javascript">
<dsml_parameter name="code"><![CDATA[1 + 1]]></dsml_parameter>
</dsml_invoke>
</dsml_tool_calls>`

	allowed := []string{"search_web", "eval_javascript"}
	parsed := ParseToolCalls(markup, allowed)
	if len(parsed) != 2 {
		t.Fatalf("expected two underscored DSML calls, got %#v", parsed)
	}

	search := parsed[0]
	searchOK := search.Name == "search_web" &&
		search.Input["query"] == "2026年5月 热点事件" &&
		search.Input["topic"] == "news"
	if !searchOK {
		t.Fatalf("unexpected first underscored DSML call: %#v", search)
	}

	eval := parsed[1]
	evalOK := eval.Name == "eval_javascript" && eval.Input["code"] == "1 + 1"
	if !evalOK {
		t.Fatalf("unexpected second underscored DSML call: %#v", eval)
	}
}
// TestParseToolCallsSupportsArbitraryPrefixedToolMarkup checks that tool
// markup whose tags carry a vendor prefix is parsed into a single Read call.
// Three prefix styles are covered: pipe-separated ("abc|"),
// underscore-separated ("vendor_") and hyphen-with-spaces ("agent - ").
func TestParseToolCallsSupportsArbitraryPrefixedToolMarkup(t *testing.T) {
	inputs := [...]string{
		`<abc|tool_calls><abc|invoke name="Read"><abc|parameter name="file_path">README.md</abc|parameter></abc|invoke></abc|tool_calls>`,
		`<vendor_tool_calls><vendor_invoke name="Read"><vendor_parameter name="file_path">README.md</vendor_parameter></vendor_invoke></vendor_tool_calls>`,
		`<agent - tool_calls><agent - invoke name="Read"><agent - parameter name="file_path">README.md</agent - parameter></agent - invoke></agent - tool_calls>`,
	}
	for i := range inputs {
		markup := inputs[i]
		parsed := ParseToolCalls(markup, []string{"Read"})
		if len(parsed) != 1 {
			t.Fatalf("expected one arbitrary-prefixed tool call for %q, got %#v", markup, parsed)
		}
		call := parsed[0]
		if matches := call.Name == "Read" && call.Input["file_path"] == "README.md"; !matches {
			t.Fatalf("unexpected arbitrary-prefixed parse result: %#v", call)
		}
	}
}
func TestParseToolCallsIgnoresBareHyphenatedToolCallsLookalike(t *testing.T) {
text := `<tool-calls><invoke name="Bash"><parameter name="command">pwd</parameter></invoke></tool-calls>`
calls := ParseToolCalls(text, []string{"Bash"})