mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-14 13:15:07 +08:00
135 lines
3.1 KiB
Go
135 lines
3.1 KiB
Go
package sse
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/json"
|
|
"errors"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestRawStreamSamplesTokenReplay(t *testing.T) {
|
|
root := filepath.Join("..", "..", "tests", "raw_stream_samples")
|
|
entries, err := os.ReadDir(root)
|
|
if err != nil {
|
|
t.Fatalf("read samples root: %v", err)
|
|
}
|
|
|
|
found := 0
|
|
for _, entry := range entries {
|
|
if !entry.IsDir() {
|
|
continue
|
|
}
|
|
ssePath := filepath.Join(root, entry.Name(), "upstream.stream.sse")
|
|
if _, err := os.Stat(ssePath); err != nil {
|
|
continue
|
|
}
|
|
found++
|
|
t.Run(entry.Name(), func(t *testing.T) {
|
|
raw, err := os.ReadFile(ssePath)
|
|
if err != nil {
|
|
t.Fatalf("read sample: %v", err)
|
|
}
|
|
parsedTokens, expectedTokens, err := replayAndCollectTokens(string(raw))
|
|
if err != nil {
|
|
t.Fatalf("replay token collection failed: %v", err)
|
|
}
|
|
if expectedTokens <= 0 {
|
|
t.Fatalf("expected positive token usage from raw stream, got %d", expectedTokens)
|
|
}
|
|
if parsedTokens != expectedTokens {
|
|
t.Fatalf("token mismatch parsed=%d expected=%d", parsedTokens, expectedTokens)
|
|
}
|
|
})
|
|
}
|
|
|
|
if found == 0 {
|
|
t.Fatalf("no upstream.stream.sse samples found under %s", root)
|
|
}
|
|
}
|
|
|
|
func replayAndCollectTokens(raw string) (parsedTokens int, expectedTokens int, err error) {
|
|
currentType := "thinking"
|
|
scanner := bufio.NewScanner(strings.NewReader(raw))
|
|
scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024)
|
|
for scanner.Scan() {
|
|
line := strings.TrimSpace(scanner.Text())
|
|
if !strings.HasPrefix(line, "data:") {
|
|
continue
|
|
}
|
|
payload := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
|
if payload == "" || payload == "[DONE]" || !strings.HasPrefix(payload, "{") {
|
|
continue
|
|
}
|
|
var chunk map[string]any
|
|
if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
|
|
continue
|
|
}
|
|
if n := rawAccumulatedTokenUsage(chunk); n > 0 {
|
|
expectedTokens = n
|
|
}
|
|
res := ParseDeepSeekContentLine([]byte(line), true, currentType)
|
|
currentType = res.NextType
|
|
if res.OutputTokens > 0 {
|
|
parsedTokens = res.OutputTokens
|
|
}
|
|
}
|
|
if scanErr := scanner.Err(); scanErr != nil {
|
|
if errors.Is(scanErr, bufio.ErrTooLong) {
|
|
return 0, 0, errors.New("raw stream line exceeds 2MiB scanner limit")
|
|
}
|
|
return 0, 0, scanErr
|
|
}
|
|
return parsedTokens, expectedTokens, nil
|
|
}
|
|
|
|
func rawAccumulatedTokenUsage(v any) int {
|
|
switch x := v.(type) {
|
|
case []any:
|
|
for _, item := range x {
|
|
if n := rawAccumulatedTokenUsage(item); n > 0 {
|
|
return n
|
|
}
|
|
}
|
|
case map[string]any:
|
|
if n := rawToInt(x["accumulated_token_usage"]); n > 0 {
|
|
return n
|
|
}
|
|
if p, _ := x["p"].(string); strings.Contains(strings.ToLower(strings.TrimSpace(p)), "accumulated_token_usage") {
|
|
if n := rawToInt(x["v"]); n > 0 {
|
|
return n
|
|
}
|
|
}
|
|
for _, vv := range x {
|
|
if n := rawAccumulatedTokenUsage(vv); n > 0 {
|
|
return n
|
|
}
|
|
}
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// rawToInt converts a decoded JSON scalar to an int, truncating toward zero.
//
// It accepts float64 (what encoding/json produces for numbers), int, and
// strings holding an integer or a float literal; surrounding whitespace in
// strings is ignored. Any other type, an unparsable string, NaN, ±Inf, or a
// value whose magnitude reaches the int limit yields 0.
func rawToInt(v any) int {
	// fromFloat truncates f toward zero. The Go spec leaves converting NaN,
	// ±Inf or an out-of-range float64 to int implementation-dependent (e.g.
	// "NaN" became min-int on amd64), so those cases map to 0 explicitly.
	fromFloat := func(f float64) int {
		const maxInt = int(^uint(0) >> 1)
		isNaN := f != f
		if isNaN || f >= float64(maxInt) || f <= -float64(maxInt) {
			return 0
		}
		return int(f)
	}

	switch x := v.(type) {
	case float64:
		return fromFloat(x)
	case int:
		return x
	case string:
		s := strings.TrimSpace(x)
		if s == "" {
			return 0
		}
		if n, err := strconv.Atoi(s); err == nil {
			return n
		}
		if f, err := strconv.ParseFloat(s, 64); err == nil {
			return fromFloat(f)
		}
	}
	return 0
}
|