mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-21 00:17:44 +08:00
Merge pull request #372 from shern-point/feat/accurate-context-token-length
Feat/accurate context token length
This commit is contained in:
@@ -222,7 +222,7 @@ func (s *chatStreamRuntime) finalize(finishReason string, deferEmptyOutput bool)
|
||||
s.sendFailedChunk(status, message, code)
|
||||
return true
|
||||
}
|
||||
usage := openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText)
|
||||
usage := openaifmt.BuildChatUsageForModel(s.model, s.finalPrompt, finalThinking, finalText)
|
||||
s.finalFinishReason = finishReason
|
||||
s.finalUsage = usage
|
||||
s.sendChunk(openaifmt.BuildChatStreamChunk(
|
||||
|
||||
@@ -72,7 +72,7 @@ func (h *Handler) handleNonStreamWithRetry(w http.ResponseWriter, ctx context.Co
|
||||
config.Logger.Warn("[openai_empty_retry] retry request failed", "surface", "chat.completions", "stream", false, "retry_attempt", attempts, "error", err)
|
||||
return
|
||||
}
|
||||
usagePrompt = usagePromptWithEmptyOutputRetry(finalPrompt, attempts)
|
||||
usagePrompt = usagePromptWithEmptyOutputRetry(usagePrompt, attempts)
|
||||
currentResp = nextResp
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,10 +109,10 @@ func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if stdReq.Stream {
|
||||
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
|
||||
h.handleStreamWithRetry(w, r, a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
|
||||
return
|
||||
}
|
||||
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
|
||||
h.handleNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, sessionID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, historySession)
|
||||
}
|
||||
|
||||
func (h *Handler) autoDeleteRemoteSession(ctx context.Context, a *auth.RequestAuth, sessionID string) {
|
||||
@@ -183,7 +183,7 @@ func (h *Handler) handleNonStream(w http.ResponseWriter, resp *http.Response, co
|
||||
}
|
||||
}
|
||||
if historySession != nil {
|
||||
historySession.success(http.StatusOK, finalThinking, finalText, finishReason, openaifmt.BuildChatUsage(finalPrompt, finalThinking, finalText))
|
||||
historySession.success(http.StatusOK, finalThinking, finalText, finishReason, openaifmt.BuildChatUsageForModel(model, finalPrompt, finalThinking, finalText))
|
||||
}
|
||||
writeJSON(w, http.StatusOK, respBody)
|
||||
}
|
||||
|
||||
@@ -35,7 +35,6 @@ func (s Service) ApplyCurrentInputFile(ctx context.Context, a *auth.RequestAuth,
|
||||
if strings.TrimSpace(fileText) == "" {
|
||||
return stdReq, errors.New("current user input file produced empty transcript")
|
||||
}
|
||||
|
||||
result, err := s.DS.UploadFile(ctx, a, dsclient.UploadFileRequest{
|
||||
Filename: currentInputFilename,
|
||||
ContentType: currentInputContentType,
|
||||
@@ -62,6 +61,9 @@ func (s Service) ApplyCurrentInputFile(ctx context.Context, a *auth.RequestAuth,
|
||||
stdReq.CurrentInputFileApplied = true
|
||||
stdReq.RefFileIDs = prependUniqueRefFileID(stdReq.RefFileIDs, fileID)
|
||||
stdReq.FinalPrompt, stdReq.ToolNames = promptcompat.BuildOpenAIPrompt(messages, stdReq.ToolsRaw, "", stdReq.ToolChoice, stdReq.Thinking)
|
||||
// Token accounting must reflect the actual downstream context:
|
||||
// the uploaded IGNORE.txt file content + the neutral live prompt.
|
||||
stdReq.PromptTokenText = fileText + "\n" + stdReq.FinalPrompt
|
||||
return stdReq, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"ds2api/internal/auth"
|
||||
dsclient "ds2api/internal/deepseek/client"
|
||||
"ds2api/internal/promptcompat"
|
||||
"ds2api/internal/util"
|
||||
)
|
||||
|
||||
func historySplitTestMessages() []any {
|
||||
@@ -298,6 +299,52 @@ func TestApplyCurrentInputFileUploadsFirstTurnWithInjectedWrapper(t *testing.T)
|
||||
if len(out.RefFileIDs) != 1 || out.RefFileIDs[0] != "file-inline-1" {
|
||||
t.Fatalf("expected current input file id in ref_file_ids, got %#v", out.RefFileIDs)
|
||||
}
|
||||
if !strings.Contains(out.PromptTokenText, "first turn content that is long enough") {
|
||||
t.Fatalf("expected prompt token text to preserve original full context, got %q", out.PromptTokenText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCurrentInputFilePreservesFullContextPromptForTokenCounting(t *testing.T) {
|
||||
ds := &inlineUploadDSStub{}
|
||||
h := &openAITestSurface{
|
||||
Store: mockOpenAIConfig{
|
||||
wideInput: true,
|
||||
currentInputEnabled: true,
|
||||
currentInputMin: 0,
|
||||
thinkingInjection: boolPtr(true),
|
||||
},
|
||||
DS: ds,
|
||||
}
|
||||
req := map[string]any{
|
||||
"model": "deepseek-v4-flash",
|
||||
"messages": historySplitTestMessages(),
|
||||
}
|
||||
stdReq, err := promptcompat.NormalizeOpenAIChatRequest(h.Store, req, "")
|
||||
if err != nil {
|
||||
t.Fatalf("normalize failed: %v", err)
|
||||
}
|
||||
|
||||
out, err := h.applyCurrentInputFile(context.Background(), &auth.RequestAuth{DeepSeekToken: "token"}, stdReq)
|
||||
if err != nil {
|
||||
t.Fatalf("apply current input file failed: %v", err)
|
||||
}
|
||||
if out.FinalPrompt == stdReq.FinalPrompt {
|
||||
t.Fatalf("expected live prompt to be rewritten after current input file")
|
||||
}
|
||||
// PromptTokenText must include the uploaded file content (which contains the full context)
|
||||
// plus the neutral live prompt — reflecting the actual downstream token cost.
|
||||
if !strings.Contains(out.PromptTokenText, "first user turn") || !strings.Contains(out.PromptTokenText, "latest user turn") {
|
||||
t.Fatalf("expected prompt token text to contain file context with full conversation, got %q", out.PromptTokenText)
|
||||
}
|
||||
if !strings.Contains(out.PromptTokenText, "[file content end]") || !strings.Contains(out.PromptTokenText, "[file name]: IGNORE") {
|
||||
t.Fatalf("expected prompt token text to include IGNORE.txt file wrapper, got %q", out.PromptTokenText)
|
||||
}
|
||||
if !strings.Contains(out.PromptTokenText, "Answer the latest user request directly.") {
|
||||
t.Fatalf("expected prompt token text to also include neutral live prompt, got %q", out.PromptTokenText)
|
||||
}
|
||||
if strings.Contains(out.FinalPrompt, "first user turn") || strings.Contains(out.FinalPrompt, "latest user turn") {
|
||||
t.Fatalf("expected live prompt to hide original turns, got %q", out.FinalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCurrentInputFileUploadsFullContextFile(t *testing.T) {
|
||||
@@ -434,6 +481,16 @@ func TestChatCompletionsCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *t
|
||||
if len(refIDs) == 0 || refIDs[0] != "file-inline-1" {
|
||||
t.Fatalf("expected uploaded current input file to be first ref_file_id, got %#v", ds.completionReq["ref_file_ids"])
|
||||
}
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
usage, _ := body["usage"].(map[string]any)
|
||||
promptTokens := int(usage["prompt_tokens"].(float64))
|
||||
neutralCount := util.CountPromptTokens(promptText, "deepseek-v4-flash")
|
||||
if promptTokens <= neutralCount {
|
||||
t.Fatalf("expected prompt_tokens to exceed neutral live prompt count (includes file context), got=%d neutral=%d", promptTokens, neutralCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *testing.T) {
|
||||
@@ -476,6 +533,16 @@ func TestResponsesCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *testing
|
||||
if strings.Contains(promptText, "first user turn") || strings.Contains(promptText, "latest user turn") {
|
||||
t.Fatalf("expected prompt to hide original turns, got %s", promptText)
|
||||
}
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
usage, _ := body["usage"].(map[string]any)
|
||||
inputTokens := int(usage["input_tokens"].(float64))
|
||||
neutralCount := util.CountPromptTokens(promptText, "deepseek-v4-flash")
|
||||
if inputTokens <= neutralCount {
|
||||
t.Fatalf("expected input_tokens to exceed neutral live prompt count (includes file context), got=%d neutral=%d", inputTokens, neutralCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompletionsCurrentInputFileMapsManagedAuthFailureTo401(t *testing.T) {
|
||||
|
||||
@@ -68,7 +68,7 @@ func (h *Handler) handleResponsesNonStreamWithRetry(w http.ResponseWriter, ctx c
|
||||
config.Logger.Warn("[openai_empty_retry] retry request failed", "surface", "responses", "stream", false, "retry_attempt", attempts, "error", err)
|
||||
return
|
||||
}
|
||||
usagePrompt = usagePromptWithEmptyOutputRetry(finalPrompt, attempts)
|
||||
usagePrompt = usagePromptWithEmptyOutputRetry(usagePrompt, attempts)
|
||||
currentResp = nextResp
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,10 +115,10 @@ func (h *Handler) Responses(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
responseID := "resp_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
if stdReq.Stream {
|
||||
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
|
||||
h.handleResponsesStreamWithRetry(w, r, a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
|
||||
return
|
||||
}
|
||||
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.FinalPrompt, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
|
||||
h.handleResponsesNonStreamWithRetry(w, r.Context(), a, resp, payload, pow, owner, responseID, stdReq.ResponseModel, stdReq.PromptTokenText, stdReq.Thinking, stdReq.Search, stdReq.ToolNames, stdReq.ToolsRaw, stdReq.ToolChoice, traceID)
|
||||
}
|
||||
|
||||
func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Response, owner, responseID, model, finalPrompt string, thinkingEnabled, searchEnabled bool, toolNames []string, toolsRaw any, toolChoice promptcompat.ToolChoicePolicy, traceID string) {
|
||||
|
||||
Reference in New Issue
Block a user