diff --git a/internal/util/token_count_tiktoken.go b/internal/util/token_count_tiktoken.go
index f9cecf9..92e48e1 100644
--- a/internal/util/token_count_tiktoken.go
+++ b/internal/util/token_count_tiktoken.go
@@ -4,19 +4,26 @@ package util
 
 import (
 	"strings"
+	"sync"
 
 	tiktoken "github.com/hupe1980/go-tiktoken"
 )
 
+var (
+	tokenEncodingPools       sync.Map
+	tokenEncodingUnsupported sync.Map
+)
+
 func countWithTokenizer(text, model string) int {
 	text = strings.TrimSpace(text)
 	if text == "" {
 		return 0
 	}
-	encoding, err := tiktoken.NewEncodingForModel(tokenizerModelForCount(model))
-	if err != nil {
+	encoding, release := tokenizerEncodingForCount(tokenizerModelForCount(model))
+	if encoding == nil {
 		return 0
 	}
+	defer release()
 	ids, _, err := encoding.Encode(text, nil, nil)
 	if err != nil {
 		return 0
@@ -24,6 +31,53 @@ func countWithTokenizer(text, model string) int {
 	return len(ids)
 }
 
+func tokenizerEncodingForCount(model string) (*tiktoken.Encoding, func()) {
+	model = strings.TrimSpace(model)
+	if model == "" {
+		model = defaultTokenizerModel
+	}
+	if _, ok := tokenEncodingUnsupported.Load(model); ok {
+		return nil, func() {}
+	}
+	if rawPool, ok := tokenEncodingPools.Load(model); ok {
+		pool, _ := rawPool.(*sync.Pool)
+		return getEncodingFromPool(pool)
+	}
+
+	encoding, err := tiktoken.NewEncodingForModel(model)
+	if err != nil {
+		tokenEncodingUnsupported.Store(model, struct{}{})
+		return nil, func() {}
+	}
+	pool := &sync.Pool{
+		New: func() any {
+			encoding, err := tiktoken.NewEncodingForModel(model)
+			if err != nil {
+				return nil
+			}
+			return encoding
+		},
+	}
+	actualPool, _ := tokenEncodingPools.LoadOrStore(model, pool)
+	pool, _ = actualPool.(*sync.Pool)
+	return encoding, func() {
+		pool.Put(encoding)
+	}
+}
+
+func getEncodingFromPool(pool *sync.Pool) (*tiktoken.Encoding, func()) {
+	if pool == nil {
+		return nil, func() {}
+	}
+	encoding, _ := pool.Get().(*tiktoken.Encoding)
+	if encoding == nil {
+		return nil, func() {}
+	}
+	return encoding, func() {
+		pool.Put(encoding)
+	}
+}
+
 func tokenizerModelForCount(model string) string {
 	model = strings.ToLower(strings.TrimSpace(model))
 	if model == "" {
diff --git a/internal/util/token_count_tiktoken_test.go b/internal/util/token_count_tiktoken_test.go
new file mode 100644
index 0000000..811c03d
--- /dev/null
+++ b/internal/util/token_count_tiktoken_test.go
@@ -0,0 +1,35 @@
+//go:build !386 && !arm && !mips && !mipsle && !wasm
+
+package util
+
+import "testing"
+
+func TestTokenizerEncodingForCountCachesSupportedModel(t *testing.T) {
+	encoding, release := tokenizerEncodingForCount(defaultTokenizerModel)
+	if encoding == nil {
+		t.Fatalf("expected tokenizer encoding for %q", defaultTokenizerModel)
+	}
+	release()
+
+	if _, ok := tokenEncodingPools.Load(defaultTokenizerModel); !ok {
+		t.Fatalf("expected tokenizer encoding pool for %q", defaultTokenizerModel)
+	}
+
+	encoding, release = tokenizerEncodingForCount(defaultTokenizerModel)
+	if encoding == nil {
+		t.Fatalf("expected cached tokenizer encoding for %q", defaultTokenizerModel)
+	}
+	release()
+}
+
+func TestTokenizerEncodingForCountCachesUnsupportedModel(t *testing.T) {
+	const model = "__ds2api_unsupported_tokenizer_model__"
+	encoding, release := tokenizerEncodingForCount(model)
+	release()
+	if encoding != nil {
+		t.Fatalf("expected nil encoding for unsupported model %q", model)
+	}
+	if _, ok := tokenEncodingUnsupported.Load(model); !ok {
+		t.Fatalf("expected unsupported tokenizer model to be cached")
+	}
+}