feat: support explicit prompt token tracking in SSE parsing and stream handlers

This commit is contained in:
CJACK
2026-04-07 01:39:27 +08:00
parent da778a18fb
commit b79a13efd5
13 changed files with 136 additions and 63 deletions

View File

@@ -37,6 +37,7 @@ type chatStreamRuntime struct {
streamToolNames map[int]string
thinking strings.Builder
text strings.Builder
promptTokens int
outputTokens int
}
@@ -170,11 +171,16 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
finishReason = "tool_calls"
}
usage := openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText)
if s.promptTokens > 0 {
usage["prompt_tokens"] = s.promptTokens
}
if s.outputTokens > 0 {
usage["completion_tokens"] = s.outputTokens
if prompt, ok := usage["prompt_tokens"].(int); ok {
usage["total_tokens"] = prompt + s.outputTokens
}
}
if s.promptTokens > 0 || s.outputTokens > 0 {
p := usage["prompt_tokens"].(int)
c := usage["completion_tokens"].(int)
usage["total_tokens"] = p + c
}
s.sendChunk(openaifmt.BuildChatStreamChunk(
s.completionID,
@@ -190,6 +196,9 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
if !parsed.Parsed {
return streamengine.ParsedDecision{}
}
if parsed.PromptTokens > 0 {
s.promptTokens = parsed.PromptTokens
}
if parsed.OutputTokens > 0 {
s.outputTokens = parsed.OutputTokens
}
@@ -243,7 +252,7 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
if !s.emitEarlyToolDeltas {
continue
}
filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.streamToolNames)
filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.streamToolNames)
if len(filtered) == 0 {
continue
}

View File

@@ -131,12 +131,17 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re
return
}
respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
if result.OutputTokens > 0 {
if result.PromptTokens > 0 || result.OutputTokens > 0 {
if usage, ok := respBody["usage"].(map[string]any); ok {
usage["completion_tokens"] = result.OutputTokens
if prompt, ok := usage["prompt_tokens"].(int); ok {
usage["total_tokens"] = prompt + result.OutputTokens
if result.PromptTokens > 0 {
usage["prompt_tokens"] = result.PromptTokens
}
if result.OutputTokens > 0 {
usage["completion_tokens"] = result.OutputTokens
}
p, _ := usage["prompt_tokens"].(int)
c, _ := usage["completion_tokens"].(int)
usage["total_tokens"] = p + c
}
}
writeJSON(w, http.StatusOK, respBody)

View File

@@ -113,7 +113,7 @@ func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]s
return out
}
func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNames []string, seenNames map[int]string) []toolCallDelta {
func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, seenNames map[int]string) []toolCallDelta {
if len(deltas) == 0 {
return nil
}

View File

@@ -48,7 +48,7 @@ func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEven
if !s.emitEarlyToolDeltas {
continue
}
filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.functionNames)
filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.functionNames)
if len(filtered) == 0 {
continue
}