Merge pull request #372 from shern-point/feat/accurate-context-token-length

Feat/accurate context token length
This commit is contained in:
CJACK.
2026-04-30 02:11:32 +08:00
committed by GitHub
30 changed files with 341 additions and 113 deletions

View File

@@ -98,6 +98,7 @@ DS2API's current core approach is not to forward the client-supplied `messages`, `tools`
- `prompt` is the primary carrier of the conversation context.
- `ref_file_ids` only carries file references, never ordinary text messages.
- `tools` is not sent downstream as a native tool schema; it is rewritten into the `prompt`.
- The `prompt_tokens` / `input_tokens` / `promptTokenCount` values returned to clients are no longer rough approximations based on the last message or a character count; they are tokenizer counts over the **full context prompt**. To keep a client from assuming more context still fits when the real context is already over the limit, the request-side context token count is bumped up slightly as a conservative margin: slightly too high is acceptable, underestimating is not (a worked sketch of the uplift follows this list).
- The current `/v1/chat/completions` business path still creates a new remote `chat_session_id` per request and sends `parent_message_id: null` by default; externally DS2API therefore behaves as "new session + history concatenated into the prompt" rather than reusing DeepSeek's native session tree.
- DeepSeek's remote API itself does support multi-turn conversation within a single `chat_session_id`. On 2026-04-27 a two-round live test was run with the project's existing DeepSeek client, without changing any business code: within the same `chat_session_id`, round 1 returned `request_message_id=1` / `response_message_id=2` / text `SESSION_TEST_ONE`; round 2, after fetching a fresh PoW and sending `parent_message_id=2`, returned `request_message_id=3` / `response_message_id=4` / text `SESSION_TEST_TWO`. So continuing the chat inside the same remote session is possible, provided each round carries the correct parent/message link information and re-obtains a PoW valid for that round.
- OpenAI Chat / Responses go natively through the unified OpenAI normalization and DeepSeek payload assembly; Claude / Gemini reuse the OpenAI prompt/tool semantics where possible, with Gemini reusing `promptcompat.BuildOpenAIPromptForAdapter` directly, and the Claude messages surface converting to the OpenAI chat form before execution in proxyable scenarios.
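
A worked sketch of the conservative uplift mentioned above. It mirrors the `conservativePromptPadding` rule added by this PR in `internal/util` (roughly 2% of the base count, never less than 4 tokens); the standalone program below is illustrative only.

```go
package main

import "fmt"

// Mirrors the padding rule from this PR's CountPromptTokens:
// padding = base/50, floored at 4, applied only when base > 0.
func conservativePromptPadding(base int) int {
	padding := base / 50
	if padding < 4 {
		padding = 4
	}
	return padding
}

func main() {
	for _, base := range []int{80, 600, 5000} {
		fmt.Printf("base=%d -> reported prompt tokens=%d\n", base, base+conservativePromptPadding(base))
	}
	// base=80   -> 84   (the minimum padding of 4 dominates)
	// base=600  -> 612  (600/50 = 12)
	// base=5000 -> 5100 (5000/50 = 100)
}
```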
@@ -249,6 +250,7 @@ OpenAI file-related implementation:
- `current_input_file` is enabled by default; it merges the "full context" into a `history.txt` context file. When the plain-text length of the latest user turn reaches `current_input_file.min_chars` (default `0`), the compat layer uploads a context file named `history.txt` and keeps only a neutral user message in the live prompt that asks the model to answer the latest request directly, without exposing the file name or asking the model to read a local file.
- If `current_input_file.enabled=false`, the request is passed through as-is and no split context file is uploaded.
- The legacy `history_split.enabled` / `history_split.trigger_after_turns` keys are still read into the config object for compatibility, but they no longer trigger split uploads and do not affect `current_input_file` being enabled by default.
- Even when `current_input_file` fires and the live prompt is shortened, the context token counts reported back to the client still follow the **pre-split full prompt semantics** rather than the shortened placeholder prompt; otherwise the real context would be reported as much smaller than it is (see the sketch after this list).
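
A minimal sketch of the accounting rule in this bullet, assuming simplified types; in the real change (shown further down in this diff), `ApplyCurrentInputFile` sets `PromptTokenText` on `promptcompat.StandardRequest`.

```go
// Simplified illustration only; field names follow this PR, the surrounding types do not.
package compat

type standardRequest struct {
	FinalPrompt     string // shortened, neutral live prompt actually sent downstream
	PromptTokenText string // text used when reporting prompt/input token counts
}

// applyCurrentInputFile shortens the live prompt but keeps token accounting
// aligned with the real downstream context: uploaded file content + live prompt.
func applyCurrentInputFile(stdReq standardRequest, fileText, neutralPrompt string) standardRequest {
	stdReq.FinalPrompt = neutralPrompt
	stdReq.PromptTokenText = fileText + "\n" + stdReq.FinalPrompt
	return stdReq
}
```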
Related implementation:

go.mod
View File

@@ -10,6 +10,11 @@ require (
github.com/router-for-me/CLIProxyAPI/v6 v6.9.14
)
require (
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/hupe1980/go-tiktoken v0.0.10 // indirect
)
require (
github.com/klauspost/compress v1.18.5 // indirect
github.com/sirupsen/logrus v1.9.4 // indirect

go.sum
View File

@@ -2,10 +2,14 @@ github.com/andybalholm/brotli v1.2.1 h1:R+f5xP285VArJDRgowrfb9DqL18yVK0gKAW/F+eT
github.com/andybalholm/brotli v1.2.1/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hupe1980/go-tiktoken v0.0.10 h1:m6phOJaGyctqWdGIgwn9X8AfJvaG74tnQoDL+ntOUEQ=
github.com/hupe1980/go-tiktoken v0.0.10/go.mod h1:NME6d8hrE+Jo+kLUZHhXShYV8e40hYkm4BbSLQKtvAo=
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"time"
"ds2api/internal/prompt"
"ds2api/internal/util"
)
@@ -43,8 +44,23 @@ func BuildMessageResponse(messageID, model string, normalizedMessages []any, fin
"stop_reason": stopReason,
"stop_sequence": nil,
"usage": map[string]any{
"input_tokens": util.EstimateTokens(fmt.Sprintf("%v", normalizedMessages)),
"output_tokens": util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText),
"input_tokens": util.CountPromptTokens(prompt.MessagesPrepareWithThinking(claudeMessageMaps(normalizedMessages), false), model),
"output_tokens": util.CountOutputTokens(finalThinking, model) + util.CountOutputTokens(finalText, model),
},
}
}
func claudeMessageMaps(messages []any) []map[string]any {
if len(messages) == 0 {
return nil
}
out := make([]map[string]any, 0, len(messages))
for _, item := range messages {
msg, ok := item.(map[string]any)
if !ok {
continue
}
out = append(out, msg)
}
return out
}

View File

@@ -29,7 +29,7 @@ func BuildChatCompletionWithToolCalls(completionID, model, finalPrompt, finalThi
"created": time.Now().Unix(),
"model": model,
"choices": []map[string]any{{"index": 0, "message": messageObj, "finish_reason": finishReason}},
"usage": BuildChatUsage(finalPrompt, finalThinking, finalText),
"usage": BuildChatUsageForModel(model, finalPrompt, finalThinking, finalText),
}
}

View File

@@ -70,7 +70,7 @@ func BuildResponseObjectFromItems(responseID, model, finalPrompt, finalThinking,
"model": model,
"output": output,
"output_text": outputText,
"usage": BuildResponsesUsage(finalPrompt, finalThinking, finalText),
"usage": BuildResponsesUsageForModel(model, finalPrompt, finalThinking, finalText),
}
}

View File

@@ -6,6 +6,7 @@ import (
"testing"
"ds2api/internal/toolcall"
"ds2api/internal/util"
)
func TestBuildResponseObjectKeepsFencedToolPayloadAsText(t *testing.T) {
@@ -177,3 +178,17 @@ func TestBuildResponseObjectWithToolCallsCoercesSchemaDeclaredStringArguments(t
t.Fatalf("expected response content stringified by schema, got %#v", args["content"])
}
}
func TestBuildChatUsageForModelUsesConservativePromptCount(t *testing.T) {
prompt := strings.Repeat("上下文token ", 40)
usage := BuildChatUsageForModel("deepseek-v4-flash", prompt, "", "ok")
promptTokens, _ := usage["prompt_tokens"].(int)
if promptTokens <= util.EstimateTokens(prompt) {
t.Fatalf("expected conservative prompt token count > rough estimate, got=%d estimate=%d", promptTokens, util.EstimateTokens(prompt))
}
totalTokens, _ := usage["total_tokens"].(int)
completionTokens, _ := usage["completion_tokens"].(int)
if totalTokens != promptTokens+completionTokens {
t.Fatalf("expected total tokens to add up, got usage=%#v", usage)
}
}

View File

@@ -2,10 +2,10 @@ package openai
import "ds2api/internal/util"
func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any {
promptTokens := util.EstimateTokens(finalPrompt)
reasoningTokens := util.EstimateTokens(finalThinking)
completionTokens := util.EstimateTokens(finalText)
func BuildChatUsageForModel(model, finalPrompt, finalThinking, finalText string) map[string]any {
promptTokens := util.CountPromptTokens(finalPrompt, model)
reasoningTokens := util.CountOutputTokens(finalThinking, model)
completionTokens := util.CountOutputTokens(finalText, model)
return map[string]any{
"prompt_tokens": promptTokens,
"completion_tokens": reasoningTokens + completionTokens,
@@ -16,13 +16,21 @@ func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any
}
}
func BuildResponsesUsage(finalPrompt, finalThinking, finalText string) map[string]any {
promptTokens := util.EstimateTokens(finalPrompt)
reasoningTokens := util.EstimateTokens(finalThinking)
completionTokens := util.EstimateTokens(finalText)
func BuildChatUsage(finalPrompt, finalThinking, finalText string) map[string]any {
return BuildChatUsageForModel("", finalPrompt, finalThinking, finalText)
}
func BuildResponsesUsageForModel(model, finalPrompt, finalThinking, finalText string) map[string]any {
promptTokens := util.CountPromptTokens(finalPrompt, model)
reasoningTokens := util.CountOutputTokens(finalThinking, model)
completionTokens := util.CountOutputTokens(finalText, model)
return map[string]any{
"input_tokens": promptTokens,
"output_tokens": reasoningTokens + completionTokens,
"total_tokens": promptTokens + reasoningTokens + completionTokens,
}
}
func BuildResponsesUsage(finalPrompt, finalThinking, finalText string) map[string]any {
return BuildResponsesUsageForModel("", finalPrompt, finalThinking, finalText)
}

View File

@@ -206,6 +206,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
h.compatStripReferenceMarkers(),
toolNames,
toolsRaw,
buildClaudePromptTokenText(messages, thinkingEnabled),
)
streamRuntime.sendMessageStart()

View File

@@ -3,8 +3,6 @@ package claude
import (
"encoding/json"
"net/http"
"ds2api/internal/util"
)
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
@@ -26,26 +24,11 @@ func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
return
}
inputTokens := 0
if sys, ok := req["system"].(string); ok {
inputTokens += util.EstimateTokens(sys)
}
for _, item := range messages {
msg, ok := item.(map[string]any)
if !ok {
continue
}
inputTokens += 2
inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
}
if tools, ok := req["tools"].([]any); ok {
for _, t := range tools {
b, _ := json.Marshal(t)
inputTokens += util.EstimateTokens(string(b))
}
}
if inputTokens < 1 {
inputTokens = 1
normalized, err := normalizeClaudeRequest(h.Store, req)
if err != nil {
writeClaudeError(w, http.StatusBadRequest, err.Error())
return
}
inputTokens := countClaudeInputTokens(normalized.Standard)
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
}

View File

@@ -0,0 +1,7 @@
package claude
import "ds2api/internal/prompt"
func buildClaudePromptTokenText(messages []any, thinkingEnabled bool) string {
return prompt.MessagesPrepareWithThinking(toMessageMaps(messages), thinkingEnabled)
}

View File

@@ -48,17 +48,18 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma
return claudeNormalizedRequest{
Standard: promptcompat.StandardRequest{
Surface: "anthropic_messages",
RequestedModel: strings.TrimSpace(model),
ResolvedModel: dsModel,
ResponseModel: strings.TrimSpace(model),
Messages: payload["messages"].([]any),
ToolsRaw: toolsRequested,
FinalPrompt: finalPrompt,
ToolNames: toolNames,
Stream: util.ToBool(req["stream"]),
Thinking: thinkingEnabled,
Search: searchEnabled,
Surface: "anthropic_messages",
RequestedModel: strings.TrimSpace(model),
ResolvedModel: dsModel,
ResponseModel: strings.TrimSpace(model),
Messages: payload["messages"].([]any),
PromptTokenText: finalPrompt,
ToolsRaw: toolsRequested,
FinalPrompt: finalPrompt,
ToolNames: toolNames,
Stream: util.ToBool(req["stream"]),
Thinking: thinkingEnabled,
Search: searchEnabled,
},
NormalizedMessages: normalizedMessages,
}, nil

View File

@@ -15,10 +15,11 @@ type claudeStreamRuntime struct {
rc *http.ResponseController
canFlush bool
model string
toolNames []string
messages []any
toolsRaw any
model string
toolNames []string
messages []any
toolsRaw any
promptTokenText string
thinkingEnabled bool
searchEnabled bool
@@ -49,6 +50,7 @@ func newClaudeStreamRuntime(
stripReferenceMarkers bool,
toolNames []string,
toolsRaw any,
promptTokenText string,
) *claudeStreamRuntime {
return &claudeStreamRuntime{
w: w,
@@ -62,6 +64,7 @@ func newClaudeStreamRuntime(
stripReferenceMarkers: stripReferenceMarkers,
toolNames: toolNames,
toolsRaw: toolsRaw,
promptTokenText: promptTokenText,
messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()),
thinkingBlockIndex: -1,
textBlockIndex: -1,

View File

@@ -42,7 +42,10 @@ func (s *claudeStreamRuntime) sendPing() {
}
func (s *claudeStreamRuntime) sendMessageStart() {
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages))
inputTokens := countClaudeInputTokensFromText(s.promptTokenText, s.model)
if inputTokens == 0 {
inputTokens = util.CountPromptTokens(fmt.Sprintf("%v", s.messages), s.model)
}
s.send("message_start", map[string]any{
"type": "message_start",
"message": map[string]any{

View File

@@ -109,7 +109,7 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
}
}
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
outputTokens := util.CountOutputTokens(finalThinking, s.model) + util.CountOutputTokens(finalText, s.model)
s.send("message_delta", map[string]any{
"type": "message_delta",
"delta": map[string]any{

View File

@@ -0,0 +1,20 @@
package claude
import (
"strings"
"ds2api/internal/promptcompat"
"ds2api/internal/util"
)
func countClaudeInputTokens(stdReq promptcompat.StandardRequest) int {
promptText := stdReq.PromptTokenText
if strings.TrimSpace(promptText) == "" {
promptText = stdReq.FinalPrompt
}
return countClaudeInputTokensFromText(promptText, stdReq.ResolvedModel)
}
func countClaudeInputTokensFromText(promptText, model string) int {
return util.CountPromptTokens(promptText, model)
}

View File

@@ -36,16 +36,17 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin
passThrough := collectGeminiPassThrough(req)
return promptcompat.StandardRequest{
Surface: "google_gemini",
RequestedModel: requestedModel,
ResolvedModel: resolvedModel,
ResponseModel: requestedModel,
Messages: messagesRaw,
FinalPrompt: finalPrompt,
ToolNames: toolNames,
Stream: stream,
Thinking: thinkingEnabled,
Search: searchEnabled,
PassThrough: passThrough,
Surface: "google_gemini",
RequestedModel: requestedModel,
ResolvedModel: resolvedModel,
ResponseModel: requestedModel,
Messages: messagesRaw,
PromptTokenText: finalPrompt,
FinalPrompt: finalPrompt,
ToolNames: toolNames,
Stream: stream,
Thinking: thinkingEnabled,
Search: searchEnabled,
PassThrough: passThrough,
}, nil
}

View File

@@ -227,7 +227,7 @@ func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *ht
//nolint:unused // retained for native Gemini non-stream handling path.
func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
usage := buildGeminiUsage(model, finalPrompt, finalThinking, finalText)
return map[string]any{
"candidates": []map[string]any{
{
@@ -245,10 +245,10 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final
}
//nolint:unused // retained for native Gemini non-stream handling path.
func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
promptTokens := util.EstimateTokens(finalPrompt)
reasoningTokens := util.EstimateTokens(finalThinking)
completionTokens := util.EstimateTokens(finalText)
func buildGeminiUsage(model, finalPrompt, finalThinking, finalText string) map[string]any {
promptTokens := util.CountPromptTokens(finalPrompt, model)
reasoningTokens := util.CountOutputTokens(finalThinking, model)
completionTokens := util.CountOutputTokens(finalText, model)
return map[string]any{
"promptTokenCount": promptTokens,
"candidatesTokenCount": reasoningTokens + completionTokens,

View File

@@ -194,6 +194,6 @@ func (s *geminiStreamRuntime) finalize() {
},
},
"modelVersion": s.model,
"usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
"usageMetadata": buildGeminiUsage(s.model, s.finalPrompt, finalThinking, finalText),
})
}

View File

@@ -222,7 +222,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
s.sendFailedChunk(status, message, code)
return true
}
usage := openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText)
usage := openaifmt.BuildChatUsageForModel(s.model, s.finalPrompt, finalThinking, finalText)
s.finalFinishReason = finishReason
s.finalUsage = usage
s.sendChunk(openaifmt.BuildChatStreamChunk(

View File

@@ -72,7 +72,7 @@ func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Co
config.Logger.Warn("[openai_empty_retry] retry request failed", "surface", "chat.completions", "stream", false, "retry_attempt", attempts, "error", err)
return
}
usagePrompt = usagePromptWithEmptyOutputRetry(finalPrompt, attempts)
usagePrompt = usagePromptWithEmptyOutputRetry(usagePrompt, attempts)
currentResp = nextResp
}
}

View File

@@ -109,10 +109,10 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
return
}
if stdReq.Stream {
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
return
}
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
}
func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) {
@@ -183,7 +183,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, co
}
}
if historySession != nil {
historySession.success(http.StatusOK, finalThinking, finalText, finishReason, openaifmt.BuildChatUsage(finalPrompt, finalThinking, finalText))
historySession.success(http.StatusOK, finalThinking, finalText, finishReason, openaifmt.BuildChatUsageForModel(model, finalPrompt, finalThinking, finalText))
}
writeJSON(w, http.StatusOK, respBody)
}

View File

@@ -35,7 +35,6 @@ func (s Service) ApplyCurrentInputFile(ctx context.Context, a *auth.RequestAuth,
if strings.TrimSpace(fileText) == "" {
return stdReq, errors.New("current user input file produced empty transcript")
}
result, err := s.DS.UploadFile(ctx, a, dsclient.UploadFileRequest{
Filename: currentInputFilename,
ContentType: currentInputContentType,
@@ -62,6 +61,9 @@ func (s Service) ApplyCurrentInputFile(ctx context.Context, a *auth.RequestAuth,
stdReq.CurrentInputFileApplied = true
stdReq.RefFileIDs = prependUniqueRefFileID(stdReq.RefFileIDs, fileID)
stdReq.FinalPrompt, stdReq.ToolNames = promptcompat.BuildOpenAIPrompt(messages, stdReq.ToolsRaw, "", stdReq.ToolChoice, stdReq.Thinking)
// Token accounting must reflect the actual downstream context:
// the uploaded IGNORE.txt file content + the neutral live prompt.
stdReq.PromptTokenText = fileText + "\n" + stdReq.FinalPrompt
return stdReq, nil
}

View File

@@ -14,6 +14,7 @@ import (
"ds2api/internal/auth"
dsclient "ds2api/internal/deepseek/client"
"ds2api/internal/promptcompat"
"ds2api/internal/util"
)
func historySplitTestMessages() []any {
@@ -298,6 +299,52 @@ func TestApplyCurrentInputFileUploadsFirstTurnWithInjectedWrapper(t *testing.T)
if len(out.RefFileIDs) != 1 || out.RefFileIDs[0] != "file-inline-1" {
t.Fatalf("expected current input file id in ref_file_ids, got %#v", out.RefFileIDs)
}
if !strings.Contains(out.PromptTokenText, "first turn content that is long enough") {
t.Fatalf("expected prompt token text to preserve original full context, got %q", out.PromptTokenText)
}
}
func TestApplyCurrentInputFilePreservesFullContextPromptForTokenCounting(t *testing.T) {
ds := &inlineUploadDSStub{}
h := &openAITestSurface{
Store: mockOpenAIConfig{
wideInput: true,
currentInputEnabled: true,
currentInputMin: 0,
thinkingInjection: boolPtr(true),
},
DS: ds,
}
req := map[string]any{
"model": "deepseek-v4-flash",
"messages": historySplitTestMessages(),
}
stdReq, err := promptcompat.NormalizeOpenAIChatRequest(h.Store, req, "")
if err != nil {
t.Fatalf("normalize failed: %v", err)
}
out, err := h.applyCurrentInputFile(context.Background(), &auth.RequestAuth{DeepSeekToken: "token"}, stdReq)
if err != nil {
t.Fatalf("apply current input file failed: %v", err)
}
if out.FinalPrompt == stdReq.FinalPrompt {
t.Fatalf("expected live prompt to be rewritten after current input file")
}
// PromptTokenText must include the uploaded file content (which contains the full context)
// plus the neutral live prompt — reflecting the actual downstream token cost.
if !strings.Contains(out.PromptTokenText, "first user turn") || !strings.Contains(out.PromptTokenText, "latest user turn") {
t.Fatalf("expected prompt token text to contain file context with full conversation, got %q", out.PromptTokenText)
}
if !strings.Contains(out.PromptTokenText, "[file content end]") || !strings.Contains(out.PromptTokenText, "[file name]: IGNORE") {
t.Fatalf("expected prompt token text to include IGNORE.txt file wrapper, got %q", out.PromptTokenText)
}
if !strings.Contains(out.PromptTokenText, "Answer the latest user request directly.") {
t.Fatalf("expected prompt token text to also include neutral live prompt, got %q", out.PromptTokenText)
}
if strings.Contains(out.FinalPrompt, "first user turn") || strings.Contains(out.FinalPrompt, "latest user turn") {
t.Fatalf("expected live prompt to hide original turns, got %q", out.FinalPrompt)
}
}
func TestApplyCurrentInputFileUploadsFullContextFile(t *testing.T) {
@@ -434,6 +481,16 @@ func TestChatCompletionsCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *t
if len(refIDs) == 0 || refIDs[0] != "file-inline-1" {
t.Fatalf("expected uploaded current input file to be first ref_file_id, got %#v", ds.completionReq["ref_file_ids"])
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("decode response failed: %v", err)
}
usage, _ := body["usage"].(map[string]any)
promptTokens := int(usage["prompt_tokens"].(float64))
neutralCount := util.CountPromptTokens(promptText, "deepseek-v4-flash")
if promptTokens <= neutralCount {
t.Fatalf("expected prompt_tokens to exceed neutral live prompt count (includes file context), got=%d neutral=%d", promptTokens, neutralCount)
}
}
func TestResponsesCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *testing.T) {
@@ -476,6 +533,16 @@ func TestResponsesCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *testing
if strings.Contains(promptText, "first user turn") || strings.Contains(promptText, "latest user turn") {
t.Fatalf("expected prompt to hide original turns, got %s", promptText)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("decode response failed: %v", err)
}
usage, _ := body["usage"].(map[string]any)
inputTokens := int(usage["input_tokens"].(float64))
neutralCount := util.CountPromptTokens(promptText, "deepseek-v4-flash")
if inputTokens <= neutralCount {
t.Fatalf("expected input_tokens to exceed neutral live prompt count (includes file context), got=%d neutral=%d", inputTokens, neutralCount)
}
}
func TestChatCompletionsCurrentInputFileMapsManagedAuthFailureTo401(t *testing.T) {

View File

@@ -68,7 +68,7 @@ func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx c
config.Logger.Warn("[openai_empty_retry] retry request failed", "surface", "responses", "stream", false, "retry_attempt", attempts, "error", err)
return
}
usagePrompt = usagePromptWithEmptyOutputRetry(finalPrompt, attempts)
usagePrompt = usagePromptWithEmptyOutputRetry(usagePrompt, attempts)
currentResp = nextResp
}
}

View File

@@ -115,10 +115,10 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
if stdReq.Stream {
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
return
}
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
}
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {

View File

@@ -39,20 +39,21 @@ func NormalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID
refFileIDs := CollectOpenAIRefFileIDs(req)
return StandardRequest{
Surface: "openai_chat",
RequestedModel: strings.TrimSpace(model),
ResolvedModel: resolvedModel,
ResponseModel: responseModel,
Messages: messagesRaw,
ToolsRaw: req["tools"],
FinalPrompt: finalPrompt,
ToolNames: toolNames,
ToolChoice: toolPolicy,
Stream: util.ToBool(req["stream"]),
Thinking: thinkingEnabled,
Search: searchEnabled,
RefFileIDs: refFileIDs,
PassThrough: passThrough,
Surface: "openai_chat",
RequestedModel: strings.TrimSpace(model),
ResolvedModel: resolvedModel,
ResponseModel: responseModel,
Messages: messagesRaw,
PromptTokenText: finalPrompt,
ToolsRaw: req["tools"],
FinalPrompt: finalPrompt,
ToolNames: toolNames,
ToolChoice: toolPolicy,
Stream: util.ToBool(req["stream"]),
Thinking: thinkingEnabled,
Search: searchEnabled,
RefFileIDs: refFileIDs,
PassThrough: passThrough,
}, nil
}
@@ -99,20 +100,21 @@ func NormalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
refFileIDs := CollectOpenAIRefFileIDs(req)
return StandardRequest{
Surface: "openai_responses",
RequestedModel: model,
ResolvedModel: resolvedModel,
ResponseModel: model,
Messages: messagesRaw,
ToolsRaw: req["tools"],
FinalPrompt: finalPrompt,
ToolNames: toolNames,
ToolChoice: toolPolicy,
Stream: util.ToBool(req["stream"]),
Thinking: thinkingEnabled,
Search: searchEnabled,
RefFileIDs: refFileIDs,
PassThrough: passThrough,
Surface: "openai_responses",
RequestedModel: model,
ResolvedModel: resolvedModel,
ResponseModel: model,
Messages: messagesRaw,
PromptTokenText: finalPrompt,
ToolsRaw: req["tools"],
FinalPrompt: finalPrompt,
ToolNames: toolNames,
ToolChoice: toolPolicy,
Stream: util.ToBool(req["stream"]),
Thinking: thinkingEnabled,
Search: searchEnabled,
RefFileIDs: refFileIDs,
PassThrough: passThrough,
}, nil
}

View File

@@ -9,6 +9,7 @@ type StandardRequest struct {
ResponseModel string
Messages []any
HistoryText string
PromptTokenText string
CurrentInputFileApplied bool
ToolsRaw any
FinalPrompt string

View File

@@ -23,9 +23,9 @@ func BuildOpenAIChatCompletion(completionID, model, finalPrompt, finalThinking,
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected, nil)
messageObj["content"] = nil
}
promptTokens := EstimateTokens(finalPrompt)
reasoningTokens := EstimateTokens(finalThinking)
completionTokens := EstimateTokens(finalText)
promptTokens := CountPromptTokens(finalPrompt, model)
reasoningTokens := CountOutputTokens(finalThinking, model)
completionTokens := CountOutputTokens(finalText, model)
return map[string]any{
"id": completionID,
@@ -86,9 +86,9 @@ func BuildOpenAIResponseObject(responseID, model, finalPrompt, finalThinking, fi
"content": content,
})
}
promptTokens := EstimateTokens(finalPrompt)
reasoningTokens := EstimateTokens(finalThinking)
completionTokens := EstimateTokens(finalText)
promptTokens := CountPromptTokens(finalPrompt, model)
reasoningTokens := CountOutputTokens(finalThinking, model)
completionTokens := CountOutputTokens(finalText, model)
return map[string]any{
"id": responseID,
"type": "response",
@@ -140,8 +140,8 @@ func BuildClaudeMessageResponse(messageID, model string, normalizedMessages []an
"stop_reason": stopReason,
"stop_sequence": nil,
"usage": map[string]any{
"input_tokens": EstimateTokens(fmt.Sprintf("%v", normalizedMessages)),
"output_tokens": EstimateTokens(finalThinking) + EstimateTokens(finalText),
"input_tokens": CountPromptTokens(fmt.Sprintf("%v", normalizedMessages), model),
"output_tokens": CountOutputTokens(finalThinking, model) + CountOutputTokens(finalText, model),
},
}
}

View File

@@ -0,0 +1,87 @@
package util
import (
"strings"
tiktoken "github.com/hupe1980/go-tiktoken"
)
const (
defaultTokenizerModel = "gpt-4o"
claudeTokenizerModel = "claude"
)
func CountPromptTokens(text, model string) int {
base := maxTokenCount(
EstimateTokens(text),
countWithTokenizer(text, model),
)
if base <= 0 {
return 0
}
return base + conservativePromptPadding(base)
}
func CountOutputTokens(text, model string) int {
base := maxTokenCount(
EstimateTokens(text),
countWithTokenizer(text, model),
)
if base <= 0 {
return 0
}
return base
}
func countWithTokenizer(text, model string) int {
text = strings.TrimSpace(text)
if text == "" {
return 0
}
encoding, err := tiktoken.NewEncodingForModel(tokenizerModelForCount(model))
if err != nil {
return 0
}
ids, _, err := encoding.Encode(text, nil, nil)
if err != nil {
return 0
}
return len(ids)
}
func tokenizerModelForCount(model string) string {
model = strings.ToLower(strings.TrimSpace(model))
if model == "" {
return defaultTokenizerModel
}
switch {
case strings.HasPrefix(model, "claude"):
return claudeTokenizerModel
case strings.HasPrefix(model, "gpt-4"), strings.HasPrefix(model, "gpt-5"), strings.HasPrefix(model, "o1"), strings.HasPrefix(model, "o3"), strings.HasPrefix(model, "o4"):
return defaultTokenizerModel
case strings.HasPrefix(model, "deepseek-v4"):
return defaultTokenizerModel
case strings.HasPrefix(model, "deepseek"):
return defaultTokenizerModel
default:
return defaultTokenizerModel
}
}
func conservativePromptPadding(base int) int {
padding := base / 50
if padding < 4 {
padding = 4
}
return padding
}
func maxTokenCount(values ...int) int {
best := 0
for _, v := range values {
if v > best {
best = v
}
}
return best
}
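
A usage-level sketch of the two entry points above, written as a test in the same module; the model string and prompt text are illustrative, and exact counts depend on the tiktoken encoding.

```go
package util_test

import (
	"strings"
	"testing"

	"ds2api/internal/util"
)

// CountPromptTokens pads the larger of the rough estimate and the tiktoken
// count; CountOutputTokens returns that same base value without padding.
func TestPromptCountIsPaddedButOutputCountIsNot(t *testing.T) {
	text := strings.Repeat("full context prompt ", 100)
	model := "deepseek-v4-flash" // every non-claude model currently maps to the gpt-4o encoding

	promptTokens := util.CountPromptTokens(text, model)
	outputTokens := util.CountOutputTokens(text, model)

	// The conservative uplift is at least 4 tokens, so the prompt-side count
	// must exceed the output-side count for the same non-empty text.
	if promptTokens <= outputTokens {
		t.Fatalf("expected padded prompt count > unpadded output count, got %d vs %d", promptTokens, outputTokens)
	}
}
```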