From c291d333c4ed65492f71d04e2f1d11b1719c9bd4 Mon Sep 17 00:00:00 2001 From: "CJACK." Date: Wed, 22 Apr 2026 19:56:28 +0000 Subject: [PATCH] feat: extract and inject assistant reasoning content into history split prompts --- internal/adapter/openai/history_split.go | 77 +++++++++++++- internal/adapter/openai/history_split_test.go | 36 +++++-- internal/adapter/openai/message_normalize.go | 100 ++++++++++++++++-- .../adapter/openai/message_normalize_test.go | 28 +++++ internal/adapter/openai/tool_sieve_state.go | 16 --- .../adapter/openai/tool_sieve_xml_test.go | 43 ++++++++ .../js/helpers/stream-tool-sieve/state.js | 20 ---- tests/node/stream-tool-sieve.test.js | 20 ++++ 8 files changed, 280 insertions(+), 60 deletions(-) diff --git a/internal/adapter/openai/history_split.go b/internal/adapter/openai/history_split.go index b65e2e2..39c9d79 100644 --- a/internal/adapter/openai/history_split.go +++ b/internal/adapter/openai/history_split.go @@ -30,6 +30,7 @@ func (h *Handler) applyHistorySplit(ctx context.Context, a *auth.RequestAuth, st return stdReq, nil } + reasoningContent := extractHistorySplitReasoningContent(historyMessages) historyText := buildOpenAIHistoryTranscript(historyMessages) if strings.TrimSpace(historyText) == "" { return stdReq, errors.New("history split produced empty transcript") @@ -51,12 +52,12 @@ func (h *Handler) applyHistorySplit(ctx context.Context, a *auth.RequestAuth, st stdReq.Messages = promptMessages stdReq.RefFileIDs = prependUniqueRefFileID(stdReq.RefFileIDs, fileID) - stdReq.FinalPrompt, stdReq.ToolNames = buildHistorySplitPrompt(promptMessages, stdReq.ToolsRaw, stdReq.ToolChoice, stdReq.Thinking) + stdReq.FinalPrompt, stdReq.ToolNames = buildHistorySplitPrompt(promptMessages, reasoningContent, stdReq.ToolsRaw, stdReq.ToolChoice, stdReq.Thinking) return stdReq, nil } -func buildHistorySplitPrompt(messages []any, toolsRaw any, toolPolicy util.ToolChoicePolicy, thinkingEnabled bool) (string, []string) { - if len(messages) == 0 { +func buildHistorySplitPrompt(messages []any, reasoningContent string, toolsRaw any, toolPolicy util.ToolChoicePolicy, thinkingEnabled bool) (string, []string) { + if len(messages) == 0 && strings.TrimSpace(reasoningContent) == "" { return "", nil } instruction := historySplitPromptInstruction() @@ -65,7 +66,7 @@ func buildHistorySplitPrompt(messages []any, toolsRaw any, toolPolicy util.ToolC "role": "system", "content": instruction, }) - withInstruction = append(withInstruction, messages...) + withInstruction = append(withInstruction, injectHistorySplitReasoningMessage(messages, reasoningContent)...) return buildOpenAIFinalPromptWithPolicy(withInstruction, toolsRaw, "", toolPolicy, thinkingEnabled) } @@ -150,7 +151,7 @@ func buildOpenAIHistoryTranscript(messages []any) string { func buildOpenAIHistoryEntry(role string, msg map[string]any) string { switch role { case "assistant": - return strings.TrimSpace(buildAssistantContentForPrompt(msg)) + return strings.TrimSpace(buildAssistantHistoryContent(msg)) case "tool", "function": return strings.TrimSpace(buildToolHistoryContent(msg)) case "user": @@ -160,6 +161,10 @@ func buildOpenAIHistoryEntry(role string, msg map[string]any) string { } } +func buildAssistantHistoryContent(msg map[string]any) string { + return strings.TrimSpace(buildAssistantContentForPrompt(msg)) +} + func buildToolHistoryContent(msg map[string]any) string { content := strings.TrimSpace(normalizeOpenAIContentForPrompt(msg["content"])) parts := make([]string, 0, 2) @@ -183,6 +188,68 @@ func buildToolHistoryContent(msg map[string]any) string { } } +func extractHistorySplitReasoningContent(messages []any) string { + for i := len(messages) - 1; i >= 0; i-- { + msg, ok := messages[i].(map[string]any) + if !ok { + continue + } + role := strings.ToLower(strings.TrimSpace(asString(msg["role"]))) + if role != "assistant" { + continue + } + reasoning := strings.TrimSpace(normalizeOpenAIReasoningContentForPrompt(msg["reasoning_content"])) + if reasoning == "" { + reasoning = strings.TrimSpace(extractOpenAIReasoningContentFromMessage(msg["content"])) + } + if reasoning != "" { + return reasoning + } + } + return "" +} + +func injectHistorySplitReasoningMessage(messages []any, reasoningContent string) []any { + reasoningContent = strings.TrimSpace(reasoningContent) + if reasoningContent == "" { + return messages + } + reasoningMsg := map[string]any{ + "role": "assistant", + "content": "", + "reasoning_content": reasoningContent, + } + lastUserIndex := lastOpenAIUserMessageIndex(messages) + if lastUserIndex < 0 { + out := make([]any, 0, len(messages)+1) + out = append(out, reasoningMsg) + out = append(out, messages...) + return out + } + out := make([]any, 0, len(messages)+1) + for i, raw := range messages { + if i == lastUserIndex { + out = append(out, reasoningMsg) + } + out = append(out, raw) + } + return out +} + +func lastOpenAIUserMessageIndex(messages []any) int { + last := -1 + for i, raw := range messages { + msg, ok := raw.(map[string]any) + if !ok { + continue + } + if strings.ToLower(strings.TrimSpace(asString(msg["role"]))) == "user" { + last = i + } + } + return last +} + func roleLabelForHistory(role string) string { role = strings.ToLower(strings.TrimSpace(role)) switch role { diff --git a/internal/adapter/openai/history_split_test.go b/internal/adapter/openai/history_split_test.go index 46fa366..864c763 100644 --- a/internal/adapter/openai/history_split_test.go +++ b/internal/adapter/openai/history_split_test.go @@ -59,8 +59,11 @@ func TestBuildOpenAIHistoryTranscriptPreservesOrderAndToolHistory(t *testing.T) if !strings.Contains(transcript, "tool_call_id=call-1") { t.Fatalf("expected tool call id in transcript, got %s", transcript) } - if strings.Contains(transcript, "hidden reasoning") { - t.Fatalf("did not expect hidden reasoning in transcript, got %s", transcript) + if !strings.Contains(transcript, "[reasoning_content]") { + t.Fatalf("expected reasoning block in HISTORY.txt, got %s", transcript) + } + if !strings.Contains(transcript, "hidden reasoning") { + t.Fatalf("expected reasoning text in HISTORY.txt, got %s", transcript) } userIdx := strings.Index(transcript, "=== 1. USER ===") @@ -72,14 +75,24 @@ func TestBuildOpenAIHistoryTranscriptPreservesOrderAndToolHistory(t *testing.T) if userIdx >= assistantIdx || assistantIdx >= toolIdx { t.Fatalf("expected USER -> ASSISTANT -> TOOL order, got %s", transcript) } + if reasoningIdx := strings.Index(transcript, "[reasoning_content]"); reasoningIdx < 0 || reasoningIdx > strings.Index(transcript, "") { + t.Fatalf("expected reasoning block before tool calls, got %s", transcript) + } + reasoning := extractHistorySplitReasoningContent(historyMessages) + if reasoning != "hidden reasoning" { + t.Fatalf("expected latest assistant reasoning to be extracted, got %q", reasoning) + } - finalPrompt, _ := buildHistorySplitPrompt(promptMessages, nil, util.DefaultToolChoicePolicy(), false) + finalPrompt, _ := buildHistorySplitPrompt(promptMessages, reasoning, nil, util.DefaultToolChoicePolicy(), false) if !strings.Contains(finalPrompt, "latest user turn") { t.Fatalf("expected latest user turn in final prompt, got %s", finalPrompt) } if strings.Contains(finalPrompt, "first user turn") { t.Fatalf("expected earlier history to be removed from final prompt, got %s", finalPrompt) } + if !strings.Contains(finalPrompt, "[reasoning_content]") || !strings.Contains(finalPrompt, "hidden reasoning") { + t.Fatalf("expected latest assistant reasoning to be attached to prompt, got %s", finalPrompt) + } if !strings.Contains(finalPrompt, "HISTORY.txt") { t.Fatalf("expected history instruction in final prompt, got %s", finalPrompt) } @@ -118,8 +131,12 @@ func TestSplitOpenAIHistoryMessagesUsesLatestUserTurn(t *testing.T) { if len(promptMessages) == 0 || len(historyMessages) == 0 { t.Fatalf("expected both prompt and history messages, got prompt=%d history=%d", len(promptMessages), len(historyMessages)) } + reasoning := extractHistorySplitReasoningContent(historyMessages) + if reasoning != "" { + t.Fatalf("expected no reasoning in this fixture, got %q", reasoning) + } - promptText := buildOpenAIFinalPromptForSplitTest(promptMessages) + promptText, _ := buildHistorySplitPrompt(promptMessages, reasoning, nil, util.DefaultToolChoicePolicy(), false) if !strings.Contains(promptText, "latest user turn") { t.Fatalf("expected latest user turn in prompt, got %s", promptText) } @@ -136,11 +153,6 @@ func TestSplitOpenAIHistoryMessagesUsesLatestUserTurn(t *testing.T) { } } -func buildOpenAIFinalPromptForSplitTest(messages []any) string { - prompt, _ := buildHistorySplitPrompt(messages, nil, util.DefaultToolChoicePolicy(), false) - return prompt -} - func TestApplyHistorySplitSkipsFirstTurn(t *testing.T) { ds := &inlineUploadDSStub{} h := &Handler{ @@ -233,6 +245,9 @@ func TestChatCompletionsHistorySplitUploadsHistoryAndKeepsLatestPrompt(t *testin if strings.Contains(promptText, "first user turn") { t.Fatalf("expected historical turns removed from completion prompt, got %s", promptText) } + if !strings.Contains(promptText, "[reasoning_content]") || !strings.Contains(promptText, "hidden reasoning") { + t.Fatalf("expected latest assistant reasoning to be attached to completion prompt, got %s", promptText) + } if !strings.Contains(promptText, "HISTORY.txt") { t.Fatalf("expected history instruction in completion prompt, got %s", promptText) } @@ -283,6 +298,9 @@ func TestResponsesHistorySplitUploadsHistoryAndKeepsLatestPrompt(t *testing.T) { if strings.Contains(promptText, "first user turn") { t.Fatalf("expected historical turns removed from completion prompt, got %s", promptText) } + if !strings.Contains(promptText, "[reasoning_content]") || !strings.Contains(promptText, "hidden reasoning") { + t.Fatalf("expected latest assistant reasoning to be attached to completion prompt, got %s", promptText) + } } func TestChatCompletionsHistorySplitUploadFailureReturnsInternalServerError(t *testing.T) { diff --git a/internal/adapter/openai/message_normalize.go b/internal/adapter/openai/message_normalize.go index 94c67ad..906c377 100644 --- a/internal/adapter/openai/message_normalize.go +++ b/internal/adapter/openai/message_normalize.go @@ -6,6 +6,8 @@ import ( "ds2api/internal/prompt" ) +const assistantReasoningLabel = "reasoning_content" + func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]any { _ = traceID out := make([]map[string]any, 0, len(raw)) @@ -55,17 +57,95 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an func buildAssistantContentForPrompt(msg map[string]any) string { content := strings.TrimSpace(normalizeOpenAIContentForPrompt(msg["content"])) - toolHistory := prompt.FormatToolCallsForPrompt(msg["tool_calls"]) - switch { - case content == "" && toolHistory == "": - return "" - case content == "": - return toolHistory - case toolHistory == "": - return content - default: - return content + "\n\n" + toolHistory + reasoning := strings.TrimSpace(normalizeOpenAIReasoningContentForPrompt(msg["reasoning_content"])) + if reasoning == "" { + reasoning = strings.TrimSpace(extractOpenAIReasoningContentFromMessage(msg["content"])) } + toolHistory := prompt.FormatToolCallsForPrompt(msg["tool_calls"]) + parts := make([]string, 0, 3) + if reasoning != "" { + parts = append(parts, formatPromptLabeledBlock(assistantReasoningLabel, reasoning)) + } + if content != "" { + parts = append(parts, content) + } + if toolHistory != "" { + parts = append(parts, toolHistory) + } + switch len(parts) { + case 0: + return "" + case 1: + return parts[0] + default: + return strings.Join(parts, "\n\n") + } +} + +func normalizeOpenAIReasoningContentForPrompt(v any) string { + switch x := v.(type) { + case string: + return x + case []any: + return strings.Join(extractOpenAIReasoningPartsFromItems(x), "\n") + case map[string]any: + return extractOpenAIReasoningTextFromItem(x) + default: + return "" + } +} + +func extractOpenAIReasoningContentFromMessage(v any) string { + switch x := v.(type) { + case []any: + return strings.Join(extractOpenAIReasoningPartsFromItems(x), "\n") + case map[string]any: + return extractOpenAIReasoningTextFromItem(x) + default: + return "" + } +} + +func extractOpenAIReasoningPartsFromItems(items []any) []string { + parts := make([]string, 0, len(items)) + for _, item := range items { + if text := extractOpenAIReasoningTextFromItemMap(item); text != "" { + parts = append(parts, text) + } + } + return parts +} + +func extractOpenAIReasoningTextFromItemMap(item any) string { + m, ok := item.(map[string]any) + if !ok { + return "" + } + return extractOpenAIReasoningTextFromItem(m) +} + +func extractOpenAIReasoningTextFromItem(m map[string]any) string { + if m == nil { + return "" + } + switch strings.ToLower(strings.TrimSpace(asString(m["type"]))) { + case "reasoning", "thinking": + for _, key := range []string{"text", "thinking", "content"} { + if text := strings.TrimSpace(asString(m[key])); text != "" { + return text + } + } + } + return "" +} + +func formatPromptLabeledBlock(label, text string) string { + label = strings.TrimSpace(label) + text = strings.TrimSpace(text) + if label == "" { + return text + } + return "[" + label + "]\n" + text + "\n[/" + label + "]" } func buildToolContentForPrompt(msg map[string]any) string { diff --git a/internal/adapter/openai/message_normalize_test.go b/internal/adapter/openai/message_normalize_test.go index 00b3ef4..564fea7 100644 --- a/internal/adapter/openai/message_normalize_test.go +++ b/internal/adapter/openai/message_normalize_test.go @@ -296,3 +296,31 @@ func TestNormalizeOpenAIMessagesForPrompt_AssistantArrayContentFallbackWhenTextE t.Fatalf("expected content fallback text preserved, got %q", content) } } + +func TestNormalizeOpenAIMessagesForPrompt_AssistantReasoningContentPreserved(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "assistant", + "content": "visible answer", + "reasoning_content": "internal reasoning", + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw, "") + if len(normalized) != 1 { + t.Fatalf("expected one normalized assistant message, got %#v", normalized) + } + content, _ := normalized[0]["content"].(string) + if !strings.Contains(content, "[reasoning_content]") { + t.Fatalf("expected labeled reasoning block in assistant content, got %q", content) + } + if !strings.Contains(content, "internal reasoning") { + t.Fatalf("expected reasoning text in assistant content, got %q", content) + } + if !strings.Contains(content, "visible answer") { + t.Fatalf("expected visible answer in assistant content, got %q", content) + } + if reasoningIdx := strings.Index(content, "[reasoning_content]"); reasoningIdx < 0 || reasoningIdx > strings.Index(content, "visible answer") { + t.Fatalf("expected reasoning block before visible answer, got %q", content) + } +} diff --git a/internal/adapter/openai/tool_sieve_state.go b/internal/adapter/openai/tool_sieve_state.go index 09de2a5..8128f8c 100644 --- a/internal/adapter/openai/tool_sieve_state.go +++ b/internal/adapter/openai/tool_sieve_state.go @@ -12,7 +12,6 @@ type toolStreamSieveState struct { codeFenceStack []int codeFencePendingTicks int codeFenceLineStart bool - recentTextTail string pendingToolRaw string pendingToolCalls []toolcall.ParsedToolCall disableDeltas bool @@ -36,9 +35,6 @@ type toolCallDelta struct { Arguments string } -// Keep in sync with JS TOOL_SIEVE_CONTEXT_TAIL_LIMIT. -const toolSieveContextTailLimit = 2048 - func (s *toolStreamSieveState) resetIncrementalToolState() { s.disableDeltas = false s.toolNameSent = false @@ -54,18 +50,6 @@ func (s *toolStreamSieveState) noteText(content string) { return } updateCodeFenceState(s, content) - s.recentTextTail = appendTail(s.recentTextTail, content, toolSieveContextTailLimit) -} - -func appendTail(prev, next string, max int) string { - if max <= 0 { - return "" - } - combined := prev + next - if len(combined) <= max { - return combined - } - return combined[len(combined)-max:] } func hasMeaningfulText(text string) bool { diff --git a/internal/adapter/openai/tool_sieve_xml_test.go b/internal/adapter/openai/tool_sieve_xml_test.go index 7fd123d..16827cc 100644 --- a/internal/adapter/openai/tool_sieve_xml_test.go +++ b/internal/adapter/openai/tool_sieve_xml_test.go @@ -42,6 +42,49 @@ func TestProcessToolSieveInterceptsXMLToolCallWithoutLeak(t *testing.T) { } } +func TestProcessToolSieveHandlesLongXMLToolCall(t *testing.T) { + var state toolStreamSieveState + const toolName = "write_to_file" + payload := strings.Repeat("x", 4096) + splitAt := len(payload) / 2 + chunks := []string{ + "\n \n " + toolName + "\n \n \n \n \n", + } + + var events []toolStreamEvent + for _, c := range chunks { + events = append(events, processToolSieveChunk(&state, c, []string{toolName})...) + } + events = append(events, flushToolSieve(&state, []string{toolName})...) + + var textContent strings.Builder + toolCalls := 0 + var gotPayload any + for _, evt := range events { + if evt.Content != "" { + textContent.WriteString(evt.Content) + } + if len(evt.ToolCalls) > 0 && gotPayload == nil { + gotPayload = evt.ToolCalls[0].Input["content"] + } + toolCalls += len(evt.ToolCalls) + } + + if toolCalls != 1 { + t.Fatalf("expected one long XML tool call, got %d events=%#v", toolCalls, events) + } + if textContent.Len() != 0 { + t.Fatalf("expected no leaked text for long XML tool call, got %q", textContent.String()) + } + got, _ := gotPayload.(string) + if got != payload { + t.Fatalf("expected long XML payload to survive intact, got len=%d want=%d", len(got), len(payload)) + } +} + func TestProcessToolSieveXMLWithLeadingText(t *testing.T) { var state toolStreamSieveState // Model outputs some prose then an XML tool call. diff --git a/internal/js/helpers/stream-tool-sieve/state.js b/internal/js/helpers/stream-tool-sieve/state.js index 9a5b1c3..447ecdf 100644 --- a/internal/js/helpers/stream-tool-sieve/state.js +++ b/internal/js/helpers/stream-tool-sieve/state.js @@ -1,14 +1,10 @@ 'use strict'; -// Keep in sync with Go toolSieveContextTailLimit. -const TOOL_SIEVE_CONTEXT_TAIL_LIMIT = 2048; - function createToolSieveState() { return { pending: '', capture: '', capturing: false, - recentTextTail: '', codeFenceStack: [], codeFencePendingTicks: 0, codeFenceLineStart: true, @@ -39,20 +35,6 @@ function noteText(state, text) { return; } updateCodeFenceState(state, text); - state.recentTextTail = appendTail(state.recentTextTail, text, TOOL_SIEVE_CONTEXT_TAIL_LIMIT); -} - -function appendTail(prev, next, max) { - const left = typeof prev === 'string' ? prev : ''; - const right = typeof next === 'string' ? next : ''; - if (!Number.isFinite(max) || max <= 0) { - return ''; - } - const combined = left + right; - if (combined.length <= max) { - return combined; - } - return combined.slice(combined.length - max); } function looksLikeToolExampleContext(text) { @@ -171,11 +153,9 @@ function toStringSafe(v) { } module.exports = { - TOOL_SIEVE_CONTEXT_TAIL_LIMIT, createToolSieveState, resetIncrementalToolState, noteText, - appendTail, looksLikeToolExampleContext, insideCodeFence, insideCodeFenceWithState, diff --git a/tests/node/stream-tool-sieve.test.js b/tests/node/stream-tool-sieve.test.js index 57c29f0..a5f29ac 100644 --- a/tests/node/stream-tool-sieve.test.js +++ b/tests/node/stream-tool-sieve.test.js @@ -98,6 +98,26 @@ test('sieve emits tool_calls when XML tag spans multiple chunks', () => { assert.equal(finalCalls[0].name, 'read_file'); }); +test('sieve keeps long XML tool calls buffered until the closing tag arrives', () => { + const longContent = 'x'.repeat(4096); + const splitAt = longContent.length / 2; + const events = runSieve( + [ + '\n \n write_to_file\n \n \n \n \n', + ], + ['write_to_file'], + ); + const leakedText = collectText(events); + const finalCalls = events.filter((evt) => evt.type === 'tool_calls').flatMap((evt) => evt.calls || []); + assert.equal(leakedText, ''); + assert.equal(finalCalls.length, 1); + assert.equal(finalCalls[0].name, 'write_to_file'); + assert.equal(finalCalls[0].input.content, longContent); +}); + test('sieve passes JSON tool_calls payload through as text (XML-only)', () => { const events = runSieve( ['{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}'],