feat: Implement streaming incremental tool call deltas with a new tool sieve and standalone parser.

This commit is contained in:
CJACK
2026-02-18 16:10:35 +08:00
parent 19289c9008
commit 7beeea5779
9 changed files with 1324 additions and 114 deletions

View File

@@ -100,6 +100,26 @@ func streamFinishReason(frames []map[string]any) string {
return ""
}
func streamToolCallArgumentChunks(frames []map[string]any) []string {
out := make([]string, 0, 4)
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
toolCalls, _ := delta["tool_calls"].([]any)
for _, tc := range toolCalls {
tcm, _ := tc.(map[string]any)
fn, _ := tcm["function"].(map[string]any)
if args, ok := fn["arguments"].(string); ok && args != "" {
out = append(out, args)
}
}
}
}
return out
}
func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
@@ -190,6 +210,37 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
}
}
func TestHandleNonStreamEmbeddedToolCallExampleNotIntercepted(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"下面是示例:"}`,
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
`data: {"p":"response/content","v":"请勿执行。"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, false, []string{"search"})
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rec.Code)
}
out := decodeJSONBody(t, rec.Body.String())
choices, _ := out["choices"].([]any)
choice, _ := choices[0].(map[string]any)
if choice["finish_reason"] != "stop" {
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
}
msg, _ := choice["message"].(map[string]any)
if _, ok := msg["tool_calls"]; ok {
t.Fatalf("did not expect tool_calls field for embedded example: %#v", msg["tool_calls"])
}
content, _ := msg["content"].(string)
if !strings.Contains(content, "示例") || !strings.Contains(content, `"tool_calls"`) {
t.Fatalf("expected embedded example to pass through as text, got %q", content)
}
}
func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
@@ -391,11 +442,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta in mixed stream, body=%s", rec.Body.String())
}
if streamHasRawToolJSONContent(frames) {
t.Fatalf("raw tool_calls JSON leaked in mixed stream: %s", rec.Body.String())
if streamHasToolCallsDelta(frames) {
t.Fatalf("did not expect tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
}
content := strings.Builder{}
for _, frame := range frames {
@@ -412,8 +460,11 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
if !strings.Contains(got, "前置正文A。") || !strings.Contains(got, "后置正文B。") {
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
if !strings.Contains(got, `"tool_calls"`) {
t.Fatalf("expected mixed stream to preserve embedded tool_calls example text, got=%q", got)
}
if streamFinishReason(frames) != "stop" {
t.Fatalf("expected finish_reason=stop for mixed prose, body=%s", rec.Body.String())
}
}
@@ -495,16 +546,16 @@ func TestHandleStreamInvalidToolJSONDoesNotLeakRawObject(t *testing.T) {
}
}
}
got := strings.ToLower(content.String())
if strings.Contains(got, "tool_calls") {
t.Fatalf("unexpected raw tool_calls leak in content: %q", content.String())
}
if !strings.Contains(content.String(), "前置正文D。") || !strings.Contains(content.String(), "后置正文E。") {
got := content.String()
if !strings.Contains(got, "前置正文D。") || !strings.Contains(got, "后置正文E。") {
t.Fatalf("expected pre/post plain text to remain, got=%q", content.String())
}
if !strings.Contains(strings.ToLower(got), "tool_calls") {
t.Fatalf("expected invalid embedded tool-like json to pass through as text, got=%q", got)
}
}
func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.T) {
func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
@@ -533,7 +584,42 @@ func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.
}
}
}
if strings.Contains(strings.ToLower(content.String()), "tool_calls") || strings.Contains(content.String(), "{") {
t.Fatalf("unexpected incomplete tool json leak in content: %q", content.String())
if !strings.Contains(strings.ToLower(content.String()), "tool_calls") || !strings.Contains(content.String(), "{") {
t.Fatalf("expected incomplete capture to flush as plain text instead of stalling, got=%q", content.String())
}
}
func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`,
`data: {"p":"response/content","v":"lang\",\"page\":1}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid11", "deepseek-chat", "prompt", false, false, []string{"search"})
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
}
if streamHasRawToolJSONContent(frames) {
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
}
argChunks := streamToolCallArgumentChunks(frames)
if len(argChunks) < 2 {
t.Fatalf("expected incremental arguments chunks, got=%v body=%s", argChunks, rec.Body.String())
}
joined := strings.Join(argChunks, "")
if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) {
t.Fatalf("unexpected merged arguments stream: %q", joined)
}
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}