feat: implement sync.Pool for tiktoken encoding instances to optimize token counting performance

This commit is contained in:
CJACK
2026-05-02 02:31:24 +08:00
parent e2756f800d
commit 4389e02b29
2 changed files with 91 additions and 2 deletions

View File

@@ -4,19 +4,26 @@ package util
import (
"strings"
"sync"
tiktoken "github.com/hupe1980/go-tiktoken"
)
// tokenEncodingPools maps a tokenizer model name to the *sync.Pool that holds
// reusable encodings for that model.
var tokenEncodingPools sync.Map

// tokenEncodingUnsupported records model names for which building an encoding
// failed, so later lookups for the same name return early instead of retrying.
var tokenEncodingUnsupported sync.Map
// countWithTokenizer returns the number of token ids produced by encoding the
// whitespace-trimmed text with the encoding chosen for model. It returns 0 when
// the trimmed text is empty, when no encoding is available, or when encoding
// reports an error.
func countWithTokenizer(text, model string) int {
text = strings.TrimSpace(text)
if text == "" {
return 0
}
// NOTE(review): the next two lines look like the pre-change implementation
// shown by the diff view (a fresh encoding built on every call); the lines
// after them are the replacement — confirm against the actual file.
encoding, err := tiktoken.NewEncodingForModel(tokenizerModelForCount(model))
if err != nil {
// Obtain an encoding through tokenizerEncodingForCount; a nil encoding means
// none could be provided for this model.
encoding, release := tokenizerEncodingForCount(tokenizerModelForCount(model))
if encoding == nil {
return 0
}
// release is deferred so it runs on every return path below.
defer release()
ids, _, err := encoding.Encode(text, nil, nil)
if err != nil {
return 0
@@ -24,6 +31,53 @@ func countWithTokenizer(text, model string) int {
return len(ids)
}
// tokenizerEncodingForCount hands out a tiktoken encoding for model together
// with a release function that puts the encoding back into the per-model pool.
//
// The model name is whitespace-trimmed and falls back to defaultTokenizerModel
// when empty. A nil encoding (with a no-op release) is returned when the model
// is already recorded as unsupported, when building an encoding fails, or when
// the existing pool cannot supply one. The release function is never nil.
func tokenizerEncodingForCount(model string) (*tiktoken.Encoding, func()) {
	noop := func() {}

	name := strings.TrimSpace(model)
	if name == "" {
		name = defaultTokenizerModel
	}

	// Known-bad model: skip construction entirely.
	if _, unsupported := tokenEncodingUnsupported.Load(name); unsupported {
		return nil, noop
	}

	// A pool already exists for this model: borrow from it.
	if cached, found := tokenEncodingPools.Load(name); found {
		existing, _ := cached.(*sync.Pool)
		return getEncodingFromPool(existing)
	}

	// First use of this model: build one encoding eagerly so an unsupported
	// model is detected (and remembered) before a pool is registered.
	first, err := tiktoken.NewEncodingForModel(name)
	if err != nil {
		tokenEncodingUnsupported.Store(name, struct{}{})
		return nil, noop
	}

	fresh := &sync.Pool{
		New: func() any {
			if enc, err := tiktoken.NewEncodingForModel(name); err == nil {
				return enc
			}
			return nil
		},
	}

	// Another goroutine may have registered a pool in the meantime; use
	// whichever pool ended up stored so all callers share it.
	stored, _ := tokenEncodingPools.LoadOrStore(name, fresh)
	shared, _ := stored.(*sync.Pool)

	return first, func() { shared.Put(first) }
}
// getEncodingFromPool borrows an encoding from pool and returns it with a
// release function that puts it back. When pool is nil, or the value obtained
// from it is not a non-nil *tiktoken.Encoding, it returns a nil encoding and a
// no-op release. The release function is never nil.
func getEncodingFromPool(pool *sync.Pool) (*tiktoken.Encoding, func()) {
	noop := func() {}
	if pool == nil {
		return nil, noop
	}

	enc, ok := pool.Get().(*tiktoken.Encoding)
	if !ok || enc == nil {
		return nil, noop
	}

	return enc, func() { pool.Put(enc) }
}
func tokenizerModelForCount(model string) string {
model = strings.ToLower(strings.TrimSpace(model))
if model == "" {

View File

@@ -0,0 +1,35 @@
//go:build !386 && !arm && !mips && !mipsle && !wasm
package util
import "testing"
// TestTokenizerEncodingForCountCachesSupportedModel checks that requesting an
// encoding for the default model succeeds, registers a pool for that model,
// and that a second request after release also yields an encoding.
func TestTokenizerEncodingForCountCachesSupportedModel(t *testing.T) {
	first, releaseFirst := tokenizerEncodingForCount(defaultTokenizerModel)
	if first == nil {
		t.Fatalf("expected tokenizer encoding for %q", defaultTokenizerModel)
	}
	releaseFirst()

	_, pooled := tokenEncodingPools.Load(defaultTokenizerModel)
	if !pooled {
		t.Fatalf("expected tokenizer encoding pool for %q", defaultTokenizerModel)
	}

	second, releaseSecond := tokenizerEncodingForCount(defaultTokenizerModel)
	if second == nil {
		t.Fatalf("expected cached tokenizer encoding for %q", defaultTokenizerModel)
	}
	releaseSecond()
}
// TestTokenizerEncodingForCountCachesUnsupportedModel checks that an unknown
// model name yields a nil encoding with a callable release function, and that
// the name is then recorded in tokenEncodingUnsupported.
func TestTokenizerEncodingForCountCachesUnsupportedModel(t *testing.T) {
	const bogusModel = "__ds2api_unsupported_tokenizer_model__"

	enc, done := tokenizerEncodingForCount(bogusModel)
	// Called before the nil check on purpose: release must be safe to invoke
	// even when no encoding was handed out.
	done()
	if enc != nil {
		t.Fatalf("expected nil encoding for unsupported model %q", bogusModel)
	}

	_, remembered := tokenEncodingUnsupported.Load(bogusModel)
	if !remembered {
		t.Fatalf("expected unsupported tokenizer model to be cached")
	}
}