From d3018c281bc25b2531451a4a88b915ead86d7c25 Mon Sep 17 00:00:00 2001 From: shern-point Date: Thu, 30 Apr 2026 00:46:04 +0800 Subject: [PATCH] feat: use tokenizer-based counting in Claude token paths Unify Claude count_tokens, legacy stream accounting, and legacy render usage with preserved prompt text so Claude stops falling back to lossy message formatting. --- internal/format/claude/render.go | 20 ++++++++++++-- internal/httpapi/claude/handler_messages.go | 1 + internal/httpapi/claude/handler_tokens.go | 27 ++++--------------- internal/httpapi/claude/prompt_token_text.go | 7 +++++ internal/httpapi/claude/standard_request.go | 23 ++++++++-------- .../httpapi/claude/stream_runtime_core.go | 11 +++++--- .../httpapi/claude/stream_runtime_emit.go | 5 +++- .../httpapi/claude/stream_runtime_finalize.go | 2 +- internal/httpapi/claude/token_count.go | 20 ++++++++++++++ 9 files changed, 75 insertions(+), 41 deletions(-) create mode 100644 internal/httpapi/claude/prompt_token_text.go create mode 100644 internal/httpapi/claude/token_count.go diff --git a/internal/format/claude/render.go b/internal/format/claude/render.go index 4f9ada5..694f5fd 100644 --- a/internal/format/claude/render.go +++ b/internal/format/claude/render.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "ds2api/internal/prompt" "ds2api/internal/util" ) @@ -43,8 +44,23 @@ func BuildMessageResponse(messageID, model string, normalizedMessages []any, fin "stop_reason": stopReason, "stop_sequence": nil, "usage": map[string]any{ - "input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalizedMessages)), - "output_tokens": util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText), + "input_tokens": util.CountPromptTokens(prompt.MessagesPrepareWithThinking(claudeMessageMaps(normalizedMessages), false), model), + "output_tokens": util.CountOutputTokens(finalThinking, model) + util.CountOutputTokens(finalText, model), }, } } + +func claudeMessageMaps(messages []any) []map[string]any { + if len(messages) == 0 { + return nil + } + out := make([]map[string]any, 0, len(messages)) + for _, item := range messages { + msg, ok := item.(map[string]any) + if !ok { + continue + } + out = append(out, msg) + } + return out +} diff --git a/internal/httpapi/claude/handler_messages.go b/internal/httpapi/claude/handler_messages.go index de47d28..e7ed4cd 100644 --- a/internal/httpapi/claude/handler_messages.go +++ b/internal/httpapi/claude/handler_messages.go @@ -206,6 +206,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ h.compatStripReferenceMarkers(), toolNames, toolsRaw, + buildClaudePromptTokenText(messages, thinkingEnabled), ) streamRuntime.sendMessageStart() diff --git a/internal/httpapi/claude/handler_tokens.go b/internal/httpapi/claude/handler_tokens.go index a369345..d122b0f 100644 --- a/internal/httpapi/claude/handler_tokens.go +++ b/internal/httpapi/claude/handler_tokens.go @@ -3,8 +3,6 @@ package claude import ( "encoding/json" "net/http" - - "ds2api/internal/util" ) func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) { @@ -26,26 +24,11 @@ func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) { writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") return } - inputTokens := 0 - if sys, ok := req["system"].(string); ok { - inputTokens += util.EstimateTokens(sys) - } - for _, item := range messages { - msg, ok := item.(map[string]any) - if !ok { - continue - } - inputTokens += 2 - inputTokens += util.EstimateTokens(extractMessageContent(msg["content"])) - } - if tools, ok := req["tools"].([]any); ok { - for _, t := range tools { - b, _ := json.Marshal(t) - inputTokens += util.EstimateTokens(string(b)) - } - } - if inputTokens < 1 { - inputTokens = 1 + normalized, err := normalizeClaudeRequest(h.Store, req) + if err != nil { + writeClaudeError(w, http.StatusBadRequest, err.Error()) + return } + inputTokens := countClaudeInputTokens(normalized.Standard) writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens}) } diff --git a/internal/httpapi/claude/prompt_token_text.go b/internal/httpapi/claude/prompt_token_text.go new file mode 100644 index 0000000..f70641c --- /dev/null +++ b/internal/httpapi/claude/prompt_token_text.go @@ -0,0 +1,7 @@ +package claude + +import "ds2api/internal/prompt" + +func buildClaudePromptTokenText(messages []any, thinkingEnabled bool) string { + return prompt.MessagesPrepareWithThinking(toMessageMaps(messages), thinkingEnabled) +} diff --git a/internal/httpapi/claude/standard_request.go b/internal/httpapi/claude/standard_request.go index 3f10723..e9edb4c 100644 --- a/internal/httpapi/claude/standard_request.go +++ b/internal/httpapi/claude/standard_request.go @@ -48,17 +48,18 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma return claudeNormalizedRequest{ Standard: promptcompat.StandardRequest{ - Surface: "anthropic_messages", - RequestedModel: strings.TrimSpace(model), - ResolvedModel: dsModel, - ResponseModel: strings.TrimSpace(model), - Messages: payload["messages"].([]any), - ToolsRaw: toolsRequested, - FinalPrompt: finalPrompt, - ToolNames: toolNames, - Stream: util.ToBool(req["stream"]), - Thinking: thinkingEnabled, - Search: searchEnabled, + Surface: "anthropic_messages", + RequestedModel: strings.TrimSpace(model), + ResolvedModel: dsModel, + ResponseModel: strings.TrimSpace(model), + Messages: payload["messages"].([]any), + PromptTokenText: finalPrompt, + ToolsRaw: toolsRequested, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, }, NormalizedMessages: normalizedMessages, }, nil diff --git a/internal/httpapi/claude/stream_runtime_core.go b/internal/httpapi/claude/stream_runtime_core.go index 49fde53..de969e7 100644 --- a/internal/httpapi/claude/stream_runtime_core.go +++ b/internal/httpapi/claude/stream_runtime_core.go @@ -15,10 +15,11 @@ type claudeStreamRuntime struct { rc *http.ResponseController canFlush bool - model string - toolNames []string - messages []any - toolsRaw any + model string + toolNames []string + messages []any + toolsRaw any + promptTokenText string thinkingEnabled bool searchEnabled bool @@ -49,6 +50,7 @@ func newClaudeStreamRuntime( stripReferenceMarkers bool, toolNames []string, toolsRaw any, + promptTokenText string, ) *claudeStreamRuntime { return &claudeStreamRuntime{ w: w, @@ -62,6 +64,7 @@ func newClaudeStreamRuntime( stripReferenceMarkers: stripReferenceMarkers, toolNames: toolNames, toolsRaw: toolsRaw, + promptTokenText: promptTokenText, messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()), thinkingBlockIndex: -1, textBlockIndex: -1, diff --git a/internal/httpapi/claude/stream_runtime_emit.go b/internal/httpapi/claude/stream_runtime_emit.go index c2fba19..e071cdc 100644 --- a/internal/httpapi/claude/stream_runtime_emit.go +++ b/internal/httpapi/claude/stream_runtime_emit.go @@ -42,7 +42,10 @@ func (s *claudeStreamRuntime) sendPing() { } func (s *claudeStreamRuntime) sendMessageStart() { - inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages)) + inputTokens := countClaudeInputTokensFromText(s.promptTokenText, s.model) + if inputTokens == 0 { + inputTokens = util.CountPromptTokens(fmt.Sprintf("%v", s.messages), s.model) + } s.send("message_start", map[string]any{ "type": "message_start", "message": map[string]any{ diff --git a/internal/httpapi/claude/stream_runtime_finalize.go b/internal/httpapi/claude/stream_runtime_finalize.go index 32e9b5f..9c239f1 100644 --- a/internal/httpapi/claude/stream_runtime_finalize.go +++ b/internal/httpapi/claude/stream_runtime_finalize.go @@ -109,7 +109,7 @@ func (s *claudeStreamRuntime) finalize(stopReason string) { } } - outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText) + outputTokens := util.CountOutputTokens(finalThinking, s.model) + util.CountOutputTokens(finalText, s.model) s.send("message_delta", map[string]any{ "type": "message_delta", "delta": map[string]any{ diff --git a/internal/httpapi/claude/token_count.go b/internal/httpapi/claude/token_count.go new file mode 100644 index 0000000..2a06537 --- /dev/null +++ b/internal/httpapi/claude/token_count.go @@ -0,0 +1,20 @@ +package claude + +import ( + "strings" + + "ds2api/internal/promptcompat" + "ds2api/internal/util" +) + +func countClaudeInputTokens(stdReq promptcompat.StandardRequest) int { + promptText := stdReq.PromptTokenText + if strings.TrimSpace(promptText) == "" { + promptText = stdReq.FinalPrompt + } + return countClaudeInputTokensFromText(promptText, stdReq.ResolvedModel) +} + +func countClaudeInputTokensFromText(promptText, model string) int { + return util.CountPromptTokens(promptText, model) +}