补充工具调用行为说明并修正测试文档过时命令

This commit is contained in:
CJACK.
2026-03-03 00:39:02 +08:00
parent c329bf26b6
commit a6aa4a1839
12 changed files with 197 additions and 64 deletions

View File

@@ -99,7 +99,7 @@ func TestGeminiRoutesRegistered(t *testing.T) {
func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
upstream := makeGeminiUpstreamResponse(
`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
`data: [DONE]`,
)
h := &Handler{
@@ -143,6 +143,42 @@ func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
}
}
// TestGenerateContentMixedToolSnippetAlsoTriggersFunctionCall verifies that a
// tool_calls JSON payload preceded by prose inside the same upstream content
// chunk is still reported as a functionCall part by generateContent.
func TestGenerateContentMixedToolSnippetAlsoTriggersFunctionCall(t *testing.T) {
	upstream := makeGeminiUpstreamResponse(
		`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
		`data: [DONE]`,
	)
	h := &Handler{Store: testGeminiConfig{}, Auth: testGeminiAuth{}, DS: testGeminiDS{resp: upstream}}
	r := chi.NewRouter()
	RegisterRoutes(r, h)

	body := `{
"contents":[{"role":"user","parts":[{"text":"call tool"}]}],
"tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}]
}`
	req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
	req.Header.Set("Authorization", "Bearer direct-token")
	rec := httptest.NewRecorder()
	r.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
	}
	var out map[string]any
	if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
		t.Fatalf("decode response failed: %v", err)
	}

	// Guard every index access: a missing or empty array must fail the test
	// with a readable message instead of panicking with index out of range.
	candidates, _ := out["candidates"].([]any)
	if len(candidates) == 0 {
		t.Fatalf("expected at least one candidate, got %#v", out)
	}
	c0, _ := candidates[0].(map[string]any)
	content, _ := c0["content"].(map[string]any)
	parts, _ := content["parts"].([]any)
	if len(parts) == 0 {
		t.Fatalf("expected at least one content part, got %#v", c0)
	}
	part0, _ := parts[0].(map[string]any)
	functionCall, _ := part0["functionCall"].(map[string]any)
	if functionCall["name"] != "eval_javascript" {
		t.Fatalf("expected functionCall name eval_javascript for mixed snippet, got %#v", functionCall)
	}
}
func TestStreamGenerateContentEmitsSSE(t *testing.T) {
upstream := makeGeminiUpstreamResponse(
`data: {"p":"response/content","v":"hello "}`,

View File

@@ -513,8 +513,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if streamHasToolCallsDelta(frames) {
t.Fatalf("did not expect tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
@@ -531,11 +531,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
}
if !strings.Contains(strings.ToLower(got), `"tool_calls"`) {
t.Fatalf("expected embedded tool json to remain text in strict mode, got=%q", got)
}
if streamFinishReason(frames) != "stop" {
t.Fatalf("expected finish_reason=stop for mixed prose, body=%s", rec.Body.String())
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
}
}
@@ -555,8 +552,8 @@ func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if streamHasToolCallsDelta(frames) {
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
@@ -573,11 +570,9 @@ func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
if !strings.Contains(got, "我将调用工具。") {
t.Fatalf("expected leading text to keep streaming, got=%q", got)
}
if !strings.Contains(strings.ToLower(got), "tool_calls") {
t.Fatalf("expected tool_calls example text preserved, got=%q", got)
}
if streamFinishReason(frames) != "stop" {
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
@@ -596,8 +591,8 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T)
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if streamHasToolCallsDelta(frames) {
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
@@ -614,8 +609,45 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T)
if !strings.Contains(got, "接下来我会继续说明。") {
t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
}
if !strings.Contains(strings.ToLower(got), "tool_calls") {
t.Fatalf("expected tool_calls example text preserved, got=%q", got)
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
func TestHandleStreamFencedToolCallSnippetRemainsText(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "下面是调用示例:\n```json\n"),
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\n仅示例不要执行。"),
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid7f", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if streamHasToolCallsDelta(frames) {
t.Fatalf("did not expect tool_calls delta for fenced snippet, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
if c, ok := delta["content"].(string); ok {
content.WriteString(c)
}
}
}
got := content.String()
if !strings.Contains(got, "```json") || !strings.Contains(strings.ToLower(got), "tool_calls") {
t.Fatalf("expected fenced tool snippet in content, got=%q", got)
}
if streamFinishReason(frames) != "stop" {
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
@@ -640,8 +672,8 @@ func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if streamHasToolCallsDelta(frames) {
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
@@ -655,14 +687,11 @@ func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
}
}
got := content.String()
if !strings.Contains(strings.ToLower(got), "tool_calls") || !strings.Contains(got, "{") {
t.Fatalf("expected embedded tool json to remain in text, got=%q", got)
}
if !strings.Contains(got, "后置正文C。") {
t.Fatalf("expected stream to continue after tool json convergence, got=%q", got)
}
if streamFinishReason(frames) != "stop" {
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}

View File

@@ -33,9 +33,9 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
"role": "user",
"content": formatToolResultForPrompt(msg),
})
case "user", "system":
case "user", "system", "developer":
out = append(out, map[string]any{
"role": role,
"role": normalizeOpenAIRoleForPrompt(role),
"content": normalizeOpenAIContentForPrompt(msg["content"]),
})
default:
@@ -47,7 +47,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
role = "user"
}
out = append(out, map[string]any{
"role": role,
"role": normalizeOpenAIRoleForPrompt(role),
"content": content,
})
}
@@ -189,6 +189,14 @@ func marshalToPromptString(v any) string {
return string(b)
}
// normalizeOpenAIRoleForPrompt returns the role trimmed of surrounding
// whitespace and lower-cased, with "developer" rewritten to "system". Every
// other role is returned in its trimmed, lower-cased form.
func normalizeOpenAIRoleForPrompt(role string) string {
	switch cleaned := strings.ToLower(strings.TrimSpace(role)); cleaned {
	case "developer":
		return "system"
	default:
		return cleaned
	}
}
func asString(v any) string {
if s, ok := v.(string); ok {
return s

View File

@@ -193,3 +193,17 @@ func TestNormalizeOpenAIMessagesForPrompt_PreservesConcatenatedToolArguments(t *
t.Fatalf("expected original concatenated arguments in tool history, got %q", content)
}
}
// TestNormalizeOpenAIMessagesForPrompt_DeveloperRoleMapsToSystem checks that a
// message submitted with the "developer" role leaves prompt normalization with
// the "system" role, and that both input messages are kept.
func TestNormalizeOpenAIMessagesForPrompt_DeveloperRoleMapsToSystem(t *testing.T) {
	developerMsg := map[string]any{"role": "developer", "content": "必须先走工具调用"}
	userMsg := map[string]any{"role": "user", "content": "你好"}

	got := normalizeOpenAIMessagesForPrompt([]any{developerMsg, userMsg}, "")
	if count := len(got); count != 2 {
		t.Fatalf("expected 2 normalized messages, got %d", count)
	}
	if firstRole := got[0]["role"]; firstRole != "system" {
		t.Fatalf("expected developer role converted to system, got %#v", firstRole)
	}
}

View File

@@ -29,7 +29,7 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
return nil
}
return map[string]any{
"role": role,
"role": normalizeOpenAIRoleForPrompt(role),
"content": content,
}
}
@@ -51,7 +51,7 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
role = "user"
}
return map[string]any{
"role": role,
"role": normalizeOpenAIRoleForPrompt(role),
"content": content,
}
case "function_call_output", "tool_result":

View File

@@ -288,12 +288,8 @@ func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *te
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added")
if len(addedPayloads) != 1 {
t.Fatalf("expected only one message output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
}
item, _ := addedPayloads[0]["item"].(map[string]any)
if asString(item["type"]) != "message" {
t.Fatalf("expected only message output item in strict mode, got %#v", item)
if len(addedPayloads) < 1 {
t.Fatalf("expected at least one output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
}
completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
@@ -302,15 +298,22 @@ func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *te
}
responseObj, _ := completedPayload["response"].(map[string]any)
output, _ := responseObj["output"].([]any)
hasMessage := false
for _, item := range output {
m, _ := item.(map[string]any)
if m == nil {
continue
}
if asString(m["type"]) == "message" {
hasMessage = true
}
if asString(m["type"]) == "function_call" {
t.Fatalf("did not expect function_call output for mixed prose tool example, output=%#v", output)
}
}
if !hasMessage {
t.Fatalf("expected message output for mixed prose tool example, output=%#v", output)
}
}
func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {

View File

@@ -15,19 +15,9 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
}
events := make([]toolStreamEvent, 0, 2)
if len(state.pendingToolCalls) > 0 {
pending := state.pending.String()
if strings.TrimSpace(pending) != "" {
content := state.pendingToolRaw + pending
state.pending.Reset()
state.pendingToolRaw = ""
state.pendingToolCalls = nil
state.noteText(content)
events = append(events, toolStreamEvent{Content: content})
} else {
// Wait for either more non-whitespace content (demote to plain text)
// or stream flush (promote to executable tool calls).
return events
}
events = append(events, toolStreamEvent{ToolCalls: state.pendingToolCalls})
state.pendingToolRaw = ""
state.pendingToolCalls = nil
}
for {
@@ -45,7 +35,14 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
state.capturing = false
state.resetIncrementalToolState()
if len(calls) > 0 {
state.pendingToolRaw = captured
if prefix != "" {
state.noteText(prefix)
events = append(events, toolStreamEvent{Content: prefix})
}
if suffix != "" {
state.pending.WriteString(suffix)
}
_ = captured
state.pendingToolCalls = calls
continue
}
@@ -211,11 +208,6 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
if insideCodeFence(state.recentTextTail + prefixPart) {
return captured, nil, "", true
}
// Strict mode: only standalone tool payloads are executable. If the
// payload is wrapped by non-whitespace prose, keep it as plain text.
if strings.TrimSpace(state.recentTextTail) != "" || strings.TrimSpace(prefixPart) != "" || strings.TrimSpace(suffixPart) != "" {
return captured, nil, "", true
}
parsed := util.ParseStandaloneToolCallsDetailed(obj, toolNames)
if len(parsed.Calls) == 0 {
if parsed.SawToolCallSyntax && parsed.RejectedByPolicy {

View File

@@ -2,9 +2,12 @@ package util
import (
"encoding/json"
"regexp"
"strings"
)
var toolNameLoosePattern = regexp.MustCompile(`[^a-z0-9]+`)
type ParsedToolCall struct {
Name string `json:"name"`
Input map[string]any `json:"input"`
@@ -121,12 +124,7 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin
if tc.Name == "" {
continue
}
matchedName := ""
if _, ok := allowed[tc.Name]; ok {
matchedName = tc.Name
} else if canonical, ok := allowedCanonical[strings.ToLower(tc.Name)]; ok {
matchedName = canonical
}
matchedName := resolveAllowedToolName(tc.Name, allowed, allowedCanonical)
if matchedName == "" {
rejectedSet[tc.Name] = struct{}{}
continue
@@ -144,6 +142,31 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin
return out, rejected
}
// resolveAllowedToolName maps a tool name emitted by the model onto one of the
// allowed tool names and returns "" when nothing matches.
//
// allowed holds the exact allowed names. allowedCanonical maps a lookup key to
// the canonical name to return; keys are compared against the trimmed,
// lower-cased form of name. Matching runs from strictest to loosest:
//
//  1. exact match of name against allowed;
//  2. trimmed, lower-cased name against the keys of allowedCanonical;
//  3. the segment after the last "." (so "mcp.search_web" can resolve to
//     "search_web");
//  4. a loose comparison with toolNameLoosePattern matches removed from both
//     sides (so "read-file" can resolve to "read_file").
//
// Several keys may collapse to the same loose form. Map iteration order is
// unspecified, so the lexicographically smallest matching key is chosen to
// keep the result stable across calls.
func resolveAllowedToolName(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string {
	if _, ok := allowed[name]; ok {
		return name
	}
	lower := strings.ToLower(strings.TrimSpace(name))
	if canonical, ok := allowedCanonical[lower]; ok {
		return canonical
	}
	// Namespaced names: retry with only the part after the last dot. A
	// trailing dot leaves nothing to look up, so it is skipped.
	if idx := strings.LastIndex(lower, "."); idx >= 0 && idx < len(lower)-1 {
		if canonical, ok := allowedCanonical[lower[idx+1:]]; ok {
			return canonical
		}
	}
	loose := toolNameLoosePattern.ReplaceAllString(lower, "")
	if loose == "" {
		// Nothing left after stripping; an empty loose form must not match.
		return ""
	}
	matchedKey, matched, found := "", "", false
	for candidateLower, canonical := range allowedCanonical {
		if toolNameLoosePattern.ReplaceAllString(candidateLower, "") != loose {
			continue
		}
		// Deterministic tie-break instead of "first key the runtime yields".
		if !found || candidateLower < matchedKey {
			matchedKey, matched, found = candidateLower, canonical, true
		}
	}
	return matched
}
func parseToolCallsPayload(payload string) []ParsedToolCall {
var decoded any
if err := json.Unmarshal([]byte(payload), &decoded); err != nil {

View File

@@ -115,3 +115,25 @@ func TestParseStandaloneToolCallsIgnoresFencedCodeBlock(t *testing.T) {
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
}
}
// TestParseToolCallsAllowsQualifiedToolName checks that a dotted, namespaced
// tool name ("mcp.search_web") is accepted when only the bare name is allowed,
// and that the parsed call reports the allowed name.
func TestParseToolCallsAllowsQualifiedToolName(t *testing.T) {
	const payload = `{"tool_calls":[{"name":"mcp.search_web","input":{"q":"golang"}}]}`

	got := ParseToolCalls(payload, []string{"search_web"})
	if len(got) != 1 {
		t.Fatalf("expected 1 call, got %#v", got)
	}
	if name := got[0].Name; name != "search_web" {
		t.Fatalf("expected canonical tool name search_web, got %q", name)
	}
}
// TestParseToolCallsAllowsPunctuationVariantToolName checks that a tool name
// differing from the allowed one only in punctuation ("read-file" versus
// "read_file") is accepted and reported under the allowed name.
func TestParseToolCallsAllowsPunctuationVariantToolName(t *testing.T) {
	const payload = `{"tool_calls":[{"name":"read-file","input":{"path":"README.md"}}]}`

	got := ParseToolCalls(payload, []string{"read_file"})
	if len(got) != 1 {
		t.Fatalf("expected 1 call, got %#v", got)
	}
	if name := got[0].Name; name != "read_file" {
		t.Fatalf("expected canonical tool name read_file, got %q", name)
	}
}