mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 08:55:28 +08:00
feat: Implement streaming incremental tool call deltas with a new tool sieve and standalone parser.
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
@@ -134,7 +135,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re
|
||||
|
||||
finalThinking := result.Thinking
|
||||
finalText := result.Text
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
detected := util.ParseStandaloneToolCalls(finalText, toolNames)
|
||||
finishReason := "stop"
|
||||
messageObj := map[string]any{"role": "assistant", "content": finalText}
|
||||
if thinkingEnabled && finalThinking != "" {
|
||||
@@ -188,6 +189,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
bufferToolContent := len(toolNames) > 0
|
||||
var toolSieve toolStreamSieveState
|
||||
toolCallsEmitted := false
|
||||
streamToolCallIDs := map[int]string{}
|
||||
initialType := "text"
|
||||
if thinkingEnabled {
|
||||
initialType = "thinking"
|
||||
@@ -220,7 +222,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
finalize := func(finishReason string) {
|
||||
finalThinking := thinking.String()
|
||||
finalText := text.String()
|
||||
detected := util.ParseToolCalls(finalText, toolNames)
|
||||
detected := util.ParseStandaloneToolCalls(finalText, toolNames)
|
||||
if len(detected) > 0 && !toolCallsEmitted {
|
||||
finishReason = "tool_calls"
|
||||
delta := map[string]any{
|
||||
@@ -352,6 +354,21 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
// Keep thinking delta only frame.
|
||||
}
|
||||
for _, evt := range events {
|
||||
if len(evt.ToolCallDeltas) > 0 {
|
||||
toolCallsEmitted = true
|
||||
tcDelta := map[string]any{
|
||||
"tool_calls": formatIncrementalStreamToolCallDeltas(evt.ToolCallDeltas, streamToolCallIDs),
|
||||
}
|
||||
if !firstChunkSent {
|
||||
tcDelta["role"] = "assistant"
|
||||
firstChunkSent = true
|
||||
}
|
||||
newChoices = append(newChoices, map[string]any{
|
||||
"delta": tcDelta,
|
||||
"index": 0,
|
||||
})
|
||||
continue
|
||||
}
|
||||
if len(evt.ToolCalls) > 0 {
|
||||
toolCallsEmitted = true
|
||||
tcDelta := map[string]any{
|
||||
@@ -441,6 +458,40 @@ func injectToolPrompt(messages []map[string]any, tools []any) ([]map[string]any,
|
||||
return messages, names
|
||||
}
|
||||
|
||||
func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]string) []map[string]any {
|
||||
if len(deltas) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]map[string]any, 0, len(deltas))
|
||||
for _, d := range deltas {
|
||||
if d.Name == "" && d.Arguments == "" {
|
||||
continue
|
||||
}
|
||||
callID, ok := ids[d.Index]
|
||||
if !ok || callID == "" {
|
||||
callID = "call_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
ids[d.Index] = callID
|
||||
}
|
||||
item := map[string]any{
|
||||
"index": d.Index,
|
||||
"id": callID,
|
||||
"type": "function",
|
||||
}
|
||||
fn := map[string]any{}
|
||||
if d.Name != "" {
|
||||
fn["name"] = d.Name
|
||||
}
|
||||
if d.Arguments != "" {
|
||||
fn["arguments"] = d.Arguments
|
||||
}
|
||||
if len(fn) > 0 {
|
||||
item["function"] = fn
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func writeOpenAIError(w http.ResponseWriter, status int, message string) {
|
||||
writeJSON(w, status, map[string]any{
|
||||
"error": map[string]any{
|
||||
|
||||
@@ -100,6 +100,26 @@ func streamFinishReason(frames []map[string]any) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func streamToolCallArgumentChunks(frames []map[string]any) []string {
|
||||
out := make([]string, 0, 4)
|
||||
for _, frame := range frames {
|
||||
choices, _ := frame["choices"].([]any)
|
||||
for _, item := range choices {
|
||||
choice, _ := item.(map[string]any)
|
||||
delta, _ := choice["delta"].(map[string]any)
|
||||
toolCalls, _ := delta["tool_calls"].([]any)
|
||||
for _, tc := range toolCalls {
|
||||
tcm, _ := tc.(map[string]any)
|
||||
fn, _ := tcm["function"].(map[string]any)
|
||||
if args, ok := fn["arguments"].(string); ok && args != "" {
|
||||
out = append(out, args)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestHandleNonStreamToolCallInterceptsChatModel(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -190,6 +210,37 @@ func TestHandleNonStreamUnknownToolStillIntercepted(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleNonStreamEmbeddedToolCallExampleNotIntercepted(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"下面是示例:"}`,
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
|
||||
`data: {"p":"response/content","v":"请勿执行。"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handleNonStream(rec, context.Background(), resp, "cid2c", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %d", rec.Code)
|
||||
}
|
||||
|
||||
out := decodeJSONBody(t, rec.Body.String())
|
||||
choices, _ := out["choices"].([]any)
|
||||
choice, _ := choices[0].(map[string]any)
|
||||
if choice["finish_reason"] != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
|
||||
}
|
||||
msg, _ := choice["message"].(map[string]any)
|
||||
if _, ok := msg["tool_calls"]; ok {
|
||||
t.Fatalf("did not expect tool_calls field for embedded example: %#v", msg["tool_calls"])
|
||||
}
|
||||
content, _ := msg["content"].(string)
|
||||
if !strings.Contains(content, "示例") || !strings.Contains(content, `"tool_calls"`) {
|
||||
t.Fatalf("expected embedded example to pass through as text, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallInterceptsWithoutRawContentLeak(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
@@ -391,11 +442,8 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta in mixed stream, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in mixed stream: %s", rec.Body.String())
|
||||
if streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("did not expect tool_calls delta in mixed prose stream, body=%s", rec.Body.String())
|
||||
}
|
||||
content := strings.Builder{}
|
||||
for _, frame := range frames {
|
||||
@@ -412,8 +460,11 @@ func TestHandleStreamToolCallMixedWithPlainTextSegments(t *testing.T) {
|
||||
if !strings.Contains(got, "前置正文A。") || !strings.Contains(got, "后置正文B。") {
|
||||
t.Fatalf("expected pre/post plain text to pass sieve, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
if !strings.Contains(got, `"tool_calls"`) {
|
||||
t.Fatalf("expected mixed stream to preserve embedded tool_calls example text, got=%q", got)
|
||||
}
|
||||
if streamFinishReason(frames) != "stop" {
|
||||
t.Fatalf("expected finish_reason=stop for mixed prose, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -495,16 +546,16 @@ func TestHandleStreamInvalidToolJSONDoesNotLeakRawObject(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
got := strings.ToLower(content.String())
|
||||
if strings.Contains(got, "tool_calls") {
|
||||
t.Fatalf("unexpected raw tool_calls leak in content: %q", content.String())
|
||||
}
|
||||
if !strings.Contains(content.String(), "前置正文D。") || !strings.Contains(content.String(), "后置正文E。") {
|
||||
got := content.String()
|
||||
if !strings.Contains(got, "前置正文D。") || !strings.Contains(got, "后置正文E。") {
|
||||
t.Fatalf("expected pre/post plain text to remain, got=%q", content.String())
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(got), "tool_calls") {
|
||||
t.Fatalf("expected invalid embedded tool-like json to pass through as text, got=%q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.T) {
|
||||
func TestHandleStreamIncompleteCapturedToolJSONFlushesAsTextOnFinalize(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\""}`,
|
||||
@@ -533,7 +584,42 @@ func TestHandleStreamIncompleteCapturedToolJSONDoesNotLeakOnFinalize(t *testing.
|
||||
}
|
||||
}
|
||||
}
|
||||
if strings.Contains(strings.ToLower(content.String()), "tool_calls") || strings.Contains(content.String(), "{") {
|
||||
t.Fatalf("unexpected incomplete tool json leak in content: %q", content.String())
|
||||
if !strings.Contains(strings.ToLower(content.String()), "tool_calls") || !strings.Contains(content.String(), "{") {
|
||||
t.Fatalf("expected incomplete capture to flush as plain text instead of stalling, got=%q", content.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamToolCallArgumentsEmitIncrementally(t *testing.T) {
|
||||
h := &Handler{}
|
||||
resp := makeSSEHTTPResponse(
|
||||
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go"}`,
|
||||
`data: {"p":"response/content","v":"lang\",\"page\":1}}]}"}`,
|
||||
`data: [DONE]`,
|
||||
)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
|
||||
h.handleStream(rec, req, resp, "cid11", "deepseek-chat", "prompt", false, false, []string{"search"})
|
||||
|
||||
frames, done := parseSSEDataFrames(t, rec.Body.String())
|
||||
if !done {
|
||||
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
|
||||
}
|
||||
if !streamHasToolCallsDelta(frames) {
|
||||
t.Fatalf("expected tool_calls delta, body=%s", rec.Body.String())
|
||||
}
|
||||
if streamHasRawToolJSONContent(frames) {
|
||||
t.Fatalf("raw tool_calls JSON leaked in content delta: %s", rec.Body.String())
|
||||
}
|
||||
argChunks := streamToolCallArgumentChunks(frames)
|
||||
if len(argChunks) < 2 {
|
||||
t.Fatalf("expected incremental arguments chunks, got=%v body=%s", argChunks, rec.Body.String())
|
||||
}
|
||||
joined := strings.Join(argChunks, "")
|
||||
if !strings.Contains(joined, `"q":"golang"`) || !strings.Contains(joined, `"page":1`) {
|
||||
t.Fatalf("unexpected merged arguments stream: %q", joined)
|
||||
}
|
||||
if streamFinishReason(frames) != "tool_calls" {
|
||||
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,14 +7,39 @@ import (
|
||||
)
|
||||
|
||||
type toolStreamSieveState struct {
|
||||
pending strings.Builder
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
pending strings.Builder
|
||||
capture strings.Builder
|
||||
capturing bool
|
||||
hasMeaningfulText bool
|
||||
toolNameSent bool
|
||||
toolName string
|
||||
toolArgsStart int
|
||||
toolArgsSent int
|
||||
toolArgsString bool
|
||||
toolArgsDone bool
|
||||
}
|
||||
|
||||
type toolStreamEvent struct {
|
||||
Content string
|
||||
ToolCalls []util.ParsedToolCall
|
||||
Content string
|
||||
ToolCalls []util.ParsedToolCall
|
||||
ToolCallDeltas []toolCallDelta
|
||||
}
|
||||
|
||||
type toolCallDelta struct {
|
||||
Index int
|
||||
Name string
|
||||
Arguments string
|
||||
}
|
||||
|
||||
const toolSieveCaptureLimit = 8 * 1024
|
||||
|
||||
func (s *toolStreamSieveState) resetIncrementalToolState() {
|
||||
s.toolNameSent = false
|
||||
s.toolName = ""
|
||||
s.toolArgsStart = -1
|
||||
s.toolArgsSent = -1
|
||||
s.toolArgsString = false
|
||||
s.toolArgsDone = false
|
||||
}
|
||||
|
||||
func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames []string) []toolStreamEvent {
|
||||
@@ -32,13 +57,31 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
state.capture.WriteString(state.pending.String())
|
||||
state.pending.Reset()
|
||||
}
|
||||
prefix, calls, suffix, ready := consumeToolCapture(state.capture.String(), toolNames)
|
||||
if deltas := buildIncrementalToolDeltas(state); len(deltas) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCallDeltas: deltas})
|
||||
}
|
||||
prefix, calls, suffix, ready := consumeToolCapture(state, toolNames)
|
||||
if !ready {
|
||||
if state.capture.Len() > toolSieveCaptureLimit {
|
||||
content := state.capture.String()
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
if strings.TrimSpace(content) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
if prefix != "" {
|
||||
if strings.TrimSpace(prefix) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
if len(calls) > 0 {
|
||||
@@ -58,11 +101,15 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
if start >= 0 {
|
||||
prefix := pending[:start]
|
||||
if prefix != "" {
|
||||
if strings.TrimSpace(prefix) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
events = append(events, toolStreamEvent{Content: prefix})
|
||||
}
|
||||
state.pending.Reset()
|
||||
state.capture.WriteString(pending[start:])
|
||||
state.capturing = true
|
||||
state.resetIncrementalToolState()
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -72,6 +119,9 @@ func processToolSieveChunk(state *toolStreamSieveState, chunk string, toolNames
|
||||
}
|
||||
state.pending.Reset()
|
||||
state.pending.WriteString(hold)
|
||||
if strings.TrimSpace(safe) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
events = append(events, toolStreamEvent{Content: safe})
|
||||
}
|
||||
|
||||
@@ -84,25 +134,42 @@ func flushToolSieve(state *toolStreamSieveState, toolNames []string) []toolStrea
|
||||
}
|
||||
events := processToolSieveChunk(state, "", toolNames)
|
||||
if state.capturing {
|
||||
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state.capture.String(), toolNames)
|
||||
consumedPrefix, consumedCalls, consumedSuffix, ready := consumeToolCapture(state, toolNames)
|
||||
if ready {
|
||||
if consumedPrefix != "" {
|
||||
if strings.TrimSpace(consumedPrefix) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
events = append(events, toolStreamEvent{Content: consumedPrefix})
|
||||
}
|
||||
if len(consumedCalls) > 0 {
|
||||
events = append(events, toolStreamEvent{ToolCalls: consumedCalls})
|
||||
}
|
||||
if consumedSuffix != "" {
|
||||
if strings.TrimSpace(consumedSuffix) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
events = append(events, toolStreamEvent{Content: consumedSuffix})
|
||||
}
|
||||
} else {
|
||||
// Incomplete captured tool JSON at stream end: suppress raw capture.
|
||||
content := state.capture.String()
|
||||
if content != "" {
|
||||
if strings.TrimSpace(content) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
}
|
||||
}
|
||||
state.capture.Reset()
|
||||
state.capturing = false
|
||||
state.resetIncrementalToolState()
|
||||
}
|
||||
if state.pending.Len() > 0 {
|
||||
events = append(events, toolStreamEvent{Content: state.pending.String()})
|
||||
content := state.pending.String()
|
||||
if strings.TrimSpace(content) != "" {
|
||||
state.hasMeaningfulText = true
|
||||
}
|
||||
events = append(events, toolStreamEvent{Content: content})
|
||||
state.pending.Reset()
|
||||
}
|
||||
return events
|
||||
@@ -154,7 +221,8 @@ func findToolSegmentStart(s string) int {
|
||||
return keyIdx
|
||||
}
|
||||
|
||||
func consumeToolCapture(captured string, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
|
||||
func consumeToolCapture(state *toolStreamSieveState, toolNames []string) (prefix string, calls []util.ParsedToolCall, suffix string, ready bool) {
|
||||
captured := state.capture.String()
|
||||
if captured == "" {
|
||||
return "", nil, "", false
|
||||
}
|
||||
@@ -171,13 +239,25 @@ func consumeToolCapture(captured string, toolNames []string) (prefix string, cal
|
||||
if !ok {
|
||||
return "", nil, "", false
|
||||
}
|
||||
parsed := util.ParseToolCalls(obj, toolNames)
|
||||
if len(parsed) == 0 {
|
||||
// `tool_calls` key exists but strict JSON parse failed.
|
||||
// Drop the captured object body to avoid leaking raw tool JSON.
|
||||
return captured[:start], nil, captured[end:], true
|
||||
prefixPart := captured[:start]
|
||||
suffixPart := captured[end:]
|
||||
if !state.toolNameSent && (state.hasMeaningfulText || strings.TrimSpace(prefixPart) != "" || strings.TrimSpace(suffixPart) != "") {
|
||||
return captured, nil, "", true
|
||||
}
|
||||
return captured[:start], parsed, captured[end:], true
|
||||
parsed := util.ParseStandaloneToolCalls(obj, toolNames)
|
||||
if len(parsed) == 0 {
|
||||
if state.toolNameSent {
|
||||
return prefixPart, nil, suffixPart, true
|
||||
}
|
||||
return captured, nil, "", true
|
||||
}
|
||||
if state.toolNameSent {
|
||||
if len(parsed) > 1 {
|
||||
return prefixPart, parsed[1:], suffixPart, true
|
||||
}
|
||||
return prefixPart, nil, suffixPart, true
|
||||
}
|
||||
return prefixPart, parsed, suffixPart, true
|
||||
}
|
||||
|
||||
func extractJSONObjectFrom(text string, start int) (string, int, bool) {
|
||||
@@ -221,3 +301,320 @@ func extractJSONObjectFrom(text string, start int) (string, int, bool) {
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func buildIncrementalToolDeltas(state *toolStreamSieveState) []toolCallDelta {
|
||||
captured := state.capture.String()
|
||||
if captured == "" || state.hasMeaningfulText {
|
||||
return nil
|
||||
}
|
||||
lower := strings.ToLower(captured)
|
||||
keyIdx := strings.Index(lower, "tool_calls")
|
||||
if keyIdx < 0 {
|
||||
return nil
|
||||
}
|
||||
start := strings.LastIndex(captured[:keyIdx], "{")
|
||||
if start < 0 || strings.TrimSpace(captured[:start]) != "" {
|
||||
return nil
|
||||
}
|
||||
callStart, ok := findFirstToolCallObjectStart(captured, keyIdx)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
deltas := make([]toolCallDelta, 0, 2)
|
||||
if state.toolName == "" {
|
||||
name, ok := extractToolCallName(captured, callStart)
|
||||
if !ok || name == "" {
|
||||
return nil
|
||||
}
|
||||
state.toolName = name
|
||||
}
|
||||
if state.toolArgsStart < 0 {
|
||||
argsStart, stringMode, ok := findToolCallArgsStart(captured, callStart)
|
||||
if ok {
|
||||
state.toolArgsString = stringMode
|
||||
if stringMode {
|
||||
state.toolArgsStart = argsStart + 1
|
||||
} else {
|
||||
state.toolArgsStart = argsStart
|
||||
}
|
||||
state.toolArgsSent = state.toolArgsStart
|
||||
}
|
||||
}
|
||||
if !state.toolNameSent {
|
||||
if state.toolArgsStart < 0 {
|
||||
return nil
|
||||
}
|
||||
state.toolNameSent = true
|
||||
deltas = append(deltas, toolCallDelta{Index: 0, Name: state.toolName})
|
||||
}
|
||||
if state.toolArgsStart < 0 || state.toolArgsDone {
|
||||
return deltas
|
||||
}
|
||||
end, complete, ok := scanToolCallArgsProgress(captured, state.toolArgsStart, state.toolArgsString)
|
||||
if !ok {
|
||||
return deltas
|
||||
}
|
||||
if end > state.toolArgsSent {
|
||||
deltas = append(deltas, toolCallDelta{
|
||||
Index: 0,
|
||||
Arguments: captured[state.toolArgsSent:end],
|
||||
})
|
||||
state.toolArgsSent = end
|
||||
}
|
||||
if complete {
|
||||
state.toolArgsDone = true
|
||||
}
|
||||
return deltas
|
||||
}
|
||||
|
||||
func findFirstToolCallObjectStart(text string, keyIdx int) (int, bool) {
|
||||
arrStart, ok := findToolCallsArrayStart(text, keyIdx)
|
||||
if !ok {
|
||||
return -1, false
|
||||
}
|
||||
i := skipSpaces(text, arrStart+1)
|
||||
if i >= len(text) || text[i] != '{' {
|
||||
return -1, false
|
||||
}
|
||||
return i, true
|
||||
}
|
||||
|
||||
func findToolCallsArrayStart(text string, keyIdx int) (int, bool) {
|
||||
i := keyIdx + len("tool_calls")
|
||||
for i < len(text) && text[i] != ':' {
|
||||
i++
|
||||
}
|
||||
if i >= len(text) {
|
||||
return -1, false
|
||||
}
|
||||
i = skipSpaces(text, i+1)
|
||||
if i >= len(text) || text[i] != '[' {
|
||||
return -1, false
|
||||
}
|
||||
return i, true
|
||||
}
|
||||
|
||||
func extractToolCallName(text string, callStart int) (string, bool) {
|
||||
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"name"})
|
||||
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
|
||||
fnStart, fnOK := findFunctionObjectStart(text, callStart)
|
||||
if !fnOK {
|
||||
return "", false
|
||||
}
|
||||
valueStart, ok = findObjectFieldValueStart(text, fnStart, []string{"name"})
|
||||
if !ok || valueStart >= len(text) || text[valueStart] != '"' {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
name, _, ok := parseJSONStringLiteral(text, valueStart)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return name, true
|
||||
}
|
||||
|
||||
func findToolCallArgsStart(text string, callStart int) (int, bool, bool) {
|
||||
keys := []string{"input", "arguments", "args", "parameters", "params"}
|
||||
valueStart, ok := findObjectFieldValueStart(text, callStart, keys)
|
||||
if !ok {
|
||||
fnStart, fnOK := findFunctionObjectStart(text, callStart)
|
||||
if !fnOK {
|
||||
return -1, false, false
|
||||
}
|
||||
valueStart, ok = findObjectFieldValueStart(text, fnStart, keys)
|
||||
if !ok {
|
||||
return -1, false, false
|
||||
}
|
||||
}
|
||||
if valueStart >= len(text) {
|
||||
return -1, false, false
|
||||
}
|
||||
ch := text[valueStart]
|
||||
if ch == '{' || ch == '[' {
|
||||
return valueStart, false, true
|
||||
}
|
||||
if ch == '"' {
|
||||
return valueStart, true, true
|
||||
}
|
||||
return -1, false, false
|
||||
}
|
||||
|
||||
func scanToolCallArgsProgress(text string, start int, stringMode bool) (int, bool, bool) {
|
||||
if start < 0 || start > len(text) {
|
||||
return 0, false, false
|
||||
}
|
||||
if stringMode {
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
return i, true, true
|
||||
}
|
||||
}
|
||||
return len(text), false, true
|
||||
}
|
||||
if start >= len(text) {
|
||||
return start, false, false
|
||||
}
|
||||
if text[start] != '{' && text[start] != '[' {
|
||||
return 0, false, false
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := start; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' || ch == '[' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' || ch == ']' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return i + 1, true, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return len(text), false, true
|
||||
}
|
||||
|
||||
func findObjectFieldValueStart(text string, objStart int, keys []string) (int, bool) {
|
||||
if objStart < 0 || objStart >= len(text) || text[objStart] != '{' {
|
||||
return 0, false
|
||||
}
|
||||
depth := 0
|
||||
quote := byte(0)
|
||||
escaped := false
|
||||
for i := objStart; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if quote != 0 {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == quote {
|
||||
quote = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
if ch == '"' || ch == '\'' {
|
||||
if depth == 1 {
|
||||
key, end, ok := parseJSONStringLiteral(text, i)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
j := skipSpaces(text, end)
|
||||
if j >= len(text) || text[j] != ':' {
|
||||
i = end - 1
|
||||
continue
|
||||
}
|
||||
j = skipSpaces(text, j+1)
|
||||
if j >= len(text) {
|
||||
return 0, false
|
||||
}
|
||||
if containsKey(keys, key) {
|
||||
return j, true
|
||||
}
|
||||
i = j - 1
|
||||
continue
|
||||
}
|
||||
quote = ch
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
depth++
|
||||
continue
|
||||
}
|
||||
if ch == '}' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func findFunctionObjectStart(text string, callStart int) (int, bool) {
|
||||
valueStart, ok := findObjectFieldValueStart(text, callStart, []string{"function"})
|
||||
if !ok || valueStart >= len(text) || text[valueStart] != '{' {
|
||||
return -1, false
|
||||
}
|
||||
return valueStart, true
|
||||
}
|
||||
|
||||
func parseJSONStringLiteral(text string, start int) (string, int, bool) {
|
||||
if start < 0 || start >= len(text) || text[start] != '"' {
|
||||
return "", 0, false
|
||||
}
|
||||
var b strings.Builder
|
||||
escaped := false
|
||||
for i := start + 1; i < len(text); i++ {
|
||||
ch := text[i]
|
||||
if escaped {
|
||||
b.WriteByte(ch)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
return b.String(), i + 1, true
|
||||
}
|
||||
b.WriteByte(ch)
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
func containsKey(keys []string, value string) bool {
|
||||
for _, k := range keys {
|
||||
if k == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func skipSpaces(text string, i int) int {
|
||||
for i < len(text) {
|
||||
switch text[i] {
|
||||
case ' ', '\t', '\n', '\r':
|
||||
i++
|
||||
default:
|
||||
return i
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
@@ -33,6 +33,36 @@ func ParseToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
return nil
|
||||
}
|
||||
|
||||
return filterToolCalls(parsed, availableToolNames)
|
||||
}
|
||||
|
||||
func ParseStandaloneToolCalls(text string, availableToolNames []string) []ParsedToolCall {
|
||||
trimmed := strings.TrimSpace(text)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
candidates := []string{trimmed}
|
||||
if strings.HasPrefix(trimmed, "```") && strings.HasSuffix(trimmed, "```") {
|
||||
if m := fencedJSONPattern.FindStringSubmatch(trimmed); len(m) >= 2 {
|
||||
candidates = append(candidates, strings.TrimSpace(m[1]))
|
||||
}
|
||||
}
|
||||
for _, candidate := range candidates {
|
||||
candidate = strings.TrimSpace(candidate)
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(candidate, "{") && !strings.HasPrefix(candidate, "[") {
|
||||
continue
|
||||
}
|
||||
if parsed := parseToolCallsPayload(candidate); len(parsed) > 0 {
|
||||
return filterToolCalls(parsed, availableToolNames)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func filterToolCalls(parsed []ParsedToolCall, availableToolNames []string) []ParsedToolCall {
|
||||
allowed := map[string]struct{}{}
|
||||
for _, name := range availableToolNames {
|
||||
allowed[name] = struct{}{}
|
||||
|
||||
@@ -62,3 +62,16 @@ func TestFormatOpenAIToolCalls(t *testing.T) {
|
||||
t.Fatalf("unexpected function name: %#v", fn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseStandaloneToolCallsOnlyMatchesStandalonePayload(t *testing.T) {
|
||||
mixed := `这里是示例:{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
|
||||
if calls := ParseStandaloneToolCalls(mixed, []string{"search"}); len(calls) != 0 {
|
||||
t.Fatalf("expected standalone parser to ignore mixed prose, got %#v", calls)
|
||||
}
|
||||
|
||||
standalone := `{"tool_calls":[{"name":"search","input":{"q":"go"}}]}`
|
||||
calls := ParseStandaloneToolCalls(standalone, []string{"search"})
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("expected standalone parser to match, got %#v", calls)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user