mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-23 01:17:44 +08:00
Merge pull request #372 from shern-point/feat/accurate-context-token-length
Feat/accurate context token length
This commit is contained in:
@@ -23,9 +23,9 @@ func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking,
|
||||
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected, nil)
|
||||
messageObj["content"] = nil
|
||||
}
|
||||
promptTokens := EstimateTokens(finalPrompt)
|
||||
reasoningTokens := EstimateTokens(finalThinking)
|
||||
completionTokens := EstimateTokens(finalText)
|
||||
promptTokens := CountPromptTokens(finalPrompt, model)
|
||||
reasoningTokens := CountOutputTokens(finalThinking, model)
|
||||
completionTokens := CountOutputTokens(finalText, model)
|
||||
|
||||
return map[string]any{
|
||||
"id": completionID,
|
||||
@@ -86,9 +86,9 @@ func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, fi
|
||||
"content": content,
|
||||
})
|
||||
}
|
||||
promptTokens := EstimateTokens(finalPrompt)
|
||||
reasoningTokens := EstimateTokens(finalThinking)
|
||||
completionTokens := EstimateTokens(finalText)
|
||||
promptTokens := CountPromptTokens(finalPrompt, model)
|
||||
reasoningTokens := CountOutputTokens(finalThinking, model)
|
||||
completionTokens := CountOutputTokens(finalText, model)
|
||||
return map[string]any{
|
||||
"id": responseID,
|
||||
"type": "response",
|
||||
@@ -140,8 +140,8 @@ func BuildClaudeMessageResponse(messageID, model string, normalizedMessages []an
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]any{
|
||||
"input_tokens": EstimateTokens(fmt.Sprintf("%v", normalizedMessages)),
|
||||
"output_tokens": EstimateTokens(finalThinking) + EstimateTokens(finalText),
|
||||
"input_tokens": CountPromptTokens(fmt.Sprintf("%v", normalizedMessages), model),
|
||||
"output_tokens": CountOutputTokens(finalThinking, model) + CountOutputTokens(finalText, model),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
87
internal/util/token_count.go
Normal file
87
internal/util/token_count.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
tiktoken "github.com/hupe1980/go-tiktoken"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTokenizerModel = "gpt-4o"
|
||||
claudeTokenizerModel = "claude"
|
||||
)
|
||||
|
||||
func CountPromptTokens(text, model string) int {
|
||||
base := maxTokenCount(
|
||||
EstimateTokens(text),
|
||||
countWithTokenizer(text, model),
|
||||
)
|
||||
if base <= 0 {
|
||||
return 0
|
||||
}
|
||||
return base + conservativePromptPadding(base)
|
||||
}
|
||||
|
||||
func CountOutputTokens(text, model string) int {
|
||||
base := maxTokenCount(
|
||||
EstimateTokens(text),
|
||||
countWithTokenizer(text, model),
|
||||
)
|
||||
if base <= 0 {
|
||||
return 0
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
func countWithTokenizer(text, model string) int {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
encoding, err := tiktoken.NewEncodingForModel(tokenizerModelForCount(model))
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
ids, _, err := encoding.Encode(text, nil, nil)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return len(ids)
|
||||
}
|
||||
|
||||
func tokenizerModelForCount(model string) string {
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
if model == "" {
|
||||
return defaultTokenizerModel
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(model, "claude"):
|
||||
return claudeTokenizerModel
|
||||
case strings.HasPrefix(model, "gpt-4"), strings.HasPrefix(model, "gpt-5"), strings.HasPrefix(model, "o1"), strings.HasPrefix(model, "o3"), strings.HasPrefix(model, "o4"):
|
||||
return defaultTokenizerModel
|
||||
case strings.HasPrefix(model, "deepseek-v4"):
|
||||
return defaultTokenizerModel
|
||||
case strings.HasPrefix(model, "deepseek"):
|
||||
return defaultTokenizerModel
|
||||
default:
|
||||
return defaultTokenizerModel
|
||||
}
|
||||
}
|
||||
|
||||
func conservativePromptPadding(base int) int {
|
||||
padding := base / 50
|
||||
if padding < 4 {
|
||||
padding = 4
|
||||
}
|
||||
return padding
|
||||
}
|
||||
|
||||
func maxTokenCount(values ...int) int {
|
||||
best := 0
|
||||
for _, v := range values {
|
||||
if v > best {
|
||||
best = v
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
Reference in New Issue
Block a user