feat: extract and inject assistant reasoning content into history split prompts

This commit is contained in:
CJACK.
2026-04-22 19:56:28 +00:00
parent 2788e20f05
commit c291d333c4
8 changed files with 280 additions and 60 deletions

View File

@@ -30,6 +30,7 @@ func (h *Handler) applyHistorySplit(ctx context.Context, a *auth.RequestAuth, st
return stdReq, nil
}
reasoningContent := extractHistorySplitReasoningContent(historyMessages)
historyText := buildOpenAIHistoryTranscript(historyMessages)
if strings.TrimSpace(historyText) == "" {
return stdReq, errors.New("history split produced empty transcript")
@@ -51,12 +52,12 @@ func (h *Handler) applyHistorySplit(ctx context.Context, a *auth.RequestAuth, st
stdReq.Messages = promptMessages
stdReq.RefFileIDs = prependUniqueRefFileID(stdReq.RefFileIDs, fileID)
stdReq.FinalPrompt, stdReq.ToolNames = buildHistorySplitPrompt(promptMessages, stdReq.ToolsRaw, stdReq.ToolChoice, stdReq.Thinking)
stdReq.FinalPrompt, stdReq.ToolNames = buildHistorySplitPrompt(promptMessages, reasoningContent, stdReq.ToolsRaw, stdReq.ToolChoice, stdReq.Thinking)
return stdReq, nil
}
func buildHistorySplitPrompt(messages []any, toolsRaw any, toolPolicy util.ToolChoicePolicy, thinkingEnabled bool) (string, []string) {
if len(messages) == 0 {
func buildHistorySplitPrompt(messages []any, reasoningContent string, toolsRaw any, toolPolicy util.ToolChoicePolicy, thinkingEnabled bool) (string, []string) {
if len(messages) == 0 && strings.TrimSpace(reasoningContent) == "" {
return "", nil
}
instruction := historySplitPromptInstruction()
@@ -65,7 +66,7 @@ func buildHistorySplitPrompt(messages []any, toolsRaw any, toolPolicy util.ToolC
"role": "system",
"content": instruction,
})
withInstruction = append(withInstruction, messages...)
withInstruction = append(withInstruction, injectHistorySplitReasoningMessage(messages, reasoningContent)...)
return buildOpenAIFinalPromptWithPolicy(withInstruction, toolsRaw, "", toolPolicy, thinkingEnabled)
}
@@ -150,7 +151,7 @@ func buildOpenAIHistoryTranscript(messages []any) string {
func buildOpenAIHistoryEntry(role string, msg map[string]any) string {
switch role {
case "assistant":
return strings.TrimSpace(buildAssistantContentForPrompt(msg))
return strings.TrimSpace(buildAssistantHistoryContent(msg))
case "tool", "function":
return strings.TrimSpace(buildToolHistoryContent(msg))
case "user":
@@ -160,6 +161,10 @@ func buildOpenAIHistoryEntry(role string, msg map[string]any) string {
}
}
// buildAssistantHistoryContent renders an assistant message for the history
// transcript by delegating to the shared assistant prompt formatter and
// stripping surrounding whitespace.
func buildAssistantHistoryContent(msg map[string]any) string {
	rendered := buildAssistantContentForPrompt(msg)
	return strings.TrimSpace(rendered)
}
func buildToolHistoryContent(msg map[string]any) string {
content := strings.TrimSpace(normalizeOpenAIContentForPrompt(msg["content"]))
parts := make([]string, 0, 2)
@@ -183,6 +188,68 @@ func buildToolHistoryContent(msg map[string]any) string {
}
}
// extractHistorySplitReasoningContent walks the history messages from newest
// to oldest and returns the first non-empty reasoning text found on an
// assistant message. The dedicated "reasoning_content" field wins; otherwise
// reasoning embedded in the structured "content" payload is used. Returns ""
// when no assistant reasoning is present.
func extractHistorySplitReasoningContent(messages []any) string {
	for idx := len(messages) - 1; idx >= 0; idx-- {
		entry, ok := messages[idx].(map[string]any)
		if !ok {
			continue
		}
		if strings.ToLower(strings.TrimSpace(asString(entry["role"]))) != "assistant" {
			continue
		}
		// Prefer the explicit reasoning_content field over content-embedded reasoning.
		if text := strings.TrimSpace(normalizeOpenAIReasoningContentForPrompt(entry["reasoning_content"])); text != "" {
			return text
		}
		if text := strings.TrimSpace(extractOpenAIReasoningContentFromMessage(entry["content"])); text != "" {
			return text
		}
	}
	return ""
}
// injectHistorySplitReasoningMessage inserts a synthetic assistant message
// carrying reasoningContent immediately before the last user message. When no
// user message exists the reasoning message is prepended instead. The input
// slice is never mutated; a blank reasoningContent returns messages unchanged.
func injectHistorySplitReasoningMessage(messages []any, reasoningContent string) []any {
	trimmed := strings.TrimSpace(reasoningContent)
	if trimmed == "" {
		return messages
	}
	carrier := map[string]any{
		"role":              "assistant",
		"content":           "",
		"reasoning_content": trimmed,
	}
	insertAt := lastOpenAIUserMessageIndex(messages)
	if insertAt < 0 {
		// No user turn found: lead with the reasoning message.
		insertAt = 0
	}
	out := make([]any, 0, len(messages)+1)
	out = append(out, messages[:insertAt]...)
	out = append(out, carrier)
	out = append(out, messages[insertAt:]...)
	return out
}
// lastOpenAIUserMessageIndex returns the index of the final message whose
// role (case-insensitive, trimmed) is "user", or -1 when none exists.
// Entries that are not map[string]any are ignored.
func lastOpenAIUserMessageIndex(messages []any) int {
	for i := len(messages) - 1; i >= 0; i-- {
		msg, ok := messages[i].(map[string]any)
		if !ok {
			continue
		}
		if strings.ToLower(strings.TrimSpace(asString(msg["role"]))) == "user" {
			return i
		}
	}
	return -1
}
func roleLabelForHistory(role string) string {
role = strings.ToLower(strings.TrimSpace(role))
switch role {

View File

@@ -59,8 +59,11 @@ func TestBuildOpenAIHistoryTranscriptPreservesOrderAndToolHistory(t *testing.T)
if !strings.Contains(transcript, "tool_call_id=call-1") {
t.Fatalf("expected tool call id in transcript, got %s", transcript)
}
if strings.Contains(transcript, "hidden reasoning") {
t.Fatalf("did not expect hidden reasoning in transcript, got %s", transcript)
if !strings.Contains(transcript, "[reasoning_content]") {
t.Fatalf("expected reasoning block in HISTORY.txt, got %s", transcript)
}
if !strings.Contains(transcript, "hidden reasoning") {
t.Fatalf("expected reasoning text in HISTORY.txt, got %s", transcript)
}
userIdx := strings.Index(transcript, "=== 1. USER ===")
@@ -72,14 +75,24 @@ func TestBuildOpenAIHistoryTranscriptPreservesOrderAndToolHistory(t *testing.T)
if userIdx >= assistantIdx || assistantIdx >= toolIdx {
t.Fatalf("expected USER -> ASSISTANT -> TOOL order, got %s", transcript)
}
if reasoningIdx := strings.Index(transcript, "[reasoning_content]"); reasoningIdx < 0 || reasoningIdx > strings.Index(transcript, "<tool_calls>") {
t.Fatalf("expected reasoning block before tool calls, got %s", transcript)
}
reasoning := extractHistorySplitReasoningContent(historyMessages)
if reasoning != "hidden reasoning" {
t.Fatalf("expected latest assistant reasoning to be extracted, got %q", reasoning)
}
finalPrompt, _ := buildHistorySplitPrompt(promptMessages, nil, util.DefaultToolChoicePolicy(), false)
finalPrompt, _ := buildHistorySplitPrompt(promptMessages, reasoning, nil, util.DefaultToolChoicePolicy(), false)
if !strings.Contains(finalPrompt, "latest user turn") {
t.Fatalf("expected latest user turn in final prompt, got %s", finalPrompt)
}
if strings.Contains(finalPrompt, "first user turn") {
t.Fatalf("expected earlier history to be removed from final prompt, got %s", finalPrompt)
}
if !strings.Contains(finalPrompt, "[reasoning_content]") || !strings.Contains(finalPrompt, "hidden reasoning") {
t.Fatalf("expected latest assistant reasoning to be attached to prompt, got %s", finalPrompt)
}
if !strings.Contains(finalPrompt, "HISTORY.txt") {
t.Fatalf("expected history instruction in final prompt, got %s", finalPrompt)
}
@@ -118,8 +131,12 @@ func TestSplitOpenAIHistoryMessagesUsesLatestUserTurn(t *testing.T) {
if len(promptMessages) == 0 || len(historyMessages) == 0 {
t.Fatalf("expected both prompt and history messages, got prompt=%d history=%d", len(promptMessages), len(historyMessages))
}
reasoning := extractHistorySplitReasoningContent(historyMessages)
if reasoning != "" {
t.Fatalf("expected no reasoning in this fixture, got %q", reasoning)
}
promptText := buildOpenAIFinalPromptForSplitTest(promptMessages)
promptText, _ := buildHistorySplitPrompt(promptMessages, reasoning, nil, util.DefaultToolChoicePolicy(), false)
if !strings.Contains(promptText, "latest user turn") {
t.Fatalf("expected latest user turn in prompt, got %s", promptText)
}
@@ -136,11 +153,6 @@ func TestSplitOpenAIHistoryMessagesUsesLatestUserTurn(t *testing.T) {
}
}
// buildOpenAIFinalPromptForSplitTest is a test helper that renders the final
// prompt for the given split messages with no tools, the default tool-choice
// policy, and thinking disabled, discarding the returned tool-name list.
func buildOpenAIFinalPromptForSplitTest(messages []any) string {
	prompt, _ := buildHistorySplitPrompt(messages, nil, util.DefaultToolChoicePolicy(), false)
	return prompt
}
func TestApplyHistorySplitSkipsFirstTurn(t *testing.T) {
ds := &inlineUploadDSStub{}
h := &Handler{
@@ -233,6 +245,9 @@ func TestChatCompletionsHistorySplitUploadsHistoryAndKeepsLatestPrompt(t *testin
if strings.Contains(promptText, "first user turn") {
t.Fatalf("expected historical turns removed from completion prompt, got %s", promptText)
}
if !strings.Contains(promptText, "[reasoning_content]") || !strings.Contains(promptText, "hidden reasoning") {
t.Fatalf("expected latest assistant reasoning to be attached to completion prompt, got %s", promptText)
}
if !strings.Contains(promptText, "HISTORY.txt") {
t.Fatalf("expected history instruction in completion prompt, got %s", promptText)
}
@@ -283,6 +298,9 @@ func TestResponsesHistorySplitUploadsHistoryAndKeepsLatestPrompt(t *testing.T) {
if strings.Contains(promptText, "first user turn") {
t.Fatalf("expected historical turns removed from completion prompt, got %s", promptText)
}
if !strings.Contains(promptText, "[reasoning_content]") || !strings.Contains(promptText, "hidden reasoning") {
t.Fatalf("expected latest assistant reasoning to be attached to completion prompt, got %s", promptText)
}
}
func TestChatCompletionsHistorySplitUploadFailureReturnsInternalServerError(t *testing.T) {

View File

@@ -6,6 +6,8 @@ import (
"ds2api/internal/prompt"
)
const assistantReasoningLabel = "reasoning_content"
func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any {
_ = traceID
out := make([]map[string]any, 0, len(raw))
@@ -55,17 +57,95 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
func buildAssistantContentForPrompt(msg map[string]any) string {
content := strings.TrimSpace(normalizeOpenAIContentForPrompt(msg["content"]))
toolHistory := prompt.FormatToolCallsForPrompt(msg["tool_calls"])
switch {
case content == "" && toolHistory == "":
return ""
case content == "":
return toolHistory
case toolHistory == "":
return content
default:
return content + "\n\n" + toolHistory
reasoning := strings.TrimSpace(normalizeOpenAIReasoningContentForPrompt(msg["reasoning_content"]))
if reasoning == "" {
reasoning = strings.TrimSpace(extractOpenAIReasoningContentFromMessage(msg["content"]))
}
toolHistory := prompt.FormatToolCallsForPrompt(msg["tool_calls"])
parts := make([]string, 0, 3)
if reasoning != "" {
parts = append(parts, formatPromptLabeledBlock(assistantReasoningLabel, reasoning))
}
if content != "" {
parts = append(parts, content)
}
if toolHistory != "" {
parts = append(parts, toolHistory)
}
switch len(parts) {
case 0:
return ""
case 1:
return parts[0]
default:
return strings.Join(parts, "\n\n")
}
}
// normalizeOpenAIReasoningContentForPrompt flattens a reasoning_content value
// into plain text. Strings pass through unchanged, item lists are joined with
// newlines, single item maps are unwrapped, and anything else yields "".
func normalizeOpenAIReasoningContentForPrompt(v any) string {
	if s, ok := v.(string); ok {
		return s
	}
	if items, ok := v.([]any); ok {
		return strings.Join(extractOpenAIReasoningPartsFromItems(items), "\n")
	}
	if item, ok := v.(map[string]any); ok {
		return extractOpenAIReasoningTextFromItem(item)
	}
	return ""
}
// extractOpenAIReasoningContentFromMessage pulls reasoning text out of a
// structured message content value (item list or single item map). Unlike
// normalizeOpenAIReasoningContentForPrompt, a plain string is NOT treated as
// reasoning here — only typed reasoning/thinking items qualify.
func extractOpenAIReasoningContentFromMessage(v any) string {
	if items, ok := v.([]any); ok {
		return strings.Join(extractOpenAIReasoningPartsFromItems(items), "\n")
	}
	if item, ok := v.(map[string]any); ok {
		return extractOpenAIReasoningTextFromItem(item)
	}
	return ""
}
// extractOpenAIReasoningPartsFromItems collects the non-empty reasoning text
// of each item, preserving input order.
func extractOpenAIReasoningPartsFromItems(items []any) []string {
	parts := make([]string, 0, len(items))
	for _, raw := range items {
		text := extractOpenAIReasoningTextFromItemMap(raw)
		if text == "" {
			continue
		}
		parts = append(parts, text)
	}
	return parts
}
// extractOpenAIReasoningTextFromItemMap is a type-asserting wrapper around
// extractOpenAIReasoningTextFromItem; non-map items yield "".
func extractOpenAIReasoningTextFromItemMap(item any) string {
	if m, ok := item.(map[string]any); ok {
		return extractOpenAIReasoningTextFromItem(m)
	}
	return ""
}
// extractOpenAIReasoningTextFromItem returns the trimmed reasoning text of a
// content item whose type is "reasoning" or "thinking" (case-insensitive).
// The first non-empty candidate among the "text", "thinking", and "content"
// keys wins; any other item type yields "".
func extractOpenAIReasoningTextFromItem(m map[string]any) string {
	if m == nil {
		return ""
	}
	kind := strings.ToLower(strings.TrimSpace(asString(m["type"])))
	if kind != "reasoning" && kind != "thinking" {
		return ""
	}
	for _, key := range []string{"text", "thinking", "content"} {
		if text := strings.TrimSpace(asString(m[key])); text != "" {
			return text
		}
	}
	return ""
}
// formatPromptLabeledBlock wraps text in a bracketed label pair, producing
// "[label]\ntext\n[/label]". Both inputs are whitespace-trimmed first; an
// empty label returns just the trimmed text.
func formatPromptLabeledBlock(label, text string) string {
	cleanLabel := strings.TrimSpace(label)
	cleanText := strings.TrimSpace(text)
	if cleanLabel == "" {
		return cleanText
	}
	var b strings.Builder
	b.WriteString("[")
	b.WriteString(cleanLabel)
	b.WriteString("]\n")
	b.WriteString(cleanText)
	b.WriteString("\n[/")
	b.WriteString(cleanLabel)
	b.WriteString("]")
	return b.String()
}
func buildToolContentForPrompt(msg map[string]any) string {

View File

@@ -296,3 +296,31 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantArrayContentFallbackWhenTextE
t.Fatalf("expected content fallback text preserved, got %q", content)
}
}
// TestNormalizeOpenAIMessagesForPrompt_AssistantReasoningContentPreserved
// verifies that an assistant message carrying a reasoning_content field keeps
// that reasoning in the normalized content as a labeled block placed before
// the visible answer text.
func TestNormalizeOpenAIMessagesForPrompt_AssistantReasoningContentPreserved(t *testing.T) {
	raw := []any{
		map[string]any{
			"role":              "assistant",
			"content":           "visible answer",
			"reasoning_content": "internal reasoning",
		},
	}
	normalized := normalizeOpenAIMessagesForPrompt(raw, "")
	if len(normalized) != 1 {
		t.Fatalf("expected one normalized assistant message, got %#v", normalized)
	}
	// Normalization flattens content to a single string.
	content, _ := normalized[0]["content"].(string)
	if !strings.Contains(content, "[reasoning_content]") {
		t.Fatalf("expected labeled reasoning block in assistant content, got %q", content)
	}
	if !strings.Contains(content, "internal reasoning") {
		t.Fatalf("expected reasoning text in assistant content, got %q", content)
	}
	if !strings.Contains(content, "visible answer") {
		t.Fatalf("expected visible answer in assistant content, got %q", content)
	}
	// The reasoning block must precede the visible answer.
	if reasoningIdx := strings.Index(content, "[reasoning_content]"); reasoningIdx < 0 || reasoningIdx > strings.Index(content, "visible answer") {
		t.Fatalf("expected reasoning block before visible answer, got %q", content)
	}
}

View File

@@ -12,7 +12,6 @@ type toolStreamSieveState struct {
codeFenceStack []int
codeFencePendingTicks int
codeFenceLineStart bool
recentTextTail string
pendingToolRaw string
pendingToolCalls []toolcall.ParsedToolCall
disableDeltas bool
@@ -36,9 +35,6 @@ type toolCallDelta struct {
Arguments string
}
// toolSieveContextTailLimit bounds, in bytes, the rolling tail of recently
// streamed text retained for context checks (see appendTail).
// Keep in sync with JS TOOL_SIEVE_CONTEXT_TAIL_LIMIT.
const toolSieveContextTailLimit = 2048
func (s *toolStreamSieveState) resetIncrementalToolState() {
s.disableDeltas = false
s.toolNameSent = false
@@ -54,18 +50,6 @@ func (s *toolStreamSieveState) noteText(content string) {
return
}
updateCodeFenceState(s, content)
s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit)
}
// appendTail concatenates prev and next and keeps only the trailing max
// bytes of the result. A non-positive max always yields the empty string.
func appendTail(prev, next string, max int) string {
	if max <= 0 {
		return ""
	}
	joined := prev + next
	if overflow := len(joined) - max; overflow > 0 {
		return joined[overflow:]
	}
	return joined
}
func hasMeaningfulText(text string) bool {

View File

@@ -42,6 +42,49 @@ func TestProcessToolSieveInterceptsXMLToolCallWithoutLeak(t *testing.T) {
}
}
// TestProcessToolSieveHandlesLongXMLToolCall streams a CDATA payload large
// enough to span several chunks and checks that the sieve buffers the entire
// XML tool call without leaking any of it as plain text, and that the
// reassembled payload survives byte-for-byte.
func TestProcessToolSieveHandlesLongXMLToolCall(t *testing.T) {
	var state toolStreamSieveState
	const toolName = "write_to_file"
	// 4 KiB payload delivered in two chunks between the opening and closing XML.
	payload := strings.Repeat("x", 4096)
	splitAt := len(payload) / 2
	chunks := []string{
		"<tool_calls>\n <tool_call>\n <tool_name>" + toolName + "</tool_name>\n <parameters>\n <content><![CDATA[",
		payload[:splitAt],
		payload[splitAt:],
		"]]></content>\n </parameters>\n </tool_call>\n</tool_calls>",
	}
	var events []toolStreamEvent
	for _, c := range chunks {
		events = append(events, processToolSieveChunk(&state, c, []string{toolName})...)
	}
	// Flush so anything still buffered at stream end is emitted.
	events = append(events, flushToolSieve(&state, []string{toolName})...)
	var textContent strings.Builder
	toolCalls := 0
	var gotPayload any
	for _, evt := range events {
		if evt.Content != "" {
			textContent.WriteString(evt.Content)
		}
		// Capture the first emitted tool call's "content" argument.
		if len(evt.ToolCalls) > 0 && gotPayload == nil {
			gotPayload = evt.ToolCalls[0].Input["content"]
		}
		toolCalls += len(evt.ToolCalls)
	}
	if toolCalls != 1 {
		t.Fatalf("expected one long XML tool call, got %d events=%#v", toolCalls, events)
	}
	if textContent.Len() != 0 {
		t.Fatalf("expected no leaked text for long XML tool call, got %q", textContent.String())
	}
	got, _ := gotPayload.(string)
	if got != payload {
		t.Fatalf("expected long XML payload to survive intact, got len=%d want=%d", len(got), len(payload))
	}
}
func TestProcessToolSieveXMLWithLeadingText(t *testing.T) {
var state toolStreamSieveState
// Model outputs some prose then an XML tool call.

View File

@@ -1,14 +1,10 @@
'use strict';
// Keep in sync with Go toolSieveContextTailLimit.
const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 2048;
function createToolSieveState() {
return {
pending: '',
capture: '',
capturing: false,
recentTextTail: '',
codeFenceStack: [],
codeFencePendingTicks: 0,
codeFenceLineStart: true,
@@ -39,20 +35,6 @@ function noteText(state, text) {
return;
}
updateCodeFenceState(state, text);
state.recentTextTail = appendTail(state.recentTextTail, text, TOOL_SIEVE_CONTEXT_TAIL_LIMIT);
}
function appendTail(prev, next, max) {
const left = typeof prev === 'string' ? prev : '';
const right = typeof next === 'string' ? next : '';
if (!Number.isFinite(max) || max <= 0) {
return '';
}
const combined = left + right;
if (combined.length <= max) {
return combined;
}
return combined.slice(combined.length - max);
}
function looksLikeToolExampleContext(text) {
@@ -171,11 +153,9 @@ function toStringSafe(v) {
}
module.exports = {
TOOL_SIEVE_CONTEXT_TAIL_LIMIT,
createToolSieveState,
resetIncrementalToolState,
noteText,
appendTail,
looksLikeToolExampleContext,
insideCodeFence,
insideCodeFenceWithState,

View File

@@ -98,6 +98,26 @@ test('sieve emits tool_calls when XML tag spans multiple chunks', () => {
assert.equal(finalCalls[0].name, 'read_file');
});
// Regression: a tool call whose CDATA payload spans multiple chunks must stay
// buffered (no text leak) until the closing tag arrives, and the payload must
// be reassembled intact.
test('sieve keeps long XML tool calls buffered until the closing tag arrives', () => {
  const longContent = 'x'.repeat(4096);
  const splitAt = longContent.length / 2;
  const events = runSieve(
    [
      '<tool_calls>\n <tool_call>\n <tool_name>write_to_file</tool_name>\n <parameters>\n <content><![CDATA[',
      longContent.slice(0, splitAt),
      longContent.slice(splitAt),
      ']]></content>\n </parameters>\n </tool_call>\n</tool_calls>',
    ],
    ['write_to_file'],
  );
  const leakedText = collectText(events);
  // Flatten all tool_calls events into one list of parsed calls.
  const finalCalls = events.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []);
  assert.equal(leakedText, '');
  assert.equal(finalCalls.length, 1);
  assert.equal(finalCalls[0].name, 'write_to_file');
  assert.equal(finalCalls[0].input.content, longContent);
});
test('sieve passes JSON tool_calls payload through as text (XML-only)', () => {
const events = runSieve(
['{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'],