From a50490562654490c774d35f4d0a320bfefc22a54 Mon Sep 17 00:00:00 2001 From: "CJACK." Date: Sun, 22 Mar 2026 12:47:00 +0800 Subject: [PATCH] Fix Claude/Gemini prompt flattening for tool history and binary parts --- internal/adapter/claude/handler_util_test.go | 15 ++- internal/adapter/claude/handler_utils.go | 122 ++++++++++++++++-- internal/adapter/gemini/convert_messages.go | 88 ++++++++++++- .../adapter/gemini/convert_messages_test.go | 8 +- 4 files changed, 218 insertions(+), 15 deletions(-) diff --git a/internal/adapter/claude/handler_util_test.go b/internal/adapter/claude/handler_util_test.go index 169b0b2..3212cca 100644 --- a/internal/adapter/claude/handler_util_test.go +++ b/internal/adapter/claude/handler_util_test.go @@ -1,6 +1,7 @@ package claude import ( + "strings" "testing" ) @@ -91,6 +92,10 @@ func TestNormalizeClaudeMessagesToolUseToAssistantToolCalls(t *testing.T) { if call["id"] != "call_1" { t.Fatalf("expected call id preserved, got %#v", call) } + content, _ := m["content"].(string) + if !containsStr(content, "search_web") || !containsStr(content, `"arguments":"{\"query\":\"latest\"}"`) { + t.Fatalf("expected assistant content to include serialized tool call for prompt roundtrip, got %q", content) + } } func TestNormalizeClaudeMessagesSkipsNonMap(t *testing.T) { @@ -125,7 +130,7 @@ func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) { "role": "user", "content": []any{ map[string]any{"type": "text", "text": "Hello"}, - map[string]any{"type": "image", "source": "data:..."}, + map[string]any{"type": "image", "source": map[string]any{"type": "base64", "data": strings.Repeat("A", 2048)}}, map[string]any{"type": "text", "text": "World"}, }, }, @@ -134,7 +139,13 @@ func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) { m := got[0].(map[string]any) content, _ := m["content"].(string) if !containsStr(content, "Hello") || !containsStr(content, "World") || !containsStr(content, `"type":"image"`) { - t.Fatalf("expected text plus raw non-text block preserved, got %q", content) + t.Fatalf("expected text plus non-text block marker preserved, got %q", content) + } + if !containsStr(content, omittedBinaryMarker) { + t.Fatalf("expected binary payload omitted marker, got %q", content) + } + if containsStr(content, strings.Repeat("A", 100)) { + t.Fatalf("expected raw base64 payload not to be included, got %q", content) } } diff --git a/internal/adapter/claude/handler_utils.go b/internal/adapter/claude/handler_utils.go index 3702202..50da3ec 100644 --- a/internal/adapter/claude/handler_utils.go +++ b/internal/adapter/claude/handler_utils.go @@ -6,6 +6,11 @@ import ( "strings" ) +const ( + maxClaudeRawPromptChars = 1024 + omittedBinaryMarker = "[omitted_binary_payload]" +) + func normalizeClaudeMessages(messages []any) []any { out := make([]any, 0, len(messages)) for _, m := range messages { @@ -49,7 +54,7 @@ func normalizeClaudeMessages(messages []any) []any { out = append(out, toolMsg) } default: - if raw := strings.TrimSpace(formatClaudeBlockRaw(b)); raw != "" { + if raw := strings.TrimSpace(formatClaudeUnknownBlockForPrompt(b)); raw != "" { textParts = append(textParts, raw) } } @@ -128,19 +133,21 @@ func normalizeClaudeToolUseToAssistant(block map[string]any) map[string]any { if err != nil || len(argsJSON) == 0 { argsJSON = []byte("{}") } - return map[string]any{ - "role": "assistant", - "tool_calls": []any{ - map[string]any{ - "id": callID, - "type": "function", - "function": map[string]any{ - "name": name, - "arguments": string(argsJSON), - }, + toolCalls := []any{ + map[string]any{ + "id": callID, + "type": "function", + "function": map[string]any{ + "name": name, + "arguments": string(argsJSON), }, }, } + return map[string]any{ + "role": "assistant", + "content": marshalCompactJSON(toolCalls), + "tool_calls": toolCalls, + } } func normalizeClaudeToolResultToToolMessage(block map[string]any) map[string]any { @@ -176,6 +183,99 @@ func formatClaudeBlockRaw(block map[string]any) string { return string(b) } +func formatClaudeUnknownBlockForPrompt(block map[string]any) string { + if block == nil { + return "" + } + safe := sanitizeClaudeBlockForPrompt(block) + raw := strings.TrimSpace(formatClaudeBlockRaw(safe)) + if raw == "" { + return "" + } + if len(raw) > maxClaudeRawPromptChars { + return raw[:maxClaudeRawPromptChars] + "...(truncated)" + } + return raw +} + +func sanitizeClaudeBlockForPrompt(block map[string]any) map[string]any { + out := cloneMap(block) + for k, v := range out { + if looksLikeBinaryFieldName(k) { + out[k] = omittedBinaryMarker + continue + } + switch inner := v.(type) { + case map[string]any: + out[k] = sanitizeClaudeBlockForPrompt(inner) + case []any: + out[k] = sanitizeClaudeArrayForPrompt(inner) + case string: + out[k] = sanitizeClaudeStringForPrompt(k, inner) + } + } + return out +} + +func sanitizeClaudeArrayForPrompt(items []any) []any { + out := make([]any, 0, len(items)) + for _, item := range items { + switch v := item.(type) { + case map[string]any: + out = append(out, sanitizeClaudeBlockForPrompt(v)) + case []any: + out = append(out, sanitizeClaudeArrayForPrompt(v)) + default: + out = append(out, v) + } + } + return out +} + +func sanitizeClaudeStringForPrompt(key, value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "" + } + if looksLikeBinaryFieldName(key) || looksLikeBase64Payload(trimmed) { + return omittedBinaryMarker + } + if len(trimmed) > maxClaudeRawPromptChars { + return trimmed[:maxClaudeRawPromptChars] + "...(truncated)" + } + return trimmed +} + +func looksLikeBinaryFieldName(name string) bool { + n := strings.ToLower(strings.TrimSpace(name)) + return n == "data" || n == "bytes" || n == "base64" || n == "inline_data" || n == "inlinedata" +} + +func looksLikeBase64Payload(v string) bool { + if len(v) < 512 { + return false + } + compact := strings.TrimRight(v, "=") + if compact == "" { + return false + } + for _, ch := range compact { + if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '+' || ch == '/' || ch == '-' || ch == '_' { + continue + } + return false + } + return true +} + +func marshalCompactJSON(v any) string { + b, err := json.Marshal(v) + if err != nil { + return strings.TrimSpace(fmt.Sprintf("%v", v)) + } + return string(b) +} + func hasSystemMessage(messages []any) bool { for _, m := range messages { msg, ok := m.(map[string]any) diff --git a/internal/adapter/gemini/convert_messages.go b/internal/adapter/gemini/convert_messages.go index 79a4de1..ec3f174 100644 --- a/internal/adapter/gemini/convert_messages.go +++ b/internal/adapter/gemini/convert_messages.go @@ -2,6 +2,8 @@ package gemini import "strings" +const maxGeminiRawPromptChars = 1024 + func geminiMessagesFromRequest(req map[string]any) []any { out := make([]any, 0, 8) if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" { @@ -110,7 +112,7 @@ func geminiMessagesFromRequest(req map[string]any) []any { continue } - if raw := strings.TrimSpace(stringifyJSON(part)); raw != "" && raw != "null" { + if raw := strings.TrimSpace(formatGeminiUnknownPartForPrompt(part)); raw != "" && raw != "null" { textParts = append(textParts, raw) } } @@ -156,3 +158,87 @@ func mapGeminiRole(v any) string { return "" } } + +func formatGeminiUnknownPartForPrompt(part map[string]any) string { + safe := sanitizeGeminiPartForPrompt(part) + raw := strings.TrimSpace(stringifyJSON(safe)) + if raw == "" { + return "" + } + if len(raw) > maxGeminiRawPromptChars { + return raw[:maxGeminiRawPromptChars] + "...(truncated)" + } + return raw +} + +func sanitizeGeminiPartForPrompt(part map[string]any) map[string]any { + out := make(map[string]any, len(part)) + for k, v := range part { + if looksLikeGeminiBinaryField(k) { + out[k] = "[omitted_binary_payload]" + continue + } + switch x := v.(type) { + case map[string]any: + out[k] = sanitizeGeminiPartForPrompt(x) + case []any: + out[k] = sanitizeGeminiArrayForPrompt(x) + case string: + out[k] = sanitizeGeminiStringForPrompt(k, x) + default: + out[k] = v + } + } + return out +} + +func sanitizeGeminiArrayForPrompt(items []any) []any { + out := make([]any, 0, len(items)) + for _, item := range items { + switch x := item.(type) { + case map[string]any: + out = append(out, sanitizeGeminiPartForPrompt(x)) + case []any: + out = append(out, sanitizeGeminiArrayForPrompt(x)) + default: + out = append(out, x) + } + } + return out +} + +func sanitizeGeminiStringForPrompt(key, value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "" + } + if looksLikeGeminiBinaryField(key) || looksLikeGeminiBase64(trimmed) { + return "[omitted_binary_payload]" + } + if len(trimmed) > maxGeminiRawPromptChars { + return trimmed[:maxGeminiRawPromptChars] + "...(truncated)" + } + return trimmed +} + +func looksLikeGeminiBinaryField(name string) bool { + n := strings.ToLower(strings.TrimSpace(name)) + return n == "data" || n == "bytes" || n == "inlinedata" || n == "inline_data" || n == "base64" +} + +func looksLikeGeminiBase64(v string) bool { + if len(v) < 512 { + return false + } + compact := strings.TrimRight(v, "=") + if compact == "" { + return false + } + for _, ch := range compact { + if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '+' || ch == '/' || ch == '-' || ch == '_' { + continue + } + return false + } + return true +} diff --git a/internal/adapter/gemini/convert_messages_test.go b/internal/adapter/gemini/convert_messages_test.go index b66b2b3..4c98778 100644 --- a/internal/adapter/gemini/convert_messages_test.go +++ b/internal/adapter/gemini/convert_messages_test.go @@ -60,7 +60,7 @@ func TestGeminiMessagesFromRequestPreservesUnknownPartAsRawJSONText(t *testing.T "role": "user", "parts": []any{ map[string]any{"text": "hello"}, - map[string]any{"inlineData": map[string]any{"mimeType": "image/png"}}, + map[string]any{"inlineData": map[string]any{"mimeType": "image/png", "data": strings.Repeat("A", 2048)}}, }, }, }, @@ -75,4 +75,10 @@ func TestGeminiMessagesFromRequestPreservesUnknownPartAsRawJSONText(t *testing.T if !strings.Contains(content, "hello") || !strings.Contains(content, "inlineData") { t.Fatalf("expected unknown part preserved as raw json text, got %q", content) } + if !strings.Contains(content, "[omitted_binary_payload]") { + t.Fatalf("expected inlineData payload to be redacted, got %q", content) + } + if strings.Contains(content, strings.Repeat("A", 100)) { + t.Fatalf("expected raw base64 payload not to be embedded, got %q", content) + } }