diff --git a/docs/prompt-compatibility.md b/docs/prompt-compatibility.md index 5bf6025..31e3927 100644 --- a/docs/prompt-compatibility.md +++ b/docs/prompt-compatibility.md @@ -98,6 +98,7 @@ DS2API 当前的核心思路,不是把客户端传来的 `messages`、`tools` - `prompt` 才是对话上下文主载体。 - `ref_file_ids` 只承载文件引用,不承载普通文本消息。 - `tools` 不会作为“原生工具 schema”直接下发给下游,而是被改写进 `prompt`。 +- 对外返回给客户端的 `prompt_tokens` / `input_tokens` / `promptTokenCount` 不再只按“最后一条消息”或字符粗估近似返回,而是基于**完整上下文 prompt**做 tokenizer 计数,并与字符粗估取较大值;为了避免上下文实际超限但客户端误以为还能塞下,请求侧上下文 token 会在此基础上额外保守上浮约 2%(至少 4 个 token),宁可略大也不低估。 - 当前 `/v1/chat/completions` 业务路径仍是“每次请求新建一个远端 `chat_session_id`,并默认发送 `parent_message_id: null`”;因此 DS2API 对外默认表现为“新会话 + prompt 拼历史”,而不是复用 DeepSeek 原生会话树。 - 但 DeepSeek 远端本身支持同一 `chat_session_id` 的跨轮次持续对话。2026-04-27 已用项目内现有 DeepSeek client 做过一次不改业务代码的双轮实测:同一 `chat_session_id` 下,第 1 轮返回 `request_message_id=1` / `response_message_id=2` / 文本 `SESSION_TEST_ONE`;第 2 轮重新获取一次 PoW,并发送 `parent_message_id=2` 后,成功返回 `request_message_id=3` / `response_message_id=4` / 文本 `SESSION_TEST_TWO`。这说明“同远端会话持续聊天”能力存在,且每轮需要携带正确的 parent/message 链接信息,同时重新获取对应轮次可用的 PoW。 - OpenAI Chat / Responses 原生走统一 OpenAI 标准化与 DeepSeek payload 组装;Claude / Gemini 会尽量复用 OpenAI prompt/tool 语义,其中 Gemini 直接复用 `promptcompat.BuildOpenAIPromptForAdapter`,Claude 消息接口在可代理场景会转换为 OpenAI chat 形态再执行。 @@ -249,6 +250,7 @@ OpenAI 文件相关实现: - `current_input_file` 默认开启;它用于把“完整上下文”合并进 `history.txt` 上下文文件。当最新 user turn 的纯文本长度达到 `current_input_file.min_chars`(默认 `0`)时,兼容层会上传一个文件名为 `history.txt` 的上下文文件,并在 live prompt 中只保留一个中性的 user 消息要求模型直接回答最新请求,不再暴露文件名或要求模型读取本地文件。 - 如果 `current_input_file.enabled=false`,请求会直接透传,不上传任何拆分上下文文件。 - 旧的 `history_split.enabled` / `history_split.trigger_after_turns` 会被读取进配置对象以保持兼容,但不会触发拆分上传,也不会影响 `current_input_file` 的默认开启。 +- 即使触发 `current_input_file` 后 live prompt 被缩短,对客户端回包里的上下文 token 统计,仍会按**上传的上下文文件内容 + 缩短后的 live prompt**(即下游实际收到的完整上下文)做计数,而不是只按缩短后的占位 prompt 计算;否则会把真实上下文显著算小。 相关实现: diff --git a/go.mod b/go.mod index 2613f89..87cabfa 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,11 @@ require ( 
github.com/router-for-me/CLIProxyAPI/v6 v6.9.14 ) +require ( + github.com/dlclark/regexp2 v1.11.5 // indirect + github.com/hupe1980/go-tiktoken v0.0.10 // indirect +) + require ( github.com/klauspost/compress v1.18.5 // indirect github.com/sirupsen/logrus v1.9.4 // indirect diff --git a/go.sum b/go.sum index 4b47fb0..811e782 100644 --- a/go.sum +++ b/go.sum @@ -2,10 +2,14 @@ github.com/andybalholm/brotli v1.2.1 h1:R+f5xP285VArJDRgowrfb9DqL18yVK0gKAW/F+eT github.com/andybalholm/brotli v1.2.1/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= +github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hupe1980/go-tiktoken v0.0.10 h1:m6phOJaGyctqWdGIgwn9X8AfJvaG74tnQoDL+ntOUEQ= +github.com/hupe1980/go-tiktoken v0.0.10/go.mod h1:NME6d8hrE+Jo+kLUZHhXShYV8e40hYkm4BbSLQKtvAo= github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/internal/format/claude/render.go b/internal/format/claude/render.go index 4f9ada5..694f5fd 100644 --- a/internal/format/claude/render.go +++ b/internal/format/claude/render.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "ds2api/internal/prompt" "ds2api/internal/util" ) @@ -43,8 +44,23 @@ func 
BuildMessageResponse(messageID, model string, normalizedMessages []any, fin "stop_reason": stopReason, "stop_sequence": nil, "usage": map[string]any{ - "input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalizedMessages)), - "output_tokens": util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText), + "input_tokens": util.CountPromptTokens(prompt.MessagesPrepareWithThinking(claudeMessageMaps(normalizedMessages), false), model), + "output_tokens": util.CountOutputTokens(finalThinking, model) + util.CountOutputTokens(finalText, model), }, } } + +func claudeMessageMaps(messages []any) []map[string]any { + if len(messages) == 0 { + return nil + } + out := make([]map[string]any, 0, len(messages)) + for _, item := range messages { + msg, ok := item.(map[string]any) + if !ok { + continue + } + out = append(out, msg) + } + return out +} diff --git a/internal/format/openai/render_chat.go b/internal/format/openai/render_chat.go index f88ba41..14a9d1f 100644 --- a/internal/format/openai/render_chat.go +++ b/internal/format/openai/render_chat.go @@ -29,7 +29,7 @@ func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThi "created": time.Now().Unix(), "model": model, "choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}}, - "usage": BuildChatUsage(finalPrompt, finalThinking, finalText), + "usage": BuildChatUsageForModel(model, finalPrompt, finalThinking, finalText), } } diff --git a/internal/format/openai/render_responses.go b/internal/format/openai/render_responses.go index 6148fdc..2d3c9dd 100644 --- a/internal/format/openai/render_responses.go +++ b/internal/format/openai/render_responses.go @@ -70,7 +70,7 @@ func BuildResponseObjectFromItems(responseID, model, finalPrompt, finalThinking, "model": model, "output": output, "output_text": outputText, - "usage": BuildResponsesUsage(finalPrompt, finalThinking, finalText), + "usage": BuildResponsesUsageForModel(model, finalPrompt, finalThinking, 
finalText), } } diff --git a/internal/format/openai/render_test.go b/internal/format/openai/render_test.go index c1dc540..61cdb3f 100644 --- a/internal/format/openai/render_test.go +++ b/internal/format/openai/render_test.go @@ -6,6 +6,7 @@ import ( "testing" "ds2api/internal/toolcall" + "ds2api/internal/util" ) func TestBuildResponseObjectKeepsFencedToolPayloadAsText(t *testing.T) { @@ -177,3 +178,17 @@ func TestBuildResponseObjectWithToolCallsCoercesSchemaDeclaredStringArguments(t t.Fatalf("expected response content stringified by schema, got %#v", args["content"]) } } + +func TestBuildChatUsageForModelUsesConservativePromptCount(t *testing.T) { + prompt := strings.Repeat("上下文token ", 40) + usage := BuildChatUsageForModel("deepseek-v4-flash", prompt, "", "ok") + promptTokens, _ := usage["prompt_tokens"].(int) + if promptTokens <= util.EstimateTokens(prompt) { + t.Fatalf("expected conservative prompt token count > rough estimate, got=%d estimate=%d", promptTokens, util.EstimateTokens(prompt)) + } + totalTokens, _ := usage["total_tokens"].(int) + completionTokens, _ := usage["completion_tokens"].(int) + if totalTokens != promptTokens+completionTokens { + t.Fatalf("expected total tokens to add up, got usage=%#v", usage) + } +} diff --git a/internal/format/openai/render_usage.go b/internal/format/openai/render_usage.go index b328d20..ad1f380 100644 --- a/internal/format/openai/render_usage.go +++ b/internal/format/openai/render_usage.go @@ -2,10 +2,10 @@ package openai import "ds2api/internal/util" -func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any { - promptTokens := util.EstimateTokens(finalPrompt) - reasoningTokens := util.EstimateTokens(finalThinking) - completionTokens := util.EstimateTokens(finalText) +func BuildChatUsageForModel(model, finalPrompt, finalThinking, finalText string) map[string]any { + promptTokens := util.CountPromptTokens(finalPrompt, model) + reasoningTokens := util.CountOutputTokens(finalThinking, model) + 
completionTokens := util.CountOutputTokens(finalText, model) return map[string]any{ "prompt_tokens": promptTokens, "completion_tokens": reasoningTokens + completionTokens, @@ -16,13 +16,21 @@ func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any } } -func BuildResponsesUsage(finalPrompt, finalThinking, finalText string) map[string]any { - promptTokens := util.EstimateTokens(finalPrompt) - reasoningTokens := util.EstimateTokens(finalThinking) - completionTokens := util.EstimateTokens(finalText) +func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any { + return BuildChatUsageForModel("", finalPrompt, finalThinking, finalText) +} + +func BuildResponsesUsageForModel(model, finalPrompt, finalThinking, finalText string) map[string]any { + promptTokens := util.CountPromptTokens(finalPrompt, model) + reasoningTokens := util.CountOutputTokens(finalThinking, model) + completionTokens := util.CountOutputTokens(finalText, model) return map[string]any{ "input_tokens": promptTokens, "output_tokens": reasoningTokens + completionTokens, "total_tokens": promptTokens + reasoningTokens + completionTokens, } } + +func BuildResponsesUsage(finalPrompt, finalThinking, finalText string) map[string]any { + return BuildResponsesUsageForModel("", finalPrompt, finalThinking, finalText) +} diff --git a/internal/httpapi/claude/handler_messages.go b/internal/httpapi/claude/handler_messages.go index de47d28..e7ed4cd 100644 --- a/internal/httpapi/claude/handler_messages.go +++ b/internal/httpapi/claude/handler_messages.go @@ -206,6 +206,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ h.compatStripReferenceMarkers(), toolNames, toolsRaw, + buildClaudePromptTokenText(messages, thinkingEnabled), ) streamRuntime.sendMessageStart() diff --git a/internal/httpapi/claude/handler_tokens.go b/internal/httpapi/claude/handler_tokens.go index a369345..d122b0f 100644 --- a/internal/httpapi/claude/handler_tokens.go +++ 
b/internal/httpapi/claude/handler_tokens.go @@ -3,8 +3,6 @@ package claude import ( "encoding/json" "net/http" - - "ds2api/internal/util" ) func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) { @@ -26,26 +24,11 @@ func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) { writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.") return } - inputTokens := 0 - if sys, ok := req["system"].(string); ok { - inputTokens += util.EstimateTokens(sys) - } - for _, item := range messages { - msg, ok := item.(map[string]any) - if !ok { - continue - } - inputTokens += 2 - inputTokens += util.EstimateTokens(extractMessageContent(msg["content"])) - } - if tools, ok := req["tools"].([]any); ok { - for _, t := range tools { - b, _ := json.Marshal(t) - inputTokens += util.EstimateTokens(string(b)) - } - } - if inputTokens < 1 { - inputTokens = 1 + normalized, err := normalizeClaudeRequest(h.Store, req) + if err != nil { + writeClaudeError(w, http.StatusBadRequest, err.Error()) + return } + inputTokens := countClaudeInputTokens(normalized.Standard) writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens}) } diff --git a/internal/httpapi/claude/prompt_token_text.go b/internal/httpapi/claude/prompt_token_text.go new file mode 100644 index 0000000..f70641c --- /dev/null +++ b/internal/httpapi/claude/prompt_token_text.go @@ -0,0 +1,7 @@ +package claude + +import "ds2api/internal/prompt" + +func buildClaudePromptTokenText(messages []any, thinkingEnabled bool) string { + return prompt.MessagesPrepareWithThinking(toMessageMaps(messages), thinkingEnabled) +} diff --git a/internal/httpapi/claude/standard_request.go b/internal/httpapi/claude/standard_request.go index 3f10723..e9edb4c 100644 --- a/internal/httpapi/claude/standard_request.go +++ b/internal/httpapi/claude/standard_request.go @@ -48,17 +48,18 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma return 
claudeNormalizedRequest{ Standard: promptcompat.StandardRequest{ - Surface: "anthropic_messages", - RequestedModel: strings.TrimSpace(model), - ResolvedModel: dsModel, - ResponseModel: strings.TrimSpace(model), - Messages: payload["messages"].([]any), - ToolsRaw: toolsRequested, - FinalPrompt: finalPrompt, - ToolNames: toolNames, - Stream: util.ToBool(req["stream"]), - Thinking: thinkingEnabled, - Search: searchEnabled, + Surface: "anthropic_messages", + RequestedModel: strings.TrimSpace(model), + ResolvedModel: dsModel, + ResponseModel: strings.TrimSpace(model), + Messages: payload["messages"].([]any), + PromptTokenText: finalPrompt, + ToolsRaw: toolsRequested, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, }, NormalizedMessages: normalizedMessages, }, nil diff --git a/internal/httpapi/claude/stream_runtime_core.go b/internal/httpapi/claude/stream_runtime_core.go index 49fde53..de969e7 100644 --- a/internal/httpapi/claude/stream_runtime_core.go +++ b/internal/httpapi/claude/stream_runtime_core.go @@ -15,10 +15,11 @@ type claudeStreamRuntime struct { rc *http.ResponseController canFlush bool - model string - toolNames []string - messages []any - toolsRaw any + model string + toolNames []string + messages []any + toolsRaw any + promptTokenText string thinkingEnabled bool searchEnabled bool @@ -49,6 +50,7 @@ func newClaudeStreamRuntime( stripReferenceMarkers bool, toolNames []string, toolsRaw any, + promptTokenText string, ) *claudeStreamRuntime { return &claudeStreamRuntime{ w: w, @@ -62,6 +64,7 @@ func newClaudeStreamRuntime( stripReferenceMarkers: stripReferenceMarkers, toolNames: toolNames, toolsRaw: toolsRaw, + promptTokenText: promptTokenText, messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()), thinkingBlockIndex: -1, textBlockIndex: -1, diff --git a/internal/httpapi/claude/stream_runtime_emit.go b/internal/httpapi/claude/stream_runtime_emit.go index 
c2fba19..e071cdc 100644 --- a/internal/httpapi/claude/stream_runtime_emit.go +++ b/internal/httpapi/claude/stream_runtime_emit.go @@ -42,7 +42,10 @@ func (s *claudeStreamRuntime) sendPing() { } func (s *claudeStreamRuntime) sendMessageStart() { - inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages)) + inputTokens := countClaudeInputTokensFromText(s.promptTokenText, s.model) + if inputTokens == 0 { + inputTokens = util.CountPromptTokens(fmt.Sprintf("%v", s.messages), s.model) + } s.send("message_start", map[string]any{ "type": "message_start", "message": map[string]any{ diff --git a/internal/httpapi/claude/stream_runtime_finalize.go b/internal/httpapi/claude/stream_runtime_finalize.go index 32e9b5f..9c239f1 100644 --- a/internal/httpapi/claude/stream_runtime_finalize.go +++ b/internal/httpapi/claude/stream_runtime_finalize.go @@ -109,7 +109,7 @@ func (s *claudeStreamRuntime) finalize(stopReason string) { } } - outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText) + outputTokens := util.CountOutputTokens(finalThinking, s.model) + util.CountOutputTokens(finalText, s.model) s.send("message_delta", map[string]any{ "type": "message_delta", "delta": map[string]any{ diff --git a/internal/httpapi/claude/token_count.go b/internal/httpapi/claude/token_count.go new file mode 100644 index 0000000..2a06537 --- /dev/null +++ b/internal/httpapi/claude/token_count.go @@ -0,0 +1,20 @@ +package claude + +import ( + "strings" + + "ds2api/internal/promptcompat" + "ds2api/internal/util" +) + +func countClaudeInputTokens(stdReq promptcompat.StandardRequest) int { + promptText := stdReq.PromptTokenText + if strings.TrimSpace(promptText) == "" { + promptText = stdReq.FinalPrompt + } + return countClaudeInputTokensFromText(promptText, stdReq.ResolvedModel) +} + +func countClaudeInputTokensFromText(promptText, model string) int { + return util.CountPromptTokens(promptText, model) +} diff --git a/internal/httpapi/gemini/convert_request.go 
b/internal/httpapi/gemini/convert_request.go index ca1497a..43697e7 100644 --- a/internal/httpapi/gemini/convert_request.go +++ b/internal/httpapi/gemini/convert_request.go @@ -36,16 +36,17 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin passThrough := collectGeminiPassThrough(req) return promptcompat.StandardRequest{ - Surface: "google_gemini", - RequestedModel: requestedModel, - ResolvedModel: resolvedModel, - ResponseModel: requestedModel, - Messages: messagesRaw, - FinalPrompt: finalPrompt, - ToolNames: toolNames, - Stream: stream, - Thinking: thinkingEnabled, - Search: searchEnabled, - PassThrough: passThrough, + Surface: "google_gemini", + RequestedModel: requestedModel, + ResolvedModel: resolvedModel, + ResponseModel: requestedModel, + Messages: messagesRaw, + PromptTokenText: finalPrompt, + FinalPrompt: finalPrompt, + ToolNames: toolNames, + Stream: stream, + Thinking: thinkingEnabled, + Search: searchEnabled, + PassThrough: passThrough, }, nil } diff --git a/internal/httpapi/gemini/handler_generate.go b/internal/httpapi/gemini/handler_generate.go index c6a08eb..00c4655 100644 --- a/internal/httpapi/gemini/handler_generate.go +++ b/internal/httpapi/gemini/handler_generate.go @@ -227,7 +227,7 @@ func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *ht //nolint:unused // retained for native Gemini non-stream handling path. func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames) - usage := buildGeminiUsage(finalPrompt, finalThinking, finalText) + usage := buildGeminiUsage(model, finalPrompt, finalThinking, finalText) return map[string]any{ "candidates": []map[string]any{ { @@ -245,10 +245,10 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final } //nolint:unused // retained for native Gemini non-stream handling path. 
-func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any { - promptTokens := util.EstimateTokens(finalPrompt) - reasoningTokens := util.EstimateTokens(finalThinking) - completionTokens := util.EstimateTokens(finalText) +func buildGeminiUsage(model, finalPrompt, finalThinking, finalText string) map[string]any { + promptTokens := util.CountPromptTokens(finalPrompt, model) + reasoningTokens := util.CountOutputTokens(finalThinking, model) + completionTokens := util.CountOutputTokens(finalText, model) return map[string]any{ "promptTokenCount": promptTokens, "candidatesTokenCount": reasoningTokens + completionTokens, diff --git a/internal/httpapi/gemini/handler_stream_runtime.go b/internal/httpapi/gemini/handler_stream_runtime.go index 13729fb..fb72981 100644 --- a/internal/httpapi/gemini/handler_stream_runtime.go +++ b/internal/httpapi/gemini/handler_stream_runtime.go @@ -194,6 +194,6 @@ func (s *geminiStreamRuntime) finalize() { }, }, "modelVersion": s.model, - "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText), + "usageMetadata": buildGeminiUsage(s.model, s.finalPrompt, finalThinking, finalText), }) } diff --git a/internal/httpapi/openai/chat/chat_stream_runtime.go b/internal/httpapi/openai/chat/chat_stream_runtime.go index 21d1f4f..e874206 100644 --- a/internal/httpapi/openai/chat/chat_stream_runtime.go +++ b/internal/httpapi/openai/chat/chat_stream_runtime.go @@ -222,7 +222,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool) s.sendFailedChunk(status, message, code) return true } - usage := openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText) + usage := openaifmt.BuildChatUsageForModel(s.model, s.finalPrompt, finalThinking, finalText) s.finalFinishReason = finishReason s.finalUsage = usage s.sendChunk(openaifmt.BuildChatStreamChunk( diff --git a/internal/httpapi/openai/chat/empty_retry_runtime.go b/internal/httpapi/openai/chat/empty_retry_runtime.go index 
c3d37b9..a33a60d 100644 --- a/internal/httpapi/openai/chat/empty_retry_runtime.go +++ b/internal/httpapi/openai/chat/empty_retry_runtime.go @@ -72,7 +72,7 @@ func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Co config.Logger.Warn("[openai_empty_retry] retry request failed", "surface", "chat.completions", "stream", false, "retry_attempt", attempts, "error", err) return } - usagePrompt = usagePromptWithEmptyOutputRetry(finalPrompt, attempts) + usagePrompt = usagePromptWithEmptyOutputRetry(usagePrompt, attempts) currentResp = nextResp } } diff --git a/internal/httpapi/openai/chat/handler_chat.go b/internal/httpapi/openai/chat/handler_chat.go index a2e421a..2e3c822 100644 --- a/internal/httpapi/openai/chat/handler_chat.go +++ b/internal/httpapi/openai/chat/handler_chat.go @@ -109,10 +109,10 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) { return } if stdReq.Stream { - h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession) + h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession) return } - h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession) + h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession) } func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) { @@ -183,7 +183,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, co } } if historySession != nil { - 
historySession.success(http.StatusOK, finalThinking, finalText, finishReason, openaifmt.BuildChatUsage(finalPrompt, finalThinking, finalText)) + historySession.success(http.StatusOK, finalThinking, finalText, finishReason, openaifmt.BuildChatUsageForModel(model, finalPrompt, finalThinking, finalText)) } writeJSON(w, http.StatusOK, respBody) } diff --git a/internal/httpapi/openai/history/current_input_file.go b/internal/httpapi/openai/history/current_input_file.go index 8a24575..464cac2 100644 --- a/internal/httpapi/openai/history/current_input_file.go +++ b/internal/httpapi/openai/history/current_input_file.go @@ -35,7 +35,6 @@ func (s Service) ApplyCurrentInputFile(ctx context.Context, a *auth.RequestAuth, if strings.TrimSpace(fileText) == "" { return stdReq, errors.New("current user input file produced empty transcript") } - result, err := s.DS.UploadFile(ctx, a, dsclient.UploadFileRequest{ Filename: currentInputFilename, ContentType: currentInputContentType, @@ -62,6 +61,9 @@ func (s Service) ApplyCurrentInputFile(ctx context.Context, a *auth.RequestAuth, stdReq.CurrentInputFileApplied = true stdReq.RefFileIDs = prependUniqueRefFileID(stdReq.RefFileIDs, fileID) stdReq.FinalPrompt, stdReq.ToolNames = promptcompat.BuildOpenAIPrompt(messages, stdReq.ToolsRaw, "", stdReq.ToolChoice, stdReq.Thinking) + // Token accounting must reflect the actual downstream context: + // the uploaded IGNORE.txt file content + the neutral live prompt. 
+ stdReq.PromptTokenText = fileText + "\n" + stdReq.FinalPrompt return stdReq, nil } diff --git a/internal/httpapi/openai/history_split_test.go b/internal/httpapi/openai/history_split_test.go index 593735a..b792794 100644 --- a/internal/httpapi/openai/history_split_test.go +++ b/internal/httpapi/openai/history_split_test.go @@ -14,6 +14,7 @@ import ( "ds2api/internal/auth" dsclient "ds2api/internal/deepseek/client" "ds2api/internal/promptcompat" + "ds2api/internal/util" ) func historySplitTestMessages() []any { @@ -298,6 +299,52 @@ func TestApplyCurrentInputFileUploadsFirstTurnWithInjectedWrapper(t *testing.T) if len(out.RefFileIDs) != 1 || out.RefFileIDs[0] != "file-inline-1" { t.Fatalf("expected current input file id in ref_file_ids, got %#v", out.RefFileIDs) } + if !strings.Contains(out.PromptTokenText, "first turn content that is long enough") { + t.Fatalf("expected prompt token text to preserve original full context, got %q", out.PromptTokenText) + } +} + +func TestApplyCurrentInputFilePreservesFullContextPromptForTokenCounting(t *testing.T) { + ds := &inlineUploadDSStub{} + h := &openAITestSurface{ + Store: mockOpenAIConfig{ + wideInput: true, + currentInputEnabled: true, + currentInputMin: 0, + thinkingInjection: boolPtr(true), + }, + DS: ds, + } + req := map[string]any{ + "model": "deepseek-v4-flash", + "messages": historySplitTestMessages(), + } + stdReq, err := promptcompat.NormalizeOpenAIChatRequest(h.Store, req, "") + if err != nil { + t.Fatalf("normalize failed: %v", err) + } + + out, err := h.applyCurrentInputFile(context.Background(), &auth.RequestAuth{DeepSeekToken: "token"}, stdReq) + if err != nil { + t.Fatalf("apply current input file failed: %v", err) + } + if out.FinalPrompt == stdReq.FinalPrompt { + t.Fatalf("expected live prompt to be rewritten after current input file") + } + // PromptTokenText must include the uploaded file content (which contains the full context) + // plus the neutral live prompt — reflecting the actual downstream token 
cost. + if !strings.Contains(out.PromptTokenText, "first user turn") || !strings.Contains(out.PromptTokenText, "latest user turn") { + t.Fatalf("expected prompt token text to contain file context with full conversation, got %q", out.PromptTokenText) + } + if !strings.Contains(out.PromptTokenText, "[file content end]") || !strings.Contains(out.PromptTokenText, "[file name]: IGNORE") { + t.Fatalf("expected prompt token text to include IGNORE.txt file wrapper, got %q", out.PromptTokenText) + } + if !strings.Contains(out.PromptTokenText, "Answer the latest user request directly.") { + t.Fatalf("expected prompt token text to also include neutral live prompt, got %q", out.PromptTokenText) + } + if strings.Contains(out.FinalPrompt, "first user turn") || strings.Contains(out.FinalPrompt, "latest user turn") { + t.Fatalf("expected live prompt to hide original turns, got %q", out.FinalPrompt) + } } func TestApplyCurrentInputFileUploadsFullContextFile(t *testing.T) { @@ -434,6 +481,16 @@ func TestChatCompletionsCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *t if len(refIDs) == 0 || refIDs[0] != "file-inline-1" { t.Fatalf("expected uploaded current input file to be first ref_file_id, got %#v", ds.completionReq["ref_file_ids"]) } + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode response failed: %v", err) + } + usage, _ := body["usage"].(map[string]any) + promptTokens := int(usage["prompt_tokens"].(float64)) + neutralCount := util.CountPromptTokens(promptText, "deepseek-v4-flash") + if promptTokens <= neutralCount { + t.Fatalf("expected prompt_tokens to exceed neutral live prompt count (includes file context), got=%d neutral=%d", promptTokens, neutralCount) + } } func TestResponsesCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *testing.T) { @@ -476,6 +533,16 @@ func TestResponsesCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *testing if strings.Contains(promptText, "first user turn") || 
strings.Contains(promptText, "latest user turn") { t.Fatalf("expected prompt to hide original turns, got %s", promptText) } + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode response failed: %v", err) + } + usage, _ := body["usage"].(map[string]any) + inputTokens := int(usage["input_tokens"].(float64)) + neutralCount := util.CountPromptTokens(promptText, "deepseek-v4-flash") + if inputTokens <= neutralCount { + t.Fatalf("expected input_tokens to exceed neutral live prompt count (includes file context), got=%d neutral=%d", inputTokens, neutralCount) + } } func TestChatCompletionsCurrentInputFileMapsManagedAuthFailureTo401(t *testing.T) { diff --git a/internal/httpapi/openai/responses/empty_retry_runtime.go b/internal/httpapi/openai/responses/empty_retry_runtime.go index a451c92..6b7bff4 100644 --- a/internal/httpapi/openai/responses/empty_retry_runtime.go +++ b/internal/httpapi/openai/responses/empty_retry_runtime.go @@ -68,7 +68,7 @@ func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx c config.Logger.Warn("[openai_empty_retry] retry request failed", "surface", "responses", "stream", false, "retry_attempt", attempts, "error", err) return } - usagePrompt = usagePromptWithEmptyOutputRetry(finalPrompt, attempts) + usagePrompt = usagePromptWithEmptyOutputRetry(usagePrompt, attempts) currentResp = nextResp } } diff --git a/internal/httpapi/openai/responses/responses_handler.go b/internal/httpapi/openai/responses/responses_handler.go index a04e7b1..5e5a29a 100644 --- a/internal/httpapi/openai/responses/responses_handler.go +++ b/internal/httpapi/openai/responses/responses_handler.go @@ -115,10 +115,10 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) { responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "") if stdReq.Stream { - h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, 
stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID) + h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID) return } - h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID) + h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID) } func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) { diff --git a/internal/promptcompat/request_normalize.go b/internal/promptcompat/request_normalize.go index 8efa772..fbb9d4c 100644 --- a/internal/promptcompat/request_normalize.go +++ b/internal/promptcompat/request_normalize.go @@ -39,20 +39,21 @@ func NormalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID refFileIDs := CollectOpenAIRefFileIDs(req) return StandardRequest{ - Surface: "openai_chat", - RequestedModel: strings.TrimSpace(model), - ResolvedModel: resolvedModel, - ResponseModel: responseModel, - Messages: messagesRaw, - ToolsRaw: req["tools"], - FinalPrompt: finalPrompt, - ToolNames: toolNames, - ToolChoice: toolPolicy, - Stream: util.ToBool(req["stream"]), - Thinking: thinkingEnabled, - Search: searchEnabled, - RefFileIDs: refFileIDs, - PassThrough: passThrough, + Surface: "openai_chat", + RequestedModel: strings.TrimSpace(model), + ResolvedModel: resolvedModel, + ResponseModel: 
responseModel, + Messages: messagesRaw, + PromptTokenText: finalPrompt, + ToolsRaw: req["tools"], + FinalPrompt: finalPrompt, + ToolNames: toolNames, + ToolChoice: toolPolicy, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + RefFileIDs: refFileIDs, + PassThrough: passThrough, }, nil } @@ -99,20 +100,21 @@ func NormalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra refFileIDs := CollectOpenAIRefFileIDs(req) return StandardRequest{ - Surface: "openai_responses", - RequestedModel: model, - ResolvedModel: resolvedModel, - ResponseModel: model, - Messages: messagesRaw, - ToolsRaw: req["tools"], - FinalPrompt: finalPrompt, - ToolNames: toolNames, - ToolChoice: toolPolicy, - Stream: util.ToBool(req["stream"]), - Thinking: thinkingEnabled, - Search: searchEnabled, - RefFileIDs: refFileIDs, - PassThrough: passThrough, + Surface: "openai_responses", + RequestedModel: model, + ResolvedModel: resolvedModel, + ResponseModel: model, + Messages: messagesRaw, + PromptTokenText: finalPrompt, + ToolsRaw: req["tools"], + FinalPrompt: finalPrompt, + ToolNames: toolNames, + ToolChoice: toolPolicy, + Stream: util.ToBool(req["stream"]), + Thinking: thinkingEnabled, + Search: searchEnabled, + RefFileIDs: refFileIDs, + PassThrough: passThrough, }, nil } diff --git a/internal/promptcompat/standard_request.go b/internal/promptcompat/standard_request.go index 6480d9b..1f4c48f 100644 --- a/internal/promptcompat/standard_request.go +++ b/internal/promptcompat/standard_request.go @@ -9,6 +9,7 @@ type StandardRequest struct { ResponseModel string Messages []any HistoryText string + PromptTokenText string CurrentInputFileApplied bool ToolsRaw any FinalPrompt string diff --git a/internal/util/render.go b/internal/util/render.go index 0092e4b..801d2f1 100644 --- a/internal/util/render.go +++ b/internal/util/render.go @@ -23,9 +23,9 @@ func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking, 
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected, nil) messageObj["content"] = nil } - promptTokens := EstimateTokens(finalPrompt) - reasoningTokens := EstimateTokens(finalThinking) - completionTokens := EstimateTokens(finalText) + promptTokens := CountPromptTokens(finalPrompt, model) + reasoningTokens := CountOutputTokens(finalThinking, model) + completionTokens := CountOutputTokens(finalText, model) return map[string]any{ "id": completionID, @@ -86,9 +86,9 @@ func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, fi "content": content, }) } - promptTokens := EstimateTokens(finalPrompt) - reasoningTokens := EstimateTokens(finalThinking) - completionTokens := EstimateTokens(finalText) + promptTokens := CountPromptTokens(finalPrompt, model) + reasoningTokens := CountOutputTokens(finalThinking, model) + completionTokens := CountOutputTokens(finalText, model) return map[string]any{ "id": responseID, "type": "response", @@ -140,8 +140,8 @@ func BuildClaudeMessageResponse(messageID, model string, normalizedMessages []an "stop_reason": stopReason, "stop_sequence": nil, "usage": map[string]any{ - "input_tokens": EstimateTokens(fmt.Sprintf("%v", normalizedMessages)), - "output_tokens": EstimateTokens(finalThinking) + EstimateTokens(finalText), + "input_tokens": CountPromptTokens(fmt.Sprintf("%v", normalizedMessages), model), + "output_tokens": CountOutputTokens(finalThinking, model) + CountOutputTokens(finalText, model), }, } }
diff --git a/internal/util/token_count.go b/internal/util/token_count.go
new file mode 100644
index 0000000..7ed75d8
--- /dev/null
+++ b/internal/util/token_count.go
@@ -0,0 +1,102 @@
+package util
+
+import (
+	"strings"
+
+	tiktoken "github.com/hupe1980/go-tiktoken"
+)
+
+const (
+	// defaultTokenizerModel is the tokenizer model name used for every model
+	// name that has no dedicated mapping in tokenizerModelForCount.
+	defaultTokenizerModel = "gpt-4o"
+	// claudeTokenizerModel is the tokenizer model name used for model names
+	// that start with "claude".
+	claudeTokenizerModel = "claude"
+)
+
+// CountPromptTokens returns a deliberately conservative token count for
+// request-side text. It takes the larger of EstimateTokens and the
+// tokenizer-based count, then adds conservativePromptPadding so the reported
+// figure errs on the high side. When both counters report nothing the result
+// is 0, with no padding applied.
+func CountPromptTokens(text, model string) int {
+	base := maxTokenCount(
+		EstimateTokens(text),
+		countWithTokenizer(text, model),
+	)
+	if base <= 0 {
+		// Keep an empty prompt at 0 instead of letting the minimum padding
+		// turn it into a non-zero figure.
+		return 0
+	}
+	return base + conservativePromptPadding(base)
+}
+
+// CountOutputTokens returns the token count for generated text: the larger of
+// EstimateTokens and the tokenizer-based count, with no padding.
+func CountOutputTokens(text, model string) int {
+	// maxTokenCount already floors at 0, so no extra clamp is needed here.
+	return maxTokenCount(
+		EstimateTokens(text),
+		countWithTokenizer(text, model),
+	)
+}
+
+// countWithTokenizer encodes text with the tokenizer selected for model and
+// returns the number of token ids. It returns 0 for blank text and for any
+// tokenizer error; both callers combine the result with EstimateTokens, so a
+// failure degrades to the estimate instead of surfacing an error.
+//
+// NOTE(review): an encoding is constructed on every call. If this shows up in
+// profiles, consider caching one encoding per tokenizer model - confirm first
+// that the encoding is safe for concurrent use.
+func countWithTokenizer(text, model string) int {
+	text = strings.TrimSpace(text)
+	if text == "" {
+		return 0
+	}
+	encoding, err := tiktoken.NewEncodingForModel(tokenizerModelForCount(model))
+	if err != nil {
+		return 0
+	}
+	ids, _, err := encoding.Encode(text, nil, nil)
+	if err != nil {
+		return 0
+	}
+	return len(ids)
+}
+
+// tokenizerModelForCount maps a requested model name to the tokenizer model
+// used for counting. Matching is case-insensitive and ignores surrounding
+// whitespace. Names starting with "claude" use claudeTokenizerModel; every
+// other name, including the empty string, uses defaultTokenizerModel.
+func tokenizerModelForCount(model string) string {
+	model = strings.ToLower(strings.TrimSpace(model))
+	if strings.HasPrefix(model, "claude") {
+		return claudeTokenizerModel
+	}
+	return defaultTokenizerModel
+}
+
+// conservativePromptPadding returns the safety margin added to a prompt count:
+// base/50 (integer division, about 2%), but never less than 4.
+func conservativePromptPadding(base int) int {
+	padding := base / 50
+	if padding < 4 {
+		padding = 4
+	}
+	return padding
+}
+
+// maxTokenCount returns the largest of values, floored at 0. With no arguments,
+// or only non-positive ones, it returns 0.
+func maxTokenCount(values ...int) int {
+	best := 0
+	for _, v := range values {
+		if v > best {
+			best = v
+		}
+	}
+	return best
+}