Merge pull request #372 from shern-point/feat/accurate-context-token-length

Feat/accurate context token length
This commit is contained in:
CJACK.
2026-04-30 02:11:32 +08:00
committed by GitHub
30 changed files with 341 additions and 113 deletions

View File

@@ -23,9 +23,9 @@ func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking,
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected, nil)
messageObj["content"] = nil
}
promptTokens := EstimateTokens(finalPrompt)
reasoningTokens := EstimateTokens(finalThinking)
completionTokens := EstimateTokens(finalText)
promptTokens := CountPromptTokens(finalPrompt, model)
reasoningTokens := CountOutputTokens(finalThinking, model)
completionTokens := CountOutputTokens(finalText, model)
return map[string]any{
"id": completionID,
@@ -86,9 +86,9 @@ func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, fi
"content": content,
})
}
promptTokens := EstimateTokens(finalPrompt)
reasoningTokens := EstimateTokens(finalThinking)
completionTokens := EstimateTokens(finalText)
promptTokens := CountPromptTokens(finalPrompt, model)
reasoningTokens := CountOutputTokens(finalThinking, model)
completionTokens := CountOutputTokens(finalText, model)
return map[string]any{
"id": responseID,
"type": "response",
@@ -140,8 +140,8 @@ func BuildClaudeMessageResponse(messageID, model string, normalizedMessages []an
"stop_reason": stopReason,
"stop_sequence": nil,
"usage": map[string]any{
"input_tokens": EstimateTokens(fmt.Sprintf("%v", normalizedMessages)),
"output_tokens": EstimateTokens(finalThinking) + EstimateTokens(finalText),
"input_tokens": CountPromptTokens(fmt.Sprintf("%v", normalizedMessages), model),
"output_tokens": CountOutputTokens(finalThinking, model) + CountOutputTokens(finalText, model),
},
}
}

View File

@@ -0,0 +1,87 @@
package util
import (
"strings"
tiktoken "github.com/hupe1980/go-tiktoken"
)
const (
defaultTokenizerModel = "gpt-4o"
claudeTokenizerModel = "claude"
)
func CountPromptTokens(text, model string) int {
base := maxTokenCount(
EstimateTokens(text),
countWithTokenizer(text, model),
)
if base <= 0 {
return 0
}
return base + conservativePromptPadding(base)
}
func CountOutputTokens(text, model string) int {
base := maxTokenCount(
EstimateTokens(text),
countWithTokenizer(text, model),
)
if base <= 0 {
return 0
}
return base
}
func countWithTokenizer(text, model string) int {
text = strings.TrimSpace(text)
if text == "" {
return 0
}
encoding, err := tiktoken.NewEncodingForModel(tokenizerModelForCount(model))
if err != nil {
return 0
}
ids, _, err := encoding.Encode(text, nil, nil)
if err != nil {
return 0
}
return len(ids)
}
func tokenizerModelForCount(model string) string {
model = strings.ToLower(strings.TrimSpace(model))
if model == "" {
return defaultTokenizerModel
}
switch {
case strings.HasPrefix(model, "claude"):
return claudeTokenizerModel
case strings.HasPrefix(model, "gpt-4"), strings.HasPrefix(model, "gpt-5"), strings.HasPrefix(model, "o1"), strings.HasPrefix(model, "o3"), strings.HasPrefix(model, "o4"):
return defaultTokenizerModel
case strings.HasPrefix(model, "deepseek-v4"):
return defaultTokenizerModel
case strings.HasPrefix(model, "deepseek"):
return defaultTokenizerModel
default:
return defaultTokenizerModel
}
}
func conservativePromptPadding(base int) int {
padding := base / 50
if padding < 4 {
padding = 4
}
return padding
}
func maxTokenCount(values ...int) int {
best := 0
for _, v := range values {
if v > best {
best = v
}
}
return best
}