feat: use tokenizer-based counting in Claude token paths

Unify Claude count_tokens, legacy stream accounting, and legacy render usage with preserved prompt text so Claude stops falling back to lossy message formatting.
This commit is contained in:
shern-point
2026-04-30 00:46:04 +08:00
parent 415a2359ad
commit d3018c281b
9 changed files with 75 additions and 41 deletions

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"time"
"ds2api/internal/prompt"
"ds2api/internal/util"
)
@@ -43,8 +44,23 @@ func BuildMessageResponse(messageID, model string, normalizedMessages []any, fin
"stop_reason": stopReason,
"stop_sequence": nil,
"usage": map[string]any{
"input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalizedMessages)),
"output_tokens": util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText),
"input_tokens": util.CountPromptTokens(prompt.MessagesPrepareWithThinking(claudeMessageMaps(normalizedMessages), false), model),
"output_tokens": util.CountOutputTokens(finalThinking, model) + util.CountOutputTokens(finalText, model),
},
}
}
func claudeMessageMaps(messages []any) []map[string]any {
if len(messages) == 0 {
return nil
}
out := make([]map[string]any, 0, len(messages))
for _, item := range messages {
msg, ok := item.(map[string]any)
if !ok {
continue
}
out = append(out, msg)
}
return out
}

View File

@@ -206,6 +206,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
h.compatStripReferenceMarkers(),
toolNames,
toolsRaw,
buildClaudePromptTokenText(messages, thinkingEnabled),
)
streamRuntime.sendMessageStart()

View File

@@ -3,8 +3,6 @@ package claude
import (
"encoding/json"
"net/http"
"ds2api/internal/util"
)
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
@@ -26,26 +24,11 @@ func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
return
}
inputTokens := 0
if sys, ok := req["system"].(string); ok {
inputTokens += util.EstimateTokens(sys)
}
for _, item := range messages {
msg, ok := item.(map[string]any)
if !ok {
continue
}
inputTokens += 2
inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
}
if tools, ok := req["tools"].([]any); ok {
for _, t := range tools {
b, _ := json.Marshal(t)
inputTokens += util.EstimateTokens(string(b))
}
}
if inputTokens < 1 {
inputTokens = 1
normalized, err := normalizeClaudeRequest(h.Store, req)
if err != nil {
writeClaudeError(w, http.StatusBadRequest, err.Error())
return
}
inputTokens := countClaudeInputTokens(normalized.Standard)
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
}

View File

@@ -0,0 +1,7 @@
package claude
import "ds2api/internal/prompt"
func buildClaudePromptTokenText(messages []any, thinkingEnabled bool) string {
return prompt.MessagesPrepareWithThinking(toMessageMaps(messages), thinkingEnabled)
}

View File

@@ -53,6 +53,7 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma
ResolvedModel: dsModel,
ResponseModel: strings.TrimSpace(model),
Messages: payload["messages"].([]any),
PromptTokenText: finalPrompt,
ToolsRaw: toolsRequested,
FinalPrompt: finalPrompt,
ToolNames: toolNames,

View File

@@ -19,6 +19,7 @@ type claudeStreamRuntime struct {
toolNames []string
messages []any
toolsRaw any
promptTokenText string
thinkingEnabled bool
searchEnabled bool
@@ -49,6 +50,7 @@ func newClaudeStreamRuntime(
stripReferenceMarkers bool,
toolNames []string,
toolsRaw any,
promptTokenText string,
) *claudeStreamRuntime {
return &claudeStreamRuntime{
w: w,
@@ -62,6 +64,7 @@ func newClaudeStreamRuntime(
stripReferenceMarkers: stripReferenceMarkers,
toolNames: toolNames,
toolsRaw: toolsRaw,
promptTokenText: promptTokenText,
messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()),
thinkingBlockIndex: -1,
textBlockIndex: -1,

View File

@@ -42,7 +42,10 @@ func (s *claudeStreamRuntime) sendPing() {
}
func (s *claudeStreamRuntime) sendMessageStart() {
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages))
inputTokens := countClaudeInputTokensFromText(s.promptTokenText, s.model)
if inputTokens == 0 {
inputTokens = util.CountPromptTokens(fmt.Sprintf("%v", s.messages), s.model)
}
s.send("message_start", map[string]any{
"type": "message_start",
"message": map[string]any{

View File

@@ -109,7 +109,7 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
}
}
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
outputTokens := util.CountOutputTokens(finalThinking, s.model) + util.CountOutputTokens(finalText, s.model)
s.send("message_delta", map[string]any{
"type": "message_delta",
"delta": map[string]any{

View File

@@ -0,0 +1,20 @@
package claude
import (
"strings"
"ds2api/internal/promptcompat"
"ds2api/internal/util"
)
func countClaudeInputTokens(stdReq promptcompat.StandardRequest) int {
promptText := stdReq.PromptTokenText
if strings.TrimSpace(promptText) == "" {
promptText = stdReq.FinalPrompt
}
return countClaudeInputTokensFromText(promptText, stdReq.ResolvedModel)
}
func countClaudeInputTokensFromText(promptText, model string) int {
return util.CountPromptTokens(promptText, model)
}