mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-16 06:05:07 +08:00
Align tool-call parsing across Go/JS and pass quality gates
This commit is contained in:
@@ -1,11 +1,20 @@
|
||||
package gemini
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const maxGeminiRawPromptChars = 1024
|
||||
|
||||
func geminiMessagesFromRequest(req map[string]any) []any {
|
||||
out := make([]any, 0, 8)
|
||||
toolCallCounter := 0
|
||||
nextToolCallID := func() string {
|
||||
toolCallCounter++
|
||||
return fmt.Sprintf("call_gemini_%d", toolCallCounter)
|
||||
}
|
||||
lastToolCallIDByName := map[string]string{}
|
||||
if sys := normalizeGeminiSystemInstruction(req["systemInstruction"]); strings.TrimSpace(sys) != "" {
|
||||
out = append(out, map[string]any{
|
||||
"role": "system",
|
||||
@@ -61,8 +70,11 @@ func geminiMessagesFromRequest(req map[string]any) []any {
|
||||
if name := strings.TrimSpace(asString(fnCall["name"])); name != "" {
|
||||
callID := strings.TrimSpace(asString(fnCall["id"]))
|
||||
if callID == "" {
|
||||
callID = "call_gemini"
|
||||
if callID = strings.TrimSpace(asString(fnCall["call_id"])); callID == "" {
|
||||
callID = nextToolCallID()
|
||||
}
|
||||
}
|
||||
lastToolCallIDByName[strings.ToLower(name)] = callID
|
||||
out = append(out, map[string]any{
|
||||
"role": "assistant",
|
||||
"tool_calls": []any{
|
||||
@@ -91,7 +103,10 @@ func geminiMessagesFromRequest(req map[string]any) []any {
|
||||
callID = strings.TrimSpace(asString(fnResp["tool_call_id"]))
|
||||
}
|
||||
if callID == "" {
|
||||
callID = "call_gemini"
|
||||
callID = strings.TrimSpace(lastToolCallIDByName[strings.ToLower(name)])
|
||||
}
|
||||
if callID == "" {
|
||||
callID = nextToolCallID()
|
||||
}
|
||||
content := fnResp["response"]
|
||||
if content == nil {
|
||||
|
||||
@@ -82,3 +82,48 @@ func TestGeminiMessagesFromRequestPreservesUnknownPartAsRawJSONText(t *testing.T
|
||||
t.Fatalf("expected raw base64 payload not to be embedded, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiMessagesFromRequestBackfillsFunctionResponseCallIDByName(t *testing.T) {
|
||||
req := map[string]any{
|
||||
"contents": []any{
|
||||
map[string]any{
|
||||
"role": "model",
|
||||
"parts": []any{
|
||||
map[string]any{
|
||||
"functionCall": map[string]any{
|
||||
"name": "search_web",
|
||||
"args": map[string]any{"query": "docs"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"parts": []any{
|
||||
map[string]any{
|
||||
"functionResponse": map[string]any{
|
||||
"name": "search_web",
|
||||
"response": map[string]any{"ok": true},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
got := geminiMessagesFromRequest(req)
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected two normalized messages, got %#v", got)
|
||||
}
|
||||
assistant, _ := got[0].(map[string]any)
|
||||
tc, _ := assistant["tool_calls"].([]any)
|
||||
call, _ := tc[0].(map[string]any)
|
||||
callID, _ := call["id"].(string)
|
||||
if !strings.HasPrefix(callID, "call_gemini_") {
|
||||
t.Fatalf("expected generated call id prefix, got %#v", call)
|
||||
}
|
||||
toolMsg, _ := got[1].(map[string]any)
|
||||
if toolMsg["tool_call_id"] != callID {
|
||||
t.Fatalf("expected tool response to inherit generated call id, tool=%#v call=%#v", toolMsg, call)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user