fix: reset tool call state between separate tool blocks to ensure unique IDs across stream segments

This commit is contained in:
CJACK.
2026-04-22 20:10:06 +00:00
parent c291d333c4
commit 5cf56e7628
10 changed files with 175 additions and 3 deletions

View File

@@ -122,6 +122,11 @@ func (s *chatStreamRuntime) sendFailedChunk(status int, message, code string) {
s.sendDone()
}
func (s *chatStreamRuntime) resetStreamToolCallState() {
s.streamToolCallIDs = map[int]string{}
s.streamToolNames = map[int]string{}
}
func (s *chatStreamRuntime) finalize(finishReason string) {
finalThinking := s.thinking.String()
finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
@@ -166,6 +171,7 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)},
nil,
))
s.resetStreamToolCallState()
}
if evt.Content == "" {
continue
@@ -309,6 +315,7 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
s.firstChunkSent = true
}
newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, tcDelta))
s.resetStreamToolCallState()
continue
}
if evt.Content != "" {

View File

@@ -213,3 +213,51 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
t.Fatalf("expected incomplete capture to flush as plain text instead of stalling, got=%q", content.String())
}
}
func TestHandleStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"前置文本\n<tool_calls>\n <tool_call>\n <tool_name>read_file</tool_name>\n <parameters>{\"path\":\"README.MD\"}</parameters>\n </tool_call>\n</tool_calls>"}`,
`data: {"p":"response/content","v":"中间文本\n<tool_calls>\n <tool_call>\n <tool_name>search</tool_name>\n <parameters>{\"q\":\"golang\"}</parameters>\n </tool_call>\n</tool_calls>"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
h.handleStream(rec, req, resp, "cid-multi", "deepseek-chat", "prompt", false, false, []string{"read_file", "search"}, nil)
frames, done := parseSSEDataFrames(t, rec.Body.String())
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
ids := make([]string, 0, 2)
seen := make(map[string]struct{})
for _, frame := range frames {
choices, _ := frame["choices"].([]any)
for _, item := range choices {
choice, _ := item.(map[string]any)
delta, _ := choice["delta"].(map[string]any)
toolCalls, _ := delta["tool_calls"].([]any)
for _, rawCall := range toolCalls {
call, _ := rawCall.(map[string]any)
id := asString(call["id"])
if id == "" {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
ids = append(ids, id)
}
}
}
if len(ids) != 2 {
t.Fatalf("expected two distinct tool call ids, got %#v body=%s", ids, rec.Body.String())
}
if ids[0] == ids[1] {
t.Fatalf("expected distinct tool call ids across blocks, got %#v body=%s", ids, rec.Body.String())
}
}

View File

@@ -128,7 +128,7 @@ func (s *responsesStreamRuntime) finalize() {
finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
if s.bufferToolContent {
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true, true)
}
textParsed := toolcall.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
@@ -224,7 +224,7 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
s.emitTextDelta(trimmed)
continue
}
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, trimmed, s.toolNames), true)
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, trimmed, s.toolNames), true, true)
}
return streamengine.ParsedDecision{ContentSeen: contentSeen}

View File

@@ -39,7 +39,7 @@ func (s *responsesStreamRuntime) sendDone() {
}
}
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool, resetAfterToolCalls bool) {
for _, evt := range events {
if emitContent && evt.Content != "" {
s.emitTextDelta(evt.Content)
@@ -56,6 +56,9 @@ func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEven
}
if len(evt.ToolCalls) > 0 {
s.emitFunctionCallDoneEvents(evt.ToolCalls)
if resetAfterToolCalls {
s.resetStreamToolCallState()
}
}
}
}

View File

@@ -152,6 +152,16 @@ func (s *responsesStreamRuntime) ensureToolCallID(callIndex int) string {
return id
}
func (s *responsesStreamRuntime) resetStreamToolCallState() {
s.streamToolCallIDs = map[int]string{}
s.functionItemIDs = map[int]string{}
s.functionOutputIDs = map[int]int{}
s.functionArgs = map[int]string{}
s.functionDone = map[int]bool{}
s.functionAdded = map[int]bool{}
s.functionNames = map[int]string{}
}
func (s *responsesStreamRuntime) ensureFunctionOutputIndex(callIndex int) int {
if idx, ok := s.functionOutputIDs[callIndex]; ok {
return idx

View File

@@ -109,6 +109,57 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
}
}
func TestHandleResponsesStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
sseLine := func(v string) string {
b, _ := json.Marshal(map[string]any{
"p": "response/content",
"v": v,
})
return "data: " + string(b) + "\n"
}
streamBody := sseLine("前置文本\n<tool_calls>\n <tool_call>\n <tool_name>read_file</tool_name>\n <parameters>{\"path\":\"README.MD\"}</parameters>\n </tool_call>\n</tool_calls>") +
sseLine("中间文本\n<tool_calls>\n <tool_call>\n <tool_name>search</tool_name>\n <parameters>{\"q\":\"golang\"}</parameters>\n </tool_call>\n</tool_calls>") +
"data: [DONE]\n"
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(streamBody)),
}
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file", "search"}, util.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
doneEvents := extractSSEEventPayloads(body, "response.function_call_arguments.done")
if len(doneEvents) < 2 {
t.Fatalf("expected at least two function call done events, got %d body=%s", len(doneEvents), body)
}
ids := make([]string, 0, 2)
seen := make(map[string]struct{})
for _, payload := range doneEvents {
callID := asString(payload["call_id"])
if callID == "" {
continue
}
if _, ok := seen[callID]; ok {
continue
}
seen[callID] = struct{}{}
ids = append(ids, callID)
}
if len(ids) != 2 {
t.Fatalf("expected two distinct call ids, got %#v body=%s", ids, body)
}
if ids[0] == ids[1] {
t.Fatalf("expected distinct call ids across blocks, got %#v body=%s", ids, body)
}
}
func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
@@ -325,3 +376,30 @@ func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
}
return nil, false
}
func extractSSEEventPayloads(body, targetEvent string) []map[string]any {
scanner := bufio.NewScanner(strings.NewReader(body))
matched := false
out := make([]map[string]any, 0, 4)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "event: ") {
evt := strings.TrimSpace(strings.TrimPrefix(line, "event: "))
matched = evt == targetEvent
continue
}
if !matched || !strings.HasPrefix(line, "data: ") {
continue
}
raw := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
if raw == "" || raw == "[DONE]" {
continue
}
var payload map[string]any
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
continue
}
out = append(out, payload)
}
return out
}

View File

@@ -18,6 +18,7 @@ const {
normalizePreparedToolNames,
boolDefaultTrue,
filterIncrementalToolCallDeltasByAllowed,
resetStreamToolCallState,
} = require('./toolcall_policy');
const {
estimateTokens,
@@ -115,6 +116,7 @@ module.exports.__test = {
normalizePreparedToolNames,
boolDefaultTrue,
filterIncrementalToolCallDeltasByAllowed,
resetStreamToolCallState,
estimateTokens,
buildUsage,
filterLeakedContentFilterParts,

View File

@@ -98,6 +98,15 @@ function filterIncrementalToolCallDeltasByAllowed(deltas, allowedNames, seenName
return out;
}
function resetStreamToolCallState(idStore, seenNames) {
if (idStore instanceof Map) {
idStore.clear();
}
if (seenNames instanceof Map) {
seenNames.clear();
}
}
function ensureStreamToolCallID(idStore, index) {
const key = Number.isInteger(index) ? index : 0;
const existing = idStore.get(key);
@@ -135,4 +144,5 @@ module.exports = {
boolDefaultTrue,
formatIncrementalToolCallDeltas,
filterIncrementalToolCallDeltasByAllowed,
resetStreamToolCallState,
};

View File

@@ -18,6 +18,7 @@ const {
formatIncrementalToolCallDeltas,
filterIncrementalToolCallDeltasByAllowed,
boolDefaultTrue,
resetStreamToolCallState,
} = require('./toolcall_policy');
const { createChatCompletionEmitter } = require('./stream_emitter');
const {
@@ -161,6 +162,7 @@ async function handleVercelStream(req, res, rawBody, payload) {
if (evt.type === 'tool_calls' && Array.isArray(evt.calls) && evt.calls.length > 0) {
toolCallsEmitted = true;
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls, streamToolCallIDs) });
resetStreamToolCallState(streamToolCallIDs, streamToolNames);
continue;
}
if (evt.text) {
@@ -283,6 +285,7 @@ async function handleVercelStream(req, res, rawBody, payload) {
if (evt.type === 'tool_calls') {
toolCallsEmitted = true;
sendDeltaFrame({ tool_calls: formatOpenAIStreamToolCalls(evt.calls, streamToolCallIDs) });
resetStreamToolCallState(streamToolCallIDs, streamToolNames);
continue;
}
if (evt.text) {

View File

@@ -17,6 +17,7 @@ const {
normalizePreparedToolNames,
boolDefaultTrue,
filterIncrementalToolCallDeltasByAllowed,
resetStreamToolCallState,
buildUsage,
estimateTokens,
shouldSkipPath,
@@ -107,6 +108,16 @@ test('incremental and final tool formatting share stable id via idStore', () =>
assert.equal(incremental[0].id, finalCalls[0].id);
});
test('resetStreamToolCallState gives each completed block a fresh id', () => {
const idStore = new Map();
const first = formatIncrementalToolCallDeltas([{ index: 0, name: 'read_file' }], idStore);
resetStreamToolCallState(idStore);
const second = formatIncrementalToolCallDeltas([{ index: 0, name: 'search' }], idStore);
assert.equal(first.length, 1);
assert.equal(second.length, 1);
assert.notEqual(first[0].id, second[0].id);
});
test('formatIncrementalToolCallDeltas drops empty deltas (Go parity)', () => {
const idStore = new Map();
const formatted = formatIncrementalToolCallDeltas([{ index: 0 }], idStore);