mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-23 01:17:44 +08:00
feat: support explicit prompt token tracking in SSE parsing and stream handlers
This commit is contained in:
@@ -12,6 +12,7 @@ import (
|
||||
type CollectResult struct {
|
||||
Text string
|
||||
Thinking string
|
||||
PromptTokens int
|
||||
OutputTokens int
|
||||
ContentFilter bool
|
||||
}
|
||||
@@ -28,6 +29,7 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
}
|
||||
text := strings.Builder{}
|
||||
thinking := strings.Builder{}
|
||||
promptTokens := 0
|
||||
outputTokens := 0
|
||||
contentFilter := false
|
||||
currentType := "text"
|
||||
@@ -40,18 +42,18 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
if !result.Parsed {
|
||||
return true
|
||||
}
|
||||
if result.PromptTokens > 0 {
|
||||
promptTokens = result.PromptTokens
|
||||
}
|
||||
if result.OutputTokens > 0 {
|
||||
outputTokens = result.OutputTokens
|
||||
}
|
||||
if result.Stop {
|
||||
if result.ContentFilter {
|
||||
contentFilter = true
|
||||
}
|
||||
if result.OutputTokens > 0 {
|
||||
outputTokens = result.OutputTokens
|
||||
}
|
||||
return false
|
||||
}
|
||||
if result.OutputTokens > 0 {
|
||||
outputTokens = result.OutputTokens
|
||||
}
|
||||
for _, p := range result.Parts {
|
||||
if p.Type == "thinking" {
|
||||
trimmed := TrimContinuationOverlap(thinking.String(), p.Text)
|
||||
@@ -66,6 +68,7 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
|
||||
return CollectResult{
|
||||
Text: text.String(),
|
||||
Thinking: thinking.String(),
|
||||
PromptTokens: promptTokens,
|
||||
OutputTokens: outputTokens,
|
||||
ContentFilter: contentFilter,
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ type LineResult struct {
|
||||
ErrorMessage string
|
||||
Parts []ContentPart
|
||||
NextType string
|
||||
PromptTokens int
|
||||
OutputTokens int
|
||||
}
|
||||
|
||||
@@ -20,9 +21,9 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
|
||||
if !parsed {
|
||||
return LineResult{NextType: currentType}
|
||||
}
|
||||
outputTokens := extractAccumulatedTokenUsage(chunk)
|
||||
promptTokens, outputTokens := extractAccumulatedTokenUsage(chunk)
|
||||
if done {
|
||||
return LineResult{Parsed: true, Stop: true, NextType: currentType, OutputTokens: outputTokens}
|
||||
return LineResult{Parsed: true, Stop: true, NextType: currentType, PromptTokens: promptTokens, OutputTokens: outputTokens}
|
||||
}
|
||||
if errObj, hasErr := chunk["error"]; hasErr {
|
||||
return LineResult{
|
||||
@@ -30,6 +31,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
|
||||
Stop: true,
|
||||
ErrorMessage: fmt.Sprintf("%v", errObj),
|
||||
NextType: currentType,
|
||||
PromptTokens: promptTokens,
|
||||
OutputTokens: outputTokens,
|
||||
}
|
||||
}
|
||||
@@ -39,6 +41,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
|
||||
Stop: true,
|
||||
ContentFilter: true,
|
||||
NextType: currentType,
|
||||
PromptTokens: promptTokens,
|
||||
OutputTokens: outputTokens,
|
||||
}
|
||||
}
|
||||
@@ -48,6 +51,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
|
||||
Stop: true,
|
||||
ContentFilter: true,
|
||||
NextType: currentType,
|
||||
PromptTokens: promptTokens,
|
||||
OutputTokens: outputTokens,
|
||||
}
|
||||
}
|
||||
@@ -58,6 +62,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
|
||||
Stop: finished,
|
||||
Parts: parts,
|
||||
NextType: nextType,
|
||||
PromptTokens: promptTokens,
|
||||
OutputTokens: outputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,34 +364,50 @@ func hasContentFilterStatusValue(v any) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func extractAccumulatedTokenUsage(chunk map[string]any) int {
|
||||
func extractAccumulatedTokenUsage(chunk map[string]any) (int, int) {
|
||||
return findAccumulatedTokenUsage(chunk)
|
||||
}
|
||||
|
||||
func findAccumulatedTokenUsage(v any) int {
|
||||
func findAccumulatedTokenUsage(v any) (int, int) {
|
||||
switch x := v.(type) {
|
||||
case map[string]any:
|
||||
if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "accumulated_token_usage") {
|
||||
if n, ok := toInt(x["v"]); ok && n > 0 {
|
||||
return n
|
||||
return 0, n
|
||||
}
|
||||
}
|
||||
if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "token_usage") {
|
||||
if m, ok := x["v"].(map[string]any); ok {
|
||||
p, _ := toInt(m["prompt_tokens"])
|
||||
c, _ := toInt(m["completion_tokens"])
|
||||
if p > 0 || c > 0 {
|
||||
return p, c
|
||||
}
|
||||
}
|
||||
}
|
||||
if n, ok := toInt(x["accumulated_token_usage"]); ok && n > 0 {
|
||||
return n
|
||||
return 0, n
|
||||
}
|
||||
if usage, ok := x["token_usage"].(map[string]any); ok {
|
||||
p, _ := toInt(usage["prompt_tokens"])
|
||||
c, _ := toInt(usage["completion_tokens"])
|
||||
if p > 0 || c > 0 {
|
||||
return p, c
|
||||
}
|
||||
}
|
||||
for _, vv := range x {
|
||||
if n := findAccumulatedTokenUsage(vv); n > 0 {
|
||||
return n
|
||||
if p, c := findAccumulatedTokenUsage(vv); p > 0 || c > 0 {
|
||||
return p, c
|
||||
}
|
||||
}
|
||||
case []any:
|
||||
for _, item := range x {
|
||||
if n := findAccumulatedTokenUsage(item); n > 0 {
|
||||
return n
|
||||
if p, c := findAccumulatedTokenUsage(item); p > 0 || c > 0 {
|
||||
return p, c
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
func toInt(v any) (int, bool) {
|
||||
|
||||
@@ -50,18 +50,6 @@ func TestShouldSkipPathQuasiStatus(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipPathElapsedSecs(t *testing.T) {
|
||||
if !shouldSkipPath("response/elapsed_secs") {
|
||||
t.Fatal("expected skip for elapsed_secs path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipPathTokenUsage(t *testing.T) {
|
||||
if !shouldSkipPath("response/token_usage") {
|
||||
t.Fatal("expected skip for token_usage path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSkipPathPendingFragment(t *testing.T) {
|
||||
if !shouldSkipPath("response/pending_fragment") {
|
||||
t.Fatal("expected skip for pending_fragment path")
|
||||
@@ -127,7 +115,7 @@ func TestParseSSEChunkForContentNoVField(t *testing.T) {
|
||||
|
||||
func TestParseSSEChunkForContentSkippedPath(t *testing.T) {
|
||||
parts, finished, nextType := ParseSSEChunkForContent(map[string]any{
|
||||
"p": "response/token_usage",
|
||||
"p": "response/quasi_status",
|
||||
"v": "some data",
|
||||
}, false, "text")
|
||||
if finished || len(parts) > 0 {
|
||||
@@ -498,7 +486,7 @@ func TestExtractContentRecursiveFinishedStatus(t *testing.T) {
|
||||
|
||||
func TestExtractContentRecursiveSkipsPath(t *testing.T) {
|
||||
items := []any{
|
||||
map[string]any{"p": "token_usage", "v": "data"},
|
||||
map[string]any{"p": "quasi_status", "v": "data"},
|
||||
}
|
||||
parts, finished := extractContentRecursive(items, "text")
|
||||
if finished {
|
||||
|
||||
@@ -19,6 +19,20 @@ func TestParseDeepSeekSSELineDone(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTokenUsage(t *testing.T) {
|
||||
chunk := map[string]any{
|
||||
"p": "response/token_usage",
|
||||
"v": map[string]any{
|
||||
"prompt_tokens": 123,
|
||||
"completion_tokens": 456,
|
||||
},
|
||||
}
|
||||
p, c := extractAccumulatedTokenUsage(chunk)
|
||||
if p != 123 || c != 456 {
|
||||
t.Fatalf("expected 123/456, got %d/%d", p, c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEChunkForContentSimple(t *testing.T) {
|
||||
parts, finished, _ := ParseSSEChunkForContent(map[string]any{"v": "hello"}, false, "text")
|
||||
if finished {
|
||||
|
||||
Reference in New Issue
Block a user