From 19289c9008a5157b10dd4f8abf3ca8fefb90b83c Mon Sep 17 00:00:00 2001 From: CJACK Date: Wed, 18 Feb 2026 00:54:54 +0800 Subject: [PATCH] refactor: Modularize OpenAI message normalization and prompt building, enhancing `MessagesPrepare` to support additional content types and tool call formatting. --- internal/adapter/openai/handler.go | 20 +- internal/adapter/openai/message_normalize.go | 192 ++++++++++++++++++ .../adapter/openai/message_normalize_test.go | 121 +++++++++++ internal/adapter/openai/prompt_build.go | 12 ++ internal/adapter/openai/prompt_build_test.go | 80 ++++++++ internal/adapter/openai/vercel_stream.go | 6 +- internal/util/messages.go | 16 +- internal/util/messages_test.go | 27 +++ 8 files changed, 449 insertions(+), 25 deletions(-) create mode 100644 internal/adapter/openai/message_normalize.go create mode 100644 internal/adapter/openai/message_normalize_test.go create mode 100644 internal/adapter/openai/prompt_build.go create mode 100644 internal/adapter/openai/prompt_build_test.go diff --git a/internal/adapter/openai/handler.go b/internal/adapter/openai/handler.go index d0a2f1d..1602cf6 100644 --- a/internal/adapter/openai/handler.go +++ b/internal/adapter/openai/handler.go @@ -86,12 +86,7 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { return } - messages := normalizeMessages(messagesRaw) - toolNames := []string{} - if tools, ok := req["tools"].([]any); ok && len(tools) > 0 { - messages, toolNames = injectToolPrompt(messages, tools) - } - finalPrompt := util.MessagesPrepare(messages) + finalPrompt, toolNames := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { @@ -405,17 +400,6 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt } } -func normalizeMessages(raw []any) []map[string]any { - out := make([]map[string]any, 0, len(raw)) - for _, item := range raw { - m, ok := item.(map[string]any) - if ok { - out = append(out, m) - } - } - return out -} - func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, []string) { toolSchemas := make([]string, 0, len(tools)) names := make([]string, 0, len(tools)) @@ -444,7 +428,7 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any, if len(toolSchemas) == 0 { return messages, names } - toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT: If calling tools, output ONLY the JSON. The response must start with { and end with }" + toolPrompt := "You have access to these tools:\n\n" + strings.Join(toolSchemas, "\n\n") + "\n\nWhen you need to use tools, output ONLY this JSON format (no other text):\n{\"tool_calls\": [{\"name\": \"tool_name\", \"input\": {\"param\": \"value\"}}]}\n\nIMPORTANT:\n1) If calling tools, output ONLY the JSON. The response must start with { and end with }.\n2) After receiving a tool result, you MUST use it to produce the final answer.\n3) Only call another tool when the previous result is missing required data or returned an error." for i := range messages { if messages[i]["role"] == "system" { diff --git a/internal/adapter/openai/message_normalize.go b/internal/adapter/openai/message_normalize.go new file mode 100644 index 0000000..3ebd1e7 --- /dev/null +++ b/internal/adapter/openai/message_normalize.go @@ -0,0 +1,192 @@ +package openai + +import ( + "encoding/json" + "fmt" + "strings" +) + +func normalizeOpenAIMessagesForPrompt(raw []any) []map[string]any { + out := make([]map[string]any, 0, len(raw)) + for _, item := range raw { + msg, ok := item.(map[string]any) + if !ok { + continue + } + role := strings.ToLower(strings.TrimSpace(asString(msg["role"]))) + switch role { + case "assistant": + content := normalizeOpenAIContentForPrompt(msg["content"]) + toolCalls := formatAssistantToolCallsForPrompt(msg) + combined := joinNonEmpty(content, toolCalls) + if combined == "" { + continue + } + out = append(out, map[string]any{ + "role": "assistant", + "content": combined, + }) + case "tool", "function": + out = append(out, map[string]any{ + "role": "user", + "content": formatToolResultForPrompt(msg), + }) + case "user", "system": + out = append(out, map[string]any{ + "role": role, + "content": normalizeOpenAIContentForPrompt(msg["content"]), + }) + default: + content := normalizeOpenAIContentForPrompt(msg["content"]) + if content == "" { + continue + } + if role == "" { + role = "user" + } + out = append(out, map[string]any{ + "role": role, + "content": content, + }) + } + } + return out +} + +func formatAssistantToolCallsForPrompt(msg map[string]any) string { + entries := make([]string, 0) + if calls, ok := msg["tool_calls"].([]any); ok { + for i, item := range calls { + call, ok := item.(map[string]any) + if !ok { + continue + } + id := strings.TrimSpace(asString(call["id"])) + if id == "" { + id = fmt.Sprintf("call_%d", i+1) + } + name := strings.TrimSpace(asString(call["name"])) + args := "" + + if fn, ok := call["function"].(map[string]any); ok { + if name == "" { + name = strings.TrimSpace(asString(fn["name"])) + } + args = normalizeOpenAIArgumentsForPrompt(fn["arguments"]) + } + if name == "" { + name = "unknown" + } + if args == "" { + args = normalizeOpenAIArgumentsForPrompt(call["arguments"]) + } + if args == "" { + args = normalizeOpenAIArgumentsForPrompt(call["input"]) + } + if args == "" { + args = "{}" + } + entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: %s\n- function.name: %s\n- function.arguments: %s", id, name, args)) + } + } + + if legacy, ok := msg["function_call"].(map[string]any); ok { + name := strings.TrimSpace(asString(legacy["name"])) + if name == "" { + name = "unknown" + } + args := normalizeOpenAIArgumentsForPrompt(legacy["arguments"]) + if args == "" { + args = "{}" + } + entries = append(entries, fmt.Sprintf("Tool call:\n- tool_call_id: call_legacy\n- function.name: %s\n- function.arguments: %s", name, args)) + } + + return strings.Join(entries, "\n\n") +} + +func formatToolResultForPrompt(msg map[string]any) string { + toolCallID := strings.TrimSpace(asString(msg["tool_call_id"])) + if toolCallID == "" { + toolCallID = strings.TrimSpace(asString(msg["id"])) + } + if toolCallID == "" { + toolCallID = "unknown" + } + + name := strings.TrimSpace(asString(msg["name"])) + if name == "" { + name = "unknown" + } + + content := normalizeOpenAIContentForPrompt(msg["content"]) + if content == "" { + content = "null" + } + + return fmt.Sprintf("Tool result:\n- tool_call_id: %s\n- name: %s\n- content: %s", toolCallID, name, content) +} + +func normalizeOpenAIContentForPrompt(v any) string { + switch x := v.(type) { + case string: + return x + case []any: + parts := make([]string, 0, len(x)) + for _, item := range x { + m, ok := item.(map[string]any) + if !ok { + continue + } + t := strings.ToLower(strings.TrimSpace(asString(m["type"]))) + if t != "text" && t != "output_text" && t != "input_text" { + continue + } + if text := asString(m["text"]); text != "" { + parts = append(parts, text) + continue + } + if text := asString(m["content"]); text != "" { + parts = append(parts, text) + } + } + return strings.Join(parts, "\n") + default: + return marshalToPromptString(v) + } +} + +func normalizeOpenAIArgumentsForPrompt(v any) string { + switch x := v.(type) { + case string: + return strings.TrimSpace(x) + default: + return marshalToPromptString(v) + } +} + +func marshalToPromptString(v any) string { + b, err := json.Marshal(v) + if err != nil { + return strings.TrimSpace(fmt.Sprintf("%v", v)) + } + return string(b) +} + +func asString(v any) string { + if s, ok := v.(string); ok { + return s + } + return "" +} + +func joinNonEmpty(parts ...string) string { + nonEmpty := make([]string, 0, len(parts)) + for _, p := range parts { + if strings.TrimSpace(p) == "" { + continue + } + nonEmpty = append(nonEmpty, p) + } + return strings.Join(nonEmpty, "\n\n") +} diff --git a/internal/adapter/openai/message_normalize_test.go b/internal/adapter/openai/message_normalize_test.go new file mode 100644 index 0000000..bb648d3 --- /dev/null +++ b/internal/adapter/openai/message_normalize_test.go @@ -0,0 +1,121 @@ +package openai + +import ( + "strings" + "testing" + + "ds2api/internal/util" +) + +func TestNormalizeOpenAIMessagesForPrompt_AssistantToolCallsAndToolResult(t *testing.T) { + raw := []any{ + map[string]any{"role": "system", "content": "You are helpful"}, + map[string]any{"role": "user", "content": "查北京天气"}, + map[string]any{ + "role": "assistant", + "content": nil, + "tool_calls": []any{ + map[string]any{ + "id": "call_1", + "type": "function", + "function": map[string]any{ + "name": "get_weather", + "arguments": "{\"city\":\"beijing\"}", + }, + }, + }, + }, + map[string]any{ + "role": "tool", + "tool_call_id": "call_1", + "name": "get_weather", + "content": "{\"temp\":18}", + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw) + if len(normalized) != 4 { + t.Fatalf("expected 4 normalized messages, got %d", len(normalized)) + } + assistantContent, _ := normalized[2]["content"].(string) + if !strings.Contains(assistantContent, "tool_call_id: call_1") || + !strings.Contains(assistantContent, "function.name: get_weather") || + !strings.Contains(assistantContent, "function.arguments: {\"city\":\"beijing\"}") { + t.Fatalf("assistant tool call not serialized correctly: %q", assistantContent) + } + toolContent, _ := normalized[3]["content"].(string) + if !strings.Contains(toolContent, "Tool result:") || !strings.Contains(toolContent, "name: get_weather") { + t.Fatalf("tool result not serialized correctly: %q", toolContent) + } + + prompt := util.MessagesPrepare(normalized) + if !strings.Contains(prompt, "tool_call_id: call_1") || !strings.Contains(prompt, "Tool result:") { + t.Fatalf("expected prompt to include tool call + result semantics: %q", prompt) + } +} + +func TestNormalizeOpenAIMessagesForPrompt_ToolObjectContentPreserved(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "tool", + "tool_call_id": "call_2", + "name": "get_weather", + "content": map[string]any{ + "temp": 18, + "condition": "sunny", + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw) + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, `"temp":18`) || !strings.Contains(got, `"condition":"sunny"`) { + t.Fatalf("expected serialized object in tool content, got %q", got) + } +} + +func TestNormalizeOpenAIMessagesForPrompt_ToolArrayBlocksJoined(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "tool", + "tool_call_id": "call_3", + "name": "read_file", + "content": []any{ + map[string]any{"type": "input_text", "text": "line-1"}, + map[string]any{"type": "output_text", "text": "line-2"}, + map[string]any{"type": "image_url", "image_url": "https://example.com/a.png"}, + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw) + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, "line-1\nline-2") { + t.Fatalf("expected joined text blocks, got %q", got) + } +} + +func TestNormalizeOpenAIMessagesForPrompt_FunctionRoleCompatible(t *testing.T) { + raw := []any{ + map[string]any{ + "role": "function", + "tool_call_id": "call_4", + "name": "legacy_tool", + "content": map[string]any{ + "ok": true, + }, + }, + } + + normalized := normalizeOpenAIMessagesForPrompt(raw) + if len(normalized) != 1 { + t.Fatalf("expected one normalized message, got %d", len(normalized)) + } + if normalized[0]["role"] != "user" { + t.Fatalf("expected function role mapped to user, got %#v", normalized[0]["role"]) + } + got, _ := normalized[0]["content"].(string) + if !strings.Contains(got, "name: legacy_tool") || !strings.Contains(got, `"ok":true`) { + t.Fatalf("unexpected normalized function-role content: %q", got) + } +} diff --git a/internal/adapter/openai/prompt_build.go b/internal/adapter/openai/prompt_build.go new file mode 100644 index 0000000..a7bbc92 --- /dev/null +++ b/internal/adapter/openai/prompt_build.go @@ -0,0 +1,12 @@ +package openai + +import "ds2api/internal/util" + +func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any) (string, []string) { + messages := normalizeOpenAIMessagesForPrompt(messagesRaw) + toolNames := []string{} + if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 { + messages, toolNames = injectToolPrompt(messages, tools) + } + return util.MessagesPrepare(messages), toolNames +} diff --git a/internal/adapter/openai/prompt_build_test.go b/internal/adapter/openai/prompt_build_test.go new file mode 100644 index 0000000..1833860 --- /dev/null +++ b/internal/adapter/openai/prompt_build_test.go @@ -0,0 +1,80 @@ +package openai + +import ( + "strings" + "testing" +) + +func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *testing.T) { + messages := []any{ + map[string]any{"role": "user", "content": "查北京天气"}, + map[string]any{ + "role": "assistant", + "tool_calls": []any{ + map[string]any{ + "id": "call_1", + "function": map[string]any{ + "name": "get_weather", + "arguments": "{\"city\":\"beijing\"}", + }, + }, + }, + }, + map[string]any{ + "role": "tool", + "tool_call_id": "call_1", + "name": "get_weather", + "content": map[string]any{"temp": 18, "condition": "sunny"}, + }, + } + tools := []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "get_weather", + "description": "Get weather", + "parameters": map[string]any{ + "type": "object", + }, + }, + }, + } + + finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools) + if len(toolNames) != 1 || toolNames[0] != "get_weather" { + t.Fatalf("unexpected tool names: %#v", toolNames) + } + if !strings.Contains(finalPrompt, "tool_call_id: call_1") || + !strings.Contains(finalPrompt, "function.name: get_weather") || + !strings.Contains(finalPrompt, "Tool result:") || + !strings.Contains(finalPrompt, `"condition":"sunny"`) { + t.Fatalf("handler finalPrompt missing tool roundtrip semantics: %q", finalPrompt) + } +} + +func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *testing.T) { + messages := []any{ + map[string]any{"role": "system", "content": "You are helpful"}, + map[string]any{"role": "user", "content": "请调用工具"}, + } + tools := []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "search", + "description": "search docs", + "parameters": map[string]any{ + "type": "object", + }, + }, + }, + } + + finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools) + if !strings.Contains(finalPrompt, "After receiving a tool result, you MUST use it to produce the final answer.") { + t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt) + } + if !strings.Contains(finalPrompt, "Only call another tool when the previous result is missing required data or returned an error.") { + t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt) + } +} diff --git a/internal/adapter/openai/vercel_stream.go b/internal/adapter/openai/vercel_stream.go index 653f3cf..85c9cd8 100644 --- a/internal/adapter/openai/vercel_stream.go +++ b/internal/adapter/openai/vercel_stream.go @@ -68,11 +68,7 @@ func (h *Handler) handleVercelStreamPrepare(w http.ResponseWriter, r *http.Reque return } - messages := normalizeMessages(messagesRaw) - if tools, ok := req["tools"].([]any); ok && len(tools) > 0 { - messages, _ = injectToolPrompt(messages, tools) - } - finalPrompt := util.MessagesPrepare(messages) + finalPrompt, _ := buildOpenAIFinalPrompt(messagesRaw, req["tools"]) sessionID, err := h.DS.CreateSession(r.Context(), a, 3) if err != nil { diff --git a/internal/util/messages.go b/internal/util/messages.go index 19f2948..fcc9484 100644 --- a/internal/util/messages.go +++ b/internal/util/messages.go @@ -1,6 +1,8 @@ package util import ( + "encoding/json" + "fmt" "regexp" "strings" @@ -68,15 +70,25 @@ func normalizeContent(v any) string { if !ok { continue } - if m["type"] == "text" { + typeStr, _ := m["type"].(string) + typeStr = strings.ToLower(strings.TrimSpace(typeStr)) + if typeStr == "text" || typeStr == "output_text" || typeStr == "input_text" { if txt, ok := m["text"].(string); ok { parts = append(parts, txt) + continue + } + if txt, ok := m["content"].(string); ok { + parts = append(parts, txt) } } } return strings.Join(parts, "\n") default: - return "" + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprintf("%v", v) + } + return string(b) } } diff --git a/internal/util/messages_test.go b/internal/util/messages_test.go index 30b8cc0..776853b 100644 --- a/internal/util/messages_test.go +++ b/internal/util/messages_test.go @@ -33,6 +33,33 @@ func TestMessagesPrepareRoles(t *testing.T) { } } +func TestMessagesPrepareObjectContent(t *testing.T) { + messages := []map[string]any{ + {"role": "user", "content": map[string]any{"temp": 18, "ok": true}}, + } + got := MessagesPrepare(messages) + if !contains(got, `"temp":18`) || !contains(got, `"ok":true`) { + t.Fatalf("expected serialized object content, got %q", got) + } +} + +func TestMessagesPrepareArrayTextVariants(t *testing.T) { + messages := []map[string]any{ + { + "role": "user", + "content": []any{ + map[string]any{"type": "output_text", "text": "line1"}, + map[string]any{"type": "input_text", "text": "line2"}, + map[string]any{"type": "image_url", "image_url": "https://example.com/a.png"}, + }, + }, + } + got := MessagesPrepare(messages) + if got != "line1\nline2" { + t.Fatalf("unexpected content from text variants: %q", got) + } +} + func TestConvertClaudeToDeepSeek(t *testing.T) { store := config.LoadStore() req := map[string]any{