mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-08 10:25:28 +08:00
feat: implement sync.Pool for tiktoken encoding instances to optimize token counting performance
This commit is contained in:
@@ -4,19 +4,26 @@ package util
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
tiktoken "github.com/hupe1980/go-tiktoken"
|
||||
)
|
||||
|
||||
var (
|
||||
tokenEncodingPools sync.Map
|
||||
tokenEncodingUnsupported sync.Map
|
||||
)
|
||||
|
||||
func countWithTokenizer(text, model string) int {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
encoding, err := tiktoken.NewEncodingForModel(tokenizerModelForCount(model))
|
||||
if err != nil {
|
||||
encoding, release := tokenizerEncodingForCount(tokenizerModelForCount(model))
|
||||
if encoding == nil {
|
||||
return 0
|
||||
}
|
||||
defer release()
|
||||
ids, _, err := encoding.Encode(text, nil, nil)
|
||||
if err != nil {
|
||||
return 0
|
||||
@@ -24,6 +31,53 @@ func countWithTokenizer(text, model string) int {
|
||||
return len(ids)
|
||||
}
|
||||
|
||||
func tokenizerEncodingForCount(model string) (*tiktoken.Encoding, func()) {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
model = defaultTokenizerModel
|
||||
}
|
||||
if _, ok := tokenEncodingUnsupported.Load(model); ok {
|
||||
return nil, func() {}
|
||||
}
|
||||
if rawPool, ok := tokenEncodingPools.Load(model); ok {
|
||||
pool, _ := rawPool.(*sync.Pool)
|
||||
return getEncodingFromPool(pool)
|
||||
}
|
||||
|
||||
encoding, err := tiktoken.NewEncodingForModel(model)
|
||||
if err != nil {
|
||||
tokenEncodingUnsupported.Store(model, struct{}{})
|
||||
return nil, func() {}
|
||||
}
|
||||
pool := &sync.Pool{
|
||||
New: func() any {
|
||||
encoding, err := tiktoken.NewEncodingForModel(model)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return encoding
|
||||
},
|
||||
}
|
||||
actualPool, _ := tokenEncodingPools.LoadOrStore(model, pool)
|
||||
pool, _ = actualPool.(*sync.Pool)
|
||||
return encoding, func() {
|
||||
pool.Put(encoding)
|
||||
}
|
||||
}
|
||||
|
||||
func getEncodingFromPool(pool *sync.Pool) (*tiktoken.Encoding, func()) {
|
||||
if pool == nil {
|
||||
return nil, func() {}
|
||||
}
|
||||
encoding, _ := pool.Get().(*tiktoken.Encoding)
|
||||
if encoding == nil {
|
||||
return nil, func() {}
|
||||
}
|
||||
return encoding, func() {
|
||||
pool.Put(encoding)
|
||||
}
|
||||
}
|
||||
|
||||
func tokenizerModelForCount(model string) string {
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
if model == "" {
|
||||
|
||||
35
internal/util/token_count_tiktoken_test.go
Normal file
35
internal/util/token_count_tiktoken_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
//go:build !386 && !arm && !mips && !mipsle && !wasm
|
||||
|
||||
package util
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestTokenizerEncodingForCountCachesSupportedModel(t *testing.T) {
|
||||
encoding, release := tokenizerEncodingForCount(defaultTokenizerModel)
|
||||
if encoding == nil {
|
||||
t.Fatalf("expected tokenizer encoding for %q", defaultTokenizerModel)
|
||||
}
|
||||
release()
|
||||
|
||||
if _, ok := tokenEncodingPools.Load(defaultTokenizerModel); !ok {
|
||||
t.Fatalf("expected tokenizer encoding pool for %q", defaultTokenizerModel)
|
||||
}
|
||||
|
||||
encoding, release = tokenizerEncodingForCount(defaultTokenizerModel)
|
||||
if encoding == nil {
|
||||
t.Fatalf("expected cached tokenizer encoding for %q", defaultTokenizerModel)
|
||||
}
|
||||
release()
|
||||
}
|
||||
|
||||
func TestTokenizerEncodingForCountCachesUnsupportedModel(t *testing.T) {
|
||||
const model = "__ds2api_unsupported_tokenizer_model__"
|
||||
encoding, release := tokenizerEncodingForCount(model)
|
||||
release()
|
||||
if encoding != nil {
|
||||
t.Fatalf("expected nil encoding for unsupported model %q", model)
|
||||
}
|
||||
if _, ok := tokenEncodingUnsupported.Load(model); !ok {
|
||||
t.Fatalf("expected unsupported tokenizer model to be cached")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user