feat: support explicit prompt token tracking in SSE parsing and stream handlers

This commit is contained in:
CJACK
2026-04-07 01:39:27 +08:00
parent da778a18fb
commit b79a13efd5
13 changed files with 136 additions and 63 deletions

View File

@@ -12,6 +12,7 @@ import (
type CollectResult struct {
Text string
Thinking string
PromptTokens int
OutputTokens int
ContentFilter bool
}
@@ -28,6 +29,7 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
}
text := strings.Builder{}
thinking := strings.Builder{}
promptTokens := 0
outputTokens := 0
contentFilter := false
currentType := "text"
@@ -40,18 +42,18 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
if !result.Parsed {
return true
}
if result.PromptTokens > 0 {
promptTokens = result.PromptTokens
}
if result.OutputTokens > 0 {
outputTokens = result.OutputTokens
}
if result.Stop {
if result.ContentFilter {
contentFilter = true
}
if result.OutputTokens > 0 {
outputTokens = result.OutputTokens
}
return false
}
if result.OutputTokens > 0 {
outputTokens = result.OutputTokens
}
for _, p := range result.Parts {
if p.Type == "thinking" {
trimmed := TrimContinuationOverlap(thinking.String(), p.Text)
@@ -66,6 +68,7 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
return CollectResult{
Text: text.String(),
Thinking: thinking.String(),
PromptTokens: promptTokens,
OutputTokens: outputTokens,
ContentFilter: contentFilter,
}

View File

@@ -10,6 +10,7 @@ type LineResult struct {
ErrorMessage string
Parts []ContentPart
NextType string
PromptTokens int
OutputTokens int
}
@@ -20,9 +21,9 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
if !parsed {
return LineResult{NextType: currentType}
}
outputTokens := extractAccumulatedTokenUsage(chunk)
promptTokens, outputTokens := extractAccumulatedTokenUsage(chunk)
if done {
return LineResult{Parsed: true, Stop: true, NextType: currentType, OutputTokens: outputTokens}
return LineResult{Parsed: true, Stop: true, NextType: currentType, PromptTokens: promptTokens, OutputTokens: outputTokens}
}
if errObj, hasErr := chunk["error"]; hasErr {
return LineResult{
@@ -30,6 +31,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
Stop: true,
ErrorMessage: fmt.Sprintf("%v", errObj),
NextType: currentType,
PromptTokens: promptTokens,
OutputTokens: outputTokens,
}
}
@@ -39,6 +41,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
Stop: true,
ContentFilter: true,
NextType: currentType,
PromptTokens: promptTokens,
OutputTokens: outputTokens,
}
}
@@ -48,6 +51,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
Stop: true,
ContentFilter: true,
NextType: currentType,
PromptTokens: promptTokens,
OutputTokens: outputTokens,
}
}
@@ -58,6 +62,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
Stop: finished,
Parts: parts,
NextType: nextType,
PromptTokens: promptTokens,
OutputTokens: outputTokens,
}
}

View File

@@ -364,34 +364,50 @@ func hasContentFilterStatusValue(v any) bool {
return false
}
func extractAccumulatedTokenUsage(chunk map[string]any) int {
func extractAccumulatedTokenUsage(chunk map[string]any) (int, int) {
return findAccumulatedTokenUsage(chunk)
}
func findAccumulatedTokenUsage(v any) int {
func findAccumulatedTokenUsage(v any) (int, int) {
switch x := v.(type) {
case map[string]any:
if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "accumulated_token_usage") {
if n, ok := toInt(x["v"]); ok && n > 0 {
return n
return 0, n
}
}
if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "token_usage") {
if m, ok := x["v"].(map[string]any); ok {
p, _ := toInt(m["prompt_tokens"])
c, _ := toInt(m["completion_tokens"])
if p > 0 || c > 0 {
return p, c
}
}
}
if n, ok := toInt(x["accumulated_token_usage"]); ok && n > 0 {
return n
return 0, n
}
if usage, ok := x["token_usage"].(map[string]any); ok {
p, _ := toInt(usage["prompt_tokens"])
c, _ := toInt(usage["completion_tokens"])
if p > 0 || c > 0 {
return p, c
}
}
for _, vv := range x {
if n := findAccumulatedTokenUsage(vv); n > 0 {
return n
if p, c := findAccumulatedTokenUsage(vv); p > 0 || c > 0 {
return p, c
}
}
case []any:
for _, item := range x {
if n := findAccumulatedTokenUsage(item); n > 0 {
return n
if p, c := findAccumulatedTokenUsage(item); p > 0 || c > 0 {
return p, c
}
}
}
return 0
return 0, 0
}
func toInt(v any) (int, bool) {

View File

@@ -50,18 +50,6 @@ func TestShouldSkipPathQuasiStatus(t *testing.T) {
}
}
func TestShouldSkipPathElapsedSecs(t *testing.T) {
if !shouldSkipPath("response/elapsed_secs") {
t.Fatal("expected skip for elapsed_secs path")
}
}
func TestShouldSkipPathTokenUsage(t *testing.T) {
if !shouldSkipPath("response/token_usage") {
t.Fatal("expected skip for token_usage path")
}
}
func TestShouldSkipPathPendingFragment(t *testing.T) {
if !shouldSkipPath("response/pending_fragment") {
t.Fatal("expected skip for pending_fragment path")
@@ -127,7 +115,7 @@ func TestParseSSEChunkForContentNoVField(t *testing.T) {
func TestParseSSEChunkForContentSkippedPath(t *testing.T) {
parts, finished, nextType := ParseSSEChunkForContent(map[string]any{
"p": "response/token_usage",
"p": "response/quasi_status",
"v": "some data",
}, false, "text")
if finished || len(parts) > 0 {
@@ -498,7 +486,7 @@ func TestExtractContentRecursiveFinishedStatus(t *testing.T) {
func TestExtractContentRecursiveSkipsPath(t *testing.T) {
items := []any{
map[string]any{"p": "token_usage", "v": "data"},
map[string]any{"p": "quasi_status", "v": "data"},
}
parts, finished := extractContentRecursive(items, "text")
if finished {

View File

@@ -19,6 +19,20 @@ func TestParseDeepSeekSSELineDone(t *testing.T) {
}
}
func TestExtractTokenUsage(t *testing.T) {
chunk := map[string]any{
"p": "response/token_usage",
"v": map[string]any{
"prompt_tokens": 123,
"completion_tokens": 456,
},
}
p, c := extractAccumulatedTokenUsage(chunk)
if p != 123 || c != 456 {
t.Fatalf("expected 123/456, got %d/%d", p, c)
}
}
func TestParseSSEChunkForContentSimple(t *testing.T) {
parts, finished, _ := ParseSSEChunkForContent(map[string]any{"v": "hello"}, false, "text")
if finished {