Mirror of https://github.com/CJackHwang/ds2api.git (synced 2026-05-04 08:25:26 +08:00)

Commit: remove upstream token-usage plumbing and always estimate from content
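Every usage figure below is now derived from util.EstimateTokens over the final prompt, thinking, and text strings rather than from upstream counters. That helper is not part of this diff; the sketch below is only a plausible stand-in (the roughly-four-characters-per-token ratio is an assumption, not the repo's actual heuristic):

package util

import "unicode/utf8"

// EstimateTokens returns a rough token count for a piece of content.
// Assumption: roughly four characters per token, rounded up.
func EstimateTokens(s string) int {
    if s == "" {
        return 0
    }
    return (utf8.RuneCountInString(s) + 3) / 4
}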
@@ -24,10 +24,9 @@ type claudeStreamRuntime struct {
     bufferToolContent     bool
     stripReferenceMarkers bool

-    messageID    string
-    thinking     strings.Builder
-    text         strings.Builder
-    outputTokens int
+    messageID string
+    thinking  strings.Builder
+    text      strings.Builder

     nextBlockIndex    int
     thinkingBlockOpen bool

@@ -70,9 +69,6 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
     if !parsed.Parsed {
         return streamengine.ParsedDecision{}
     }
-    if parsed.OutputTokens > 0 {
-        s.outputTokens = parsed.OutputTokens
-    }
     if parsed.ErrorMessage != "" {
         s.upstreamErr = parsed.ErrorMessage
         return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReason("upstream_error")}

@@ -109,9 +109,6 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
     }

     outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
-    if s.outputTokens > 0 {
-        outputTokens = s.outputTokens
-    }
     s.send("message_delta", map[string]any{
         "type": "message_delta",
         "delta": map[string]any{

@@ -149,14 +149,13 @@ func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *ht
         cleanVisibleOutput(result.Thinking, stripReferenceMarkers),
         cleanVisibleOutput(result.Text, stripReferenceMarkers),
         toolNames,
-        result.OutputTokens,
     ))
 }

 //nolint:unused // retained for native Gemini non-stream handling path.
-func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string, outputTokens int) map[string]any {
+func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
     parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
-    usage := buildGeminiUsage(finalPrompt, finalThinking, finalText, outputTokens)
+    usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
     return map[string]any{
         "candidates": []map[string]any{
             {

@@ -174,14 +173,10 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final
 }

 //nolint:unused // retained for native Gemini non-stream handling path.
-func buildGeminiUsage(finalPrompt, finalThinking, finalText string, outputTokens int) map[string]any {
+func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
     promptTokens := util.EstimateTokens(finalPrompt)
     reasoningTokens := util.EstimateTokens(finalThinking)
     completionTokens := util.EstimateTokens(finalText)
-    if outputTokens > 0 {
-        completionTokens = outputTokens
-        reasoningTokens = 0
-    }
     return map[string]any{
         "promptTokenCount":     promptTokens,
         "candidatesTokenCount": reasoningTokens + completionTokens,

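Read end to end, the surviving buildGeminiUsage is pure estimation. A stitched-together sketch of the post-commit function (the package name and import path are hypothetical, and the hunk cuts off after candidatesTokenCount, so the totalTokenCount line is an assumption):

package gemini // hypothetical package name; not shown in the diff

import "github.com/CJackHwang/ds2api/internal/util" // hypothetical import path

//nolint:unused // retained for native Gemini non-stream handling path.
func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
    promptTokens := util.EstimateTokens(finalPrompt)
    reasoningTokens := util.EstimateTokens(finalThinking)
    completionTokens := util.EstimateTokens(finalText)
    return map[string]any{
        "promptTokenCount":     promptTokens,
        "candidatesTokenCount": reasoningTokens + completionTokens,
        // Assumed tail: the hunk above is truncated after candidatesTokenCount.
        "totalTokenCount": promptTokens + reasoningTokens + completionTokens,
    }
}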
@@ -65,9 +65,8 @@ type geminiStreamRuntime struct {
     stripReferenceMarkers bool
     toolNames             []string

-    thinking     strings.Builder
-    text         strings.Builder
-    outputTokens int
+    thinking strings.Builder
+    text     strings.Builder
 }

 //nolint:unused // retained for native Gemini stream handling path.

@@ -112,9 +111,6 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
     if !parsed.Parsed {
         return streamengine.ParsedDecision{}
     }
-    if parsed.OutputTokens > 0 {
-        s.outputTokens = parsed.OutputTokens
-    }
     if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
         return streamengine.ParsedDecision{Stop: true}
     }

@@ -198,6 +194,6 @@ func (s *geminiStreamRuntime) finalize() {
             },
         },
         "modelVersion": s.model,
-        "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText, s.outputTokens),
+        "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
     })
 }

@@ -37,8 +37,6 @@ type chatStreamRuntime struct {
     streamToolNames map[int]string
     thinking        strings.Builder
     text            strings.Builder
-    promptTokens    int
-    outputTokens    int
 }

 func newChatStreamRuntime(

@@ -171,17 +169,6 @@ func (s *chatStreamRuntime) finalize(finishReason string) {
         finishReason = "tool_calls"
     }
     usage := openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText)
-    if s.promptTokens > 0 {
-        usage["prompt_tokens"] = s.promptTokens
-    }
-    if s.outputTokens > 0 {
-        usage["completion_tokens"] = s.outputTokens
-    }
-    if s.promptTokens > 0 || s.outputTokens > 0 {
-        p := usage["prompt_tokens"].(int)
-        c := usage["completion_tokens"].(int)
-        usage["total_tokens"] = p + c
-    }
     s.sendChunk(openaifmt.BuildChatStreamChunk(
         s.completionID,
         s.created,

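The chat stream now emits whatever openaifmt.BuildChatUsage computes, with no post-hoc patching. That builder is not shown in this commit; judging from the prompt_tokens / completion_tokens / total_tokens keys the deleted override touched, a plausible shape is (an assumption, not the repo's actual implementation):

package openaifmt

import "github.com/CJackHwang/ds2api/internal/util" // hypothetical import path

// BuildChatUsage estimates OpenAI-style usage purely from content.
func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any {
    p := util.EstimateTokens(finalPrompt)
    c := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
    return map[string]any{
        "prompt_tokens":     p,
        "completion_tokens": c,
        "total_tokens":      p + c,
    }
}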
@@ -196,12 +183,6 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD
     if !parsed.Parsed {
         return streamengine.ParsedDecision{}
     }
-    if parsed.PromptTokens > 0 {
-        s.promptTokens = parsed.PromptTokens
-    }
-    if parsed.OutputTokens > 0 {
-        s.outputTokens = parsed.OutputTokens
-    }
     if parsed.ContentFilter {
         return streamengine.ParsedDecision{Stop: true, StopReason: streamengine.StopReasonHandlerRequested}
     }

@@ -131,19 +131,6 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, ctx context.Context, re
         return
     }
     respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames)
-    if result.PromptTokens > 0 || result.OutputTokens > 0 {
-        if usage, ok := respBody["usage"].(map[string]any); ok {
-            if result.PromptTokens > 0 {
-                usage["prompt_tokens"] = result.PromptTokens
-            }
-            if result.OutputTokens > 0 {
-                usage["completion_tokens"] = result.OutputTokens
-            }
-            p, _ := usage["prompt_tokens"].(int)
-            c, _ := usage["completion_tokens"].(int)
-            usage["total_tokens"] = p + c
-        }
-    }
     writeJSON(w, http.StatusOK, respBody)
 }

@@ -130,19 +130,6 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res
     }

     responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, sanitizedThinking, sanitizedText, toolNames)
-    if result.PromptTokens > 0 || result.OutputTokens > 0 {
-        if usage, ok := responseObj["usage"].(map[string]any); ok {
-            if result.PromptTokens > 0 {
-                usage["input_tokens"] = result.PromptTokens
-            }
-            if result.OutputTokens > 0 {
-                usage["output_tokens"] = result.OutputTokens
-            }
-            input, _ := usage["input_tokens"].(int)
-            output, _ := usage["output_tokens"].(int)
-            usage["total_tokens"] = input + output
-        }
-    }
     h.getResponseStore().put(owner, responseID, responseObj)
     writeJSON(w, http.StatusOK, responseObj)
 }

@@ -51,8 +51,6 @@ type responsesStreamRuntime struct {
     messagePartAdded bool
     sequence         int
     failed           bool
-    promptTokens     int
-    outputTokens     int

     persistResponse func(obj map[string]any)
 }

@@ -150,24 +148,6 @@ func (s *responsesStreamRuntime) finalize() {
     s.closeIncompleteFunctionItems()

     obj := s.buildCompletedResponseObject(finalThinking, finalText, detected)
-    if s.outputTokens > 0 {
-        if usage, ok := obj["usage"].(map[string]any); ok {
-            usage["output_tokens"] = s.outputTokens
-        }
-    }
-    if s.promptTokens > 0 || s.outputTokens > 0 {
-        if usage, ok := obj["usage"].(map[string]any); ok {
-            if s.promptTokens > 0 {
-                usage["input_tokens"] = s.promptTokens
-            }
-            if s.outputTokens > 0 {
-                usage["output_tokens"] = s.outputTokens
-            }
-            input, _ := usage["input_tokens"].(int)
-            output, _ := usage["output_tokens"].(int)
-            usage["total_tokens"] = input + output
-        }
-    }
     if s.persistResponse != nil {
         s.persistResponse(obj)
     }

@@ -196,12 +176,6 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa
     if !parsed.Parsed {
         return streamengine.ParsedDecision{}
     }
-    if parsed.PromptTokens > 0 {
-        s.promptTokens = parsed.PromptTokens
-    }
-    if parsed.OutputTokens > 0 {
-        s.outputTokens = parsed.OutputTokens
-    }
     if parsed.ContentFilter || parsed.ErrorMessage != "" || parsed.Stop {
         return streamengine.ParsedDecision{Stop: true}
     }

@@ -37,7 +37,6 @@ func TestGoCompatSSEFixtures(t *testing.T) {
         Finished      bool   `json:"finished"`
         NewType       string `json:"new_type"`
         ContentFilter bool   `json:"content_filter"`
-        OutputTokens  int    `json:"output_tokens"`
         ErrorMessage  string `json:"error_message"`
     }
     mustLoadJSON(t, expectedPath, &expected)

@@ -58,11 +57,10 @@
         res.Stop != expected.Finished ||
         res.NextType != expected.NewType ||
         res.ContentFilter != expected.ContentFilter ||
-        res.OutputTokens != expected.OutputTokens ||
         res.ErrorMessage != expected.ErrorMessage {
-        t.Fatalf("fixture %s mismatch:\n got parts=%#v finished=%v newType=%q contentFilter=%v outputTokens=%d errorMessage=%q\nwant parts=%#v finished=%v newType=%q contentFilter=%v outputTokens=%d errorMessage=%q",
-            name, gotParts, res.Stop, res.NextType, res.ContentFilter, res.OutputTokens, res.ErrorMessage,
-            expected.Parts, expected.Finished, expected.NewType, expected.ContentFilter, expected.OutputTokens, expected.ErrorMessage)
+        t.Fatalf("fixture %s mismatch:\n got parts=%#v finished=%v newType=%q contentFilter=%v errorMessage=%q\nwant parts=%#v finished=%v newType=%q contentFilter=%v errorMessage=%q",
+            name, gotParts, res.Stop, res.NextType, res.ContentFilter, res.ErrorMessage,
+            expected.Parts, expected.Finished, expected.NewType, expected.ContentFilter, expected.ErrorMessage)
     }
 }
 }

@@ -125,8 +125,6 @@ async function handleVercelStream(req, res, rawBody, payload) {
   let currentType = thinkingEnabled ? 'thinking' : 'text';
   let thinkingText = '';
   let outputText = '';
-  let promptTokens = 0;
-  let outputTokens = 0;
   const toolSieveEnabled = toolPolicy.toolSieveEnabled;
   const toolSieveState = createToolSieveState();
   let toolCallsEmitted = false;

@@ -179,7 +177,7 @@
     created,
     model,
     choices: [{ delta: {}, index: 0, finish_reason: reason }],
-    usage: buildUsage(finalPrompt, thinkingText, outputText, outputTokens, promptTokens),
+    usage: buildUsage(finalPrompt, thinkingText, outputText),
   });
   if (!res.writableEnded && !res.destroyed) {
     res.write('data: [DONE]\n\n');

@@ -228,12 +226,6 @@
       if (!parsed.parsed) {
         continue;
       }
-      if (parsed.promptTokens > 0) {
-        promptTokens = parsed.promptTokens;
-      }
-      if (parsed.outputTokens > 0) {
-        outputTokens = parsed.outputTokens;
-      }
       currentType = parsed.newType;
       if (parsed.errorMessage) {
         await finish('content_filter');

@@ -12,8 +12,6 @@ import (
 type CollectResult struct {
     Text          string
     Thinking      string
-    PromptTokens  int
-    OutputTokens  int
     ContentFilter bool
 }

@@ -29,8 +27,6 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
     }
     text := strings.Builder{}
     thinking := strings.Builder{}
-    promptTokens := 0
-    outputTokens := 0
     contentFilter := false
     currentType := "text"
     if thinkingEnabled {

@@ -42,12 +38,6 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
         if !result.Parsed {
             return true
         }
-        if result.PromptTokens > 0 {
-            promptTokens = result.PromptTokens
-        }
-        if result.OutputTokens > 0 {
-            outputTokens = result.OutputTokens
-        }
         if result.Stop {
             if result.ContentFilter {
                 contentFilter = true

@@ -68,8 +58,6 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co
     return CollectResult{
         Text:          text.String(),
         Thinking:      thinking.String(),
-        PromptTokens:  promptTokens,
-        OutputTokens:  outputTokens,
         ContentFilter: contentFilter,
     }
 }

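CollectResult now carries only content plus the content-filter flag, so any caller that still wants token numbers derives them from the collected text. A hedged caller sketch (consumeNonStream and the import path are hypothetical, not code from the repo):

package deepseek

import (
    "net/http"

    "github.com/CJackHwang/ds2api/internal/util" // hypothetical import path
)

// consumeNonStream drains a non-stream upstream response and estimates
// output tokens from the collected content only.
func consumeNonStream(resp *http.Response) (thinking, text string, outputTokens int) {
    res := CollectStream(resp, true /* thinkingEnabled */, true /* closeBody */)
    outputTokens = util.EstimateTokens(res.Thinking) + util.EstimateTokens(res.Text)
    return res.Thinking, res.Text, outputTokens
}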
@@ -10,8 +10,6 @@ type LineResult struct {
     ErrorMessage string
     Parts        []ContentPart
     NextType     string
-    PromptTokens int
-    OutputTokens int
 }

 // ParseDeepSeekContentLine centralizes one-line DeepSeek SSE parsing for both

@@ -21,9 +19,8 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
     if !parsed {
         return LineResult{NextType: currentType}
     }
-    promptTokens, outputTokens := extractAccumulatedTokenUsage(chunk)
     if done {
-        return LineResult{Parsed: true, Stop: true, NextType: currentType, PromptTokens: promptTokens, OutputTokens: outputTokens}
+        return LineResult{Parsed: true, Stop: true, NextType: currentType}
     }
     if errObj, hasErr := chunk["error"]; hasErr {
         return LineResult{

@@ -31,8 +28,6 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
             Stop:         true,
             ErrorMessage: fmt.Sprintf("%v", errObj),
             NextType:     currentType,
-            PromptTokens: promptTokens,
-            OutputTokens: outputTokens,
         }
     }
     if code, _ := chunk["code"].(string); code == "content_filter" {

@@ -41,8 +36,6 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
             Stop:          true,
             ContentFilter: true,
             NextType:      currentType,
-            PromptTokens:  promptTokens,
-            OutputTokens:  outputTokens,
         }
     }
     if hasContentFilterStatus(chunk) {

@@ -51,18 +44,14 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri
             Stop:          true,
             ContentFilter: true,
             NextType:      currentType,
-            PromptTokens:  promptTokens,
-            OutputTokens:  outputTokens,
         }
     }
     parts, finished, nextType := ParseSSEChunkForContent(chunk, thinkingEnabled, currentType)
     parts = filterLeakedContentFilterParts(parts)
     return LineResult{
-        Parsed:       true,
-        Stop:         finished,
-        Parts:        parts,
-        NextType:     nextType,
-        PromptTokens: promptTokens,
-        OutputTokens: outputTokens,
+        Parsed:   true,
+        Stop:     finished,
+        Parts:    parts,
+        NextType: nextType,
     }
 }

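With the token fields gone, a LineResult consumer tracks only the parts, the current-type cursor, and the stop flag. A minimal driver sketch (drainStream is a hypothetical helper living alongside ParseDeepSeekContentLine in the sse package; it is not part of this commit):

package sse

// drainStream replays raw SSE lines through the parser and reports whether
// the stream signalled a stop; token accounting is no longer done here.
func drainStream(lines [][]byte, thinkingEnabled bool) bool {
    currentType := "text"
    if thinkingEnabled {
        currentType = "thinking"
    }
    for _, raw := range lines {
        res := ParseDeepSeekContentLine(raw, thinkingEnabled, currentType)
        if !res.Parsed {
            continue
        }
        currentType = res.NextType
        for range res.Parts {
            // Route each part into the thinking or text buffer here.
        }
        if res.Stop {
            return true
        }
    }
    return false
}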
@@ -26,7 +26,7 @@ func TestParseDeepSeekContentLineContentFilter(t *testing.T) {
     }
 }

-func TestParseDeepSeekContentLineContentFilterCodeIgnoresUpstreamOutputTokens(t *testing.T) {
+func TestParseDeepSeekContentLineContentFilterCodeStops(t *testing.T) {
     res := ParseDeepSeekContentLine(
         []byte(`data: {"code":"content_filter","accumulated_token_usage":99}`),
         false, "text",

@@ -34,9 +34,6 @@ func TestParseDeepSeekContentLineContentFilterCodeIgnoresUpstreamOutputTokens(t
     if !res.Parsed || !res.Stop || !res.ContentFilter {
         t.Fatalf("expected content-filter stop result: %#v", res)
     }
-    if res.OutputTokens != 0 {
-        t.Fatalf("expected upstream output token usage to be ignored, got %d", res.OutputTokens)
-    }
 }

 func TestParseDeepSeekContentLineContentFilterStatus(t *testing.T) {

@@ -48,26 +45,23 @@ func TestParseDeepSeekContentLineContentFilterStatus(t *testing.T) {

 func TestParseDeepSeekContentLineIgnoresAccumulatedTokenUsage(t *testing.T) {
     res := ParseDeepSeekContentLine([]byte(`data: {"p":"response","o":"BATCH","v":[{"p":"accumulated_token_usage","v":1383},{"p":"quasi_status","v":"FINISHED"}]}`), false, "text")
-    if res.OutputTokens != 0 {
-        t.Fatalf("expected accumulated token usage ignored, got %d", res.OutputTokens)
+    if !res.Parsed {
+        t.Fatalf("expected parsed result")
     }
 }

 func TestParseDeepSeekContentLineIgnoresAccumulatedTokenUsageString(t *testing.T) {
     res := ParseDeepSeekContentLine([]byte(`data: {"p":"response","o":"BATCH","v":[{"p":"accumulated_token_usage","v":"190"},{"p":"quasi_status","v":"FINISHED"}]}`), false, "text")
-    if res.OutputTokens != 0 {
-        t.Fatalf("expected accumulated token usage string ignored, got %d", res.OutputTokens)
+    if !res.Parsed {
+        t.Fatalf("expected parsed result")
     }
 }

-func TestParseDeepSeekContentLineErrorIgnoresUpstreamOutputTokens(t *testing.T) {
+func TestParseDeepSeekContentLineErrorStops(t *testing.T) {
     res := ParseDeepSeekContentLine([]byte(`data: {"error":"boom","accumulated_token_usage":123}`), false, "text")
     if !res.Parsed || !res.Stop {
         t.Fatalf("expected stop on error: %#v", res)
     }
-    if res.OutputTokens != 0 {
-        t.Fatalf("expected output token usage ignored on error, got %d", res.OutputTokens)
-    }
 }

 func TestParseDeepSeekContentLineContent(t *testing.T) {

@@ -361,10 +361,3 @@ func hasContentFilterStatusValue(v any) bool {
     }
     return false
 }
-
-func extractAccumulatedTokenUsage(chunk map[string]any) (int, int) {
-    // Temporary policy: ignore the upstream usage fields (accumulated_token_usage / token_usage)
-    // and let downstream rely on the internal token estimate instead, since the upstream
-    // counter accumulates over the whole context and can badly skew single-response output.
-    _ = chunk
-    return 0, 0
-}

@@ -19,20 +19,6 @@ func TestParseDeepSeekSSELineDone(t *testing.T) {
     }
 }

-func TestExtractTokenUsage(t *testing.T) {
-    chunk := map[string]any{
-        "p": "response/token_usage",
-        "v": map[string]any{
-            "prompt_tokens":     123,
-            "completion_tokens": 456,
-        },
-    }
-    p, c := extractAccumulatedTokenUsage(chunk)
-    if p != 0 || c != 0 {
-        t.Fatalf("expected upstream usage ignored as 0/0, got %d/%d", p, c)
-    }
-}
-
 func TestParseSSEChunkForContentSimple(t *testing.T) {
     parts, finished, _ := ParseSSEChunkForContent(map[string]any{"v": "hello"}, false, "text")
     if finished {

@@ -1,134 +0,0 @@
-package sse
-
-import (
-    "bufio"
-    "encoding/json"
-    "errors"
-    "os"
-    "path/filepath"
-    "strconv"
-    "strings"
-    "testing"
-)
-
-func TestRawStreamSamplesTokenReplay(t *testing.T) {
-    root := filepath.Join("..", "..", "tests", "raw_stream_samples")
-    entries, err := os.ReadDir(root)
-    if err != nil {
-        t.Fatalf("read samples root: %v", err)
-    }
-
-    found := 0
-    for _, entry := range entries {
-        if !entry.IsDir() {
-            continue
-        }
-        ssePath := filepath.Join(root, entry.Name(), "upstream.stream.sse")
-        if _, err := os.Stat(ssePath); err != nil {
-            continue
-        }
-        found++
-        t.Run(entry.Name(), func(t *testing.T) {
-            raw, err := os.ReadFile(ssePath)
-            if err != nil {
-                t.Fatalf("read sample: %v", err)
-            }
-            parsedTokens, expectedTokens, err := replayAndCollectTokens(string(raw))
-            if err != nil {
-                t.Fatalf("replay token collection failed: %v", err)
-            }
-            if expectedTokens <= 0 {
-                t.Fatalf("expected positive token usage from raw stream, got %d", expectedTokens)
-            }
-            if parsedTokens != 0 {
-                t.Fatalf("expected parser to ignore upstream token usage, got parsed=%d expectedRaw=%d", parsedTokens, expectedTokens)
-            }
-        })
-    }
-
-    if found == 0 {
-        t.Fatalf("no upstream.stream.sse samples found under %s", root)
-    }
-}
-
-func replayAndCollectTokens(raw string) (parsedTokens int, expectedTokens int, err error) {
-    currentType := "thinking"
-    scanner := bufio.NewScanner(strings.NewReader(raw))
-    scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
-    for scanner.Scan() {
-        line := strings.TrimSpace(scanner.Text())
-        if !strings.HasPrefix(line, "data:") {
-            continue
-        }
-        payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
-        if payload == "" || payload == "[DONE]" || !strings.HasPrefix(payload, "{") {
-            continue
-        }
-        var chunk map[string]any
-        if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
-            continue
-        }
-        if n := rawAccumulatedTokenUsage(chunk); n > 0 {
-            expectedTokens = n
-        }
-        res := ParseDeepSeekContentLine([]byte(line), true, currentType)
-        currentType = res.NextType
-        if res.OutputTokens > 0 {
-            parsedTokens = res.OutputTokens
-        }
-    }
-    if scanErr := scanner.Err(); scanErr != nil {
-        if errors.Is(scanErr, bufio.ErrTooLong) {
-            return 0, 0, errors.New("raw stream line exceeds 2MiB scanner limit")
-        }
-        return 0, 0, scanErr
-    }
-    return parsedTokens, expectedTokens, nil
-}
-
-func rawAccumulatedTokenUsage(v any) int {
-    switch x := v.(type) {
-    case []any:
-        for _, item := range x {
-            if n := rawAccumulatedTokenUsage(item); n > 0 {
-                return n
-            }
-        }
-    case map[string]any:
-        if n := rawToInt(x["accumulated_token_usage"]); n > 0 {
-            return n
-        }
-        if p, _ := x["p"].(string); strings.Contains(strings.ToLower(strings.TrimSpace(p)), "accumulated_token_usage") {
-            if n := rawToInt(x["v"]); n > 0 {
-                return n
-            }
-        }
-        for _, vv := range x {
-            if n := rawAccumulatedTokenUsage(vv); n > 0 {
-                return n
-            }
-        }
-    }
-    return 0
-}
-
-func rawToInt(v any) int {
-    switch x := v.(type) {
-    case float64:
-        return int(x)
-    case int:
-        return x
-    case string:
-        s := strings.TrimSpace(x)
-        if s == "" {
-            return 0
-        }
-        if n, err := strconv.Atoi(s); err == nil {
-            return n
-        }
-        if f, err := strconv.ParseFloat(s, 64); err == nil {
-            return int(f)
-        }
-    }
-    return 0
-}

@@ -31,7 +31,6 @@ test('js compat: sse fixtures', () => {
     assert.equal(got.finished, expected.finished, `${name}: finished mismatch`);
    assert.equal(got.newType, expected.new_type, `${name}: newType mismatch`);
    assert.equal(Boolean(got.contentFilter), Boolean(expected.content_filter), `${name}: contentFilter mismatch`);
-    assert.equal(Number(got.outputTokens || 0), Number(expected.output_tokens || 0), `${name}: outputTokens mismatch`);
    assert.equal(got.errorMessage || '', expected.error_message || '', `${name}: errorMessage mismatch`);
   }
 });