Merge pull request #372 from shern-point/feat/accurate-context-token-length

Feat/accurate context token length
This commit is contained in:
CJACK.
2026-04-30 02:11:32 +08:00
committed by GitHub
30 changed files with 341 additions and 113 deletions

View File

@@ -206,6 +206,7 @@ func (h *Handler) handleClaudeStreamRealtime(w http.ResponseWriter, r *http.Requ
h.compatStripReferenceMarkers(),
toolNames,
toolsRaw,
buildClaudePromptTokenText(messages, thinkingEnabled),
)
streamRuntime.sendMessageStart()

View File

@@ -3,8 +3,6 @@ package claude
import (
"encoding/json"
"net/http"
"ds2api/internal/util"
)
func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
@@ -26,26 +24,11 @@ func (h *Handler) CountTokens(w http.ResponseWriter, r *http.Request) {
writeClaudeError(w, http.StatusBadRequest, "Request must include 'model' and 'messages'.")
return
}
inputTokens := 0
if sys, ok := req["system"].(string); ok {
inputTokens += util.EstimateTokens(sys)
}
for _, item := range messages {
msg, ok := item.(map[string]any)
if !ok {
continue
}
inputTokens += 2
inputTokens += util.EstimateTokens(extractMessageContent(msg["content"]))
}
if tools, ok := req["tools"].([]any); ok {
for _, t := range tools {
b, _ := json.Marshal(t)
inputTokens += util.EstimateTokens(string(b))
}
}
if inputTokens < 1 {
inputTokens = 1
normalized, err := normalizeClaudeRequest(h.Store, req)
if err != nil {
writeClaudeError(w, http.StatusBadRequest, err.Error())
return
}
inputTokens := countClaudeInputTokens(normalized.Standard)
writeJSON(w, http.StatusOK, map[string]any{"input_tokens": inputTokens})
}

View File

@@ -0,0 +1,7 @@
package claude
import "ds2api/internal/prompt"
func buildClaudePromptTokenText(messages []any, thinkingEnabled bool) string {
return prompt.MessagesPrepareWithThinking(toMessageMaps(messages), thinkingEnabled)
}

View File

@@ -48,17 +48,18 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma
return claudeNormalizedRequest{
Standard: promptcompat.StandardRequest{
Surface: "anthropic_messages",
RequestedModel: strings.TrimSpace(model),
ResolvedModel: dsModel,
ResponseModel: strings.TrimSpace(model),
Messages: payload["messages"].([]any),
ToolsRaw: toolsRequested,
FinalPrompt: finalPrompt,
ToolNames: toolNames,
Stream: util.ToBool(req["stream"]),
Thinking: thinkingEnabled,
Search: searchEnabled,
Surface: "anthropic_messages",
RequestedModel: strings.TrimSpace(model),
ResolvedModel: dsModel,
ResponseModel: strings.TrimSpace(model),
Messages: payload["messages"].([]any),
PromptTokenText: finalPrompt,
ToolsRaw: toolsRequested,
FinalPrompt: finalPrompt,
ToolNames: toolNames,
Stream: util.ToBool(req["stream"]),
Thinking: thinkingEnabled,
Search: searchEnabled,
},
NormalizedMessages: normalizedMessages,
}, nil

View File

@@ -15,10 +15,11 @@ type claudeStreamRuntime struct {
rc *http.ResponseController
canFlush bool
model string
toolNames []string
messages []any
toolsRaw any
model string
toolNames []string
messages []any
toolsRaw any
promptTokenText string
thinkingEnabled bool
searchEnabled bool
@@ -49,6 +50,7 @@ func newClaudeStreamRuntime(
stripReferenceMarkers bool,
toolNames []string,
toolsRaw any,
promptTokenText string,
) *claudeStreamRuntime {
return &claudeStreamRuntime{
w: w,
@@ -62,6 +64,7 @@ func newClaudeStreamRuntime(
stripReferenceMarkers: stripReferenceMarkers,
toolNames: toolNames,
toolsRaw: toolsRaw,
promptTokenText: promptTokenText,
messageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()),
thinkingBlockIndex: -1,
textBlockIndex: -1,

View File

@@ -42,7 +42,10 @@ func (s *claudeStreamRuntime) sendPing() {
}
func (s *claudeStreamRuntime) sendMessageStart() {
inputTokens := util.EstimateTokens(fmt.Sprintf("%v", s.messages))
inputTokens := countClaudeInputTokensFromText(s.promptTokenText, s.model)
if inputTokens == 0 {
inputTokens = util.CountPromptTokens(fmt.Sprintf("%v", s.messages), s.model)
}
s.send("message_start", map[string]any{
"type": "message_start",
"message": map[string]any{

View File

@@ -109,7 +109,7 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
}
}
outputTokens := util.EstimateTokens(finalThinking) + util.EstimateTokens(finalText)
outputTokens := util.CountOutputTokens(finalThinking, s.model) + util.CountOutputTokens(finalText, s.model)
s.send("message_delta", map[string]any{
"type": "message_delta",
"delta": map[string]any{

View File

@@ -0,0 +1,20 @@
package claude
import (
"strings"
"ds2api/internal/promptcompat"
"ds2api/internal/util"
)
// countClaudeInputTokens returns the input token count for a normalized
// request. It counts stdReq.PromptTokenText, unless that field is empty or
// whitespace-only, in which case it counts stdReq.FinalPrompt instead. The
// count is computed against stdReq.ResolvedModel.
func countClaudeInputTokens(stdReq promptcompat.StandardRequest) int {
	text := stdReq.FinalPrompt
	if candidate := stdReq.PromptTokenText; strings.TrimSpace(candidate) != "" {
		text = candidate
	}
	return countClaudeInputTokensFromText(text, stdReq.ResolvedModel)
}
// countClaudeInputTokensFromText returns the token count of promptText for the
// given model by delegating to util.CountPromptTokens.
func countClaudeInputTokensFromText(promptText, model string) int {
	tokenCount := util.CountPromptTokens(promptText, model)
	return tokenCount
}

View File

@@ -36,16 +36,17 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin
passThrough := collectGeminiPassThrough(req)
return promptcompat.StandardRequest{
Surface: "google_gemini",
RequestedModel: requestedModel,
ResolvedModel: resolvedModel,
ResponseModel: requestedModel,
Messages: messagesRaw,
FinalPrompt: finalPrompt,
ToolNames: toolNames,
Stream: stream,
Thinking: thinkingEnabled,
Search: searchEnabled,
PassThrough: passThrough,
Surface: "google_gemini",
RequestedModel: requestedModel,
ResolvedModel: resolvedModel,
ResponseModel: requestedModel,
Messages: messagesRaw,
PromptTokenText: finalPrompt,
FinalPrompt: finalPrompt,
ToolNames: toolNames,
Stream: stream,
Thinking: thinkingEnabled,
Search: searchEnabled,
PassThrough: passThrough,
}, nil
}

View File

@@ -227,7 +227,7 @@ func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *ht
//nolint:unused // retained for native Gemini non-stream handling path.
func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames)
usage := buildGeminiUsage(finalPrompt, finalThinking, finalText)
usage := buildGeminiUsage(model, finalPrompt, finalThinking, finalText)
return map[string]any{
"candidates": []map[string]any{
{
@@ -245,10 +245,10 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final
}
//nolint:unused // retained for native Gemini non-stream handling path.
func buildGeminiUsage(finalPrompt, finalThinking, finalText string) map[string]any {
promptTokens := util.EstimateTokens(finalPrompt)
reasoningTokens := util.EstimateTokens(finalThinking)
completionTokens := util.EstimateTokens(finalText)
func buildGeminiUsage(model, finalPrompt, finalThinking, finalText string) map[string]any {
promptTokens := util.CountPromptTokens(finalPrompt, model)
reasoningTokens := util.CountOutputTokens(finalThinking, model)
completionTokens := util.CountOutputTokens(finalText, model)
return map[string]any{
"promptTokenCount": promptTokens,
"candidatesTokenCount": reasoningTokens + completionTokens,

View File

@@ -194,6 +194,6 @@ func (s *geminiStreamRuntime) finalize() {
},
},
"modelVersion": s.model,
"usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText),
"usageMetadata": buildGeminiUsage(s.model, s.finalPrompt, finalThinking, finalText),
})
}

View File

@@ -222,7 +222,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
s.sendFailedChunk(status, message, code)
return true
}
usage := openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText)
usage := openaifmt.BuildChatUsageForModel(s.model, s.finalPrompt, finalThinking, finalText)
s.finalFinishReason = finishReason
s.finalUsage = usage
s.sendChunk(openaifmt.BuildChatStreamChunk(

View File

@@ -72,7 +72,7 @@ func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Co
config.Logger.Warn("[openai_empty_retry] retry request failed", "surface", "chat.completions", "stream", false, "retry_attempt", attempts, "error", err)
return
}
usagePrompt = usagePromptWithEmptyOutputRetry(finalPrompt, attempts)
usagePrompt = usagePromptWithEmptyOutputRetry(usagePrompt, attempts)
currentResp = nextResp
}
}

View File

@@ -109,10 +109,10 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
return
}
if stdReq.Stream {
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
return
}
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
}
func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) {
@@ -183,7 +183,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, co
}
}
if historySession != nil {
historySession.success(http.StatusOK, finalThinking, finalText, finishReason, openaifmt.BuildChatUsage(finalPrompt, finalThinking, finalText))
historySession.success(http.StatusOK, finalThinking, finalText, finishReason, openaifmt.BuildChatUsageForModel(model, finalPrompt, finalThinking, finalText))
}
writeJSON(w, http.StatusOK, respBody)
}

View File

@@ -35,7 +35,6 @@ func (s Service) ApplyCurrentInputFile(ctx context.Context, a *auth.RequestAuth,
if strings.TrimSpace(fileText) == "" {
return stdReq, errors.New("current user input file produced empty transcript")
}
result, err := s.DS.UploadFile(ctx, a, dsclient.UploadFileRequest{
Filename: currentInputFilename,
ContentType: currentInputContentType,
@@ -62,6 +61,9 @@ func (s Service) ApplyCurrentInputFile(ctx context.Context, a *auth.RequestAuth,
stdReq.CurrentInputFileApplied = true
stdReq.RefFileIDs = prependUniqueRefFileID(stdReq.RefFileIDs, fileID)
stdReq.FinalPrompt, stdReq.ToolNames = promptcompat.BuildOpenAIPrompt(messages, stdReq.ToolsRaw, "", stdReq.ToolChoice, stdReq.Thinking)
// Token accounting must reflect the actual downstream context:
// the uploaded IGNORE.txt file content + the neutral live prompt.
stdReq.PromptTokenText = fileText + "\n" + stdReq.FinalPrompt
return stdReq, nil
}

View File

@@ -14,6 +14,7 @@ import (
"ds2api/internal/auth"
dsclient "ds2api/internal/deepseek/client"
"ds2api/internal/promptcompat"
"ds2api/internal/util"
)
func historySplitTestMessages() []any {
@@ -298,6 +299,52 @@ func TestApplyCurrentInputFileUploadsFirstTurnWithInjectedWrapper(t *testing.T)
if len(out.RefFileIDs) != 1 || out.RefFileIDs[0] != "file-inline-1" {
t.Fatalf("expected current input file id in ref_file_ids, got %#v", out.RefFileIDs)
}
if !strings.Contains(out.PromptTokenText, "first turn content that is long enough") {
t.Fatalf("expected prompt token text to preserve original full context, got %q", out.PromptTokenText)
}
}
// TestApplyCurrentInputFilePreservesFullContextPromptForTokenCounting checks
// that after applyCurrentInputFile runs, FinalPrompt has been rewritten and no
// longer contains the original turn text, while PromptTokenText still contains
// the original turns, the IGNORE file wrapper markers, and the neutral live
// prompt sentence.
func TestApplyCurrentInputFilePreservesFullContextPromptForTokenCounting(t *testing.T) {
ds := &inlineUploadDSStub{}
h := &openAITestSurface{
Store: mockOpenAIConfig{
wideInput: true,
currentInputEnabled: true,
currentInputMin: 0,
thinkingInjection: boolPtr(true),
},
DS: ds,
}
req := map[string]any{
"model": "deepseek-v4-flash",
"messages": historySplitTestMessages(),
}
// Normalize first so the pre-rewrite FinalPrompt is available for comparison below.
stdReq, err := promptcompat.NormalizeOpenAIChatRequest(h.Store, req, "")
if err != nil {
t.Fatalf("normalize failed: %v", err)
}
out, err := h.applyCurrentInputFile(context.Background(), &auth.RequestAuth{DeepSeekToken: "token"}, stdReq)
if err != nil {
t.Fatalf("apply current input file failed: %v", err)
}
// The live prompt must differ from the one produced by normalization.
if out.FinalPrompt == stdReq.FinalPrompt {
t.Fatalf("expected live prompt to be rewritten after current input file")
}
// PromptTokenText must include the uploaded file content (which contains the full context)
// plus the neutral live prompt — reflecting the actual downstream token cost.
if !strings.Contains(out.PromptTokenText, "first user turn") || !strings.Contains(out.PromptTokenText, "latest user turn") {
t.Fatalf("expected prompt token text to contain file context with full conversation, got %q", out.PromptTokenText)
}
// Both wrapper markers must be present in the token-accounting text.
if !strings.Contains(out.PromptTokenText, "[file content end]") || !strings.Contains(out.PromptTokenText, "[file name]: IGNORE") {
t.Fatalf("expected prompt token text to include IGNORE.txt file wrapper, got %q", out.PromptTokenText)
}
if !strings.Contains(out.PromptTokenText, "Answer the latest user request directly.") {
t.Fatalf("expected prompt token text to also include neutral live prompt, got %q", out.PromptTokenText)
}
// The rewritten live prompt itself must not contain the original turn text.
if strings.Contains(out.FinalPrompt, "first user turn") || strings.Contains(out.FinalPrompt, "latest user turn") {
t.Fatalf("expected live prompt to hide original turns, got %q", out.FinalPrompt)
}
}
func TestApplyCurrentInputFileUploadsFullContextFile(t *testing.T) {
@@ -434,6 +481,16 @@ func TestChatCompletionsCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *t
if len(refIDs) == 0 || refIDs[0] != "file-inline-1" {
t.Fatalf("expected uploaded current input file to be first ref_file_id, got %#v", ds.completionReq["ref_file_ids"])
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("decode response failed: %v", err)
}
usage, _ := body["usage"].(map[string]any)
promptTokens := int(usage["prompt_tokens"].(float64))
neutralCount := util.CountPromptTokens(promptText, "deepseek-v4-flash")
if promptTokens <= neutralCount {
t.Fatalf("expected prompt_tokens to exceed neutral live prompt count (includes file context), got=%d neutral=%d", promptTokens, neutralCount)
}
}
func TestResponsesCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *testing.T) {
@@ -476,6 +533,16 @@ func TestResponsesCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *testing
if strings.Contains(promptText, "first user turn") || strings.Contains(promptText, "latest user turn") {
t.Fatalf("expected prompt to hide original turns, got %s", promptText)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("decode response failed: %v", err)
}
usage, _ := body["usage"].(map[string]any)
inputTokens := int(usage["input_tokens"].(float64))
neutralCount := util.CountPromptTokens(promptText, "deepseek-v4-flash")
if inputTokens <= neutralCount {
t.Fatalf("expected input_tokens to exceed neutral live prompt count (includes file context), got=%d neutral=%d", inputTokens, neutralCount)
}
}
func TestChatCompletionsCurrentInputFileMapsManagedAuthFailureTo401(t *testing.T) {

View File

@@ -68,7 +68,7 @@ func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx c
config.Logger.Warn("[openai_empty_retry] retry request failed", "surface", "responses", "stream", false, "retry_attempt", attempts, "error", err)
return
}
usagePrompt = usagePromptWithEmptyOutputRetry(finalPrompt, attempts)
usagePrompt = usagePromptWithEmptyOutputRetry(usagePrompt, attempts)
currentResp = nextResp
}
}

View File

@@ -115,10 +115,10 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
if stdReq.Stream {
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
return
}
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
}
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {