From 210d9f5793635b02200a67d92437b1f7398ab7cf Mon Sep 17 00:00:00 2001 From: CJACK Date: Thu, 19 Feb 2026 04:44:01 +0800 Subject: [PATCH] feat: enhance message normalization for OpenAI tool calls and Claude system message tool injection --- internal/adapter/claude/standard_request.go | 60 ++++- .../adapter/claude/standard_request_test.go | 54 ++++ .../openai/responses_embeddings_test.go | 77 ++++++ internal/adapter/openai/responses_handler.go | 249 ++++++++++++++++-- .../adapter/openai/responses_stream_test.go | 8 +- internal/format/openai/render.go | 10 +- internal/format/openai/render_test.go | 64 +++++ 7 files changed, 479 insertions(+), 43 deletions(-) create mode 100644 internal/format/openai/render_test.go diff --git a/internal/adapter/claude/standard_request.go b/internal/adapter/claude/standard_request.go index cdbb675..23520c0 100644 --- a/internal/adapter/claude/standard_request.go +++ b/internal/adapter/claude/standard_request.go @@ -27,9 +27,7 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma payload := cloneMap(req) payload["messages"] = normalizedMessages toolsRequested, _ := req["tools"].([]any) - if len(toolsRequested) > 0 && !hasSystemMessage(normalizedMessages) { - payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...) - } + payload["messages"] = injectClaudeToolPrompt(payload, normalizedMessages, toolsRequested) dsPayload := convertClaudeToDeepSeek(payload, store) dsModel, _ := dsPayload["model"].(string) @@ -57,3 +55,59 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma NormalizedMessages: normalizedMessages, }, nil } + +func injectClaudeToolPrompt(payload map[string]any, normalizedMessages []any, tools []any) []any { + if len(tools) == 0 { + return normalizedMessages + } + toolPrompt := strings.TrimSpace(buildClaudeToolPrompt(tools)) + if toolPrompt == "" { + return normalizedMessages + } + + // Prefer top-level Anthropic-style system prompt when available. + if systemText, ok := payload["system"].(string); ok && strings.TrimSpace(systemText) != "" { + payload["system"] = mergeSystemPrompt(systemText, toolPrompt) + return normalizedMessages + } + + messages := cloneAnySlice(normalizedMessages) + for i := range messages { + msg, ok := messages[i].(map[string]any) + if !ok { + continue + } + role, _ := msg["role"].(string) + if !strings.EqualFold(strings.TrimSpace(role), "system") { + continue + } + copied := cloneMap(msg) + copied["content"] = mergeSystemPrompt(strings.TrimSpace(fmt.Sprintf("%v", copied["content"])), toolPrompt) + messages[i] = copied + return messages + } + + return append([]any{map[string]any{"role": "system", "content": toolPrompt}}, messages...) +} + +func mergeSystemPrompt(base, extra string) string { + base = strings.TrimSpace(base) + extra = strings.TrimSpace(extra) + switch { + case base == "": + return extra + case extra == "": + return base + default: + return base + "\n\n" + extra + } +} + +func cloneAnySlice(in []any) []any { + if len(in) == 0 { + return nil + } + out := make([]any, len(in)) + copy(out, in) + return out +} diff --git a/internal/adapter/claude/standard_request_test.go b/internal/adapter/claude/standard_request_test.go index 7ffdfb8..6110124 100644 --- a/internal/adapter/claude/standard_request_test.go +++ b/internal/adapter/claude/standard_request_test.go @@ -36,3 +36,57 @@ func TestNormalizeClaudeRequest(t *testing.T) { t.Fatalf("expected non-empty final prompt") } } + +func TestNormalizeClaudeRequestInjectsToolsIntoExistingSystemMessage(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{}`) + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "messages": []any{ + map[string]any{"role": "system", "content": "baseline rule"}, + map[string]any{"role": "user", "content": "hello"}, + }, + "tools": []any{ + map[string]any{"name": "search", "description": "Search"}, + }, + } + + norm, err := normalizeClaudeRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + + if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") { + t.Fatalf("expected tool prompt injected into final prompt, got=%q", norm.Standard.FinalPrompt) + } + if !containsStr(norm.Standard.FinalPrompt, "baseline rule") { + t.Fatalf("expected existing system message preserved, got=%q", norm.Standard.FinalPrompt) + } +} + +func TestNormalizeClaudeRequestInjectsToolsIntoTopLevelSystem(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{}`) + store := config.LoadStore() + req := map[string]any{ + "model": "claude-sonnet-4-5", + "system": "top-level system", + "messages": []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + "tools": []any{ + map[string]any{"name": "search", "description": "Search"}, + }, + } + + norm, err := normalizeClaudeRequest(store, req) + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + + if !containsStr(norm.Standard.FinalPrompt, "top-level system") { + t.Fatalf("expected top-level system preserved, got=%q", norm.Standard.FinalPrompt) + } + if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") { + t.Fatalf("expected tool prompt injected, got=%q", norm.Standard.FinalPrompt) + } +} diff --git a/internal/adapter/openai/responses_embeddings_test.go b/internal/adapter/openai/responses_embeddings_test.go index d270e1a..a5e2b72 100644 --- a/internal/adapter/openai/responses_embeddings_test.go +++ b/internal/adapter/openai/responses_embeddings_test.go @@ -1,6 +1,7 @@ package openai import ( + "strings" "testing" "time" ) @@ -32,6 +33,82 @@ func TestResponsesMessagesFromRequestWithInstructions(t *testing.T) { } } +func TestNormalizeResponsesInputAsMessagesObjectRoleContentBlocks(t *testing.T) { + msgs := normalizeResponsesInputAsMessages(map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "input_text", "text": "line-1"}, + map[string]any{"type": "input_text", "text": "line-2"}, + }, + }) + if len(msgs) != 1 { + t.Fatalf("expected one message, got %d", len(msgs)) + } + m, _ := msgs[0].(map[string]any) + if m["role"] != "user" { + t.Fatalf("unexpected role: %#v", m) + } + if strings.TrimSpace(normalizeOpenAIContentForPrompt(m["content"])) != "line-1\nline-2" { + t.Fatalf("unexpected content: %#v", m["content"]) + } +} + +func TestNormalizeResponsesInputAsMessagesFunctionCallOutput(t *testing.T) { + msgs := normalizeResponsesInputAsMessages([]any{ + map[string]any{ + "type": "function_call_output", + "call_id": "call_123", + "output": map[string]any{"ok": true}, + }, + }) + if len(msgs) != 1 { + t.Fatalf("expected one message, got %d", len(msgs)) + } + m, _ := msgs[0].(map[string]any) + if m["role"] != "tool" { + t.Fatalf("expected tool role, got %#v", m) + } + if m["tool_call_id"] != "call_123" { + t.Fatalf("expected tool_call_id propagated, got %#v", m) + } +} + +func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) { + msgs := normalizeResponsesInputAsMessages([]any{ + map[string]any{ + "type": "function_call", + "call_id": "call_456", + "name": "search", + "arguments": `{"q":"golang"}`, + }, + }) + if len(msgs) != 1 { + t.Fatalf("expected one message, got %d", len(msgs)) + } + m, _ := msgs[0].(map[string]any) + if m["role"] != "assistant" { + t.Fatalf("expected assistant role, got %#v", m["role"]) + } + toolCalls, _ := m["tool_calls"].([]any) + if len(toolCalls) != 1 { + t.Fatalf("expected one tool_call, got %#v", m["tool_calls"]) + } + call, _ := toolCalls[0].(map[string]any) + if call["id"] != "call_456" { + t.Fatalf("expected call id preserved, got %#v", call) + } + if call["type"] != "function" { + t.Fatalf("expected function type, got %#v", call) + } + fn, _ := call["function"].(map[string]any) + if fn["name"] != "search" { + t.Fatalf("expected call name preserved, got %#v", call) + } + if fn["arguments"] != `{"q":"golang"}` { + t.Fatalf("expected call arguments preserved, got %#v", call) + } +} + func TestExtractEmbeddingInputs(t *testing.T) { got := extractEmbeddingInputs([]any{"a", "b"}) if len(got) != 2 || got[0] != "a" || got[1] != "b" { diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index e767b2b..522521d 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -203,40 +203,231 @@ func normalizeResponsesInputAsMessages(input any) []any { } return []any{map[string]any{"role": "user", "content": v}} case []any: - if len(v) == 0 { - return nil - } - // If caller already provides role-shaped items, keep as-is. - if first, ok := v[0].(map[string]any); ok { - if _, hasRole := first["role"]; hasRole { - return v - } - } - parts := make([]string, 0, len(v)) - for _, item := range v { - if m, ok := item.(map[string]any); ok { - if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") { - if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { - parts = append(parts, txt) - continue - } - } - } - if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" { - parts = append(parts, s) - } - } - if len(parts) == 0 { - return nil - } - return []any{map[string]any{"role": "user", "content": strings.Join(parts, "\n")}} + return normalizeResponsesInputArray(v) case map[string]any: + if msg := normalizeResponsesInputItem(v); msg != nil { + return []any{msg} + } if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" { return []any{map[string]any{"role": "user", "content": txt}} } - if content, ok := v["content"].(string); ok && strings.TrimSpace(content) != "" { - return []any{map[string]any{"role": "user", "content": content}} + if content, ok := v["content"]; ok { + if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" { + return []any{map[string]any{"role": "user", "content": content}} + } } } return nil } + +func normalizeResponsesInputArray(items []any) []any { + if len(items) == 0 { + return nil + } + out := make([]any, 0, len(items)) + fallbackParts := make([]string, 0, len(items)) + flushFallback := func() { + if len(fallbackParts) == 0 { + return + } + out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")}) + fallbackParts = fallbackParts[:0] + } + + for _, item := range items { + switch x := item.(type) { + case map[string]any: + if msg := normalizeResponsesInputItem(x); msg != nil { + flushFallback() + out = append(out, msg) + continue + } + if s := normalizeResponsesFallbackPart(x); s != "" { + fallbackParts = append(fallbackParts, s) + } + default: + if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" { + fallbackParts = append(fallbackParts, s) + } + } + } + flushFallback() + if len(out) == 0 { + return nil + } + return out +} + +func normalizeResponsesInputItem(m map[string]any) map[string]any { + if m == nil { + return nil + } + + role := strings.ToLower(strings.TrimSpace(asString(m["role"]))) + if role != "" { + content := m["content"] + if content == nil { + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + content = txt + } + } + if content == nil { + return nil + } + return map[string]any{ + "role": role, + "content": content, + } + } + + itemType := strings.ToLower(strings.TrimSpace(asString(m["type"]))) + switch itemType { + case "message", "input_message": + content := m["content"] + if content == nil { + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + content = txt + } + } + if content == nil { + return nil + } + role := strings.ToLower(strings.TrimSpace(asString(m["role"]))) + if role == "" { + role = "user" + } + return map[string]any{ + "role": role, + "content": content, + } + case "function_call_output", "tool_result": + content := m["output"] + if content == nil { + content = m["content"] + } + if content == nil { + content = "" + } + out := map[string]any{ + "role": "tool", + "content": content, + } + if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" { + out["tool_call_id"] = callID + } else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" { + out["tool_call_id"] = callID + } + if name := strings.TrimSpace(asString(m["name"])); name != "" { + out["name"] = name + } else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" { + out["name"] = name + } + return out + case "function_call", "tool_call": + name := strings.TrimSpace(asString(m["name"])) + var fn map[string]any + if rawFn, ok := m["function"].(map[string]any); ok { + fn = rawFn + if name == "" { + name = strings.TrimSpace(asString(fn["name"])) + } + } + if name == "" { + return nil + } + + var argsRaw any + if v, ok := m["arguments"]; ok { + argsRaw = v + } else if v, ok := m["input"]; ok { + argsRaw = v + } + if argsRaw == nil && fn != nil { + if v, ok := fn["arguments"]; ok { + argsRaw = v + } else if v, ok := fn["input"]; ok { + argsRaw = v + } + } + + functionPayload := map[string]any{ + "name": name, + "arguments": stringifyToolCallArguments(argsRaw), + } + call := map[string]any{ + "type": "function", + "function": functionPayload, + } + if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" { + call["id"] = callID + } else if callID = strings.TrimSpace(asString(m["id"])); callID != "" { + call["id"] = callID + } + return map[string]any{ + "role": "assistant", + "tool_calls": []any{call}, + } + case "input_text": + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return map[string]any{ + "role": "user", + "content": txt, + } + } + } + + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return map[string]any{ + "role": "user", + "content": txt, + } + } + if content, ok := m["content"]; ok { + if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" { + return map[string]any{ + "role": "user", + "content": content, + } + } + } + return nil +} + +func normalizeResponsesFallbackPart(m map[string]any) string { + if m == nil { + return "" + } + if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") { + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return txt + } + } + if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" { + return txt + } + if content, ok := m["content"]; ok { + if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" { + return normalized + } + } + return strings.TrimSpace(fmt.Sprintf("%v", m)) +} + +func stringifyToolCallArguments(v any) string { + switch x := v.(type) { + case nil: + return "{}" + case string: + s := strings.TrimSpace(x) + if s == "" { + return "{}" + } + return s + default: + b, err := json.Marshal(x) + if err != nil || len(b) == 0 { + return "{}" + } + return string(b) + } +} diff --git a/internal/adapter/openai/responses_stream_test.go b/internal/adapter/openai/responses_stream_test.go index 9b0a5ac..03752a7 100644 --- a/internal/adapter/openai/responses_stream_test.go +++ b/internal/adapter/openai/responses_stream_test.go @@ -54,8 +54,12 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T t.Fatalf("expected at least one tool_call in output, got %#v", first["tool_calls"]) } call0, _ := toolCalls[0].(map[string]any) - if call0["name"] != "read_file" { - t.Fatalf("unexpected tool call name: %#v", call0["name"]) + if call0["type"] != "function" { + t.Fatalf("unexpected tool call type: %#v", call0["type"]) + } + fn, _ := call0["function"].(map[string]any) + if fn["name"] != "read_file" { + t.Fatalf("unexpected tool call name: %#v", fn["name"]) } if strings.Contains(outputText, `"tool_calls"`) { t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText) diff --git a/internal/format/openai/render.go b/internal/format/openai/render.go index fc7473f..3d2f967 100644 --- a/internal/format/openai/render.go +++ b/internal/format/openai/render.go @@ -48,17 +48,9 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex output := make([]any, 0, 2) if len(detected) > 0 { exposedOutputText = "" - toolCalls := make([]any, 0, len(detected)) - for _, tc := range detected { - toolCalls = append(toolCalls, map[string]any{ - "type": "tool_call", - "name": tc.Name, - "arguments": tc.Input, - }) - } output = append(output, map[string]any{ "type": "tool_calls", - "tool_calls": toolCalls, + "tool_calls": util.FormatOpenAIToolCalls(detected), }) } else { content := []any{ diff --git a/internal/format/openai/render_test.go b/internal/format/openai/render_test.go new file mode 100644 index 0000000..1da68d0 --- /dev/null +++ b/internal/format/openai/render_test.go @@ -0,0 +1,64 @@ +package openai + +import ( + "encoding/json" + "testing" +) + +func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) { + obj := BuildResponseObject( + "resp_test", + "gpt-4o", + "prompt", + "", + `{"tool_calls":[{"name":"search","input":{"q":"golang"}}]}`, + []string{"search"}, + ) + + outputText, _ := obj["output_text"].(string) + if outputText != "" { + t.Fatalf("expected output_text to be hidden for tool calls, got %q", outputText) + } + + output, _ := obj["output"].([]any) + if len(output) != 1 { + t.Fatalf("expected one tool_calls wrapper, got %#v", obj["output"]) + } + + first, _ := output[0].(map[string]any) + if first["type"] != "tool_calls" { + t.Fatalf("expected first output item type tool_calls, got %#v", first["type"]) + } + var toolCalls []map[string]any + switch v := first["tool_calls"].(type) { + case []map[string]any: + toolCalls = v + case []any: + toolCalls = make([]map[string]any, 0, len(v)) + for _, item := range v { + m, _ := item.(map[string]any) + if m != nil { + toolCalls = append(toolCalls, m) + } + } + } + if len(toolCalls) != 1 { + t.Fatalf("expected one tool call, got %#v", first["tool_calls"]) + } + tc := toolCalls[0] + if tc["type"] != "function" || tc["id"] == "" { + t.Fatalf("unexpected tool call shape: %#v", tc) + } + fn, _ := tc["function"].(map[string]any) + if fn["name"] != "search" { + t.Fatalf("unexpected function name: %#v", fn["name"]) + } + argsRaw, _ := fn["arguments"].(string) + var args map[string]any + if err := json.Unmarshal([]byte(argsRaw), &args); err != nil { + t.Fatalf("arguments should be valid json string, got=%q err=%v", argsRaw, err) + } + if args["q"] != "golang" { + t.Fatalf("unexpected arguments: %#v", args) + } +}