mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-23 09:27:46 +08:00
fix: reset tool call state between separate tool blocks to ensure unique IDs across stream segments
This commit is contained in:
@@ -122,6 +122,11 @@ func (s *chatStreamRuntime) sendFailedChunk(status int, message, code string) {
|
||||
s.sendDone()
|
||||
}
|
||||
|
||||
func (s *chatStreamRuntime) resetStreamToolCallState() {
|
||||
s.streamToolCallIDs = map[int]string{}
|
||||
s.streamToolNames = map[int]string{}
|
||||
}
|
||||
|
||||
func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
finalThinking := s.thinking.String()
|
||||
finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
|
||||
@@ -166,6 +171,7 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
|
||||
[]map[string]any{openaifmt.BuildChatStreamDeltaChoice(0, tcDelta)},
|
||||
nil,
|
||||
))
|
||||
s.resetStreamToolCallState()
|
||||
}
|
||||
if evt.Content == "" {
|
||||
continue
|
||||
@@ -309,6 +315,7 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
|
||||
s.firstChunkSent = true
|
||||
}
|
||||
newChoices = append(newChoices, openaifmt.BuildChatStreamDeltaChoice(0, tcDelta))
|
||||
s.resetStreamToolCallState()
|
||||
continue
|
||||
}
|
||||
if evt.Content != "" {
|
||||
|
||||
@@ -213,3 +213,51 @@ func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testin
|
||||
t.Fatalf("expected incomplete capture to flush as plain text instead of stalling, got=%q", content.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"前置文本\n<tool_calls>\n <tool_call>\n <tool_name>read_file</tool_name>\n <parameters>{\"path\":\"README.MD\"}</parameters>\n </tool_call>\n</tool_calls>"}`,
|
||||
`data: {"p":"response/content","v":"中间文本\n<tool_calls>\n <tool_call>\n <tool_name>search</tool_name>\n <parameters>{\"q\":\"golang\"}</parameters>\n </tool_call>\n</tool_calls>"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid-multi", "deepseek-chat", "prompt", false, false, []string{"read_file", "search"}, nil)
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
|
||||
ids := make([]string, 0, 2)
|
||||
seen := make(map[string]struct{})
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
for _, rawCall := range toolCalls {
|
||||
call, _ := rawCall.(map[string]any)
|
||||
id := asString(call["id"])
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(ids) != 2 {
|
||||
t.Fatalf("expected two distinct tool call ids, got %#v body=%s", ids, rec.Body.String())
|
||||
}
|
||||
if ids[0] == ids[1] {
|
||||
t.Fatalf("expected distinct tool call ids across blocks, got %#v body=%s", ids, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,7 +128,7 @@ func (s *responsesStreamRuntime) finalize() {
|
||||
finalText := cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
|
||||
|
||||
if s.bufferToolContent {
|
||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true)
|
||||
s.processToolStreamEvents(flushToolSieve(&s.sieve, s.toolNames), true, true)
|
||||
}
|
||||
|
||||
textParsed := toolcall.ParseStandaloneToolCallsDetailed(finalText, s.toolNames)
|
||||
@@ -224,7 +224,7 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
|
||||
s.emitTextDelta(trimmed)
|
||||
continue
|
||||
}
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, trimmed, s.toolNames), true)
|
||||
s.processToolStreamEvents(processToolSieveChunk(&s.sieve, trimmed, s.toolNames), true, true)
|
||||
}
|
||||
|
||||
return streamengine.ParsedDecision{ContentSeen: contentSeen}
|
||||
|
||||
@@ -39,7 +39,7 @@ func (s *responsesStreamRuntime) sendDone() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool) {
|
||||
func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEvent, emitContent bool, resetAfterToolCalls bool) {
|
||||
for _, evt := range events {
|
||||
if emitContent && evt.Content != "" {
|
||||
s.emitTextDelta(evt.Content)
|
||||
@@ -56,6 +56,9 @@ func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEven
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
s.emitFunctionCallDoneEvents(evt.ToolCalls)
|
||||
if resetAfterToolCalls {
|
||||
s.resetStreamToolCallState()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,6 +152,16 @@ func (s *responsesStreamRuntime) ensureToolCallID(callIndex int) string {
|
||||
return id
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) resetStreamToolCallState() {
|
||||
s.streamToolCallIDs = map[int]string{}
|
||||
s.functionItemIDs = map[int]string{}
|
||||
s.functionOutputIDs = map[int]int{}
|
||||
s.functionArgs = map[int]string{}
|
||||
s.functionDone = map[int]bool{}
|
||||
s.functionAdded = map[int]bool{}
|
||||
s.functionNames = map[int]string{}
|
||||
}
|
||||
|
||||
func (s *responsesStreamRuntime) ensureFunctionOutputIndex(callIndex int) int {
|
||||
if idx, ok := s.functionOutputIDs[callIndex]; ok {
|
||||
return idx
|
||||
|
||||
@@ -109,6 +109,57 @@ func TestHandleResponsesStreamOutputTextDeltaCarriesItemIndexes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamEmitsDistinctToolCallIDsAcrossSeparateToolBlocks(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
sseLine := func(v string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"p": "response/content",
|
||||
"v": v,
|
||||
})
|
||||
return "data: " + string(b) + "\n"
|
||||
}
|
||||
|
||||
streamBody := sseLine("前置文本\n<tool_calls>\n <tool_call>\n <tool_name>read_file</tool_name>\n <parameters>{\"path\":\"README.MD\"}</parameters>\n </tool_call>\n</tool_calls>") +
|
||||
sseLine("中间文本\n<tool_calls>\n <tool_call>\n <tool_name>search</tool_name>\n <parameters>{\"q\":\"golang\"}</parameters>\n </tool_call>\n</tool_calls>") +
|
||||
"data: [DONE]\n"
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(streamBody)),
|
||||
}
|
||||
|
||||
h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file", "search"}, util.DefaultToolChoicePolicy(), "")
|
||||
|
||||
body := rec.Body.String()
|
||||
doneEvents := extractSSEEventPayloads(body, "response.function_call_arguments.done")
|
||||
if len(doneEvents) < 2 {
|
||||
t.Fatalf("expected at least two function call done events, got %d body=%s", len(doneEvents), body)
|
||||
}
|
||||
|
||||
ids := make([]string, 0, 2)
|
||||
seen := make(map[string]struct{})
|
||||
for _, payload := range doneEvents {
|
||||
callID := asString(payload["call_id"])
|
||||
if callID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[callID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[callID] = struct{}{}
|
||||
ids = append(ids, callID)
|
||||
}
|
||||
|
||||
if len(ids) != 2 {
|
||||
t.Fatalf("expected two distinct call ids, got %#v body=%s", ids, body)
|
||||
}
|
||||
if ids[0] == ids[1] {
|
||||
t.Fatalf("expected distinct call ids across blocks, got %#v body=%s", ids, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleResponsesStreamRequiredToolChoiceFailure(t *testing.T) {
|
||||
h := &Handler{}
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
@@ -325,3 +376,30 @@ func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func extractSSEEventPayloads(body, targetEvent string) []map[string]any {
|
||||
scanner := bufio.NewScanner(strings.NewReader(body))
|
||||
matched := false
|
||||
out := make([]map[string]any, 0, 4)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "event: ") {
|
||||
evt := strings.TrimSpace(strings.TrimPrefix(line, "event: "))
|
||||
matched = evt == targetEvent
|
||||
continue
|
||||
}
|
||||
if !matched || !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
raw := strings.TrimSpace(strings.TrimPrefix(line, "data: "))
|
||||
if raw == "" || raw == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, payload)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user