mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 00:45:29 +08:00
feat: enhance message normalization for OpenAI tool calls and Claude system message tool injection
This commit is contained in:
@@ -27,9 +27,7 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma
|
||||
payload := cloneMap(req)
|
||||
payload["messages"] = normalizedMessages
|
||||
toolsRequested, _ := req["tools"].([]any)
|
||||
if len(toolsRequested) > 0 && !hasSystemMessage(normalizedMessages) {
|
||||
payload["messages"] = append([]any{map[string]any{"role": "system", "content": buildClaudeToolPrompt(toolsRequested)}}, normalizedMessages...)
|
||||
}
|
||||
payload["messages"] = injectClaudeToolPrompt(payload, normalizedMessages, toolsRequested)
|
||||
|
||||
dsPayload := convertClaudeToDeepSeek(payload, store)
|
||||
dsModel, _ := dsPayload["model"].(string)
|
||||
@@ -57,3 +55,59 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma
|
||||
NormalizedMessages: normalizedMessages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func injectClaudeToolPrompt(payload map[string]any, normalizedMessages []any, tools []any) []any {
|
||||
if len(tools) == 0 {
|
||||
return normalizedMessages
|
||||
}
|
||||
toolPrompt := strings.TrimSpace(buildClaudeToolPrompt(tools))
|
||||
if toolPrompt == "" {
|
||||
return normalizedMessages
|
||||
}
|
||||
|
||||
// Prefer top-level Anthropic-style system prompt when available.
|
||||
if systemText, ok := payload["system"].(string); ok && strings.TrimSpace(systemText) != "" {
|
||||
payload["system"] = mergeSystemPrompt(systemText, toolPrompt)
|
||||
return normalizedMessages
|
||||
}
|
||||
|
||||
messages := cloneAnySlice(normalizedMessages)
|
||||
for i := range messages {
|
||||
msg, ok := messages[i].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role, _ := msg["role"].(string)
|
||||
if !strings.EqualFold(strings.TrimSpace(role), "system") {
|
||||
continue
|
||||
}
|
||||
copied := cloneMap(msg)
|
||||
copied["content"] = mergeSystemPrompt(strings.TrimSpace(fmt.Sprintf("%v", copied["content"])), toolPrompt)
|
||||
messages[i] = copied
|
||||
return messages
|
||||
}
|
||||
|
||||
return append([]any{map[string]any{"role": "system", "content": toolPrompt}}, messages...)
|
||||
}
|
||||
|
||||
// mergeSystemPrompt joins two prompt fragments with a blank line between them.
// Both fragments are whitespace-trimmed first; when either one is empty after
// trimming, the other is returned on its own (two blank inputs yield "").
func mergeSystemPrompt(base, extra string) string {
	head := strings.TrimSpace(base)
	tail := strings.TrimSpace(extra)
	if head == "" {
		return tail
	}
	if tail == "" {
		return head
	}
	return head + "\n\n" + tail
}
|
||||
|
||||
// cloneAnySlice returns a shallow copy of in: the elements are copied into a
// freshly allocated slice of the same length, but the values they hold are
// shared with the original. A nil or empty input yields nil.
func cloneAnySlice(in []any) []any {
	n := len(in)
	if n == 0 {
		return nil
	}
	dup := make([]any, n)
	copy(dup, in)
	return dup
}
|
||||
|
||||
@@ -36,3 +36,57 @@ func TestNormalizeClaudeRequest(t *testing.T) {
|
||||
t.Fatalf("expected non-empty final prompt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeRequestInjectsToolsIntoExistingSystemMessage(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "system", "content": "baseline rule"},
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
"tools": []any{
|
||||
map[string]any{"name": "search", "description": "Search"},
|
||||
},
|
||||
}
|
||||
|
||||
norm, err := normalizeClaudeRequest(store, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
|
||||
if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") {
|
||||
t.Fatalf("expected tool prompt injected into final prompt, got=%q", norm.Standard.FinalPrompt)
|
||||
}
|
||||
if !containsStr(norm.Standard.FinalPrompt, "baseline rule") {
|
||||
t.Fatalf("expected existing system message preserved, got=%q", norm.Standard.FinalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeClaudeRequestInjectsToolsIntoTopLevelSystem(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{}`)
|
||||
store := config.LoadStore()
|
||||
req := map[string]any{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"system": "top-level system",
|
||||
"messages": []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
"tools": []any{
|
||||
map[string]any{"name": "search", "description": "Search"},
|
||||
},
|
||||
}
|
||||
|
||||
norm, err := normalizeClaudeRequest(store, req)
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
|
||||
if !containsStr(norm.Standard.FinalPrompt, "top-level system") {
|
||||
t.Fatalf("expected top-level system preserved, got=%q", norm.Standard.FinalPrompt)
|
||||
}
|
||||
if !containsStr(norm.Standard.FinalPrompt, "You have access to these tools") {
|
||||
t.Fatalf("expected tool prompt injected, got=%q", norm.Standard.FinalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -32,6 +33,82 @@ func TestResponsesMessagesFromRequestWithInstructions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesInputAsMessagesObjectRoleContentBlocks(t *testing.T) {
|
||||
msgs := normalizeResponsesInputAsMessages(map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "input_text", "text": "line-1"},
|
||||
map[string]any{"type": "input_text", "text": "line-2"},
|
||||
},
|
||||
})
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("expected one message, got %d", len(msgs))
|
||||
}
|
||||
m, _ := msgs[0].(map[string]any)
|
||||
if m["role"] != "user" {
|
||||
t.Fatalf("unexpected role: %#v", m)
|
||||
}
|
||||
if strings.TrimSpace(normalizeOpenAIContentForPrompt(m["content"])) != "line-1\nline-2" {
|
||||
t.Fatalf("unexpected content: %#v", m["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesInputAsMessagesFunctionCallOutput(t *testing.T) {
|
||||
msgs := normalizeResponsesInputAsMessages([]any{
|
||||
map[string]any{
|
||||
"type": "function_call_output",
|
||||
"call_id": "call_123",
|
||||
"output": map[string]any{"ok": true},
|
||||
},
|
||||
})
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("expected one message, got %d", len(msgs))
|
||||
}
|
||||
m, _ := msgs[0].(map[string]any)
|
||||
if m["role"] != "tool" {
|
||||
t.Fatalf("expected tool role, got %#v", m)
|
||||
}
|
||||
if m["tool_call_id"] != "call_123" {
|
||||
t.Fatalf("expected tool_call_id propagated, got %#v", m)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesInputAsMessagesFunctionCallItem(t *testing.T) {
|
||||
msgs := normalizeResponsesInputAsMessages([]any{
|
||||
map[string]any{
|
||||
"type": "function_call",
|
||||
"call_id": "call_456",
|
||||
"name": "search",
|
||||
"arguments": `{"q":"golang"}`,
|
||||
},
|
||||
})
|
||||
if len(msgs) != 1 {
|
||||
t.Fatalf("expected one message, got %d", len(msgs))
|
||||
}
|
||||
m, _ := msgs[0].(map[string]any)
|
||||
if m["role"] != "assistant" {
|
||||
t.Fatalf("expected assistant role, got %#v", m["role"])
|
||||
}
|
||||
toolCalls, _ := m["tool_calls"].([]any)
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected one tool_call, got %#v", m["tool_calls"])
|
||||
}
|
||||
call, _ := toolCalls[0].(map[string]any)
|
||||
if call["id"] != "call_456" {
|
||||
t.Fatalf("expected call id preserved, got %#v", call)
|
||||
}
|
||||
if call["type"] != "function" {
|
||||
t.Fatalf("expected function type, got %#v", call)
|
||||
}
|
||||
fn, _ := call["function"].(map[string]any)
|
||||
if fn["name"] != "search" {
|
||||
t.Fatalf("expected call name preserved, got %#v", call)
|
||||
}
|
||||
if fn["arguments"] != `{"q":"golang"}` {
|
||||
t.Fatalf("expected call arguments preserved, got %#v", call)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractEmbeddingInputs(t *testing.T) {
|
||||
got := extractEmbeddingInputs([]any{"a", "b"})
|
||||
if len(got) != 2 || got[0] != "a" || got[1] != "b" {
|
||||
|
||||
@@ -203,40 +203,231 @@ func normalizeResponsesInputAsMessages(input any) []any {
|
||||
}
|
||||
return []any{map[string]any{"role": "user", "content": v}}
|
||||
case []any:
|
||||
if len(v) == 0 {
|
||||
return nil
|
||||
}
|
||||
// If caller already provides role-shaped items, keep as-is.
|
||||
if first, ok := v[0].(map[string]any); ok {
|
||||
if _, hasRole := first["role"]; hasRole {
|
||||
return v
|
||||
}
|
||||
}
|
||||
parts := make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
parts = append(parts, txt)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
|
||||
parts = append(parts, s)
|
||||
}
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
return []any{map[string]any{"role": "user", "content": strings.Join(parts, "\n")}}
|
||||
return normalizeResponsesInputArray(v)
|
||||
case map[string]any:
|
||||
if msg := normalizeResponsesInputItem(v); msg != nil {
|
||||
return []any{msg}
|
||||
}
|
||||
if txt, _ := v["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return []any{map[string]any{"role": "user", "content": txt}}
|
||||
}
|
||||
if content, ok := v["content"].(string); ok && strings.TrimSpace(content) != "" {
|
||||
return []any{map[string]any{"role": "user", "content": content}}
|
||||
if content, ok := v["content"]; ok {
|
||||
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
|
||||
return []any{map[string]any{"role": "user", "content": content}}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeResponsesInputArray(items []any) []any {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]any, 0, len(items))
|
||||
fallbackParts := make([]string, 0, len(items))
|
||||
flushFallback := func() {
|
||||
if len(fallbackParts) == 0 {
|
||||
return
|
||||
}
|
||||
out = append(out, map[string]any{"role": "user", "content": strings.Join(fallbackParts, "\n")})
|
||||
fallbackParts = fallbackParts[:0]
|
||||
}
|
||||
|
||||
for _, item := range items {
|
||||
switch x := item.(type) {
|
||||
case map[string]any:
|
||||
if msg := normalizeResponsesInputItem(x); msg != nil {
|
||||
flushFallback()
|
||||
out = append(out, msg)
|
||||
continue
|
||||
}
|
||||
if s := normalizeResponsesFallbackPart(x); s != "" {
|
||||
fallbackParts = append(fallbackParts, s)
|
||||
}
|
||||
default:
|
||||
if s := strings.TrimSpace(fmt.Sprintf("%v", item)); s != "" {
|
||||
fallbackParts = append(fallbackParts, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
flushFallback()
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeResponsesInputItem(m map[string]any) map[string]any {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
|
||||
if role != "" {
|
||||
content := m["content"]
|
||||
if content == nil {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
content = txt
|
||||
}
|
||||
}
|
||||
if content == nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{
|
||||
"role": role,
|
||||
"content": content,
|
||||
}
|
||||
}
|
||||
|
||||
itemType := strings.ToLower(strings.TrimSpace(asString(m["type"])))
|
||||
switch itemType {
|
||||
case "message", "input_message":
|
||||
content := m["content"]
|
||||
if content == nil {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
content = txt
|
||||
}
|
||||
}
|
||||
if content == nil {
|
||||
return nil
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(asString(m["role"])))
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
return map[string]any{
|
||||
"role": role,
|
||||
"content": content,
|
||||
}
|
||||
case "function_call_output", "tool_result":
|
||||
content := m["output"]
|
||||
if content == nil {
|
||||
content = m["content"]
|
||||
}
|
||||
if content == nil {
|
||||
content = ""
|
||||
}
|
||||
out := map[string]any{
|
||||
"role": "tool",
|
||||
"content": content,
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
|
||||
out["tool_call_id"] = callID
|
||||
} else if callID = strings.TrimSpace(asString(m["tool_call_id"])); callID != "" {
|
||||
out["tool_call_id"] = callID
|
||||
}
|
||||
if name := strings.TrimSpace(asString(m["name"])); name != "" {
|
||||
out["name"] = name
|
||||
} else if name = strings.TrimSpace(asString(m["tool_name"])); name != "" {
|
||||
out["name"] = name
|
||||
}
|
||||
return out
|
||||
case "function_call", "tool_call":
|
||||
name := strings.TrimSpace(asString(m["name"]))
|
||||
var fn map[string]any
|
||||
if rawFn, ok := m["function"].(map[string]any); ok {
|
||||
fn = rawFn
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(asString(fn["name"]))
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var argsRaw any
|
||||
if v, ok := m["arguments"]; ok {
|
||||
argsRaw = v
|
||||
} else if v, ok := m["input"]; ok {
|
||||
argsRaw = v
|
||||
}
|
||||
if argsRaw == nil && fn != nil {
|
||||
if v, ok := fn["arguments"]; ok {
|
||||
argsRaw = v
|
||||
} else if v, ok := fn["input"]; ok {
|
||||
argsRaw = v
|
||||
}
|
||||
}
|
||||
|
||||
functionPayload := map[string]any{
|
||||
"name": name,
|
||||
"arguments": stringifyToolCallArguments(argsRaw),
|
||||
}
|
||||
call := map[string]any{
|
||||
"type": "function",
|
||||
"function": functionPayload,
|
||||
}
|
||||
if callID := strings.TrimSpace(asString(m["call_id"])); callID != "" {
|
||||
call["id"] = callID
|
||||
} else if callID = strings.TrimSpace(asString(m["id"])); callID != "" {
|
||||
call["id"] = callID
|
||||
}
|
||||
return map[string]any{
|
||||
"role": "assistant",
|
||||
"tool_calls": []any{call},
|
||||
}
|
||||
case "input_text":
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": txt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": txt,
|
||||
}
|
||||
}
|
||||
if content, ok := m["content"]; ok {
|
||||
if strings.TrimSpace(normalizeOpenAIContentForPrompt(content)) != "" {
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": content,
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeResponsesFallbackPart(m map[string]any) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
if t, _ := m["type"].(string); strings.EqualFold(strings.TrimSpace(t), "input_text") {
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return txt
|
||||
}
|
||||
}
|
||||
if txt, _ := m["text"].(string); strings.TrimSpace(txt) != "" {
|
||||
return txt
|
||||
}
|
||||
if content, ok := m["content"]; ok {
|
||||
if normalized := strings.TrimSpace(normalizeOpenAIContentForPrompt(content)); normalized != "" {
|
||||
return normalized
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", m))
|
||||
}
|
||||
|
||||
// stringifyToolCallArguments renders tool-call arguments as the string stored
// in a chat-style tool call's "arguments" field.
//
//   - nil, a blank string, a value that cannot be marshalled, or a value that
//     marshals to JSON null (a typed nil map, slice or pointer) yields "{}".
//   - A non-blank string is returned trimmed and otherwise untouched; it is
//     not checked for being valid JSON.
//   - Any other value is returned as its JSON encoding.
func stringifyToolCallArguments(v any) string {
	const emptyObject = "{}"

	switch x := v.(type) {
	case nil:
		return emptyObject
	case string:
		if s := strings.TrimSpace(x); s != "" {
			return s
		}
		return emptyObject
	default:
		b, err := json.Marshal(x)
		// "null" comes from typed nil values; treat it like the untyped nil
		// case so callers always receive an object-shaped default.
		if err != nil || len(b) == 0 || string(b) == "null" {
			return emptyObject
		}
		return string(b)
	}
}
|
||||
|
||||
@@ -54,8 +54,12 @@ func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T
|
||||
t.Fatalf("expected at least one tool_call in output, got %#v", first["tool_calls"])
|
||||
}
|
||||
call0, _ := toolCalls[0].(map[string]any)
|
||||
if call0["name"] != "read_file" {
|
||||
t.Fatalf("unexpected tool call name: %#v", call0["name"])
|
||||
if call0["type"] != "function" {
|
||||
t.Fatalf("unexpected tool call type: %#v", call0["type"])
|
||||
}
|
||||
fn, _ := call0["function"].(map[string]any)
|
||||
if fn["name"] != "read_file" {
|
||||
t.Fatalf("unexpected tool call name: %#v", fn["name"])
|
||||
}
|
||||
if strings.Contains(outputText, `"tool_calls"`) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText)
|
||||
|
||||
@@ -48,17 +48,9 @@ func BuildResponseObject(responseID, model, finalPrompt, finalThinking, finalTex
|
||||
output := make([]any, 0, 2)
|
||||
if len(detected) > 0 {
|
||||
exposedOutputText = ""
|
||||
toolCalls := make([]any, 0, len(detected))
|
||||
for _, tc := range detected {
|
||||
toolCalls = append(toolCalls, map[string]any{
|
||||
"type": "tool_call",
|
||||
"name": tc.Name,
|
||||
"arguments": tc.Input,
|
||||
})
|
||||
}
|
||||
output = append(output, map[string]any{
|
||||
"type": "tool_calls",
|
||||
"tool_calls": toolCalls,
|
||||
"tool_calls": util.FormatOpenAIToolCalls(detected),
|
||||
})
|
||||
} else {
|
||||
content := []any{
|
||||
|
||||
64
internal/format/openai/render_test.go
Normal file
64
internal/format/openai/render_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
|
||||
obj := BuildResponseObject(
|
||||
"resp_test",
|
||||
"gpt-4o",
|
||||
"prompt",
|
||||
"",
|
||||
`{"tool_calls":[{"name":"search","input":{"q":"golang"}}]}`,
|
||||
[]string{"search"},
|
||||
)
|
||||
|
||||
outputText, _ := obj["output_text"].(string)
|
||||
if outputText != "" {
|
||||
t.Fatalf("expected output_text to be hidden for tool calls, got %q", outputText)
|
||||
}
|
||||
|
||||
output, _ := obj["output"].([]any)
|
||||
if len(output) != 1 {
|
||||
t.Fatalf("expected one tool_calls wrapper, got %#v", obj["output"])
|
||||
}
|
||||
|
||||
first, _ := output[0].(map[string]any)
|
||||
if first["type"] != "tool_calls" {
|
||||
t.Fatalf("expected first output item type tool_calls, got %#v", first["type"])
|
||||
}
|
||||
var toolCalls []map[string]any
|
||||
switch v := first["tool_calls"].(type) {
|
||||
case []map[string]any:
|
||||
toolCalls = v
|
||||
case []any:
|
||||
toolCalls = make([]map[string]any, 0, len(v))
|
||||
for _, item := range v {
|
||||
m, _ := item.(map[string]any)
|
||||
if m != nil {
|
||||
toolCalls = append(toolCalls, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(toolCalls) != 1 {
|
||||
t.Fatalf("expected one tool call, got %#v", first["tool_calls"])
|
||||
}
|
||||
tc := toolCalls[0]
|
||||
if tc["type"] != "function" || tc["id"] == "" {
|
||||
t.Fatalf("unexpected tool call shape: %#v", tc)
|
||||
}
|
||||
fn, _ := tc["function"].(map[string]any)
|
||||
if fn["name"] != "search" {
|
||||
t.Fatalf("unexpected function name: %#v", fn["name"])
|
||||
}
|
||||
argsRaw, _ := fn["arguments"].(string)
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(argsRaw), &args); err != nil {
|
||||
t.Fatalf("arguments should be valid json string, got=%q err=%v", argsRaw, err)
|
||||
}
|
||||
if args["q"] != "golang" {
|
||||
t.Fatalf("unexpected arguments: %#v", args)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user