refactor: thread tool schemas through chat tool outputs

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
shern-point
2026-04-28 13:38:57 +08:00
parent 206c3d5479
commit 801b5abce3
7 changed files with 31 additions and 26 deletions

View File

@@ -6,12 +6,12 @@ import (
"time"
)
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string, toolsRaw any) map[string]any {
detected := toolcall.ParseAssistantToolCallsDetailed(finalText, finalThinking, toolNames)
return BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls)
return BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls, toolsRaw)
}
func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall) map[string]any {
func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText string, detected []toolcall.ParsedToolCall, toolsRaw any) map[string]any {
finishReason := "stop"
messageObj := map[string]any{"role": "assistant", "content": finalText}
if strings.TrimSpace(finalThinking) != "" {
@@ -19,7 +19,7 @@ func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThi
}
if len(detected) > 0 {
finishReason = "tool_calls"
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected)
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected, toolsRaw)
messageObj["content"] = nil
}

View File

@@ -194,7 +194,7 @@ func TestHandleStreamContextCancelledMarksHistoryStopped(t *testing.T) {
rec := httptest.NewRecorder()
resp := makeOpenAISSEHTTPResponse(`data: {"p":"response/content","v":"hello"}`, `data: [DONE]`)
h.handleStream(rec, req, resp, "cid-stop", "deepseek-v4-flash", "prompt", false, false, nil, session)
h.handleStream(rec, req, resp, "cid-stop", "deepseek-v4-flash", "prompt", false, false, nil, nil, session)
snapshot, err := historyStore.Snapshot()
if err != nil {

View File

@@ -21,6 +21,7 @@ type chatStreamRuntime struct {
model string
finalPrompt string
toolNames []string
toolsRaw any
thinkingEnabled bool
searchEnabled bool
@@ -61,6 +62,7 @@ func newChatStreamRuntime(
searchEnabled bool,
stripReferenceMarkers bool,
toolNames []string,
toolsRaw any,
bufferToolContent bool,
emitEarlyToolDeltas bool,
) *chatStreamRuntime {
@@ -73,6 +75,7 @@ func newChatStreamRuntime(
model: model,
finalPrompt: finalPrompt,
toolNames: toolNames,
toolsRaw: toolsRaw,
thinkingEnabled: thinkingEnabled,
searchEnabled: searchEnabled,
stripReferenceMarkers: stripReferenceMarkers,
@@ -142,7 +145,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
if len(detected.Calls) > 0 && !s.toolCallsDoneEmitted {
finishReason = "tool_calls"
delta := map[string]any{
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs),
"tool_calls": formatFinalStreamToolCallsWithStableIDs(detected.Calls, s.streamToolCallIDs, s.toolsRaw),
}
if !s.firstChunkSent {
delta["role"] = "assistant"
@@ -164,7 +167,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
tcDelta := map[string]any{
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs, s.toolsRaw),
}
if !s.firstChunkSent {
tcDelta["role"] = "assistant"
@@ -320,7 +323,7 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
s.toolCallsEmitted = true
s.toolCallsDoneEmitted = true
tcDelta := map[string]any{
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs),
"tool_calls": formatFinalStreamToolCallsWithStableIDs(evt.ToolCalls, s.streamToolCallIDs, s.toolsRaw),
}
if !s.firstChunkSent {
tcDelta["role"] = "assistant"

View File

@@ -26,14 +26,14 @@ type chatNonStreamResult struct {
responseMessageID int
}
func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Context, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
attempts := 0
currentResp := resp
usagePrompt := finalPrompt
accumulatedThinking := ""
accumulatedToolDetectionThinking := ""
for {
result, ok := h.collectChatNonStreamAttempt(w, currentResp, completionID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames)
result, ok := h.collectChatNonStreamAttempt(w, currentResp, completionID, model, usagePrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw)
if !ok {
return
}
@@ -43,7 +43,7 @@ func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Co
result.toolDetectionThinking = accumulatedToolDetectionThinking
detected := detectAssistantToolCalls(result.text, result.thinking, result.toolDetectionThinking, toolNames)
result.detectedCalls = len(detected.Calls)
result.body = openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, result.thinking, result.text, detected.Calls)
result.body = openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, result.thinking, result.text, detected.Calls, toolsRaw)
result.finishReason = chatFinishReason(result.body)
if !shouldRetryChatNonStream(result, attempts) {
h.finishChatNonStreamResult(w, result, attempts, usagePrompt, historySession)
@@ -72,7 +72,7 @@ func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Co
}
}
func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.Response, completionID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string) (chatNonStreamResult, bool) {
func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.Response, completionID, model, usagePrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any) (chatNonStreamResult, bool) {
if resp.StatusCode != http.StatusOK {
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
@@ -88,7 +88,7 @@ func (h *Handler) collectChatNonStreamAttempt(w http.ResponseWriter, resp *http.
finalText = replaceCitationMarkersWithLinks(finalText, result.CitationLinks)
}
detected := detectAssistantToolCalls(finalText, finalThinking, finalToolDetectionThinking, toolNames)
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, finalThinking, finalText, detected.Calls)
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, usagePrompt, finalThinking, finalText, detected.Calls, toolsRaw)
return chatNonStreamResult{
thinking: finalThinking,
toolDetectionThinking: finalToolDetectionThinking,
@@ -139,8 +139,8 @@ func shouldRetryChatNonStream(result chatNonStreamResult, attempts int) bool {
strings.TrimSpace(result.text) == ""
}
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, historySession)
func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request, a *auth.RequestAuth, resp *http.Response, payload map[string]any, pow, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
streamRuntime, initialType, ok := h.prepareChatStreamRuntime(w, resp, completionID, model, finalPrompt, thinkingEnabled, searchEnabled, toolNames, toolsRaw, historySession)
if !ok {
return
}
@@ -182,7 +182,7 @@ func (h *Handler) handleStreamWithRetry(w http.ResponseWriter, r *http.Request,
}
}
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) (*chatStreamRuntime, string, bool) {
if resp.StatusCode != http.StatusOK {
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
@@ -207,7 +207,7 @@ func (h *Handler) prepareChatStreamRuntime(w http.ResponseWriter, resp *http.Res
}
streamRuntime := newChatStreamRuntime(
w, rc, canFlush, completionID, time.Now().Unix(), model, finalPrompt,
thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames,
thinkingEnabled, searchEnabled, h.compatStripReferenceMarkers(), toolNames, toolsRaw,
len(toolNames) > 0, h.toolcallFeatureMatchEnabled() && h.toolcallEarlyEmitHighConfidence(),
)
return streamRuntime, initialType, true

View File

@@ -144,8 +144,8 @@ func filterIncrementalToolCallDeltasByAllowed(deltas []toolstream.ToolCallDelta,
return shared.FilterIncrementalToolCallDeltasByAllowed(deltas, seenNames)
}
func formatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string) []map[string]any {
return shared.FormatFinalStreamToolCallsWithStableIDs(calls, ids)
func formatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string, toolsRaw any) []map[string]any {
return shared.FormatFinalStreamToolCallsWithStableIDs(calls, ids, toolsRaw)
}
func detectAssistantToolCalls(text, exposedThinking, detectionThinking string, toolNames []string) toolcall.ToolCallParseResult {

View File

@@ -109,10 +109,10 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
return
}
if stdReq.Stream {
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
return
}
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, historySession)
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
}
func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) {
@@ -148,7 +148,7 @@ func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAu
}
}
func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
if resp.StatusCode != http.StatusOK {
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
@@ -176,7 +176,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, co
writeUpstreamEmptyOutputError(w, finalText, finalThinking, result.ContentFilter)
return
}
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls)
respBody := openaifmt.BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThinking, finalText, detected.Calls, toolsRaw)
finishReason := "stop"
if choices, ok := respBody["choices"].([]map[string]any); ok && len(choices) > 0 {
if fr, _ := choices[0]["finish_reason"].(string); strings.TrimSpace(fr) != "" {
@@ -189,7 +189,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, co
writeJSON(w, http.StatusOK, respBody)
}
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, historySession *chatHistorySession) {
func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *http.Response, completionID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, historySession *chatHistorySession) {
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
@@ -230,6 +230,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
searchEnabled,
stripReferenceMarkers,
toolNames,
toolsRaw,
bufferToolContent,
emitEarlyToolDeltas,
)

View File

@@ -70,12 +70,13 @@ func FilterIncrementalToolCallDeltasByAllowed(deltas []toolstream.ToolCallDelta,
return out
}
func FormatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string) []map[string]any {
func FormatFinalStreamToolCallsWithStableIDs(calls []toolcall.ParsedToolCall, ids map[int]string, toolsRaw any) []map[string]any {
if len(calls) == 0 {
return nil
}
normalizedCalls := toolcall.NormalizeParsedToolCallsForSchemas(calls, toolsRaw)
out := make([]map[string]any, 0, len(calls))
for i, c := range calls {
for i, c := range normalizedCalls {
callID := ""
if ids != nil {
callID = strings.TrimSpace(ids[i])