refactor: remove legacy function call support and simplify tool sieve logic

This commit is contained in:
CJACK
2026-04-19 04:38:48 +08:00
parent 146d59e7bf
commit 0c644d1f4d
21 changed files with 55 additions and 2747 deletions

View File

@@ -138,77 +138,6 @@ func TestHandleClaudeStreamRealtimeThinkingDelta(t *testing.T) {
}
}
func TestHandleClaudeStreamRealtimeToolSafety(t *testing.T) {
h := &Handler{}
resp := makeClaudeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
`data: {"p":"response/content","v":",\"input\":{\"q\":\"go\"}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, false, false, []string{"search"})
frames := parseClaudeFrames(t, rec.Body.String())
for _, f := range findClaudeFrames(frames, "content_block_delta") {
delta, _ := f.Payload["delta"].(map[string]any)
if delta["type"] == "text_delta" && strings.Contains(asString(delta["text"]), `"tool_calls"`) {
t.Fatalf("raw tool_calls JSON leaked in text delta: body=%s", rec.Body.String())
}
}
foundToolUse := false
for _, f := range findClaudeFrames(frames, "content_block_start") {
contentBlock, _ := f.Payload["content_block"].(map[string]any)
if contentBlock["type"] == "tool_use" {
foundToolUse = true
break
}
}
if !foundToolUse {
t.Fatalf("expected tool_use block in stream, body=%s", rec.Body.String())
}
foundToolUseStop := false
for _, f := range findClaudeFrames(frames, "message_delta") {
delta, _ := f.Payload["delta"].(map[string]any)
if delta["stop_reason"] == "tool_use" {
foundToolUseStop = true
break
}
}
if !foundToolUseStop {
t.Fatalf("expected stop_reason=tool_use, body=%s", rec.Body.String())
}
}
func TestHandleClaudeStreamRealtimeToolDetectionFromThinkingFallback(t *testing.T) {
h := &Handler{}
resp := makeClaudeSSEHTTPResponse(
`data: {"p":"response/thinking_content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
`data: {"p":"response/thinking_content","v":",\"input\":{\"q\":\"go\"}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", nil)
h.handleClaudeStreamRealtime(rec, req, resp, "claude-sonnet-4-5", []any{map[string]any{"role": "user", "content": "use tool"}}, true, false, []string{"search"})
frames := parseClaudeFrames(t, rec.Body.String())
foundToolUse := false
for _, f := range findClaudeFrames(frames, "content_block_start") {
contentBlock, _ := f.Payload["content_block"].(map[string]any)
if contentBlock["type"] == "tool_use" && contentBlock["name"] == "search" {
foundToolUse = true
break
}
}
if !foundToolUse {
t.Fatalf("expected tool_use block from thinking fallback, body=%s", rec.Body.String())
}
}
func TestHandleClaudeStreamRealtimeSkipsThinkingFallbackWhenFinalTextExists(t *testing.T) {
h := &Handler{}
resp := makeClaudeSSEHTTPResponse(

View File

@@ -3,7 +3,6 @@ package openai
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
@@ -121,160 +120,7 @@ func streamToolCallArgumentChunks(frames []map[string]any) []string {
return out
}
func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, context.Background(), resp, "cid1", "deepseek-chat", "prompt", false, []string{"search"})
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rec.Code)
}
out := decodeJSONBody(t, rec.Body.String())
choices, _ := out["choices"].([]any)
if len(choices) != 1 {
t.Fatalf("unexpected choices: %#v", out["choices"])
}
choice, _ := choices[0].(map[string]any)
if choice["finish_reason"] != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
}
msg, _ := choice["message"].(map[string]any)
if msg["content"] != nil {
t.Fatalf("expected content nil, got %#v", msg["content"])
}
toolCalls, _ := msg["tool_calls"].([]any)
if len(toolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %#v", msg["tool_calls"])
}
}
func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/thinking_content","v":"先想一下"}`,
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, context.Background(), resp, "cid2", "deepseek-reasoner", "prompt", true, []string{"search"})
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rec.Code)
}
out := decodeJSONBody(t, rec.Body.String())
choices, _ := out["choices"].([]any)
choice, _ := choices[0].(map[string]any)
msg, _ := choice["message"].(map[string]any)
if msg["reasoning_content"] != "先想一下" {
t.Fatalf("expected reasoning_content, got %#v", msg["reasoning_content"])
}
if msg["content"] != nil {
t.Fatalf("expected content nil, got %#v", msg["content"])
}
if choice["finish_reason"] != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
}
}
func TestHandleNonStreamUnknownToolIntercepted(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, context.Background(), resp, "cid2b", "deepseek-chat", "prompt", false, []string{"search"})
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rec.Code)
}
out := decodeJSONBody(t, rec.Body.String())
choices, _ := out["choices"].([]any)
choice, _ := choices[0].(map[string]any)
if choice["finish_reason"] != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
}
msg, _ := choice["message"].(map[string]any)
toolCalls, _ := msg["tool_calls"].([]any)
if len(toolCalls) != 1 {
t.Fatalf("expected tool_calls for unknown schema name, got %#v", msg["tool_calls"])
}
}
func TestHandleNonStreamEmbeddedToolCallExamplePromotesToolCall(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"下面是示例:"}`,
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
`data: {"p":"response/content","v":"请勿执行。"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, []string{"search"})
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rec.Code)
}
out := decodeJSONBody(t, rec.Body.String())
choices, _ := out["choices"].([]any)
choice, _ := choices[0].(map[string]any)
if choice["finish_reason"] != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
}
msg, _ := choice["message"].(map[string]any)
toolCalls, _ := msg["tool_calls"].([]any)
if len(toolCalls) != 1 {
t.Fatalf("expected one tool_call field for embedded example: %#v", msg["tool_calls"])
}
content, _ := msg["content"].(string)
if strings.Contains(content, `"tool_calls"`) {
t.Fatalf("expected raw tool_calls json stripped from content, got %#v", content)
}
}
func TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
"data: {\"p\":\"response/content\",\"v\":\"```json\\n{\\\"tool_calls\\\":[{\\\"name\\\":\\\"search\\\",\\\"input\\\":{\\\"q\\\":\\\"go\\\"}}]}\\n```\"}",
`data: [DONE]`,
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, context.Background(), resp, "cid2d", "deepseek-chat", "prompt", false, []string{"search"})
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rec.Code)
}
out := decodeJSONBody(t, rec.Body.String())
choices, _ := out["choices"].([]any)
choice, _ := choices[0].(map[string]any)
if choice["finish_reason"] == "tool_calls" {
t.Fatalf("expected fenced example to remain content-only, got finish_reason=%#v", choice["finish_reason"])
}
msg, _ := choice["message"].(map[string]any)
toolCalls, _ := msg["tool_calls"].([]any)
if len(toolCalls) != 0 {
t.Fatalf("expected no tool_call field for fenced example: %#v", msg["tool_calls"])
}
content, _ := msg["content"].(string)
if !strings.Contains(content, `"tool_calls"`) {
t.Fatalf("expected fenced example content preserved, got %q", content)
}
}
// Backward-compatible alias for historical test name used in CI logs.
func TestHandleNonStreamFencedToolCallExamplePromotesToolCall(t *testing.T) {
TestHandleNonStreamFencedToolCallExampleDoesNotPromoteToolCall(t)
}
func TestHandleNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
@@ -332,193 +178,6 @@ func TestHandleNonStreamReturns429WhenUpstreamHasOnlyThinking(t *testing.T) {
}
}
func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
`data: {"p":"response/content","v":",\"input\":{\"q\":\"go\"}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid3", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
foundToolIndex := false
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
toolCalls, _ := delta["tool_calls"].([]any)
for _, tc := range toolCalls {
tcm, _ := tc.(map[string]any)
if _, ok := tcm["index"].(float64); ok {
foundToolIndex = true
}
}
}
}
if !foundToolIndex {
t.Fatalf("expected stream tool_calls item with index, body=%s", rec.Body.String())
}
if streamHasRawToolJSONContent(frames) {
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
func TestHandleStreamToolCallLargeArgumentsStillIntercepted(t *testing.T) {
h := &Handler{}
large := strings.Repeat("a", 9000)
payload := fmt.Sprintf(`{"tool_calls":[{"name":"search","input":{"q":"%s"}}]}`, large)
splitAt := len(payload) / 2
resp := makeSSEHTTPResponse(
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[:splitAt]),
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, payload[splitAt:]),
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid3-large", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
if streamHasRawToolJSONContent(frames) {
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/thinking_content","v":"思考中"}`,
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid4", "deepseek-reasoner", "prompt", true, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
foundToolIndex := false
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
toolCalls, _ := delta["tool_calls"].([]any)
for _, tc := range toolCalls {
tcm, _ := tc.(map[string]any)
if _, ok := tcm["index"].(float64); ok {
foundToolIndex = true
}
}
}
}
if !foundToolIndex {
t.Fatalf("expected stream tool_calls item with index, body=%s", rec.Body.String())
}
if streamHasRawToolJSONContent(frames) {
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
hasThinkingDelta := false
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
if _, ok := delta["reasoning_content"]; ok {
hasThinkingDelta = true
}
}
}
if !hasThinkingDelta {
t.Fatalf("expected reasoning_content delta in reasoner stream: %s", rec.Body.String())
}
}
func TestHandleStreamUnknownToolEmitsToolCall(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid5", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta for unknown schema name, body=%s", rec.Body.String())
}
if streamHasRawToolJSONContent(frames) {
t.Fatalf("did not expect raw tool_calls json leak for unknown schema name: %s", rec.Body.String())
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
func TestHandleStreamUnknownToolNoArgsEmitsToolCall(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\"}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid5b", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta for unknown schema name (no args), body=%s", rec.Body.String())
}
if streamHasRawToolJSONContent(frames) {
t.Fatalf("did not expect raw tool_calls json leak for unknown schema name (no args): %s", rec.Body.String())
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
@@ -557,287 +216,6 @@ func TestHandleStreamToolsPlainTextStreamsBeforeFinish(t *testing.T) {
}
}
func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"下面是示例:"}`,
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
`data: {"p":"response/content","v":"请勿执行。"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid7", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
if c, ok := delta["content"].(string); ok {
content.WriteString(c)
}
}
}
got := content.String()
if !strings.Contains(got, "下面是示例:") || !strings.Contains(got, "请勿执行。") {
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls for mixed prose, body=%s", rec.Body.String())
}
}
func TestHandleStreamToolCallAfterLeadingTextRemainsText(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"我将调用工具。"}`,
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid7b", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
if c, ok := delta["content"].(string); ok {
content.WriteString(c)
}
}
}
got := content.String()
if !strings.Contains(got, "我将调用工具。") {
t.Fatalf("expected leading text to keep streaming, got=%q", got)
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
func TestHandleStreamToolCallWithSameChunkTrailingTextRemainsText(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}接下来我会继续说明。"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid7c", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
if c, ok := delta["content"].(string); ok {
content.WriteString(c)
}
}
}
got := content.String()
if !strings.Contains(got, "接下来我会继续说明。") {
t.Fatalf("expected trailing plain text to be preserved, got=%q", got)
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
func TestHandleStreamFencedToolCallSnippetPromotesToolCall(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "下面是调用示例:\n```json\n"),
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\n仅示例不要执行。"),
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid7f", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta for fenced snippet, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
if c, ok := delta["content"].(string); ok {
content.WriteString(c)
}
}
}
got := content.String()
if strings.Contains(strings.ToLower(got), "tool_calls") {
t.Fatalf("expected raw fenced tool_calls snippet stripped from content, got=%q", got)
}
if strings.Contains(strings.ToLower(got), "```json") || strings.Contains(got, "\n```\n") {
t.Fatalf("expected consumed fenced tool payload to not leave empty code fence, got=%q", got)
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
func TestHandleStreamStandaloneToolCallAfterClosedFenceKeepsFence(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "先给一个代码示例:\n```text\nhello\n```\n"),
fmt.Sprintf(`data: {"p":"response/content","v":%q}`, "{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"),
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid7g", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta for standalone payload, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
if c, ok := delta["content"].(string); ok {
content.WriteString(c)
}
}
}
got := content.String()
if !strings.Contains(got, "```") {
t.Fatalf("expected closed fence before standalone tool json to be preserved, got=%q", got)
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
func TestHandleStreamToolCallKeyAppearsLateRemainsText(t *testing.T) {
h := &Handler{}
spaces := strings.Repeat(" ", 200)
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{`+spaces+`"}`,
`data: {"p":"response/content","v":"\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
`data: {"p":"response/content","v":"后置正文C。"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid8", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
if c, ok := delta["content"].(string); ok {
content.WriteString(c)
}
}
}
got := content.String()
if !strings.Contains(got, "后置正文C。") {
t.Fatalf("expected stream to continue after tool json convergence, got=%q", got)
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
func TestHandleStreamInvalidToolJSONDoesNotLeakRawObject(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"前置正文D。"}`,
`data: {"p":"response/content","v":"{'tool_calls':[{'name':'search','input':{'q':'go'}}]}"}`,
`data: {"p":"response/content","v":"后置正文E。"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid9", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if streamHasToolCallsDelta(frames) {
t.Fatalf("did not expect tool_calls delta for invalid json, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
if c, ok := delta["content"].(string); ok {
content.WriteString(c)
}
}
}
got := content.String()
if !strings.Contains(got, "前置正文D。") || !strings.Contains(got, "后置正文E。") {
t.Fatalf("expected pre/post plain text to remain, got=%q", content.String())
}
if !strings.Contains(strings.ToLower(got), "tool_calls") {
t.Fatalf("expected invalid embedded tool-like json to pass through as text, got=%q", got)
}
}
func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
@@ -871,108 +249,3 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
t.Fatalf("expected incomplete capture to flush as plain text instead of stalling, got=%q", content.String())
}
}
func TestHandleStreamToolCallArgumentsEmitAsSingleCompletedChunk(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`,
`data: {"p":"response/content","v":"lang\",\"page\":1}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid11", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
if streamHasRawToolJSONContent(frames) {
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
}
argChunks := streamToolCallArgumentChunks(frames)
if len(argChunks) == 0 {
t.Fatalf("expected tool call arguments chunk, got=%v body=%s", argChunks, rec.Body.String())
}
joined := strings.Join(argChunks, "")
if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) {
t.Fatalf("unexpected merged arguments stream: %q", joined)
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}
func TestHandleStreamMultiToolCallDoesNotMergeNamesOrArguments(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search_web\",\"input\":{\"query\":\"latest ai news\"}},{"}`,
`data: {"p":"response/content","v":"\"name\":\"eval_javascript\",\"input\":{\"code\":\"1+1\"}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid12", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
foundSearch := false
foundEval := false
foundIndex1 := false
toolCallsDeltaLens := make([]int, 0, 2)
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
rawToolCalls, hasToolCalls := delta["tool_calls"]
if !hasToolCalls {
continue
}
toolCalls, _ := rawToolCalls.([]any)
toolCallsDeltaLens = append(toolCallsDeltaLens, len(toolCalls))
for _, tc := range toolCalls {
tcm, _ := tc.(map[string]any)
if idx, ok := tcm["index"].(float64); ok && int(idx) == 1 {
foundIndex1 = true
}
fn, _ := tcm["function"].(map[string]any)
name, _ := fn["name"].(string)
switch name {
case "search_web":
foundSearch = true
case "eval_javascript":
foundEval = true
case "search_webeval_javascript":
t.Fatalf("unexpected merged tool name: %s, body=%s", name, rec.Body.String())
}
if args, ok := fn["arguments"].(string); ok && strings.Contains(args, `}{"`) {
t.Fatalf("unexpected concatenated tool arguments: %q, body=%s", args, rec.Body.String())
}
}
}
}
if !foundSearch || !foundEval {
t.Fatalf("expected both tool names in stream deltas, foundSearch=%v foundEval=%v body=%s", foundSearch, foundEval, rec.Body.String())
}
if len(toolCallsDeltaLens) != 1 || toolCallsDeltaLens[0] != 2 {
t.Fatalf("expected exactly one tool_calls delta with two calls, got lens=%v body=%s", toolCallsDeltaLens, rec.Body.String())
}
if !foundIndex1 {
t.Fatalf("expected second tool call index in stream deltas, body=%s", rec.Body.String())
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}

View File

@@ -12,149 +12,6 @@ import (
"ds2api/internal/util"
)
func TestHandleResponsesStreamToolCallsHideRawOutputTextInCompleted(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
rawToolJSON := `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`
streamBody := sseLine(rawToolJSON) + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
completed, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
if !ok {
t.Fatalf("expected response.completed event, body=%s", rec.Body.String())
}
responseObj, _ := completed["response"].(map[string]any)
outputText, _ := responseObj["output_text"].(string)
if outputText != "" {
t.Fatalf("expected empty output_text for tool_calls response, got output_text=%q", outputText)
}
output, _ := responseObj["output"].([]any)
if len(output) == 0 {
t.Fatalf("expected structured output entries, got %#v", responseObj["output"])
}
hasFunctionCall := false
hasLegacyWrapper := false
for _, item := range output {
m, _ := item.(map[string]any)
if m == nil {
continue
}
if m["type"] == "function_call" {
hasFunctionCall = true
}
if m["type"] == "tool_calls" {
hasLegacyWrapper = true
}
}
if !hasFunctionCall {
t.Fatalf("expected function_call item, got %#v", responseObj["output"])
}
if hasLegacyWrapper {
t.Fatalf("did not expect legacy tool_calls wrapper, got %#v", responseObj["output"])
}
if strings.Contains(outputText, `"tool_calls"`) {
t.Fatalf("raw tool_calls JSON leaked in output_text: %q", outputText)
}
}
func TestHandleResponsesStreamUsesOfficialOutputItemEvents(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.output_item.added") {
t.Fatalf("expected response.output_item.added event, body=%s", body)
}
if !strings.Contains(body, "event: response.output_item.done") {
t.Fatalf("expected response.output_item.done event, body=%s", body)
}
if !strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("expected response.function_call_arguments.done event, body=%s", body)
}
if strings.Contains(body, "event: response.output_tool_call.delta") || strings.Contains(body, "event: response.output_tool_call.done") {
t.Fatalf("legacy response.output_tool_call.* event must not appear, body=%s", body)
}
addedPayloads := extractAllSSEEventPayloads(body, "response.output_item.added")
hasFunctionCallAdded := false
for _, payload := range addedPayloads {
item, _ := payload["item"].(map[string]any)
if item == nil || asString(item["type"]) != "function_call" {
continue
}
hasFunctionCallAdded = true
if asString(item["arguments"]) != "" {
t.Fatalf("expected in-progress function_call.arguments to start empty string, got %#v", item["arguments"])
}
}
if !hasFunctionCallAdded {
t.Fatalf("expected function_call output_item.added payload, body=%s", body)
}
donePayload, ok := extractSSEEventPayload(body, "response.function_call_arguments.done")
if !ok {
t.Fatalf("expected to parse response.function_call_arguments.done payload, body=%s", body)
}
doneCallID := strings.TrimSpace(asString(donePayload["call_id"]))
if doneCallID == "" {
t.Fatalf("expected non-empty call_id in done payload, payload=%#v", donePayload)
}
completed, ok := extractSSEEventPayload(body, "response.completed")
if !ok {
t.Fatalf("expected response.completed payload, body=%s", body)
}
responseObj, _ := completed["response"].(map[string]any)
output, _ := responseObj["output"].([]any)
var completedCallID string
for _, item := range output {
m, _ := item.(map[string]any)
if m == nil || m["type"] != "function_call" {
continue
}
completedCallID = strings.TrimSpace(asString(m["call_id"]))
if completedCallID != "" {
break
}
}
if completedCallID == "" {
t.Fatalf("expected function_call.call_id in completed output, output=%#v", output)
}
if completedCallID != doneCallID {
t.Fatalf("expected completed call_id to match stream done call_id, done=%q completed=%q", doneCallID, completedCallID)
}
}
func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
@@ -181,51 +38,6 @@ func TestHandleResponsesStreamDoesNotEmitReasoningTextCompatEvents(t *testing.T)
}
}
func TestHandleResponsesStreamMultiToolCallKeepsNameAndCallIDAligned(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine(`{"tool_calls":[{"name":"search_web","input":{"query":"latest ai news"}},`) +
sseLine(`{"name":"eval_javascript","input":{"code":"1+1"}}]}`) +
"data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"search_web", "eval_javascript"}, util.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
donePayloads := extractAllSSEEventPayloads(body, "response.function_call_arguments.done")
if len(donePayloads) != 2 {
t.Fatalf("expected two response.function_call_arguments.done events, got %d body=%s", len(donePayloads), body)
}
seenNames := map[string]string{}
for _, payload := range donePayloads {
name := strings.TrimSpace(asString(payload["name"]))
callID := strings.TrimSpace(asString(payload["call_id"]))
if name != "search_web" && name != "eval_javascript" {
t.Fatalf("unexpected tool name in done payload: %#v", payload)
}
if callID == "" {
t.Fatalf("expected non-empty call_id in done payload: %#v", payload)
}
seenNames[name] = callID
}
if seenNames["search_web"] == seenNames["eval_javascript"] {
t.Fatalf("expected distinct call_id per tool, got %#v", seenNames)
}
}
func TestHandleResponsesStreamEmitsOutputTextDoneBeforeContentPartDone(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
@@ -297,123 +109,6 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
}
}
func TestHandleResponsesStreamThinkingAndMixedToolExampleEmitsFunctionCall(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(path, value string) string {
b, _ := json.Marshal(map[string]any{
"p": path,
"v": value,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine("response/thinking_content", "thinking...") +
sseLine("response/content", "先读取文件。") +
sseLine("response/content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
"data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-reasoner", "prompt", true, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
addedPayloads := extractAllSSEEventPayloads(rec.Body.String(), "response.output_item.added")
if len(addedPayloads) < 1 {
t.Fatalf("expected at least one output_item.added event, got %d body=%s", len(addedPayloads), rec.Body.String())
}
completedPayload, ok := extractSSEEventPayload(rec.Body.String(), "response.completed")
if !ok {
t.Fatalf("expected response.completed payload, body=%s", rec.Body.String())
}
responseObj, _ := completedPayload["response"].(map[string]any)
output, _ := responseObj["output"].([]any)
hasMessage := false
hasFunctionCall := false
for _, item := range output {
m, _ := item.(map[string]any)
if m == nil {
continue
}
if asString(m["type"]) == "message" {
hasMessage = true
}
if asString(m["type"]) == "function_call" {
hasFunctionCall = true
}
}
if !hasMessage {
t.Fatalf("expected message output for mixed prose tool example, output=%#v", output)
}
if !hasFunctionCall {
t.Fatalf("expected function_call output for mixed prose tool example, output=%#v", output)
}
}
func TestHandleResponsesStreamToolChoiceNoneStillAllowsFunctionCall(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, policy, "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("expected function_call events for tool_choice=none, body=%s", body)
}
}
func TestHandleResponsesStreamMalformedToolJSONFallsBackToText(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
// invalid JSON (NaN) should remain plain text in strict mode.
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
if strings.Contains(body, "event: response.function_call_arguments.delta") || strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("did not expect function_call events for malformed payload in strict mode, body=%s", body)
}
if !strings.Contains(body, "event: response.output_text.delta") {
t.Fatalf("expected response.output_text.delta for malformed payload, body=%s", body)
}
if !strings.Contains(body, "event: response.completed") {
t.Fatalf("expected response.completed event, body=%s", body)
}
}
func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
@@ -448,76 +143,6 @@ func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
}
}
func TestHandleResponsesStreamRequiredToolChoiceIgnoresThinkingToolPayload(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(path, value string) string {
b, _ := json.Marshal(map[string]any{
"p": path,
"v": value,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine("response/thinking_content", `{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"}}]}`) +
sseLine("response/content", "plain text only") +
"data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
policy := util.ToolChoicePolicy{
Mode: util.ToolChoiceRequired,
Allowed: map[string]struct{}{"read_file": {}},
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", true, false, []string{"read_file"}, policy, "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.failed") {
t.Fatalf("expected response.failed event for required tool_choice violation, body=%s", body)
}
if strings.Contains(body, "event: response.completed") {
t.Fatalf("did not expect response.completed after failure, body=%s", body)
}
}
func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine(`{"tool_calls":[{"name":"read_file","input":{"path":"README.MD"},"x":NaN}]}`) + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
policy := util.ToolChoicePolicy{
Mode: util.ToolChoiceRequired,
Allowed: map[string]struct{}{"read_file": {}},
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, policy, "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.failed") {
t.Fatalf("expected response.failed event, body=%s", body)
}
if strings.Contains(body, "event: response.completed") {
t.Fatalf("did not expect response.completed, body=%s", body)
}
}
func TestHandleResponsesStreamFailsWhenUpstreamHasOnlyThinking(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
@@ -556,32 +181,6 @@ func TestHandleResponsesStreamFailsWhenUpstreamHasOnlyThinking(t *testing.T) {
}
}
func TestHandleResponsesStreamAllowsUnknownToolName(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine(`{"tool_calls":[{"name":"not_in_schema","input":{"q":"go"}}]}`) + "data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
if !strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("expected function_call events for unknown tool, body=%s", body)
}
}
func TestHandleResponsesNonStreamRequiredToolChoiceViolation(t *testing.T) {
h := &Handler{}
rec := httptest.NewRecorder()
@@ -635,36 +234,6 @@ func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t
}
}
func TestHandleResponsesNonStreamToolChoiceNoneStillAllowsFunctionCall(t *testing.T) {
h := &Handler{}
rec := httptest.NewRecorder()
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}"}` + "\n" +
`data: [DONE]` + "\n",
)),
}
policy := util.ToolChoicePolicy{Mode: util.ToolChoiceNone}
h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, policy, "")
if rec.Code != http.StatusOK {
t.Fatalf("expected 200 for tool_choice=none handling, got %d body=%s", rec.Code, rec.Body.String())
}
out := decodeJSONBody(t, rec.Body.String())
output, _ := out["output"].([]any)
foundFunctionCall := false
for _, item := range output {
m, _ := item.(map[string]any)
if m != nil && m["type"] == "function_call" {
foundFunctionCall = true
}
}
if !foundFunctionCall {
t.Fatalf("expected function_call output item for tool_choice=none, got %#v", output)
}
}
func TestHandleResponsesNonStreamReturns429WhenUpstreamOutputEmpty(t *testing.T) {
h := &Handler{}
rec := httptest.NewRecorder()

View File

@@ -146,53 +146,6 @@ func TestResponsesStreamStatusCapturedAs200(t *testing.T) {
}
}
func TestResponsesNonStreamMixedProseToolPayloadHandlerPath(t *testing.T) {
statuses := make([]int, 0, 1)
content, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": "我来调用工具\n{\"tool_calls\":[{\"name\":\"read_file\",\"input\":{\"path\":\"README.MD\"}}]}",
})
h := &Handler{
Store: mockOpenAIConfig{wideInput: true},
Auth: streamStatusAuthStub{},
DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse("data: "+string(content), "data: [DONE]")},
}
r := chi.NewRouter()
r.Use(captureStatusMiddleware(&statuses))
RegisterRoutes(r, h)
reqBody := `{"model":"deepseek-chat","input":"请调用工具","tools":[{"type":"function","function":{"name":"read_file","description":"read","parameters":{"type":"object","properties":{"path":{"type":"string"}}}}}],"stream":false}`
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
req.Header.Set("Authorization", "Bearer direct-token")
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
if len(statuses) != 1 || statuses[0] != http.StatusOK {
t.Fatalf("expected captured status 200, got %#v", statuses)
}
var out map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil {
t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String())
}
outputText, _ := out["output_text"].(string)
if outputText != "" {
t.Fatalf("expected output_text hidden for mixed prose tool payload, got %q", outputText)
}
output, _ := out["output"].([]any)
if len(output) != 1 {
t.Fatalf("expected one output item, got %#v", output)
}
first, _ := output[0].(map[string]any)
if first["type"] != "function_call" {
t.Fatalf("expected function_call output item, got %#v", output)
}
}
func TestChatCompletionsStreamContentFilterStopsNormallyWithoutLeak(t *testing.T) {
statuses := make([]int, 0, 1)
h := &Handler{

View File

@@ -147,35 +147,13 @@ func splitSafeContentForToolDetection(s string) (safe, hold string) {
if s == "" {
return "", ""
}
suspiciousStart := findSuspiciousPrefixStart(s)
if suspiciousStart < 0 {
return s, ""
}
if suspiciousStart > 0 {
return s[:suspiciousStart], s[suspiciousStart:]
}
// If suspicious content starts at position 0, keep holding until we can
// parse a complete tool JSON block or reach stream flush.
return "", s
}
func findSuspiciousPrefixStart(s string) int {
start := -1
indices := []int{
strings.LastIndex(s, "{"),
strings.LastIndex(s, "["),
strings.LastIndex(s, "```"),
}
for _, idx := range indices {
if idx > start {
start = idx
if xmlIdx := findPartialXMLToolTagStart(s); xmlIdx >= 0 {
if xmlIdx > 0 {
return s[:xmlIdx], s[xmlIdx:]
}
return "", s
}
// Also check for partial XML tool tag at end of string.
if xmlIdx := findPartialXMLToolTagStart(s); xmlIdx >= 0 && xmlIdx > start {
start = xmlIdx
}
return start
return s, ""
}
func findToolSegmentStart(s string) int {
@@ -183,47 +161,14 @@ func findToolSegmentStart(s string) int {
return -1
}
lower := strings.ToLower(s)
keywords := []string{"tool_calls", "\"function\"", "function.name:", "\"tool_use\""}
bestKeyIdx := -1
for _, kw := range keywords {
idx := strings.Index(lower, kw)
if idx >= 0 && (bestKeyIdx < 0 || idx < bestKeyIdx) {
bestKeyIdx = idx
}
}
if fnKeyIdx := findQuotedFunctionCallKeyStart(s); fnKeyIdx >= 0 && (bestKeyIdx < 0 || fnKeyIdx < bestKeyIdx) {
bestKeyIdx = fnKeyIdx
}
// Also detect XML tool call tags.
for _, tag := range xmlToolTagsToDetect {
idx := strings.Index(lower, tag)
if idx >= 0 && (bestKeyIdx < 0 || idx < bestKeyIdx) {
bestKeyIdx = idx
}
}
if bestKeyIdx < 0 {
return -1
}
// For XML tags, the '<' is itself the segment start.
if bestKeyIdx < len(s) && s[bestKeyIdx] == '<' {
if fenceStart, ok := openFenceStartBefore(s, bestKeyIdx); ok {
return fenceStart
}
return bestKeyIdx
}
start := strings.LastIndex(s[:bestKeyIdx], "{")
if start < 0 {
start = bestKeyIdx
}
// If the keyword matched inside an XML tag (e.g. "tool_calls" in "<tool_calls>"),
// back up past the '<' to capture the full tag.
if start > 0 && s[start-1] == '<' {
start--
}
if fenceStart, ok := openFenceStartBefore(s, start); ok {
return fenceStart
}
return start
return bestKeyIdx
}
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []toolcall.ParsedToolCall, suffix string, ready bool) {
@@ -232,7 +177,7 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
return "", nil, "", false
}
// Try XML tool call extraction first.
// XML tool call extraction only.
if xmlPrefix, xmlCalls, xmlSuffix, xmlReady := consumeXMLToolCapture(captured, toolNames); xmlReady {
return xmlPrefix, xmlCalls, xmlSuffix, true
}
@@ -240,45 +185,5 @@ func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix
if hasOpenXMLToolTag(captured) {
return "", nil, "", false
}
lower := strings.ToLower(captured)
keyIdx := -1
keywords := []string{"tool_calls", "\"function\"", "function.name:", "\"tool_use\""}
for _, kw := range keywords {
idx := strings.Index(lower, kw)
if idx >= 0 && (keyIdx < 0 || idx < keyIdx) {
keyIdx = idx
}
}
if fnKeyIdx := findQuotedFunctionCallKeyStart(captured); fnKeyIdx >= 0 && (keyIdx < 0 || fnKeyIdx < keyIdx) {
keyIdx = fnKeyIdx
}
if keyIdx < 0 {
return "", nil, "", false
}
start := strings.LastIndex(captured[:keyIdx], "{")
if start < 0 {
start = keyIdx
}
obj, end, ok := extractJSONObjectFrom(captured, start)
if !ok {
return "", nil, "", false
}
prefixPart := captured[:start]
suffixPart := captured[end:]
parsed := toolcall.ParseStandaloneToolCallsDetailed(obj, toolNames)
if len(parsed.Calls) == 0 {
if parsed.SawToolCallSyntax && parsed.RejectedByPolicy {
// Parsed as tool-call payload but rejected by schema/policy:
// consume it to avoid leaking raw tool_calls JSON to user content.
return prefixPart, nil, suffixPart, true
}
// If it has obvious keywords but failed to parse even after loose repair,
// we still might want to intercept it if it looks like an attempt at tool call.
// For now, keep the original logic but rely on loose JSON repair.
return captured, nil, "", true
}
prefixPart, suffixPart = trimWrappingJSONFence(prefixPart, suffixPart)
return prefixPart, parsed.Calls, suffixPart, true
return "", nil, "", false
}

View File

@@ -1,100 +0,0 @@
package openai
import "strings"
func findQuotedFunctionCallKeyStart(s string) int {
lower := strings.ToLower(s)
quotedIdx := findFunctionCallKeyStart(lower, `"functioncall"`)
bareIdx := findFunctionCallKeyStart(lower, "functioncall")
// Prefer the quoted JSON key whenever we have a structural match.
// Bare-key detection is only for loose payloads where the quoted form
// is absent.
if quotedIdx >= 0 {
return quotedIdx
}
return bareIdx
}
func findFunctionCallKeyStart(lower, key string) int {
for from := 0; from < len(lower); {
rel := strings.Index(lower[from:], key)
if rel < 0 {
return -1
}
idx := from + rel
if isInsideJSONString(lower, idx) {
from = idx + 1
continue
}
if !hasJSONObjectContextPrefix(lower[:idx]) {
from = idx + 1
continue
}
if !hasJSONKeyBoundary(lower, idx, len(key)) {
from = idx + 1
continue
}
j := idx + len(key)
for j < len(lower) && (lower[j] == ' ' || lower[j] == '\t' || lower[j] == '\r' || lower[j] == '\n') {
j++
}
if j < len(lower) && lower[j] == ':' {
k := j + 1
for k < len(lower) && (lower[k] == ' ' || lower[k] == '\t' || lower[k] == '\r' || lower[k] == '\n') {
k++
}
if k < len(lower) && lower[k] != '{' {
from = idx + 1
continue
}
return idx
}
from = idx + 1
}
return -1
}
func isInsideJSONString(s string, idx int) bool {
inString := false
escaped := false
for i := 0; i < idx; i++ {
c := s[i]
if escaped {
escaped = false
continue
}
if c == '\\' && inString {
escaped = true
continue
}
if c == '"' {
inString = !inString
}
}
return inString
}
func hasJSONObjectContextPrefix(prefix string) bool {
return strings.LastIndex(prefix, "{") >= 0
}
func hasJSONKeyBoundary(s string, idx, keyLen int) bool {
if idx > 0 {
prev := s[idx-1]
if isLowerAlphaNumeric(prev) {
return false
}
}
if end := idx + keyLen; end < len(s) {
next := s[end]
if isLowerAlphaNumeric(next) {
return false
}
}
return true
}
func isLowerAlphaNumeric(b byte) bool {
return (b >= 'a' && b <= 'z') || (b >= '0' && b <= '9') || b == '_'
}

View File

@@ -1,23 +0,0 @@
package openai
import "testing"
func TestFindQuotedFunctionCallKeyStart_PrefersEarlierBareKey(t *testing.T) {
input := `{functionCall:{"name":"a","arguments":"{}"},"message":"literal text: \"functionCall\": not a key"}`
got := findQuotedFunctionCallKeyStart(input)
want := 1
if got != want {
t.Fatalf("findQuotedFunctionCallKeyStart() = %d, want %d", got, want)
}
}
func TestFindQuotedFunctionCallKeyStart_PrefersEarlierQuotedKey(t *testing.T) {
input := `{"functionCall":{"name":"a","arguments":"{}"},"note":"functionCall appears in prose"}`
got := findQuotedFunctionCallKeyStart(input)
want := 1
if got != want {
t.Fatalf("findQuotedFunctionCallKeyStart() = %d, want %d", got, want)
}
}

View File

@@ -147,11 +147,11 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) {
want int
}{
{"tool_calls_tag", "some text <tool_calls>\n", 10},
{"gemini_function_call_json", `some text {"functionCall":{"name":"search","args":{"q":"latest"}}}`, 10},
{"tool_call_tag", "prefix <tool_call>\n", 7},
{"invoke_tag", "text <invoke name=\"foo\">body</invoke>", 5},
{"function_call_tag", "<function_call name=\"foo\">body</function_call>", 0},
{"no_xml", "just plain text", -1},
{"gemini_json_no_detect", `some text {"functionCall":{"name":"search"}}`, -1},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
@@ -163,81 +163,6 @@ func TestFindToolSegmentStartDetectsXMLToolCalls(t *testing.T) {
}
}
func TestFindToolSegmentStartIgnoresFunctionCallProse(t *testing.T) {
input := "Please explain the functionCall API field and how clients should parse it."
if got := findToolSegmentStart(input); got != -1 {
t.Fatalf("expected no tool segment start for prose, got %d", got)
}
}
func TestFindToolSegmentStartDetectsQuotedFunctionCallKey(t *testing.T) {
input := `prefix {"functionCall": {"name":"search_web","args":{"query":"x"}}}`
want := strings.Index(input, "{")
if got := findToolSegmentStart(input); got != want {
t.Fatalf("expected JSON object start %d, got %d", want, got)
}
}
func TestFindToolSegmentStartDetectsLooseFunctionCallKey(t *testing.T) {
input := `prefix {functionCall: {"name":"search_web","args":{"query":"x"}}}`
want := strings.Index(input, "{")
if got := findToolSegmentStart(input); got != want {
t.Fatalf("expected JSON object start %d, got %d", want, got)
}
}
func TestFindToolSegmentStartPrefersQuotedFunctionCallOverEarlierBareProse(t *testing.T) {
input := `prefix {note} functionCall: docs hint {"functionCall":{"name":"search_web","args":{"query":"x"}}}`
want := strings.Index(input, `{"functionCall"`)
if got := findToolSegmentStart(input); got != want {
t.Fatalf("expected quoted functionCall JSON start %d, got %d", want, got)
}
}
func TestFindToolSegmentStartIgnoresLooseFunctionCallProse(t *testing.T) {
input := "Please explain why functionCall: is used in documentation examples."
if got := findToolSegmentStart(input); got != -1 {
t.Fatalf("expected no tool segment start for prose, got %d", got)
}
}
func TestProcessToolSieveDoesNotBufferFunctionCallProse(t *testing.T) {
var state toolStreamSieveState
chunk := "Please explain the functionCall API field and keep streaming this sentence."
events := processToolSieveChunk(&state, chunk, []string{"search_web"})
var text string
for _, evt := range events {
text += evt.Content
if len(evt.ToolCalls) > 0 {
t.Fatalf("expected no tool calls for prose, got %#v", evt.ToolCalls)
}
}
if text != chunk {
t.Fatalf("expected prose to pass through immediately, got %q", text)
}
}
func TestProcessToolSieveDetectsGeminiFunctionCallPayload(t *testing.T) {
var state toolStreamSieveState
events := processToolSieveChunk(&state, `{"functionCall":{"name":"search_web","args":{"query":"latest"}}}`, []string{"search_web"})
events = append(events, flushToolSieve(&state, []string{"search_web"})...)
var textContent string
var toolCalls int
for _, evt := range events {
if evt.Content != "" {
textContent += evt.Content
}
toolCalls += len(evt.ToolCalls)
}
if toolCalls != 1 {
t.Fatalf("expected one tool call from functionCall payload, got events=%#v", events)
}
if strings.Contains(strings.ToLower(textContent), "functioncall") {
t.Fatalf("functionCall json leaked into text content: %q", textContent)
}
}
func TestFindPartialXMLToolTagStart(t *testing.T) {
cases := []struct {
name string

View File

@@ -1,12 +1,10 @@
package compat
import (
"ds2api/internal/toolcall"
"encoding/json"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
"ds2api/internal/sse"
@@ -65,55 +63,6 @@ func TestGoCompatSSEFixtures(t *testing.T) {
}
}
func TestGoCompatToolcallFixtures(t *testing.T) {
files, err := filepath.Glob(compatPath("fixtures", "toolcalls", "*.json"))
if err != nil {
t.Fatalf("glob toolcall fixtures failed: %v", err)
}
if len(files) == 0 {
t.Fatal("no toolcall fixtures found")
}
for _, fixturePath := range files {
name := trimExt(filepath.Base(fixturePath))
expectedPath := compatPath("expected", "toolcalls_"+name+".json")
var fixture struct {
Text string `json:"text"`
ToolNames []string `json:"tool_names"`
Mode string `json:"mode"`
}
mustLoadJSON(t, fixturePath, &fixture)
var expected struct {
Calls []toolcall.ParsedToolCall `json:"calls"`
SawToolCallSyntax bool `json:"sawToolCallSyntax"`
RejectedByPolicy bool `json:"rejectedByPolicy"`
RejectedToolNames []string `json:"rejectedToolNames"`
}
mustLoadJSON(t, expectedPath, &expected)
var got toolcall.ToolCallParseResult
switch strings.ToLower(strings.TrimSpace(fixture.Mode)) {
case "standalone":
got = toolcall.ParseStandaloneToolCallsDetailed(fixture.Text, fixture.ToolNames)
default:
got = toolcall.ParseToolCallsDetailed(fixture.Text, fixture.ToolNames)
}
if got.Calls == nil {
got.Calls = []toolcall.ParsedToolCall{}
}
if got.RejectedToolNames == nil {
got.RejectedToolNames = []string{}
}
if !reflect.DeepEqual(got.Calls, expected.Calls) ||
got.SawToolCallSyntax != expected.SawToolCallSyntax ||
got.RejectedByPolicy != expected.RejectedByPolicy ||
!reflect.DeepEqual(got.RejectedToolNames, expected.RejectedToolNames) {
t.Fatalf("toolcall fixture %s mismatch:\n got=%#v\nwant=%#v", name, got, expected)
}
}
}
func TestGoCompatTokenFixtures(t *testing.T) {
var fixture struct {
Cases []struct {

View File

@@ -2,32 +2,6 @@ package claude
import "testing"
func TestBuildMessageResponseDetectsToolCallsFromThinkingFallback(t *testing.T) {
resp := BuildMessageResponse(
"msg_1",
"claude-sonnet-4-5",
[]any{map[string]any{"role": "user", "content": "hi"}},
`{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`,
"",
[]string{"search"},
)
if resp["stop_reason"] != "tool_use" {
t.Fatalf("expected stop_reason=tool_use, got=%#v", resp["stop_reason"])
}
content, _ := resp["content"].([]map[string]any)
if len(content) < 2 {
t.Fatalf("expected thinking + tool_use content blocks, got=%#v", resp["content"])
}
last := content[len(content)-1]
if last["type"] != "tool_use" {
t.Fatalf("expected last content block tool_use, got=%#v", last["type"])
}
if last["name"] != "search" {
t.Fatalf("expected tool name search, got=%#v", last["name"])
}
}
func TestBuildMessageResponseSkipsThinkingFallbackWhenFinalTextExists(t *testing.T) {
resp := BuildMessageResponse(
"msg_1",

View File

@@ -1,75 +1,10 @@
package openai
import (
"encoding/json"
"strings"
"testing"
)
func TestBuildResponseObjectToolCallsFollowChatShape(t *testing.T) {
obj := BuildResponseObject(
"resp_test",
"gpt-4o",
"prompt",
"",
`{"tool_calls":[{"name":"search","input":{"q":"golang"}}]}`,
[]string{"search"},
)
outputText, _ := obj["output_text"].(string)
if outputText != "" {
t.Fatalf("expected output_text to be hidden for tool calls, got %q", outputText)
}
output, _ := obj["output"].([]any)
if len(output) != 1 {
t.Fatalf("expected function_call output only, got %#v", obj["output"])
}
first, _ := output[0].(map[string]any)
if first["type"] != "function_call" {
t.Fatalf("expected first output item type function_call, got %#v", first["type"])
}
if first["call_id"] == "" {
t.Fatalf("expected function_call item to have call_id, got %#v", first)
}
if first["name"] != "search" {
t.Fatalf("unexpected function name: %#v", first["name"])
}
argsRaw, _ := first["arguments"].(string)
var args map[string]any
if err := json.Unmarshal([]byte(argsRaw), &args); err != nil {
t.Fatalf("arguments should be valid json string, got=%q err=%v", argsRaw, err)
}
if args["q"] != "golang" {
t.Fatalf("unexpected arguments: %#v", args)
}
}
func TestBuildResponseObjectPromotesMixedProseToolPayloadToFunctionCall(t *testing.T) {
obj := BuildResponseObject(
"resp_test",
"gpt-4o",
"prompt",
"",
`示例格式:{"tool_calls":[{"name":"search","input":{"q":"golang"}}]},但这条是普通回答。`,
[]string{"search"},
)
outputText, _ := obj["output_text"].(string)
if outputText != "" {
t.Fatalf("expected output_text hidden for mixed prose tool payload, got %q", outputText)
}
output, _ := obj["output"].([]any)
if len(output) != 1 {
t.Fatalf("expected one function_call output item, got %#v", obj["output"])
}
first, _ := output[0].(map[string]any)
if first["type"] != "function_call" {
t.Fatalf("expected function_call output type, got %#v", first["type"])
}
}
func TestBuildResponseObjectKeepsFencedToolPayloadAsText(t *testing.T) {
obj := BuildResponseObject(
"resp_test",

View File

@@ -4,7 +4,7 @@ import (
"testing"
)
// ─── FormatOpenAIStreamToolCalls ─────────────────────────────────────
// --- FormatOpenAIStreamToolCalls ---
func TestFormatOpenAIStreamToolCalls(t *testing.T) {
formatted := FormatOpenAIStreamToolCalls([]ParsedToolCall{
@@ -22,15 +22,7 @@ func TestFormatOpenAIStreamToolCalls(t *testing.T) {
}
}
// ─── ParseToolCalls more edge cases ──────────────────────────────────
func TestParseToolCallsNoToolNames(t *testing.T) {
text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
calls := ParseToolCalls(text, nil)
if len(calls) != 1 {
t.Fatalf("expected 1 call with nil tool names, got %d", len(calls))
}
}
// --- ParseToolCalls edge cases ---
func TestParseToolCallsEmptyText(t *testing.T) {
calls := ParseToolCalls("", []string{"search"})
@@ -38,55 +30,3 @@ func TestParseToolCallsEmptyText(t *testing.T) {
t.Fatalf("expected 0 calls for empty text, got %d", len(calls))
}
}
func TestParseToolCallsMultipleTools(t *testing.T) {
text := `{"tool_calls":[{"name":"search","input":{"q":"go"}},{"name":"get_weather","input":{"city":"beijing"}}]}`
calls := ParseToolCalls(text, []string{"search", "get_weather"})
if len(calls) != 2 {
t.Fatalf("expected 2 calls, got %d", len(calls))
}
}
func TestParseToolCallsInputAsString(t *testing.T) {
text := `{"tool_calls":[{"name":"search","input":"{\"q\":\"golang\"}"}]}`
calls := ParseToolCalls(text, []string{"search"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if calls[0].Input["q"] != "golang" {
t.Fatalf("expected parsed string input, got %#v", calls[0].Input)
}
}
func TestParseToolCallsWithFunctionWrapper(t *testing.T) {
text := `{"tool_calls":[{"function":{"name":"calc","arguments":{"x":1,"y":2}}}]}`
calls := ParseToolCalls(text, []string{"calc"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if calls[0].Name != "calc" {
t.Fatalf("expected calc, got %q", calls[0].Name)
}
}
func TestParseStandaloneToolCallsFencedCodeBlock(t *testing.T) {
fenced := "Here's an example:\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```\nDon't execute this."
calls := ParseStandaloneToolCalls(fenced, []string{"search"})
if len(calls) != 0 {
t.Fatalf("expected fenced code block to be ignored, got %d calls", len(calls))
}
}
// ─── looksLikeToolExampleContext ─────────────────────────────────────
func TestLooksLikeToolExampleContextNone(t *testing.T) {
if looksLikeToolExampleContext("I will call the tool now") {
t.Fatal("expected false for non-example context")
}
}
func TestLooksLikeToolExampleContextFenced(t *testing.T) {
if !looksLikeToolExampleContext("```json") {
t.Fatal("expected true for fenced code block context")
}
}

View File

@@ -1,205 +0,0 @@
package toolcall
import (
"regexp"
"strings"
)
var toolCallPattern = regexp.MustCompile(`\{\s*["']tool_calls["']\s*:\s*\[(.*?)\]\s*\}`)
var fencedJSONPattern = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")
var fencedCodeBlockPattern = regexp.MustCompile("(?s)```[\\s\\S]*?```")
//nolint:unused // retained for future markup tool-call heuristics.
var markupToolSyntaxPattern = regexp.MustCompile(`(?i)<(?:(?:[a-z0-9_:-]+:)?(?:tool_call|function_call|invoke)\b|(?:[a-z0-9_:-]+:)?function_calls\b|(?:[a-z0-9_:-]+:)?tool_use\b)`)
func buildToolCallCandidates(text string) []string {
trimmed := strings.TrimSpace(text)
candidates := []string{trimmed}
// fenced code block candidates: ```json ... ```
for _, match := range fencedJSONPattern.FindAllStringSubmatch(trimmed, -1) {
if len(match) >= 2 {
candidates = append(candidates, strings.TrimSpace(match[1]))
}
}
// best-effort extraction around tool call keywords in mixed text payloads.
candidates = append(candidates, extractToolCallObjects(trimmed)...)
// best-effort object slice: from first '{' to last '}'
first := strings.Index(trimmed, "{")
last := strings.LastIndex(trimmed, "}")
if first >= 0 && last > first {
candidates = append(candidates, strings.TrimSpace(trimmed[first:last+1]))
}
// best-effort array slice: from first '[' to last ']'
firstArr := strings.Index(trimmed, "[")
lastArr := strings.LastIndex(trimmed, "]")
if firstArr >= 0 && lastArr > firstArr {
candidates = append(candidates, strings.TrimSpace(trimmed[firstArr:lastArr+1]))
}
// legacy regex extraction fallback
if m := toolCallPattern.FindStringSubmatch(trimmed); len(m) >= 2 {
candidates = append(candidates, "{"+`"tool_calls":[`+m[1]+"]}")
}
uniq := make([]string, 0, len(candidates))
seen := map[string]struct{}{}
for _, c := range candidates {
if c == "" {
continue
}
if _, ok := seen[c]; ok {
continue
}
seen[c] = struct{}{}
uniq = append(uniq, c)
}
return uniq
}
func extractToolCallObjects(text string) []string {
if text == "" {
return nil
}
lower := strings.ToLower(text)
out := []string{}
offset := 0
keywords := []string{"tool_calls", "\"function\"", "function.name:", "functioncall", "\"tool_use\""}
for {
bestIdx := -1
matchedKeyword := ""
for _, kw := range keywords {
idx := strings.Index(lower[offset:], kw)
if idx >= 0 {
absIdx := offset + idx
if bestIdx < 0 || absIdx < bestIdx {
bestIdx = absIdx
matchedKeyword = kw
}
}
}
if bestIdx < 0 {
break
}
idx := bestIdx
// Avoid backtracking too far to prevent OOM on malicious or very long strings
searchLimit := idx - 2000
if searchLimit < offset {
searchLimit = offset
}
start := strings.LastIndex(text[searchLimit:idx], "{")
if start >= 0 {
start += searchLimit
}
if start < 0 {
offset = idx + len(matchedKeyword)
continue
}
foundObj := false
for start >= searchLimit {
candidate, end, ok := extractJSONObject(text, start)
if ok {
// Move forward to avoid repeatedly matching the same object.
offset = end
out = append(out, strings.TrimSpace(candidate))
foundObj = true
break
}
// Try previous '{'
if start > searchLimit {
prevStart := strings.LastIndex(text[searchLimit:start], "{")
if prevStart >= 0 {
start = searchLimit + prevStart
continue
}
}
break
}
if !foundObj {
offset = idx + len(matchedKeyword)
}
}
return out
}
func extractJSONObject(text string, start int) (string, int, bool) {
if start < 0 || start >= len(text) || text[start] != '{' {
return "", 0, false
}
depth := 0
quote := byte(0)
escaped := false
// Limit scan length to avoid OOM on unclosed objects
maxLen := start + 50000
if maxLen > len(text) {
maxLen = len(text)
}
for i := start; i < maxLen; i++ {
ch := text[i]
if quote != 0 {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == quote {
quote = 0
}
continue
}
if ch == '"' || ch == '\'' {
quote = ch
continue
}
if ch == '{' {
depth++
continue
}
if ch == '}' {
depth--
if depth == 0 {
return text[start : i+1], i + 1, true
}
}
}
return "", 0, false
}
func looksLikeToolExampleContext(text string) bool {
t := strings.ToLower(strings.TrimSpace(text))
if t == "" {
return false
}
return strings.Contains(t, "```")
}
func shouldSkipToolCallParsingForCodeFenceExample(text string) bool {
if !looksLikeToolCallSyntax(text) {
return false
}
stripped := strings.TrimSpace(stripFencedCodeBlocks(text))
return !looksLikeToolCallSyntax(stripped)
}
//nolint:unused // retained for future markup tool-call heuristics.
func looksLikeMarkupToolSyntax(text string) bool {
return markupToolSyntaxPattern.MatchString(text)
}
func stripFencedCodeBlocks(text string) string {
if text == "" {
return ""
}
return fencedCodeBlockPattern.ReplaceAllString(text, " ")
}

View File

@@ -68,8 +68,31 @@ func parseMarkupToolCalls(text string) []ParsedToolCall {
}
func parseMarkupSingleToolCall(attrs string, inner string) ParsedToolCall {
if parsed := parseToolCallsPayload(inner); len(parsed) > 0 {
return parsed[0]
// Try parsing inner content as a JSON tool call object.
if raw := strings.TrimSpace(inner); raw != "" && strings.HasPrefix(raw, "{") {
var obj map[string]any
if err := json.Unmarshal([]byte(raw), &obj); err == nil {
name, _ := obj["name"].(string)
if name == "" {
if fn, ok := obj["function"].(map[string]any); ok {
name, _ = fn["name"].(string)
}
}
if name == "" {
if fc, ok := obj["functionCall"].(map[string]any); ok {
name, _ = fc["name"].(string)
}
}
if strings.TrimSpace(name) != "" {
input := parseToolCallInput(obj["input"])
if input == nil || len(input) == 0 {
if args, ok := obj["arguments"]; ok {
input = parseToolCallInput(args)
}
}
return ParsedToolCall{Name: strings.TrimSpace(name), Input: input}
}
}
}
name := ""

View File

@@ -1,35 +0,0 @@
package toolcall
import (
"regexp"
"strings"
)
//nolint:unused // retained for policy-level tool-name matching compatibility.
var toolNameLoosePattern = regexp.MustCompile(`[^a-z0-9]+`)
//nolint:unused // retained for policy-level tool-name matching compatibility.
func resolveAllowedToolNameWithLooseMatch(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string {
if _, ok := allowed[name]; ok {
return name
}
lower := strings.ToLower(strings.TrimSpace(name))
if canonical, ok := allowedCanonical[lower]; ok {
return canonical
}
if idx := strings.LastIndex(lower, "."); idx >= 0 && idx < len(lower)-1 {
if canonical, ok := allowedCanonical[lower[idx+1:]]; ok {
return canonical
}
}
loose := toolNameLoosePattern.ReplaceAllString(lower, "")
if loose == "" {
return ""
}
for candidateLower, canonical := range allowedCanonical {
if toolNameLoosePattern.ReplaceAllString(candidateLower, "") == loose {
return canonical
}
}
return ""
}

View File

@@ -1,7 +1,6 @@
package toolcall
import (
"encoding/json"
"strings"
)
@@ -22,126 +21,33 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
}
func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
result := ToolCallParseResult{}
if strings.TrimSpace(text) == "" {
return result
}
result.SawToolCallSyntax = looksLikeToolCallSyntax(text)
if shouldSkipToolCallParsingForCodeFenceExample(text) {
return result
}
candidates := buildToolCallCandidates(text)
for _, candidate := range candidates {
if !isLikelyJSONToolPayloadCandidate(candidate) {
continue
}
tc := parseToolCallsPayload(candidate)
if len(tc) == 0 {
continue
}
parsed := tc
calls, rejectedNames := filterToolCallsDetailed(parsed)
result.Calls = calls
result.RejectedToolNames = rejectedNames
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
result.SawToolCallSyntax = true
return result
}
var parsed []ParsedToolCall
for _, candidate := range candidates {
tc := parseXMLToolCalls(candidate)
if len(tc) == 0 {
tc = parseMarkupToolCalls(candidate)
}
if len(tc) == 0 {
tc = parseToolCallsPayload(candidate)
}
if len(tc) == 0 {
tc = parseTextKVToolCalls(candidate)
}
if len(tc) > 0 {
parsed = tc
result.SawToolCallSyntax = true
break
}
}
if len(parsed) == 0 {
parsed = parseXMLToolCalls(text)
if len(parsed) == 0 {
parsed = parseTextKVToolCalls(text)
if len(parsed) == 0 {
return result
}
}
result.SawToolCallSyntax = true
}
calls, rejectedNames := filterToolCallsDetailed(parsed)
result.Calls = calls
result.RejectedToolNames = rejectedNames
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
return result
return parseToolCallsDetailedXMLOnly(text)
}
func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall {
return ParseStandaloneToolCallsDetailed(text, availableToolNames).Calls
}
func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) ToolCallParseResult {
return parseToolCallsDetailedXMLOnly(text)
}
func parseToolCallsDetailedXMLOnly(text string) ToolCallParseResult {
result := ToolCallParseResult{}
trimmed := strings.TrimSpace(text)
if trimmed == "" {
return result
}
result.SawToolCallSyntax = looksLikeToolCallSyntax(trimmed)
if shouldSkipToolCallParsingForCodeFenceExample(trimmed) {
return result
}
candidates := buildToolCallCandidates(trimmed)
var parsed []ParsedToolCall
for _, candidate := range candidates {
if !isLikelyJSONToolPayloadCandidate(candidate) {
continue
}
parsed = parseToolCallsPayload(candidate)
if len(parsed) == 0 {
continue
}
result.SawToolCallSyntax = true
calls, rejectedNames := filterToolCallsDetailed(parsed)
result.Calls = calls
result.RejectedToolNames = rejectedNames
result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0
return result
}
for _, candidate := range candidates {
candidate = strings.TrimSpace(candidate)
if candidate == "" {
continue
}
parsed = parseXMLToolCalls(candidate)
if len(parsed) == 0 {
parsed = parseMarkupToolCalls(candidate)
}
if len(parsed) == 0 {
parsed = parseToolCallsPayload(candidate)
}
if len(parsed) == 0 {
parsed = parseTextKVToolCalls(candidate)
}
if len(parsed) > 0 {
break
}
parsed := parseXMLToolCalls(trimmed)
if len(parsed) == 0 {
parsed = parseMarkupToolCalls(trimmed)
}
if len(parsed) == 0 {
parsed = parseXMLToolCalls(trimmed)
if len(parsed) == 0 {
parsed = parseTextKVToolCalls(trimmed)
if len(parsed) == 0 {
return result
}
}
return result
}
result.SawToolCallSyntax = true
calls, rejectedNames := filterToolCallsDetailed(parsed)
result.Calls = calls
@@ -164,70 +70,16 @@ func filterToolCallsDetailed(parsed []ParsedToolCall) ([]ParsedToolCall, []strin
return out, nil
}
//nolint:unused // retained for policy-level tool-name matching compatibility.
func resolveAllowedToolName(name string, allowed map[string]struct{}, allowedCanonical map[string]string) string {
return resolveAllowedToolNameWithLooseMatch(name, allowed, allowedCanonical)
}
func parseToolCallsPayload(payload string) []ParsedToolCall {
var decoded any
if err := json.Unmarshal([]byte(payload), &decoded); err != nil {
// Try to repair backslashes first! Because LLMs often mix these two problems.
repaired := repairInvalidJSONBackslashes(payload)
// Try loose repair on top of that
repaired = RepairLooseJSON(repaired)
if err := json.Unmarshal([]byte(repaired), &decoded); err != nil {
return nil
}
}
switch v := decoded.(type) {
case map[string]any:
if tc, ok := v["tool_calls"]; ok {
if isLikelyChatMessageEnvelope(v) {
return nil
}
return parseToolCallList(tc)
}
if parsed, ok := parseToolCallItem(v); ok {
return []ParsedToolCall{parsed}
}
case []any:
return parseToolCallList(v)
}
return nil
}
func isLikelyChatMessageEnvelope(v map[string]any) bool {
if v == nil {
return false
}
if _, ok := v["tool_calls"]; !ok {
return false
}
if role, ok := v["role"].(string); ok {
switch strings.ToLower(strings.TrimSpace(role)) {
case "assistant", "tool", "user", "system":
return true
}
}
if _, ok := v["tool_call_id"]; ok {
return true
}
if _, ok := v["content"]; ok {
return true
}
return false
}
func looksLikeToolCallSyntax(text string) bool {
lower := strings.ToLower(text)
return strings.Contains(lower, "tool_calls") ||
strings.Contains(lower, "\"function\"") ||
strings.Contains(lower, "functioncall") ||
strings.Contains(lower, "\"tool_use\"") ||
return strings.Contains(lower, "<tool_calls") ||
strings.Contains(lower, "<tool_call") ||
strings.Contains(lower, "<function_calls") ||
strings.Contains(lower, "<function_call") ||
strings.Contains(lower, "<function_name") ||
strings.Contains(lower, "<invoke") ||
strings.Contains(lower, "function.name:")
strings.Contains(lower, "<tool_use") ||
strings.Contains(lower, "<attempt_completion") ||
strings.Contains(lower, "<ask_followup_question") ||
strings.Contains(lower, "<new_task") ||
strings.Contains(lower, "<result")
}

View File

@@ -1,87 +0,0 @@
package toolcall
import "strings"
func isLikelyJSONToolPayloadCandidate(candidate string) bool {
trimmed := strings.TrimSpace(candidate)
if trimmed == "" {
return false
}
if !strings.HasPrefix(trimmed, "{") && !strings.HasPrefix(trimmed, "[") {
return false
}
lower := strings.ToLower(trimmed)
return strings.Contains(lower, "tool_calls") ||
strings.Contains(lower, "\"function\"") ||
strings.Contains(lower, "functioncall") ||
strings.Contains(lower, "\"tool_use\"")
}
func parseToolCallList(v any) []ParsedToolCall {
items, ok := v.([]any)
if !ok {
return nil
}
out := make([]ParsedToolCall, 0, len(items))
for _, item := range items {
m, ok := item.(map[string]any)
if !ok {
continue
}
if tc, ok := parseToolCallItem(m); ok {
out = append(out, tc)
}
}
if len(out) == 0 {
return nil
}
return out
}
func parseToolCallItem(m map[string]any) (ParsedToolCall, bool) {
name, _ := m["name"].(string)
inputRaw, hasInput := m["input"]
if fnCall, ok := m["functionCall"].(map[string]any); ok {
if name == "" {
name, _ = fnCall["name"].(string)
}
if !hasInput {
if v, ok := fnCall["args"]; ok {
inputRaw = v
hasInput = true
}
}
if !hasInput {
if v, ok := fnCall["arguments"]; ok {
inputRaw = v
hasInput = true
}
}
}
if fn, ok := m["function"].(map[string]any); ok {
if name == "" {
name, _ = fn["name"].(string)
}
if !hasInput {
if v, ok := fn["arguments"]; ok {
inputRaw = v
hasInput = true
}
}
}
if !hasInput {
for _, key := range []string{"arguments", "args", "parameters", "params"} {
if v, ok := m[key]; ok {
inputRaw = v
break
}
}
}
if strings.TrimSpace(name) == "" {
return ParsedToolCall{}, false
}
return ParsedToolCall{
Name: strings.TrimSpace(name),
Input: parseToolCallInput(inputRaw),
}, true
}

View File

@@ -5,89 +5,6 @@ import (
"testing"
)
func TestParseToolCalls(t *testing.T) {
text := `prefix {"tool_calls":[{"name":"search","input":{"q":"golang"}}]} suffix`
calls := ParseToolCalls(text, []string{"search"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if calls[0].Name != "search" {
t.Fatalf("unexpected tool name: %s", calls[0].Name)
}
if calls[0].Input["q"] != "golang" {
t.Fatalf("unexpected args: %#v", calls[0].Input)
}
}
func TestParseToolCallsIgnoresFencedJSON(t *testing.T) {
text := "I will call tools now\n```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"news\"}}]}\n```"
calls := ParseToolCalls(text, []string{"search"})
if len(calls) != 0 {
t.Fatalf("expected fenced tool_call payload to be ignored, got %#v", calls)
}
}
func TestParseToolCallsWithFunctionArgumentsString(t *testing.T) {
text := `{"tool_calls":[{"function":{"name":"get_weather","arguments":"{\"city\":\"beijing\"}"}}]}`
calls := ParseToolCalls(text, []string{"get_weather"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if calls[0].Name != "get_weather" {
t.Fatalf("unexpected tool name: %s", calls[0].Name)
}
if calls[0].Input["city"] != "beijing" {
t.Fatalf("unexpected args: %#v", calls[0].Input)
}
}
func TestParseToolCallsKeepsUnknownToolName(t *testing.T) {
text := `{"tool_calls":[{"name":"unknown","input":{}}]}`
calls := ParseToolCalls(text, []string{"search"})
if len(calls) != 1 || calls[0].Name != "unknown" {
t.Fatalf("expected unknown tool to be preserved, got %#v", calls)
}
}
func TestParseToolCallsKeepsOriginalToolNameCase(t *testing.T) {
text := `{"tool_calls":[{"name":"Bash","input":{"command":"ls -al"}}]}`
calls := ParseToolCalls(text, []string{"bash"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "Bash" {
t.Fatalf("expected original tool name Bash, got %q", calls[0].Name)
}
}
func TestParseToolCallsDetailedDoesNotRejectByPolicy(t *testing.T) {
text := `{"tool_calls":[{"name":"unknown","input":{}}]}`
res := ParseToolCallsDetailed(text, []string{"search"})
if !res.SawToolCallSyntax {
t.Fatalf("expected SawToolCallSyntax=true, got %#v", res)
}
if res.RejectedByPolicy {
t.Fatalf("expected RejectedByPolicy=false, got %#v", res)
}
if len(res.Calls) != 1 || res.Calls[0].Name != "unknown" {
t.Fatalf("expected call to be preserved, got %#v", res.Calls)
}
}
func TestParseToolCallsDetailedAllowsWhenAllowListEmpty(t *testing.T) {
text := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
res := ParseToolCallsDetailed(text, nil)
if !res.SawToolCallSyntax {
t.Fatalf("expected SawToolCallSyntax=true, got %#v", res)
}
if res.RejectedByPolicy {
t.Fatalf("expected RejectedByPolicy=false, got %#v", res)
}
if len(res.Calls) != 1 || res.Calls[0].Name != "search" {
t.Fatalf("expected calls when allow-list is empty, got %#v", res.Calls)
}
}
func TestFormatOpenAIToolCalls(t *testing.T) {
formatted := FormatOpenAIToolCalls([]ParsedToolCall{{Name: "search", Input: map[string]any{"q": "x"}}})
if len(formatted) != 1 {
@@ -99,55 +16,6 @@ func TestFormatOpenAIToolCalls(t *testing.T) {
}
}
func TestParseStandaloneToolCallsSupportsMixedProsePayload(t *testing.T) {
mixed := `这里是示例:{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
if calls := ParseStandaloneToolCalls(mixed, []string{"search"}); len(calls) != 1 {
t.Fatalf("expected standalone parser to parse mixed prose payload, got %#v", calls)
}
standalone := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
calls := ParseStandaloneToolCalls(standalone, []string{"search"})
if len(calls) != 1 {
t.Fatalf("expected standalone parser to match, got %#v", calls)
}
}
func TestParseStandaloneToolCallsIgnoresFencedCodeBlock(t *testing.T) {
fenced := "```json\n{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}\n```"
if calls := ParseStandaloneToolCalls(fenced, []string{"search"}); len(calls) != 0 {
t.Fatalf("expected fenced tool_call payload to be ignored, got %#v", calls)
}
}
func TestParseStandaloneToolCallsIgnoresChatTranscriptEnvelope(t *testing.T) {
transcript := `[{"role":"user","content":"请展示完整会话"},{"role":"assistant","content":null,"tool_calls":[{"function":{"name":"search","arguments":"{\"q\":\"go\"}"}}]}]`
if calls := ParseStandaloneToolCalls(transcript, []string{"search"}); len(calls) != 0 {
t.Fatalf("expected transcript envelope not to trigger tool call parse, got %#v", calls)
}
}
func TestParseToolCallsAllowsQualifiedToolName(t *testing.T) {
text := `{"tool_calls":[{"name":"mcp.search_web","input":{"q":"golang"}}]}`
calls := ParseToolCalls(text, []string{"search_web"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "mcp.search_web" {
t.Fatalf("expected original tool name mcp.search_web, got %q", calls[0].Name)
}
}
func TestParseToolCallsAllowsPunctuationVariantToolName(t *testing.T) {
text := `{"tool_calls":[{"name":"read-file","input":{"path":"README.md"}}]}`
calls := ParseToolCalls(text, []string{"read_file"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "read-file" {
t.Fatalf("expected original tool name read-file, got %q", calls[0].Name)
}
}
func TestParseToolCallsSupportsClaudeXMLToolCall(t *testing.T) {
text := `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command><description>show cwd</description></parameters></tool_call>`
calls := ParseToolCalls(text, []string{"bash"})
@@ -223,20 +91,6 @@ func TestParseToolCallsDoesNotTreatParameterNameTagAsToolName(t *testing.T) {
}
}
func TestParseToolCallsPrefersJSONPayloadOverIncidentalXMLInString(t *testing.T) {
text := `{"tool_calls":[{"name":"search","input":{"q":"latest <tool_call><tool_name>wrong</tool_name><parameters>{\"x\":1}</parameters></tool_call>"}}]}`
calls := ParseToolCallsDetailed(text, []string{"search"}).Calls
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "search" {
t.Fatalf("expected tool name search, got %q", calls[0].Name)
}
if calls[0].Input["q"] == nil {
t.Fatalf("expected q argument from json payload, got %#v", calls[0].Input)
}
}
func TestParseToolCallsDetailedMarksXMLToolCallSyntax(t *testing.T) {
text := `<tool_call><tool_name>Bash</tool_name><parameters><command>pwd</command></parameters></tool_call>`
res := ParseToolCallsDetailed(text, []string{"bash"})
@@ -318,34 +172,6 @@ func TestParseToolCallsSupportsInvokeFunctionCallStyle(t *testing.T) {
}
}
func TestParseToolCallsSupportsGeminiFunctionCallJSON(t *testing.T) {
text := `{"functionCall":{"name":"search_web","args":{"query":"latest"}}}`
calls := ParseToolCalls(text, []string{"search_web"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "search_web" {
t.Fatalf("expected search_web, got %q", calls[0].Name)
}
if calls[0].Input["query"] != "latest" {
t.Fatalf("expected query argument, got %#v", calls[0].Input)
}
}
func TestParseToolCallsSupportsClaudeToolUseJSON(t *testing.T) {
text := `{"type":"tool_use","name":"read_file","input":{"path":"README.md"}}`
calls := ParseToolCalls(text, []string{"read_file"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "read_file" {
t.Fatalf("expected read_file, got %q", calls[0].Name)
}
if calls[0].Input["path"] != "README.md" {
t.Fatalf("expected path argument, got %#v", calls[0].Input)
}
}
func TestParseToolCallsSupportsToolUseFunctionParameterStyle(t *testing.T) {
text := `<tool_use><function name="search_web"><parameter name="query">test</parameter></function></tool_use>`
calls := ParseToolCalls(text, []string{"search_web"})
@@ -495,104 +321,6 @@ func TestRepairLooseJSON(t *testing.T) {
}
}
func TestParseToolCallsWithUnquotedKeys(t *testing.T) {
text := `这里是列表:{tool_calls: [{"name": "todowrite", "input": {"todos": "test"}}]}`
availableTools := []string{"todowrite"}
parsed := ParseToolCalls(text, availableTools)
if len(parsed) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(parsed))
}
if parsed[0].Name != "todowrite" {
t.Errorf("expected tool todowrite, got %s", parsed[0].Name)
}
}
func TestParseToolCallsWithInvalidBackslashes(t *testing.T) {
// DeepSeek sometimes outputs Windows paths with single backslashes in JSON strings
// Note: using raw string to simulate what AI actually sends in the stream
text := `好的,执行以下命令:{"name": "execute_command", "input": "{\"command\": \"cd D:\git_codes && dir\"}"}`
availableTools := []string{"execute_command"}
parsed := ParseToolCalls(text, availableTools)
// If standard JSON fails, buildToolCallCandidates should still extract the object,
// and parseToolCallsPayload should repair it.
if len(parsed) != 1 {
// If it still fails, let's see why
candidates := buildToolCallCandidates(text)
t.Logf("Candidates: %v", candidates)
t.Fatalf("expected 1 tool call, got %d", len(parsed))
}
cmd, ok := parsed[0].Input["command"].(string)
if !ok {
t.Fatalf("expected command string in input, got %v", parsed[0].Input)
}
expected := "cd D:\\git_codes && dir"
if cmd != expected {
t.Errorf("expected command %q, got %q", expected, cmd)
}
}
func TestParseToolCallsWithDeepSeekHallucination(t *testing.T) {
// 模拟 DeepSeek 典型的幻觉输出:未加引号的键名 + 包含 Windows 路径的嵌套 JSON 字符串 + 漏掉列表的方括号
text := `检测到实施意图——实现经典算法。需在misc/目录创建Python文件。
关键约束:
1. Windows UTF-8编码处理
2. 必须用绝对路径导入
3. 禁止write覆盖已有文件misc/目录允许创建新文件)
将任务分解并委托:
- 研究8皇后算法模式并行探索
- 实现带可视化输出的解决方案unspecified-high
先创建todo列表追踪步骤。
{tool_calls: [{"name": "todowrite", "input": {"todos": {"content": "研究8皇后问题算法模式回溯法和输出格式", "status": "pending", "priority": "high"}, {"content": "在misc/目录创建8皇后Python脚本包含完整解决方案和可视化输出", "status": "pending", "priority": "high"}, {"content": "验证脚本正确性(运行测试)", "status": "pending", "priority": "medium"}}}]}`
availableTools := []string{"todowrite"}
parsed := ParseToolCalls(text, availableTools)
if len(parsed) != 1 {
cands := buildToolCallCandidates(text)
for i, c := range cands {
t.Logf("CAND %d: %s", i, c)
repaired := RepairLooseJSON(c)
t.Logf(" REPAIRED: %s", repaired)
}
t.Fatalf("expected 1 tool call, got %d. Candidates: %v", len(parsed), buildToolCallCandidates(text))
}
if parsed[0].Name != "todowrite" {
t.Errorf("expected tool name 'todowrite', got %q", parsed[0].Name)
}
todos, ok := parsed[0].Input["todos"].([]any)
if !ok {
t.Fatalf("expected 'todos' to be parsed as a list, got %T: %#v", parsed[0].Input["todos"], parsed[0].Input["todos"])
}
if len(todos) != 3 {
t.Errorf("expected 3 todo items, got %d", len(todos))
}
}
func TestParseToolCallsWithMixedWindowsPaths(t *testing.T) {
// 更复杂的案例:嵌套 JSON 字符串中的反斜杠未转义
text := `关键约束: 1. Windows UTF-8编码处理 2. 必须用绝对路径导入 D:\git_codes\ds2api\misc
{tool_calls: [{"name": "write_file", "input": "{\"path\": \"D:\\git_codes\\ds2api\\misc\\queens.py\", \"content\": \"print('hello')\"}"}]}`
availableTools := []string{"write_file"}
parsed := ParseToolCalls(text, availableTools)
if len(parsed) != 1 {
t.Fatalf("expected 1 tool call from mixed text with paths, got %d", len(parsed))
}
path, _ := parsed[0].Input["path"].(string)
// 在解析后的 Go map 中,反斜杠应该被还原
if !strings.Contains(path, "D:\\git_codes") && !strings.Contains(path, "D:/git_codes") {
t.Errorf("expected path to contain Windows style separators, got %q", path)
}
}
func TestParseToolCallInputRepairsControlCharsInPath(t *testing.T) {
in := `{"path":"D:\tmp\new\readme.txt","content":"line1\nline2"}`
parsed := parseToolCallInput(in)
@@ -703,15 +431,3 @@ func TestParseToolCallsUnescapesHTMLEntityArguments(t *testing.T) {
t.Fatalf("expected html entities to be unescaped in command, got %q", cmd)
}
}
func TestParseToolCallsJSONPayloadKeepsLiteralEntities(t *testing.T) {
text := `{"tool_calls":[{"name":"bash","input":{"command":"echo &gt; literally"}}]}`
calls := ParseToolCalls(text, []string{"bash"})
if len(calls) != 1 {
t.Fatalf("expected one call, got %#v", calls)
}
cmd, _ := calls[0].Input["command"].(string)
if cmd != "echo &gt; literally" {
t.Fatalf("expected json payload to keep literal entities, got %q", cmd)
}
}

View File

@@ -1,55 +0,0 @@
package toolcall
import (
"regexp"
"strings"
)
var textKVNamePattern = regexp.MustCompile(`(?is)function\.name:\s*([a-zA-Z0-9_\-.]+)`)
func parseTextKVToolCalls(text string) []ParsedToolCall {
var out []ParsedToolCall
matches := textKVNamePattern.FindAllStringSubmatchIndex(text, -1)
if len(matches) == 0 {
return nil
}
for i, match := range matches {
name := text[match[2]:match[3]]
offset := match[1]
endSearch := len(text)
if i+1 < len(matches) {
endSearch = matches[i+1][0]
}
searchArea := text[offset:endSearch]
argIdx := strings.Index(searchArea, "function.arguments:")
if argIdx < 0 {
continue
}
startIdx := offset + argIdx + len("function.arguments:")
braceIdx := strings.IndexByte(text[startIdx:endSearch], '{')
if braceIdx < 0 {
continue
}
actualStart := startIdx + braceIdx
objJson, _, ok := extractJSONObject(text, actualStart)
if !ok {
continue
}
input := parseToolCallInput(objJson)
out = append(out, ParsedToolCall{
Name: name,
Input: input,
})
}
if len(out) == 0 {
return nil
}
return out
}

View File

@@ -1,61 +0,0 @@
package toolcall
import (
"testing"
)
func TestParseTextKVToolCalls_Basic(t *testing.T) {
text := `
status: already_called
origin: assistant
not_user_input: true
tool_call_id: call_3fcd15235eb94f7eae3a8de5a9cfa36b
function.name: execute_command
function.arguments: {"command":"cd scripts && python check_syntax.py example.py","cwd":null,"timeout":30}
Some other text thinking...
`
calls := ParseToolCalls(text, []string{"execute_command"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if calls[0].Name != "execute_command" {
t.Fatalf("unexpected name: %s", calls[0].Name)
}
if calls[0].Input["command"] != "cd scripts && python check_syntax.py example.py" {
t.Fatalf("unexpected command arg: %v", calls[0].Input["command"])
}
}
func TestParseTextKVToolCalls_Multiple(t *testing.T) {
text := `
function.name: read_file
function.arguments: {
"path": "abc.txt"
}
function.name: bash
function.arguments: {"command": "ls"}
`
calls := ParseToolCalls(text, []string{"read_file", "bash"})
if len(calls) != 2 {
t.Fatalf("expected 2 calls, got %d", len(calls))
}
if calls[0].Name != "read_file" {
t.Fatalf("unexpected 1st name: %s", calls[0].Name)
}
if calls[1].Name != "bash" {
t.Fatalf("unexpected 2nd name: %s", calls[1].Name)
}
}
func TestParseTextKVToolCalls_Standalone(t *testing.T) {
text := "function.name: read_file\nfunction.arguments: {\"path\":\"README.md\"}"
calls := ParseStandaloneToolCalls(text, []string{"read_file"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %d", len(calls))
}
if calls[0].Name != "read_file" {
t.Fatalf("unexpected name: %s", calls[0].Name)
}
}

View File

@@ -2,36 +2,6 @@ package util
import "testing"
func TestBuildOpenAIChatCompletionWithToolCalls(t *testing.T) {
out := BuildOpenAIChatCompletion(
"cid1",
"deepseek-chat",
"prompt",
"",
`{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`,
[]string{"search"},
)
if out["object"] != "chat.completion" {
t.Fatalf("unexpected object: %#v", out["object"])
}
choices, _ := out["choices"].([]map[string]any)
if len(choices) == 0 {
// json-like map from generic marshalling may be []any in some paths
rawChoices, _ := out["choices"].([]any)
if len(rawChoices) == 0 {
t.Fatalf("expected choices")
}
c0, _ := rawChoices[0].(map[string]any)
if c0["finish_reason"] != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, got %#v", c0["finish_reason"])
}
return
}
if choices[0]["finish_reason"] != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, got %#v", choices[0]["finish_reason"])
}
}
func TestBuildOpenAIResponseObjectWithText(t *testing.T) {
out := BuildOpenAIResponseObject(
"resp_1",
@@ -53,42 +23,3 @@ func TestBuildOpenAIResponseObjectWithText(t *testing.T) {
t.Fatalf("expected first output type message, got %#v", first["type"])
}
}
func TestBuildOpenAIResponseObjectToolCallsHidesRawOutputText(t *testing.T) {
out := BuildOpenAIResponseObject(
"resp_2",
"gpt-4o",
"prompt",
"",
`{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`,
[]string{"search"},
)
if out["output_text"] != "" {
t.Fatalf("expected empty output_text for tool_calls, got %#v", out["output_text"])
}
output, _ := out["output"].([]any)
if len(output) == 0 {
t.Fatalf("expected output entries")
}
first, _ := output[0].(map[string]any)
if first["type"] != "tool_calls" {
t.Fatalf("expected first output type tool_calls, got %#v", first["type"])
}
}
func TestBuildClaudeMessageResponseToolUse(t *testing.T) {
out := BuildClaudeMessageResponse(
"msg_1",
"claude-sonnet-4-5",
[]any{map[string]any{"role": "user", "content": "hi"}},
"",
`{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`,
[]string{"search"},
)
if out["type"] != "message" {
t.Fatalf("unexpected type: %#v", out["type"])
}
if out["stop_reason"] != "tool_use" {
t.Fatalf("expected stop_reason=tool_use, got %#v", out["stop_reason"])
}
}