mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-12 04:07:42 +08:00
补充工具调用行为说明并修正测试文档过时命令
This commit is contained in:
@@ -99,7 +99,7 @@ func TestGeminiRoutesRegistered(t *testing.T) {
|
||||
|
||||
func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
|
||||
upstream := makeGeminiUpstreamResponse(
|
||||
`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
h := &Handler{
|
||||
@@ -143,6 +143,42 @@ func TestGenerateContentReturnsFunctionCallParts(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateContentMixedToolSnippetAlsoTriggersFunctionCall(t *testing.T) {
|
||||
upstream := makeGeminiUpstreamResponse(
|
||||
`data: {"p":"response/content","v":"我来调用工具\n{\"tool_calls\":[{\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
h := &Handler{Store: testGeminiConfig{}, Auth: testGeminiAuth{}, DS: testGeminiDS{resp: upstream}}
|
||||
r := chi.NewRouter()
|
||||
RegisterRoutes(r, h)
|
||||
|
||||
body := `{
|
||||
"contents":[{"role":"user","parts":[{"text":"call tool"}]}],
|
||||
"tools":[{"functionDeclarations":[{"name":"eval_javascript","description":"eval","parameters":{"type":"object","properties":{"code":{"type":"string"}}}}]}]
|
||||
}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-pro:generateContent", strings.NewReader(body))
|
||||
req.Header.Set("Authorization", "Bearer direct-token")
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
candidates, _ := out["candidates"].([]any)
|
||||
c0, _ := candidates[0].(map[string]any)
|
||||
content, _ := c0["content"].(map[string]any)
|
||||
parts, _ := content["parts"].([]any)
|
||||
part0, _ := parts[0].(map[string]any)
|
||||
functionCall, _ := part0["functionCall"].(map[string]any)
|
||||
if functionCall["name"] != "eval_javascript" {
|
||||
t.Fatalf("expected functionCall name eval_javascript for mixed snippet, got %#v", functionCall)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamGenerateContentEmitsSSE(t *testing.T) {
|
||||
upstream := makeGeminiUpstreamResponse(
|
||||
`data: {"p":"response/content","v":"hello "}`,
|
||||
|
||||
@@ -513,8 +513,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -531,11 +531,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
|
||||
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(got), `"tool_calls"`) {
|
||||
t.Fatalf("expected embedded tool json to remain text in strict mode, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop for mixed prose, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -555,8 +552,8 @@ func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -573,11 +570,9 @@ func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
|
||||
if !strings.Contains(got, "我将调用工具。") {
|
||||
t.Fatalf("expected leading text to keep streaming, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected tool_calls example text preserved, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -596,8 +591,8 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T)
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -614,8 +609,45 @@ func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T)
|
||||
if !strings.Contains(got, "接下来我会继续说明。") {
|
||||
t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected tool_calls example text preserved, got=%q", got)
|
||||
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamFencedToolCallSnippetRemainsText(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "下面是调用示例:\n```json\n"),
|
||||
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\n仅示例,不要执行。"),
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid7f", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta for fenced snippet, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
if c, ok := delta["content"].(string); ok {
|
||||
content.WriteString(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "```json") || !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected fenced tool snippet in content, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
@@ -640,8 +672,8 @@ func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta, body=%s", rec.Body.String())
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -655,14 +687,11 @@ func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
|
||||
}
|
||||
}
|
||||
got := content.String()
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") || !strings.Contains(got, "{") {
|
||||
t.Fatalf("expected embedded tool json to remain in text, got=%q", got)
|
||||
}
|
||||
if !strings.Contains(got, "后置正文C。") {
|
||||
t.Fatalf("expected stream to continue after tool json convergence, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -33,9 +33,9 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
|
||||
"role": "user",
|
||||
"content": formatToolResultForPrompt(msg),
|
||||
})
|
||||
case "user", "system":
|
||||
case "user", "system", "developer":
|
||||
out = append(out, map[string]any{
|
||||
"role": role,
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": normalizeOpenAIContentForPrompt(msg["content"]),
|
||||
})
|
||||
default:
|
||||
@@ -47,7 +47,7 @@ func normalizeOpenAIMessagesForPrompt(raw []any, traceID string) []map[string]an
|
||||
role = "user"
|
||||
}
|
||||
out = append(out, map[string]any{
|
||||
"role": role,
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": content,
|
||||
})
|
||||
}
|
||||
@@ -189,6 +189,14 @@ func marshalToPromptString(v any) string {
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// normalizeOpenAIRoleForPrompt trims and lowercases role, then folds the
// "developer" role into "system". Any other role is returned in its trimmed,
// lowercased form.
func normalizeOpenAIRoleForPrompt(role string) string {
	switch normalized := strings.ToLower(strings.TrimSpace(role)); normalized {
	case "developer":
		return "system"
	default:
		return normalized
	}
}
|
||||
|
||||
func asString(v any) string {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
|
||||
@@ -193,3 +193,17 @@ func TestNormalizeOpenAIMessagesForPrompt_PreservesConcatenatedToolArguments(t *
|
||||
t.Fatalf("expected original concatenated arguments in tool history, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOpenAIMessagesForPrompt_DeveloperRoleMapsToSystem(t *testing.T) {
|
||||
raw := []any{
|
||||
map[string]any{"role": "developer", "content": "必须先走工具调用"},
|
||||
map[string]any{"role": "user", "content": "你好"},
|
||||
}
|
||||
normalized := normalizeOpenAIMessagesForPrompt(raw, "")
|
||||
if len(normalized) != 2 {
|
||||
t.Fatalf("expected 2 normalized messages, got %d", len(normalized))
|
||||
}
|
||||
if normalized[0]["role"] != "system" {
|
||||
t.Fatalf("expected developer role converted to system, got %#v", normalized[0]["role"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
|
||||
return nil
|
||||
}
|
||||
return map[string]any{
|
||||
"role": role,
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": content,
|
||||
}
|
||||
}
|
||||
@@ -51,7 +51,7 @@ func normalizeResponsesInputItemWithState(m map[string]any, callNameByID map[str
|
||||
role = "user"
|
||||
}
|
||||
return map[string]any{
|
||||
"role": role,
|
||||
"role": normalizeOpenAIRoleForPrompt(role),
|
||||
"content": content,
|
||||
}
|
||||
case "function_call_output", "tool_result":
|
||||
|
||||
@@ -288,12 +288,8 @@ func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *te
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added")
|
||||
if len(addedPayloads) != 1 {
|
||||
t.Fatalf("expected only one message output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
|
||||
}
|
||||
item, _ := addedPayloads[0]["item"].(map[string]any)
|
||||
if asString(item["type"]) != "message" {
|
||||
t.Fatalf("expected only message output item in strict mode, got %#v", item)
|
||||
if len(addedPayloads) < 1 {
|
||||
t.Fatalf("expected at least one output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
|
||||
}
|
||||
|
||||
completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
|
||||
@@ -302,15 +298,22 @@ func TestHandleResponsesStreamThinkingAndMixedToolExampleRemainMessageOnly(t *te
|
||||
}
|
||||
responseObj, _ := completedPayload["response"].(map[string]any)
|
||||
output, _ := responseObj["output"].([]any)
|
||||
hasMessage := false
|
||||
for _, item := range output {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if asString(m["type"]) == "message" {
|
||||
hasMessage = true
|
||||
}
|
||||
if asString(m["type"]) == "function_call" {
|
||||
t.Fatalf("did not expect function_call output for mixed prose tool example, output=%#v", output)
|
||||
}
|
||||
}
|
||||
if !hasMessage {
|
||||
t.Fatalf("expected message output for mixed prose tool example, output=%#v", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
|
||||
|
||||
@@ -15,19 +15,9 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
}
|
||||
events := make([]toolStreamEvent, 0, 2)
|
||||
if len(state.pendingToolCalls) > 0 {
|
||||
pending := state.pending.String()
|
||||
if strings.TrimSpace(pending) != "" {
|
||||
content := state.pendingToolRaw + pending
|
||||
state.pending.Reset()
|
||||
state.pendingToolRaw = ""
|
||||
state.pendingToolCalls = nil
|
||||
state.noteText(content)
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
} else {
|
||||
// Wait for either more non-whitespace content (demote to plain text)
|
||||
// or stream flush (promote to executable tool calls).
|
||||
return events
|
||||
}
|
||||
events = append(events, toolStreamEvent{ToolCalls: state.pendingToolCalls})
|
||||
state.pendingToolRaw = ""
|
||||
state.pendingToolCalls = nil
|
||||
}
|
||||
|
||||
for {
|
||||
@@ -45,7 +35,14 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
if len(calls) > 0 {
|
||||
state.pendingToolRaw = captured
|
||||
if prefix != "" {
|
||||
state.noteText(prefix)
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
if suffix != "" {
|
||||
state.pending.WriteString(suffix)
|
||||
}
|
||||
_ = captured
|
||||
state.pendingToolCalls = calls
|
||||
continue
|
||||
}
|
||||
@@ -211,11 +208,6 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
|
||||
if insideCodeFence(state.recentTextTail + prefixPart) {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
// Strict mode: only standalone tool payloads are executable. If the
|
||||
// payload is wrapped by non-whitespace prose, keep it as plain text.
|
||||
if strings.TrimSpace(state.recentTextTail) != "" || strings.TrimSpace(prefixPart) != "" || strings.TrimSpace(suffixPart) != "" {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
parsed := util.ParseStandaloneToolCallsDetailed(obj, toolNames)
|
||||
if len(parsed.Calls) == 0 {
|
||||
if parsed.SawToolCallSyntax && parsed.RejectedByPolicy {
|
||||
|
||||
@@ -2,9 +2,12 @@ package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var toolNameLoosePattern = regexp.MustCompile(`[^a-z0-9]+`)
|
||||
|
||||
type ParsedToolCall struct {
|
||||
Name string `json:"name"`
|
||||
Input map[string]any `json:"input"`
|
||||
@@ -121,12 +124,7 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin
|
||||
if tc.Name == "" {
|
||||
continue
|
||||
}
|
||||
matchedName := ""
|
||||
if _, ok := allowed[tc.Name]; ok {
|
||||
matchedName = tc.Name
|
||||
} else if canonical, ok := allowedCanonical[strings.ToLower(tc.Name)]; ok {
|
||||
matchedName = canonical
|
||||
}
|
||||
matchedName := resolveAllowedToolName(tc.Name, allowed, allowedCanonical)
|
||||
if matchedName == "" {
|
||||
rejectedSet[tc.Name] = struct{}{}
|
||||
continue
|
||||
@@ -144,6 +142,31 @@ func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []strin
|
||||
return out, rejected
|
||||
}
|
||||
|
||||
func resolveAllowedToolName(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string {
|
||||
if _, ok := allowed[name]; ok {
|
||||
return name
|
||||
}
|
||||
lower := strings.ToLower(strings.TrimSpace(name))
|
||||
if canonical, ok := allowedCanonical[lower]; ok {
|
||||
return canonical
|
||||
}
|
||||
if idx := strings.LastIndex(lower, "."); idx >= 0 && idx < len(lower)-1 {
|
||||
if canonical, ok := allowedCanonical[lower[idx+1:]]; ok {
|
||||
return canonical
|
||||
}
|
||||
}
|
||||
loose := toolNameLoosePattern.ReplaceAllString(lower, "")
|
||||
if loose == "" {
|
||||
return ""
|
||||
}
|
||||
for candidateLower, canonical := range allowedCanonical {
|
||||
if toolNameLoosePattern.ReplaceAllString(candidateLower, "") == loose {
|
||||
return canonical
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseToolCallsPayload(payload string) []ParsedToolCall {
|
||||
var decoded any
|
||||
if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
|
||||
|
||||
@@ -115,3 +115,25 @@ func TestParseStandaloneToolCallsIgnoresFencedCodeBlock(t *testing.T) {
|
||||
t.Fatalf("expected fenced tool_call example to be ignored, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsAllowsQualifiedToolName(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"mcp.search_web","input":{"q":"golang"}}]}`
|
||||
calls := ParseToolCalls(text, []string{"search_web"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "search_web" {
|
||||
t.Fatalf("expected canonical tool name search_web, got %q", calls[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCallsAllowsPunctuationVariantToolName(t *testing.T) {
|
||||
text := `{"tool_calls":[{"name":"read-file","input":{"path":"README.md"}}]}`
|
||||
calls := ParseToolCalls(text, []string{"read_file"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected 1 call, got %#v", calls)
|
||||
}
|
||||
if calls[0].Name != "read_file" {
|
||||
t.Fatalf("expected canonical tool name read_file, got %q", calls[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user