Merge branch 'CJackHwang:main' into main

This commit is contained in:
VanceHud
2026-05-02 12:19:09 +08:00
committed by GitHub
61 changed files with 1720 additions and 346 deletions

View File

@@ -23,6 +23,7 @@ type UploadFileRequest struct {
Filename string
ContentType string
Purpose string
ModelType string
Data []byte
}
@@ -54,6 +55,7 @@ func (c *Client) UploadFile(ctx context.Context, a *auth.RequestAuth, req Upload
contentType = "application/octet-stream"
}
purpose := strings.TrimSpace(req.Purpose)
modelType := strings.ToLower(strings.TrimSpace(req.ModelType))
body, contentTypeHeader, err := buildUploadMultipartBody(filename, contentType, req.Data)
if err != nil {
return nil, err
@@ -64,6 +66,9 @@ func (c *Client) UploadFile(ctx context.Context, a *auth.RequestAuth, req Upload
"purpose": purpose,
"bytes": len(req.Data),
}
if modelType != "" {
capturePayload["model_type"] = modelType
}
captureSession := c.capture.Start("deepseek_upload_file", dsprotocol.DeepSeekUploadFileURL, a.AccountID, capturePayload)
attempts := 0
refreshed := false
@@ -81,6 +86,9 @@ func (c *Client) UploadFile(ctx context.Context, a *auth.RequestAuth, req Upload
}
headers := c.authHeaders(a.DeepSeekToken)
headers["Content-Type"] = contentTypeHeader
if modelType != "" {
headers["x-model-type"] = modelType
}
headers["x-ds-pow-response"] = powHeader
headers["x-file-size"] = strconv.Itoa(len(req.Data))
headers["x-thinking-enabled"] = "1"

View File

@@ -82,6 +82,7 @@ func TestUploadFileUsesUploadTargetPowAndMultipartHeaders(t *testing.T) {
var seenTargetPath string
var seenContentType string
var seenFileSize string
var seenModelType string
var seenBody string
call := 0
client := &Client{
@@ -96,6 +97,7 @@ func TestUploadFileUsesUploadTargetPowAndMultipartHeaders(t *testing.T) {
seenPow = req.Header.Get("x-ds-pow-response")
seenContentType = req.Header.Get("Content-Type")
seenFileSize = req.Header.Get("x-file-size")
seenModelType = req.Header.Get("x-model-type")
seenBody = string(bodyBytes)
return &http.Response{StatusCode: http.StatusOK, Header: make(http.Header), Body: io.NopCloser(strings.NewReader(uploadResponse)), Request: req}, nil
default:
@@ -112,6 +114,7 @@ func TestUploadFileUsesUploadTargetPowAndMultipartHeaders(t *testing.T) {
Filename: "demo.txt",
ContentType: "text/plain",
Purpose: "assistants",
ModelType: "vision",
Data: []byte("hello"),
}, 1)
if err != nil {
@@ -140,6 +143,9 @@ func TestUploadFileUsesUploadTargetPowAndMultipartHeaders(t *testing.T) {
if seenFileSize != "5" {
t.Fatalf("expected x-file-size=5, got %q", seenFileSize)
}
if seenModelType != "vision" {
t.Fatalf("expected x-model-type=vision, got %q", seenModelType)
}
if !strings.HasPrefix(seenContentType, "multipart/form-data; boundary=") {
t.Fatalf("expected multipart content type, got %q", seenContentType)
}

View File

@@ -159,6 +159,6 @@ func toStringSet(in []string) map[string]struct{} {
const (
KeepAliveTimeout = 5
StreamIdleTimeout = 90
MaxKeepaliveCount = 10
StreamIdleTimeout = 300
MaxKeepaliveCount = 40
)

View File

@@ -3,12 +3,14 @@ package claude
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"ds2api/internal/config"
"ds2api/internal/httpapi/requestbody"
streamengine "ds2api/internal/stream"
"ds2api/internal/translatorcliproxy"
"ds2api/internal/util"
@@ -33,7 +35,11 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, store ConfigReader) bool {
raw, err := io.ReadAll(r.Body)
if err != nil {
writeClaudeError(w, http.StatusBadRequest, "invalid body")
if errors.Is(err, requestbody.ErrInvalidUTF8Body) {
writeClaudeError(w, http.StatusBadRequest, "invalid json")
} else {
writeClaudeError(w, http.StatusBadRequest, "invalid body")
}
return true
}
var req map[string]any

View File

@@ -2,8 +2,8 @@ package gemini
import (
"bytes"
"ds2api/internal/toolcall"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
@@ -11,7 +11,9 @@ import (
"github.com/go-chi/chi/v5"
"ds2api/internal/httpapi/requestbody"
"ds2api/internal/sse"
"ds2api/internal/toolcall"
"ds2api/internal/translatorcliproxy"
"ds2api/internal/util"
@@ -32,7 +34,11 @@ func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request,
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, stream bool) bool {
raw, err := io.ReadAll(r.Body)
if err != nil {
writeGeminiError(w, http.StatusBadRequest, "invalid body")
if errors.Is(err, requestbody.ErrInvalidUTF8Body) {
writeGeminiError(w, http.StatusBadRequest, "invalid json")
} else {
writeGeminiError(w, http.StatusBadRequest, "invalid body")
}
return true
}
routeModel := strings.TrimSpace(chi.URLParam(r, "model"))

View File

@@ -311,16 +311,16 @@ func TestChatCompletionsCurrentInputFilePersistsNeutralPrompt(t *testing.T) {
if len(ds.uploadCalls) != 1 {
t.Fatalf("expected current input upload to happen, got %d", len(ds.uploadCalls))
}
if ds.uploadCalls[0].Filename != "history.txt" {
t.Fatalf("expected history.txt upload, got %q", ds.uploadCalls[0].Filename)
if ds.uploadCalls[0].Filename != "DS2API_HISTORY.txt" {
t.Fatalf("expected DS2API_HISTORY.txt upload, got %q", ds.uploadCalls[0].Filename)
}
if full.HistoryText != string(ds.uploadCalls[0].Data) {
t.Fatalf("expected uploaded current input file to be persisted in history text")
}
if len(full.Messages) != 1 {
t.Fatalf("expected neutral prompt to be the only persisted message, got %#v", full.Messages)
t.Fatalf("expected continuation prompt to be the only persisted message, got %#v", full.Messages)
}
if !strings.Contains(full.Messages[0].Content, "Answer the latest user request directly.") {
t.Fatalf("expected neutral prompt to be persisted, got %#v", full.Messages[0])
if !strings.Contains(full.Messages[0].Content, "Continue from the latest state in the attached DS2API_HISTORY.txt context.") {
t.Fatalf("expected continuation prompt to be persisted, got %#v", full.Messages[0])
}
}

View File

@@ -173,6 +173,15 @@ func (s *chatStreamRuntime) sendFailedChunk(status int, message, code string) {
s.sendDone()
}
func (s *chatStreamRuntime) markContextCancelled() {
s.finalErrorStatus = 499
s.finalErrorMessage = "request context cancelled"
s.finalErrorCode = string(streamengine.StopReasonContextCancelled)
s.finalThinking = s.thinking.String()
s.finalText = cleanVisibleOutput(s.text.String(), s.stripReferenceMarkers)
s.finalFinishReason = string(streamengine.StopReasonContextCancelled)
}
func (s *chatStreamRuntime) resetStreamToolCallState() {
s.streamToolCallIDs = map[int]string{}
s.streamToolNames = map[int]string{}

View File

@@ -247,11 +247,15 @@ func (h *Handler) consumeChatStreamAttempt(r *http.Request, resp *http.Response,
}
},
OnContextDone: func() {
streamRuntime.markContextCancelled()
if historySession != nil {
historySession.stopped(streamRuntime.thinking.String(), streamRuntime.text.String(), string(streamengine.StopReasonContextCancelled))
}
},
})
if streamRuntime.finalErrorCode == string(streamengine.StopReasonContextCancelled) {
return true, false
}
terminalWritten := streamRuntime.finalize(finalReason, allowDeferEmpty && finalReason != "content_filter")
if terminalWritten {
recordChatStreamHistory(streamRuntime, historySession)
@@ -283,6 +287,10 @@ func logChatStreamTerminal(streamRuntime *chatStreamRuntime, attempts int) {
if attempts > 0 {
source = "synthetic_retry"
}
if streamRuntime.finalErrorCode == string(streamengine.StopReasonContextCancelled) {
config.Logger.Info("[openai_empty_retry] terminal cancelled", "surface", "chat.completions", "stream", true, "retry_attempts", attempts, "error_code", streamRuntime.finalErrorCode)
return
}
if streamRuntime.finalErrorMessage != "" {
config.Logger.Info("[openai_empty_retry] terminal empty output", "surface", "chat.completions", "stream", true, "retry_attempts", attempts, "success_source", "none", "error_code", streamRuntime.finalErrorCode)
return

View File

@@ -0,0 +1,85 @@
package chat
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"ds2api/internal/chathistory"
"ds2api/internal/stream"
)
func TestConsumeChatStreamAttemptMarksContextCancelledState(t *testing.T) {
historyStore := newTestChatHistoryStore(t)
entry, err := historyStore.Start(chathistory.StartParams{
CallerID: "caller:test",
Model: "deepseek-v4-flash",
Stream: true,
UserInput: "hello",
})
if err != nil {
t.Fatalf("start history failed: %v", err)
}
session := &chatHistorySession{
store: historyStore,
entryID: entry.ID,
startedAt: time.Now(),
lastPersist: time.Now(),
finalPrompt: "prompt",
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil).WithContext(ctx)
rec := httptest.NewRecorder()
streamRuntime := newChatStreamRuntime(
rec,
http.NewResponseController(rec),
true,
"cid-cancelled",
time.Now().Unix(),
"deepseek-v4-flash",
"prompt",
false,
false,
true,
nil,
nil,
false,
false,
)
resp := makeOpenAISSEHTTPResponse(
`data: {"p":"response/content","v":"hello"}`,
`data: [DONE]`,
)
h := &Handler{}
terminalWritten, retryable := h.consumeChatStreamAttempt(req, resp, streamRuntime, "text", false, session, true)
if !terminalWritten || retryable {
t.Fatalf("expected cancelled attempt to terminate without retry, got terminalWritten=%v retryable=%v", terminalWritten, retryable)
}
if got, want := streamRuntime.finalErrorCode, string(stream.StopReasonContextCancelled); got != want {
t.Fatalf("expected cancelled final error code %q, got %q", want, got)
}
if streamRuntime.finalErrorMessage == "" {
t.Fatalf("expected cancelled final error message to be preserved")
}
snapshot, err := historyStore.Snapshot()
if err != nil {
t.Fatalf("snapshot failed: %v", err)
}
if len(snapshot.Items) != 1 {
t.Fatalf("expected one history item, got %d", len(snapshot.Items))
}
full, err := historyStore.Get(snapshot.Items[0].ID)
if err != nil {
t.Fatalf("get detail failed: %v", err)
}
if full.Status != "stopped" {
t.Fatalf("expected stopped status, got %#v", full)
}
}

View File

@@ -130,8 +130,8 @@ func TestHandleVercelStreamPrepareAppliesCurrentInputFile(t *testing.T) {
t.Fatalf("expected payload object, got %#v", body["payload"])
}
promptText, _ := payload["prompt"].(string)
if !strings.Contains(promptText, "Answer the latest user request directly.") {
t.Fatalf("expected neutral prompt, got %s", promptText)
if !strings.Contains(promptText, "Continue from the latest state in the attached DS2API_HISTORY.txt context.") {
t.Fatalf("expected continuation prompt, got %s", promptText)
}
if strings.Contains(promptText, "first user turn") || strings.Contains(promptText, "latest user turn") {
t.Fatalf("expected original turns hidden from prompt, got %s", promptText)

View File

@@ -94,6 +94,9 @@ func TestPreprocessInlineFileInputsReplacesDataURLAndCollectsRefFileIDs(t *testi
if len(ds.uploadCalls) != 1 {
t.Fatalf("expected 1 upload, got %d", len(ds.uploadCalls))
}
if ds.uploadCalls[0].ModelType != "default" {
t.Fatalf("expected default model type when request omits model, got %q", ds.uploadCalls[0].ModelType)
}
if ds.lastCtx != ctx {
t.Fatalf("expected upload to use request context")
}
@@ -149,7 +152,7 @@ func TestPreprocessInlineFileInputsDeduplicatesIdenticalPayloads(t *testing.T) {
func TestChatCompletionsUploadsInlineFilesBeforeCompletion(t *testing.T) {
ds := &inlineUploadDSStub{}
h := &openAITestSurface{Store: mockOpenAIConfig{wideInput: true}, Auth: streamStatusAuthStub{}, DS: ds}
reqBody := `{"model":"deepseek-v4-flash","messages":[{"role":"user","content":[{"type":"input_text","text":"hi"},{"type":"image_url","image_url":{"url":"data:image/png;base64,QUJDRA=="}}]}],"stream":false}`
reqBody := `{"model":"deepseek-v4-vision","messages":[{"role":"user","content":[{"type":"input_text","text":"hi"},{"type":"image_url","image_url":{"url":"data:image/png;base64,QUJDRA=="}}]}],"stream":false}`
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(reqBody))
req.Header.Set("Authorization", "Bearer direct-token")
req.Header.Set("Content-Type", "application/json")
@@ -163,6 +166,9 @@ func TestChatCompletionsUploadsInlineFilesBeforeCompletion(t *testing.T) {
if len(ds.uploadCalls) != 1 {
t.Fatalf("expected 1 upload call, got %d", len(ds.uploadCalls))
}
if ds.uploadCalls[0].ModelType != "vision" {
t.Fatalf("expected vision model type for vision request, got %q", ds.uploadCalls[0].ModelType)
}
if ds.completionReq == nil {
t.Fatal("expected completion payload to be captured")
}
@@ -177,7 +183,7 @@ func TestResponsesUploadsInlineFilesBeforeCompletion(t *testing.T) {
h := &openAITestSurface{Store: mockOpenAIConfig{wideInput: true}, Auth: streamStatusAuthStub{}, DS: ds}
r := chi.NewRouter()
registerOpenAITestRoutes(r, h)
reqBody := `{"model":"deepseek-v4-flash","input":[{"role":"user","content":[{"type":"input_text","text":"hi"},{"type":"input_image","image_url":{"url":"data:image/png;base64,QUJDRA=="}}]}],"stream":false}`
reqBody := `{"model":"deepseek-v4-pro","input":[{"role":"user","content":[{"type":"input_text","text":"hi"},{"type":"input_image","image_url":{"url":"data:image/png;base64,QUJDRA=="}}]}],"stream":false}`
req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody))
req.Header.Set("Authorization", "Bearer direct-token")
req.Header.Set("Content-Type", "application/json")
@@ -191,6 +197,9 @@ func TestResponsesUploadsInlineFilesBeforeCompletion(t *testing.T) {
if len(ds.uploadCalls) != 1 {
t.Fatalf("expected 1 upload call, got %d", len(ds.uploadCalls))
}
if ds.uploadCalls[0].ModelType != "expert" {
t.Fatalf("expected expert model type for pro request, got %q", ds.uploadCalls[0].ModelType)
}
refIDs, _ := ds.completionReq["ref_file_ids"].([]any)
if len(refIDs) != 1 || refIDs[0] != "file-inline-1" {
t.Fatalf("unexpected completion ref_file_ids: %#v", ds.completionReq["ref_file_ids"])

View File

@@ -12,6 +12,7 @@ import (
"strings"
"ds2api/internal/auth"
"ds2api/internal/config"
dsclient "ds2api/internal/deepseek/client"
"ds2api/internal/httpapi/openai/shared"
"ds2api/internal/promptcompat"
@@ -42,6 +43,7 @@ type inlineUploadState struct {
ctx context.Context
handler *Handler
auth *auth.RequestAuth
modelType string
uploadedByID map[string]string
uploadCount int
inlineFileBytes int
@@ -58,10 +60,19 @@ func (h *Handler) PreprocessInlineFileInputs(ctx context.Context, a *auth.Reques
if h == nil || h.DS == nil || len(req) == 0 {
return nil
}
modelType := "default"
if requestedModel, ok := req["model"].(string); ok {
if resolvedModel, ok := config.ResolveModel(h.Store, requestedModel); ok {
if resolvedType, ok := config.GetModelType(resolvedModel); ok {
modelType = resolvedType
}
}
}
state := &inlineUploadState{
ctx: ctx,
handler: h,
auth: a,
modelType: modelType,
uploadedByID: map[string]string{},
}
for _, key := range []string{"messages", "input", "attachments"} {
@@ -174,6 +185,7 @@ func (s *inlineUploadState) uploadInlineFile(file inlineDecodedFile) (string, er
result, err := s.handler.DS.UploadFile(s.ctx, s.auth, dsclient.UploadFileRequest{
Filename: file.Filename,
ContentType: contentType,
ModelType: s.modelType,
Data: file.Data,
}, 3)
if err != nil {

View File

@@ -8,6 +8,7 @@ import (
"ds2api/internal/auth"
"ds2api/internal/chathistory"
"ds2api/internal/config"
dsclient "ds2api/internal/deepseek/client"
"ds2api/internal/httpapi/openai/shared"
)
@@ -66,10 +67,12 @@ func (h *Handler) UploadFile(w http.ResponseWriter, r *http.Request) {
if contentType == "" && len(data) > 0 {
contentType = http.DetectContentType(data)
}
modelType := resolveUploadModelType(h.Store, r)
result, err := h.DS.UploadFile(r.Context(), a, dsclient.UploadFileRequest{
Filename: header.Filename,
ContentType: contentType,
Purpose: strings.TrimSpace(r.FormValue("purpose")),
ModelType: modelType,
Data: data,
}, 3)
if err != nil {
@@ -82,6 +85,32 @@ func (h *Handler) UploadFile(w http.ResponseWriter, r *http.Request) {
shared.WriteJSON(w, http.StatusOK, buildOpenAIFileObject(result))
}
func resolveUploadModelType(store shared.ConfigReader, r *http.Request) string {
for _, candidate := range []string{r.FormValue("model_type"), r.Header.Get("X-Model-Type")} {
if modelType := normalizeUploadModelType(candidate); modelType != "" {
return modelType
}
}
requestedModel := strings.TrimSpace(r.FormValue("model"))
if requestedModel != "" {
if resolvedModel, ok := config.ResolveModel(store, requestedModel); ok {
if modelType, ok := config.GetModelType(resolvedModel); ok {
return modelType
}
}
}
return "default"
}
func normalizeUploadModelType(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "default", "expert", "vision":
return strings.ToLower(strings.TrimSpace(raw))
default:
return ""
}
}
func buildOpenAIFileObject(result *dsclient.UploadFileResult) map[string]any {
if result == nil {
obj := map[string]any{

View File

@@ -77,7 +77,7 @@ func (m *filesRouteDSStub) DeleteAllSessionsForToken(_ context.Context, _ string
return nil
}
func newMultipartUploadRequest(t *testing.T, purpose string, filename string, data []byte) *http.Request {
func newMultipartUploadRequest(t *testing.T, purpose string, filename string, data []byte, model string) *http.Request {
t.Helper()
var body bytes.Buffer
writer := multipart.NewWriter(&body)
@@ -86,6 +86,11 @@ func newMultipartUploadRequest(t *testing.T, purpose string, filename string, da
t.Fatalf("write purpose failed: %v", err)
}
}
if model != "" {
if err := writer.WriteField("model", model); err != nil {
t.Fatalf("write model failed: %v", err)
}
}
part, err := writer.CreateFormFile("file", filename)
if err != nil {
t.Fatalf("create form file failed: %v", err)
@@ -108,7 +113,7 @@ func TestFilesRouteUploadSuccess(t *testing.T) {
r := chi.NewRouter()
registerOpenAITestRoutes(r, h)
req := newMultipartUploadRequest(t, "assistants", "notes.txt", []byte("hello world"))
req := newMultipartUploadRequest(t, "assistants", "notes.txt", []byte("hello world"), "deepseek-v4-vision")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
@@ -121,6 +126,9 @@ func TestFilesRouteUploadSuccess(t *testing.T) {
if ds.lastReq.Purpose != "assistants" {
t.Fatalf("expected purpose assistants, got %q", ds.lastReq.Purpose)
}
if ds.lastReq.ModelType != "vision" {
t.Fatalf("expected vision model type, got %q", ds.lastReq.ModelType)
}
if string(ds.lastReq.Data) != "hello world" {
t.Fatalf("unexpected uploaded data: %q", string(ds.lastReq.Data))
}
@@ -145,7 +153,7 @@ func TestFilesRouteUploadIncludesAccountIDForManagedAccount(t *testing.T) {
r := chi.NewRouter()
registerOpenAITestRoutes(r, h)
req := newMultipartUploadRequest(t, "assistants", "notes.txt", []byte("hello world"))
req := newMultipartUploadRequest(t, "assistants", "notes.txt", []byte("hello world"), "deepseek-v4-vision")
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)

View File

@@ -7,6 +7,7 @@ import (
"strings"
"ds2api/internal/auth"
"ds2api/internal/config"
dsclient "ds2api/internal/deepseek/client"
"ds2api/internal/httpapi/openai/shared"
"ds2api/internal/promptcompat"
@@ -35,10 +36,15 @@ func (s Service) ApplyCurrentInputFile(ctx context.Context, a *auth.RequestAuth,
if strings.TrimSpace(fileText) == "" {
return stdReq, errors.New("current user input file produced empty transcript")
}
modelType := "default"
if resolvedType, ok := config.GetModelType(stdReq.ResolvedModel); ok {
modelType = resolvedType
}
result, err := s.DS.UploadFile(ctx, a, dsclient.UploadFileRequest{
Filename: currentInputFilename,
ContentType: currentInputContentType,
Purpose: currentInputPurpose,
ModelType: modelType,
Data: []byte(fileText),
}, 3)
if err != nil {
@@ -62,7 +68,7 @@ func (s Service) ApplyCurrentInputFile(ctx context.Context, a *auth.RequestAuth,
stdReq.RefFileIDs = prependUniqueRefFileID(stdReq.RefFileIDs, fileID)
stdReq.FinalPrompt, stdReq.ToolNames = promptcompat.BuildOpenAIPrompt(messages, stdReq.ToolsRaw, "", stdReq.ToolChoice, stdReq.Thinking)
// Token accounting must reflect the actual downstream context:
// the uploaded history.txt file content + the neutral live prompt.
// the uploaded DS2API_HISTORY.txt file content + the continuation live prompt.
stdReq.PromptTokenText = fileText + "\n" + stdReq.FinalPrompt
return stdReq, nil
}
@@ -87,5 +93,5 @@ func latestUserInputForFile(messages []any) (int, string) {
}
func currentInputFilePrompt() string {
return "The current request and prior conversation context have already been provided. Answer the latest user request directly."
return "Continue from the latest state in the attached DS2API_HISTORY.txt context. Treat it as the current working state and answer the latest user request directly."
}

View File

@@ -61,26 +61,33 @@ func (streamStatusManagedAuthStub) DetermineCaller(_ *http.Request) (*auth.Reque
func (streamStatusManagedAuthStub) Release(_ *auth.RequestAuth) {}
func TestBuildOpenAICurrentInputContextTranscriptUsesInjectedFileWrapper(t *testing.T) {
func TestBuildOpenAICurrentInputContextTranscriptUsesNumberedHistorySections(t *testing.T) {
_, historyMessages := splitOpenAIHistoryMessages(historySplitTestMessages(), 1)
transcript := buildOpenAICurrentInputContextTranscript(historyMessages)
if strings.Contains(transcript, "[file content end]") || strings.Contains(transcript, "[file content begin]") || strings.Contains(transcript, "[file name]:") {
t.Fatalf("expected plain transcript without file wrapper tags, got %q", transcript)
t.Fatalf("expected transcript without file wrapper tags, got %q", transcript)
}
if !strings.Contains(transcript, "<begin▁of▁sentence>") {
t.Fatalf("expected serialized conversation markers, got %q", transcript)
if !strings.Contains(transcript, "# DS2API_HISTORY.txt") {
t.Fatalf("expected history transcript header, got %q", transcript)
}
if !strings.Contains(transcript, "first user turn") || !strings.Contains(transcript, "tool result") {
t.Fatalf("expected historical turns preserved, got %q", transcript)
if !strings.Contains(transcript, "Prior conversation history and tool progress.") {
t.Fatalf("expected history transcript description, got %q", transcript)
}
if !strings.Contains(transcript, "[reasoning_content]") || !strings.Contains(transcript, "hidden reasoning") {
t.Fatalf("expected reasoning block preserved, got %q", transcript)
for _, want := range []string{
"=== 1. USER ===",
"=== 2. ASSISTANT ===",
"=== 3. TOOL ===",
"first user turn",
"tool result",
"[reasoning_content]",
"hidden reasoning",
"<|DSML|tool_calls>",
} {
if !strings.Contains(transcript, want) {
t.Fatalf("expected transcript to contain %q, got %q", want, transcript)
}
}
if !strings.Contains(transcript, "<|DSML|tool_calls>") {
t.Fatalf("expected tool calls preserved, got %q", transcript)
}
}
func TestSplitOpenAIHistoryMessagesUsesLatestUserTurn(t *testing.T) {
@@ -220,7 +227,7 @@ func TestApplyCurrentInputFileDisabledPassThrough(t *testing.T) {
DS: ds,
}
req := map[string]any{
"model": "deepseek-v4-flash",
"model": "deepseek-v4-vision",
"messages": historySplitTestMessages(),
}
stdReq, err := promptcompat.NormalizeOpenAIChatRequest(h.Store, req, "")
@@ -243,7 +250,7 @@ func TestApplyCurrentInputFileDisabledPassThrough(t *testing.T) {
}
}
func TestApplyCurrentInputFileUploadsFirstTurnWithInjectedWrapper(t *testing.T) {
func TestApplyCurrentInputFileUploadsFirstTurnWithNumberedHistoryTranscript(t *testing.T) {
ds := &inlineUploadDSStub{}
h := &openAITestSurface{
Store: mockOpenAIConfig{
@@ -273,15 +280,21 @@ func TestApplyCurrentInputFileUploadsFirstTurnWithInjectedWrapper(t *testing.T)
t.Fatalf("expected 1 current input upload, got %d", len(ds.uploadCalls))
}
upload := ds.uploadCalls[0]
if upload.Filename != "history.txt" {
if upload.Filename != "DS2API_HISTORY.txt" {
t.Fatalf("unexpected upload filename: %q", upload.Filename)
}
uploadedText := string(upload.Data)
if strings.Contains(uploadedText, "[file content end]") || strings.Contains(uploadedText, "[file content begin]") || strings.Contains(uploadedText, "[file name]:") {
t.Fatalf("expected uploaded transcript without file wrapper tags, got %q", uploadedText)
}
if !strings.Contains(uploadedText, "<begin▁of▁sentence><User>first turn content that is long enough") {
t.Fatalf("expected serialized current user turn markers, got %q", uploadedText)
for _, want := range []string{
"# DS2API_HISTORY.txt",
"=== 1. USER ===",
"first turn content that is long enough",
} {
if !strings.Contains(uploadedText, want) {
t.Fatalf("expected uploaded transcript to contain %q, got %q", want, uploadedText)
}
}
if !strings.Contains(uploadedText, promptcompat.ThinkingInjectionMarker) {
t.Fatalf("expected thinking injection in current input file, got %q", uploadedText)
@@ -290,11 +303,11 @@ func TestApplyCurrentInputFileUploadsFirstTurnWithInjectedWrapper(t *testing.T)
if strings.Contains(out.FinalPrompt, "first turn content that is long enough") {
t.Fatalf("expected current input text to be replaced in live prompt, got %s", out.FinalPrompt)
}
if strings.Contains(out.FinalPrompt, "CURRENT_USER_INPUT.txt") || strings.Contains(out.FinalPrompt, "history.txt") || strings.Contains(out.FinalPrompt, "Read that file") {
if strings.Contains(out.FinalPrompt, "CURRENT_USER_INPUT.txt") || strings.Contains(out.FinalPrompt, "Read that file") {
t.Fatalf("expected live prompt not to instruct file reads, got %s", out.FinalPrompt)
}
if !strings.Contains(out.FinalPrompt, "Answer the latest user request directly.") {
t.Fatalf("expected neutral continuation instruction in live prompt, got %s", out.FinalPrompt)
if !strings.Contains(out.FinalPrompt, "Continue from the latest state in the attached DS2API_HISTORY.txt context.") {
t.Fatalf("expected continuation-oriented prompt in live prompt, got %s", out.FinalPrompt)
}
if len(out.RefFileIDs) != 1 || out.RefFileIDs[0] != "file-inline-1" {
t.Fatalf("expected current input file id in ref_file_ids, got %#v", out.RefFileIDs)
@@ -302,6 +315,9 @@ func TestApplyCurrentInputFileUploadsFirstTurnWithInjectedWrapper(t *testing.T)
if !strings.Contains(out.PromptTokenText, "first turn content that is long enough") {
t.Fatalf("expected prompt token text to preserve original full context, got %q", out.PromptTokenText)
}
if !strings.Contains(out.PromptTokenText, "# DS2API_HISTORY.txt") || !strings.Contains(out.PromptTokenText, "=== 1. USER ===") {
t.Fatalf("expected prompt token text to include numbered history transcript, got %q", out.PromptTokenText)
}
}
func TestApplyCurrentInputFilePreservesFullContextPromptForTokenCounting(t *testing.T) {
@@ -316,7 +332,7 @@ func TestApplyCurrentInputFilePreservesFullContextPromptForTokenCounting(t *test
DS: ds,
}
req := map[string]any{
"model": "deepseek-v4-flash",
"model": "deepseek-v4-vision",
"messages": historySplitTestMessages(),
}
stdReq, err := promptcompat.NormalizeOpenAIChatRequest(h.Store, req, "")
@@ -337,10 +353,13 @@ func TestApplyCurrentInputFilePreservesFullContextPromptForTokenCounting(t *test
t.Fatalf("expected prompt token text to contain file context with full conversation, got %q", out.PromptTokenText)
}
if strings.Contains(out.PromptTokenText, "[file content end]") || strings.Contains(out.PromptTokenText, "[file name]:") {
t.Fatalf("expected prompt token text to use raw transcript without wrapper tags, got %q", out.PromptTokenText)
t.Fatalf("expected prompt token text to omit file wrapper tags, got %q", out.PromptTokenText)
}
if !strings.Contains(out.PromptTokenText, "Answer the latest user request directly.") {
t.Fatalf("expected prompt token text to also include neutral live prompt, got %q", out.PromptTokenText)
if !strings.Contains(out.PromptTokenText, "# DS2API_HISTORY.txt") || !strings.Contains(out.PromptTokenText, "=== 1. SYSTEM ===") {
t.Fatalf("expected prompt token text to include numbered history transcript, got %q", out.PromptTokenText)
}
if !strings.Contains(out.PromptTokenText, "Continue from the latest state in the attached DS2API_HISTORY.txt context.") {
t.Fatalf("expected prompt token text to also include continuation prompt, got %q", out.PromptTokenText)
}
if strings.Contains(out.FinalPrompt, "first user turn") || strings.Contains(out.FinalPrompt, "latest user turn") {
t.Fatalf("expected live prompt to hide original turns, got %q", out.FinalPrompt)
@@ -359,7 +378,7 @@ func TestApplyCurrentInputFileUploadsFullContextFile(t *testing.T) {
DS: ds,
}
req := map[string]any{
"model": "deepseek-v4-flash",
"model": "deepseek-v4-vision",
"messages": historySplitTestMessages(),
}
stdReq, err := promptcompat.NormalizeOpenAIChatRequest(h.Store, req, "")
@@ -378,20 +397,23 @@ func TestApplyCurrentInputFileUploadsFullContextFile(t *testing.T) {
t.Fatalf("expected one current input upload, got %d", len(ds.uploadCalls))
}
upload := ds.uploadCalls[0]
if upload.Filename != "history.txt" {
t.Fatalf("expected history.txt upload, got %q", upload.Filename)
if upload.Filename != "DS2API_HISTORY.txt" {
t.Fatalf("expected DS2API_HISTORY.txt upload, got %q", upload.Filename)
}
if upload.ModelType != "vision" {
t.Fatalf("expected vision model type for vision request, got %q", upload.ModelType)
}
uploadedText := string(upload.Data)
for _, want := range []string{"system instructions", "first user turn", "hidden reasoning", "tool result", "latest user turn", promptcompat.ThinkingInjectionMarker} {
for _, want := range []string{"# DS2API_HISTORY.txt", "=== 1. SYSTEM ===", "=== 2. USER ===", "=== 3. ASSISTANT ===", "=== 4. TOOL ===", "=== 5. USER ===", "system instructions", "first user turn", "hidden reasoning", "tool result", "latest user turn", promptcompat.ThinkingInjectionMarker} {
if !strings.Contains(uploadedText, want) {
t.Fatalf("expected full context file to contain %q, got %q", want, uploadedText)
}
}
if strings.Contains(out.FinalPrompt, "first user turn") || strings.Contains(out.FinalPrompt, "latest user turn") || strings.Contains(out.FinalPrompt, "CURRENT_USER_INPUT.txt") || strings.Contains(out.FinalPrompt, "history.txt") || strings.Contains(out.FinalPrompt, "Read that file") {
t.Fatalf("expected live prompt to use only a neutral continuation instruction, got %s", out.FinalPrompt)
if strings.Contains(out.FinalPrompt, "first user turn") || strings.Contains(out.FinalPrompt, "latest user turn") || strings.Contains(out.FinalPrompt, "CURRENT_USER_INPUT.txt") || strings.Contains(out.FinalPrompt, "Read that file") {
t.Fatalf("expected live prompt to use only a continuation instruction, got %s", out.FinalPrompt)
}
if !strings.Contains(out.FinalPrompt, "Answer the latest user request directly.") {
t.Fatalf("expected neutral continuation instruction in live prompt, got %s", out.FinalPrompt)
if !strings.Contains(out.FinalPrompt, "Continue from the latest state in the attached DS2API_HISTORY.txt context.") {
t.Fatalf("expected continuation-oriented prompt in live prompt, got %s", out.FinalPrompt)
}
}
@@ -423,6 +445,9 @@ func TestApplyCurrentInputFileCarriesHistoryText(t *testing.T) {
if out.HistoryText != string(ds.uploadCalls[0].Data) {
t.Fatalf("expected current input file flow to preserve uploaded text in history, got %q", out.HistoryText)
}
if !strings.Contains(out.HistoryText, "# DS2API_HISTORY.txt") || !strings.Contains(out.HistoryText, "=== 1. SYSTEM ===") {
t.Fatalf("expected history text to use numbered transcript format, got %q", out.HistoryText)
}
}
func TestChatCompletionsCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *testing.T) {
@@ -454,7 +479,7 @@ func TestChatCompletionsCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *t
t.Fatalf("expected 1 upload call, got %d", len(ds.uploadCalls))
}
upload := ds.uploadCalls[0]
if upload.Filename != "history.txt" {
if upload.Filename != "DS2API_HISTORY.txt" {
t.Fatalf("unexpected upload filename: %q", upload.Filename)
}
if upload.Purpose != "assistants" {
@@ -462,7 +487,10 @@ func TestChatCompletionsCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *t
}
historyText := string(upload.Data)
if strings.Contains(historyText, "[file content end]") || strings.Contains(historyText, "[file content begin]") || strings.Contains(historyText, "[file name]:") {
t.Fatalf("expected plain history transcript without wrapper tags, got %s", historyText)
t.Fatalf("expected history transcript without file wrapper tags, got %s", historyText)
}
if !strings.Contains(historyText, "# DS2API_HISTORY.txt") || !strings.Contains(historyText, "=== 1. SYSTEM ===") {
t.Fatalf("expected history transcript to use numbered sections, got %s", historyText)
}
if !strings.Contains(historyText, "latest user turn") {
t.Fatalf("expected full context to include latest turn, got %s", historyText)
@@ -471,8 +499,8 @@ func TestChatCompletionsCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *t
t.Fatal("expected completion payload to be captured")
}
promptText, _ := ds.completionReq["prompt"].(string)
if !strings.Contains(promptText, "Answer the latest user request directly.") {
t.Fatalf("expected neutral completion prompt, got %s", promptText)
if !strings.Contains(promptText, "Continue from the latest state in the attached DS2API_HISTORY.txt context.") {
t.Fatalf("expected continuation-oriented prompt, got %s", promptText)
}
if strings.Contains(promptText, "first user turn") || strings.Contains(promptText, "latest user turn") {
t.Fatalf("expected prompt to hide original turns, got %s", promptText)
@@ -523,12 +551,16 @@ func TestResponsesCurrentInputFileUploadsContextAndKeepsNeutralPrompt(t *testing
if len(ds.uploadCalls) != 1 {
t.Fatalf("expected 1 upload call, got %d", len(ds.uploadCalls))
}
historyText := string(ds.uploadCalls[0].Data)
if !strings.Contains(historyText, "# DS2API_HISTORY.txt") || !strings.Contains(historyText, "=== 1. SYSTEM ===") {
t.Fatalf("expected uploaded history text to use numbered transcript format, got %s", historyText)
}
if ds.completionReq == nil {
t.Fatal("expected completion payload to be captured")
}
promptText, _ := ds.completionReq["prompt"].(string)
if !strings.Contains(promptText, "Answer the latest user request directly.") {
t.Fatalf("expected neutral completion prompt, got %s", promptText)
if !strings.Contains(promptText, "Continue from the latest state in the attached DS2API_HISTORY.txt context.") {
t.Fatalf("expected continuation-oriented prompt, got %s", promptText)
}
if strings.Contains(promptText, "first user turn") || strings.Contains(promptText, "latest user turn") {
t.Fatalf("expected prompt to hide original turns, got %s", promptText)
@@ -669,11 +701,15 @@ func TestCurrentInputFileWorksAcrossAutoDeleteModes(t *testing.T) {
if len(ds.uploadCalls) != 1 {
t.Fatalf("expected current input upload for mode=%s, got %d", mode, len(ds.uploadCalls))
}
historyText := string(ds.uploadCalls[0].Data)
if !strings.Contains(historyText, "# DS2API_HISTORY.txt") || !strings.Contains(historyText, "=== 1. SYSTEM ===") {
t.Fatalf("expected uploaded history text to use numbered transcript format, got %s", historyText)
}
if ds.completionReq == nil {
t.Fatalf("expected completion payload for mode=%s", mode)
}
promptText, _ := ds.completionReq["prompt"].(string)
if !strings.Contains(promptText, "Answer the latest user request directly.") || strings.Contains(promptText, "first user turn") || strings.Contains(promptText, "latest user turn") {
if !strings.Contains(promptText, "Continue from the latest state in the attached DS2API_HISTORY.txt context.") || strings.Contains(promptText, "first user turn") || strings.Contains(promptText, "latest user turn") {
t.Fatalf("unexpected prompt for mode=%s: %s", mode, promptText)
}
})

View File

@@ -222,7 +222,13 @@ func (h *Handler) consumeResponsesStreamAttempt(r *http.Request, resp *http.Resp
finalReason = "content_filter"
}
},
OnContextDone: func() {
streamRuntime.markContextCancelled()
},
})
if streamRuntime.finalErrorCode == string(streamengine.StopReasonContextCancelled) {
return true, false
}
terminalWritten := streamRuntime.finalize(finalReason, allowDeferEmpty && finalReason != "content_filter")
if terminalWritten {
return true, false
@@ -235,6 +241,10 @@ func logResponsesStreamTerminal(streamRuntime *responsesStreamRuntime, attempts
if attempts > 0 {
source = "synthetic_retry"
}
if streamRuntime.finalErrorCode == string(streamengine.StopReasonContextCancelled) {
config.Logger.Info("[openai_empty_retry] terminal cancelled", "surface", "responses", "stream", true, "retry_attempts", attempts, "error_code", streamRuntime.finalErrorCode)
return
}
if streamRuntime.failed {
config.Logger.Info("[openai_empty_retry] terminal empty output", "surface", "responses", "stream", true, "retry_attempts", attempts, "success_source", "none", "error_code", streamRuntime.finalErrorCode)
return

View File

@@ -0,0 +1,70 @@
package responses
import (
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"ds2api/internal/promptcompat"
"ds2api/internal/stream"
)
func makeResponsesOpenAISSEHTTPResponse(lines ...string) *http.Response {
body := strings.Join(lines, "\n")
if !strings.HasSuffix(body, "\n") {
body += "\n"
}
return &http.Response{
StatusCode: http.StatusOK,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
func TestConsumeResponsesStreamAttemptMarksContextCancelledState(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil).WithContext(ctx)
rec := httptest.NewRecorder()
streamRuntime := newResponsesStreamRuntime(
rec,
http.NewResponseController(rec),
true,
"resp-cancelled",
"deepseek-v4-flash",
"prompt",
false,
false,
true,
nil,
nil,
false,
false,
promptcompat.DefaultToolChoicePolicy(),
"",
nil,
)
resp := makeResponsesOpenAISSEHTTPResponse(
`data: {"p":"response/content","v":"hello"}`,
`data: [DONE]`,
)
h := &Handler{}
terminalWritten, retryable := h.consumeResponsesStreamAttempt(req, resp, streamRuntime, "text", false, true)
if !terminalWritten || retryable {
t.Fatalf("expected cancelled attempt to terminate without retry, got terminalWritten=%v retryable=%v", terminalWritten, retryable)
}
if !streamRuntime.failed {
t.Fatalf("expected cancelled response stream to be marked failed")
}
if got, want := streamRuntime.finalErrorCode, string(stream.StopReasonContextCancelled); got != want {
t.Fatalf("expected cancelled final error code %q, got %q", want, got)
}
if streamRuntime.finalErrorMessage == "" {
t.Fatalf("expected cancelled final error message to be preserved")
}
}

View File

@@ -139,6 +139,13 @@ func (s *responsesStreamRuntime) failResponse(status int, message, code string)
s.sendDone()
}
func (s *responsesStreamRuntime) markContextCancelled() {
s.failed = true
s.finalErrorStatus = 499
s.finalErrorMessage = "request context cancelled"
s.finalErrorCode = string(streamengine.StopReasonContextCancelled)
}
func (s *responsesStreamRuntime) finalize(finishReason string, deferEmptyOutput bool) bool {
s.failed = false
s.finalErrorStatus = 0

View File

@@ -0,0 +1,134 @@
package requestbody
import (
"bytes"
"errors"
"io"
"mime"
"net/http"
"strings"
"unicode/utf8"
)
var (
ErrInvalidUTF8Body = errors.New("invalid utf-8 request body")
errRequestBodyTooLarge = errors.New("request body too large")
)
const maxJSONUTF8ValidationSize = 100 << 20
// ValidateJSONUTF8 validates complete JSON request bodies before downstream
// decoders can silently replace malformed UTF-8 or stop before trailing bytes.
func ValidateJSONUTF8(next http.Handler) http.Handler {
if next == nil {
return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if shouldValidateJSONBody(r) {
r.Body = validateAndReplayBody(r.Body)
}
next.ServeHTTP(w, r)
})
}
func shouldValidateJSONBody(r *http.Request) bool {
if r == nil || r.Body == nil {
return false
}
path := ""
if r.URL != nil {
path = r.URL.Path
}
return isJSONContentType(r.Header.Get("Content-Type")) || isKnownJSONRequestPath(r.Method, path)
}
func isJSONContentType(raw string) bool {
raw = strings.TrimSpace(raw)
if raw == "" {
return false
}
mediaType, _, err := mime.ParseMediaType(raw)
if err != nil {
mediaType = raw
}
mediaType = strings.ToLower(strings.TrimSpace(mediaType))
return strings.Contains(mediaType, "json")
}
func isKnownJSONRequestPath(method, path string) bool {
switch strings.ToUpper(strings.TrimSpace(method)) {
case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
default:
return false
}
path = strings.TrimSpace(path)
if path == "" {
return false
}
switch {
case path == "/v1/chat/completions" || path == "/chat/completions":
return true
case path == "/v1/responses" || path == "/responses":
return true
case path == "/v1/embeddings" || path == "/embeddings":
return true
case path == "/anthropic/v1/messages" || path == "/v1/messages" || path == "/messages":
return true
case path == "/anthropic/v1/messages/count_tokens" || path == "/v1/messages/count_tokens" || path == "/messages/count_tokens":
return true
case strings.HasPrefix(path, "/v1beta/models/") || strings.HasPrefix(path, "/v1/models/"):
return strings.Contains(path, ":generateContent") || strings.Contains(path, ":streamGenerateContent")
case strings.HasPrefix(path, "/admin/"):
return true
default:
return false
}
}
func validateAndReplayBody(body io.ReadCloser) io.ReadCloser {
if body == nil {
return body
}
raw, err := io.ReadAll(io.LimitReader(body, maxJSONUTF8ValidationSize+1))
if err != nil {
return &errorReadCloser{err: err, closer: body}
}
if len(raw) > maxJSONUTF8ValidationSize {
return &errorReadCloser{err: errRequestBodyTooLarge, closer: body}
}
if !utf8.Valid(raw) {
return &errorReadCloser{err: ErrInvalidUTF8Body, closer: body}
}
return &replayReadCloser{Reader: bytes.NewReader(raw), closer: body}
}
type replayReadCloser struct {
*bytes.Reader
closer io.Closer
}
func (r *replayReadCloser) Close() error {
if r == nil || r.closer == nil {
return nil
}
return r.closer.Close()
}
type errorReadCloser struct {
err error
closer io.Closer
}
func (r *errorReadCloser) Read([]byte) (int, error) {
if r == nil || r.err == nil {
return 0, io.EOF
}
return 0, r.err
}
func (r *errorReadCloser) Close() error {
if r == nil || r.closer == nil {
return nil
}
return r.closer.Close()
}

View File

@@ -0,0 +1,158 @@
package requestbody
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
type singleByteReadCloser struct {
data []byte
pos int
}
func (r *singleByteReadCloser) Read(p []byte) (int, error) {
if r.pos >= len(r.data) {
return 0, io.EOF
}
p[0] = r.data[r.pos]
r.pos++
return 1, nil
}
func (r *singleByteReadCloser) Close() error {
return nil
}
func TestValidateJSONUTF8AllowsSplitMultibyteRunes(t *testing.T) {
body := []byte(`{"text":"你好"}`)
handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("unexpected decode error: %v", err)
}
w.WriteHeader(http.StatusNoContent)
}))
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", &singleByteReadCloser{data: body})
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNoContent {
t.Fatalf("expected 204 for valid utf-8 json, got %d body=%q", rec.Code, rec.Body.String())
}
}
func TestValidateJSONUTF8RejectsInvalidBytesBeforeJSONDecode(t *testing.T) {
body := append([]byte(`{"text":"`), 0xff)
body = append(body, []byte(`"}`)...)
handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(err.Error()))
return
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json; charset=utf-8")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid utf-8 json, got %d body=%q", rec.Code, rec.Body.String())
}
if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid utf-8") {
t.Fatalf("expected utf-8 validation error, got %q", rec.Body.String())
}
}
func TestValidateJSONUTF8RejectsInvalidBytesWithoutJSONContentTypeOnKnownPath(t *testing.T) {
body := append([]byte(`{"text":"`), 0xff)
body = append(body, []byte(`"}`)...)
handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(err.Error()))
return
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
req.Header.Set("Content-Type", "text/plain")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid utf-8 json, got %d body=%q", rec.Code, rec.Body.String())
}
if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid utf-8") {
t.Fatalf("expected utf-8 validation error, got %q", rec.Body.String())
}
}
func TestValidateJSONUTF8RejectsTrailingInvalidBytesAfterJSONValue(t *testing.T) {
body := append([]byte(`{"text":"ok"}`), 0xff)
handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(err.Error()))
return
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for trailing invalid utf-8, got %d body=%q", rec.Code, rec.Body.String())
}
if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid utf-8") {
t.Fatalf("expected utf-8 validation error, got %q", rec.Body.String())
}
}
func TestIsJSONContentType(t *testing.T) {
for _, raw := range []string{
"application/json",
"application/json; charset=utf-8",
"application/problem+json",
"application/vnd.api+json",
} {
if !isJSONContentType(raw) {
t.Fatalf("expected %q to be recognized as json", raw)
}
}
for _, raw := range []string{
"multipart/form-data; boundary=abc",
"text/plain",
"application/octet-stream",
} {
if isJSONContentType(raw) {
t.Fatalf("expected %q not to be recognized as json", raw)
}
}
}
func TestIsKnownJSONRequestPathIncludesGeminiStream(t *testing.T) {
if !isKnownJSONRequestPath(http.MethodPost, "/v1beta/models/gemini-pro:streamGenerateContent") {
t.Fatal("expected Gemini stream generate path to be recognized as json")
}
}

View File

@@ -846,10 +846,18 @@ function parseMarkupValue(raw, paramName = '') {
if (cdata.ok) {
const literal = parseJSONLiteralValue(cdata.value);
if (literal.ok) {
const literalArray = coerceArrayValue(literal.value, paramName);
if (literalArray.ok) {
return literalArray.value;
}
return literal.value;
}
const structured = parseStructuredCDATAParameterValue(paramName, cdata.value);
return structured.ok ? structured.value : cdata.value;
if (structured.ok) {
return structured.value;
}
const looseArray = parseLooseJSONArrayValue(cdata.value, paramName);
return looseArray.ok ? looseArray.value : cdata.value;
}
const s = toStringSafe(extractRawTagValue(raw)).trim();
if (!s) {
@@ -862,8 +870,14 @@ function parseMarkupValue(raw, paramName = '') {
return nested;
}
if (nested && typeof nested === 'object') {
const nestedArray = coerceArrayValue(nested, paramName);
if (nestedArray.ok) {
return nestedArray.value;
}
if (isOnlyRawValue(nested)) {
return toStringSafe(nested._raw);
const rawValue = toStringSafe(nested._raw);
const looseArray = parseLooseJSONArrayValue(rawValue, paramName);
return looseArray.ok ? looseArray.value : rawValue;
}
return nested;
}
@@ -871,8 +885,16 @@ function parseMarkupValue(raw, paramName = '') {
const literal = parseJSONLiteralValue(s);
if (literal.ok) {
const literalArray = coerceArrayValue(literal.value, paramName);
if (literalArray.ok) {
return literalArray.value;
}
return literal.value;
}
const looseArray = parseLooseJSONArrayValue(s, paramName);
if (looseArray.ok) {
return looseArray.value;
}
return s;
}
@@ -1008,6 +1030,226 @@ function parseJSONLiteralValue(raw) {
}
}
function parseLooseJSONArrayValue(raw, paramName = '') {
if (preservesCDATAStringParameter(paramName)) {
return { ok: false, value: null };
}
const s = toStringSafe(raw).trim();
if (!s) {
return { ok: false, value: null };
}
const candidate = parseLooseJSONArrayCandidate(s, paramName);
if (candidate.ok) {
return candidate;
}
const segments = splitTopLevelJSONValues(s);
if (segments.length < 2) {
return { ok: false, value: null };
}
const out = [];
for (const segment of segments) {
const parsed = parseLooseArrayElementValue(segment);
if (!parsed.ok) {
return { ok: false, value: null };
}
out.push(parsed.value);
}
return { ok: true, value: out };
}
function parseLooseJSONArrayCandidate(raw, paramName = '') {
const parsed = parseLooseArrayElementValue(raw);
if (!parsed.ok) {
return { ok: false, value: null };
}
return coerceArrayValue(parsed.value, paramName);
}
function parseLooseArrayElementValue(raw) {
const s = toStringSafe(raw).trim();
if (!s) {
return { ok: false, value: null };
}
const literal = parseJSONLiteralValue(s);
if (literal.ok) {
return literal;
}
const repairedBackslashes = repairInvalidJSONBackslashes(s);
if (repairedBackslashes !== s) {
try {
const parsed = JSON.parse(repairedBackslashes);
return { ok: true, value: parsed };
} catch (_err) {
// Fall through.
}
}
const repairedLoose = repairLooseJSON(s);
if (repairedLoose !== s) {
try {
const parsed = JSON.parse(repairedLoose);
return { ok: true, value: parsed };
} catch (_err) {
// Fall through.
}
}
if (s.includes('<') && s.includes('>')) {
const parsed = parseMarkupInput(s);
if (Array.isArray(parsed)) {
return { ok: true, value: parsed };
}
if (parsed && typeof parsed === 'object') {
return { ok: true, value: parsed };
}
}
return { ok: false, value: null };
}
function coerceArrayValue(value, paramName = '') {
if (Array.isArray(value)) {
return { ok: true, value };
}
if (!value || typeof value !== 'object') {
return { ok: false, value: null };
}
const keys = Object.keys(value);
if (keys.length !== 1) {
return { ok: false, value: null };
}
if (Object.prototype.hasOwnProperty.call(value, 'item')) {
const items = value.item;
const nested = coerceArrayValue(items, '');
return nested.ok ? nested : { ok: true, value: [items] };
}
if (paramName && Object.prototype.hasOwnProperty.call(value, paramName)) {
const nested = coerceArrayValue(value[paramName], '');
if (nested.ok) {
return nested;
}
}
return { ok: false, value: null };
}
function splitTopLevelJSONValues(raw) {
const s = toStringSafe(raw).trim();
if (!s) {
return [];
}
const values = [];
let start = 0;
let depth = 0;
let inString = false;
let escaped = false;
for (let i = 0; i < s.length; i += 1) {
const ch = s[i];
if (inString) {
if (escaped) {
escaped = false;
continue;
}
if (ch === '\\') {
escaped = true;
continue;
}
if (ch === '"') {
inString = false;
}
continue;
}
if (ch === '"') {
inString = true;
continue;
}
if (ch === '{' || ch === '[') {
depth += 1;
continue;
}
if (ch === '}' || ch === ']') {
if (depth > 0) {
depth -= 1;
}
continue;
}
if (ch === ',' && depth === 0) {
const segment = s.slice(start, i).trim();
if (!segment) {
return [];
}
values.push(segment);
start = i + 1;
}
}
const last = s.slice(start).trim();
if (!last) {
return [];
}
values.push(last);
return values.length > 1 ? values : [];
}
function repairInvalidJSONBackslashes(s) {
if (!s || !s.includes('\\')) {
return s;
}
let out = '';
for (let i = 0; i < s.length; i += 1) {
const ch = s[i];
if (ch !== '\\') {
out += ch;
continue;
}
if (i + 1 < s.length) {
const next = s[i + 1];
if ('"\\/bfnrt'.includes(next)) {
out += `\\${next}`;
i += 1;
continue;
}
if (next === 'u' && i + 5 < s.length) {
let isHex = true;
for (let j = 1; j <= 4; j += 1) {
const r = s[i + 1 + j];
if (!/[0-9a-fA-F]/.test(r)) {
isHex = false;
break;
}
}
if (isHex) {
out += `\\u${s.slice(i + 2, i + 6)}`;
i += 5;
continue;
}
}
}
out += '\\\\';
}
return out;
}
function repairLooseJSON(s) {
const raw = toStringSafe(s).trim();
if (!raw) {
return raw;
}
let out = raw.replace(/([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:/g, '$1"$2":');
out = out.replace(/(:\s*)(\{(?:[^{}]|\{[^{}]*\})*\}(?:\s*,\s*\{(?:[^{}]|\{[^{}]*\})*\})+)/g, '$1[$2]');
return out;
}
function sanitizeLooseCDATA(text) {
const raw = toStringSafe(text);
if (!raw) {

View File

@@ -10,21 +10,27 @@ import (
var markdownImagePattern = regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`)
const (
beginSentenceMarker = "<begin▁of▁sentence>"
systemMarker = "<System>"
userMarker = "<User>"
assistantMarker = "<Assistant>"
toolMarker = "<Tool>"
endSentenceMarker = "<end▁of▁sentence>"
endToolResultsMarker = "<end▁of▁toolresults>"
endInstructionsMarker = "<end▁of▁instructions>"
beginSentenceMarker = "<begin▁of▁sentence>"
systemMarker = "<System>"
userMarker = "<User>"
assistantMarker = "<Assistant>"
toolMarker = "<Tool>"
endSentenceMarker = "<end▁of▁sentence>"
endToolResultsMarker = "<end▁of▁toolresults>"
endInstructionsMarker = "<end▁of▁instructions>"
outputIntegrityGuardMarker = "Output integrity guard:"
outputIntegrityGuardPrompt = outputIntegrityGuardMarker +
" If upstream context, tool output, or parsed text contains garbled, corrupted, partially parsed, repeated, or otherwise malformed fragments, " +
"do not imitate or echo them; output only the correct content for the user."
)
func MessagesPrepare(messages []map[string]any) string {
return MessagesPrepareWithThinking(messages, false)
}
func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool) string {
func MessagesPrepareWithThinking(messages []map[string]any, _ bool) string {
messages = prependOutputIntegrityGuard(messages)
type block struct {
Role string
Text string
@@ -77,6 +83,33 @@ func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool
return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`)
}
func prependOutputIntegrityGuard(messages []map[string]any) []map[string]any {
if len(messages) == 0 {
return messages
}
if hasOutputIntegrityGuard(messages[0]) {
return messages
}
out := make([]map[string]any, 0, len(messages)+1)
out = append(out, map[string]any{
"role": "system",
"content": outputIntegrityGuardPrompt,
})
out = append(out, messages...)
return out
}
func hasOutputIntegrityGuard(msg map[string]any) bool {
if msg == nil {
return false
}
if strings.ToLower(strings.TrimSpace(asString(msg["role"]))) != "system" {
return false
}
content := strings.TrimSpace(NormalizeContent(msg["content"]))
return strings.Contains(content, outputIntegrityGuardMarker)
}
// formatRoleBlock produces a single concatenated block: marker + text + endMarker.
// No whitespace is inserted between marker and text so role boundaries stay
// compact and predictable for downstream parsers.

View File

@@ -35,8 +35,8 @@ func TestMessagesPrepareUsesTurnSuffixes(t *testing.T) {
if !strings.HasPrefix(got, "<begin▁of▁sentence>") {
t.Fatalf("expected begin-of-sentence marker, got %q", got)
}
if !strings.Contains(got, "<System>System rule<end▁of▁instructions>") {
t.Fatalf("expected system instructions suffix, got %q", got)
if !strings.Contains(got, "<System>") || !strings.Contains(got, "<end▁of▁instructions>") || !strings.Contains(got, "System rule") {
t.Fatalf("expected system instructions to remain present, got %q", got)
}
if !strings.Contains(got, "<User>Question") {
t.Fatalf("expected user question, got %q", got)
@@ -49,6 +49,23 @@ func TestMessagesPrepareUsesTurnSuffixes(t *testing.T) {
}
}
func TestMessagesPreparePrependsOutputIntegrityGuard(t *testing.T) {
messages := []map[string]any{
{"role": "system", "content": "System rule"},
{"role": "user", "content": "Question"},
}
got := MessagesPrepare(messages)
if !strings.HasPrefix(got, beginSentenceMarker+systemMarker+outputIntegrityGuardPrompt) {
t.Fatalf("expected output integrity guard to be prepended, got %q", got)
}
if !strings.Contains(got, outputIntegrityGuardPrompt+"\n\nSystem rule") {
t.Fatalf("expected output integrity guard to precede system prompt content, got %q", got)
}
if !strings.Contains(got, "<User>Question") {
t.Fatalf("expected user question after guard, got %q", got)
}
}
func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) {
got := NormalizeContent([]any{
map[string]any{"type": "text", "text": "", "content": "from-content"},

View File

@@ -1,35 +1,108 @@
package promptcompat
import (
"fmt"
"strings"
"ds2api/internal/prompt"
)
const CurrentInputContextFilename = "history.txt"
const CurrentInputContextFilename = "DS2API_HISTORY.txt"
const historyTranscriptTitle = "# DS2API_HISTORY.txt"
const historyTranscriptSummary = "Prior conversation history and tool progress."
func BuildOpenAIHistoryTranscript(messages []any) string {
return buildOpenAIInjectedFileTranscript(messages)
return buildOpenAIHistoryTranscript(messages)
}
func BuildOpenAICurrentUserInputTranscript(text string) string {
if strings.TrimSpace(text) == "" {
return ""
}
return BuildOpenAICurrentInputContextTranscript([]any{
return buildOpenAIHistoryTranscript([]any{
map[string]any{"role": "user", "content": text},
})
}
func BuildOpenAICurrentInputContextTranscript(messages []any) string {
return buildOpenAIInjectedFileTranscript(messages)
return buildOpenAIHistoryTranscript(messages)
}
func buildOpenAIInjectedFileTranscript(messages []any) string {
normalized := NormalizeOpenAIMessagesForPrompt(messages, "")
transcript := strings.TrimSpace(prompt.MessagesPrepare(normalized))
func buildOpenAIHistoryTranscript(messages []any) string {
if len(messages) == 0 {
return ""
}
var b strings.Builder
b.WriteString(historyTranscriptTitle)
b.WriteString("\n")
b.WriteString(historyTranscriptSummary)
b.WriteString("\n\n")
entry := 0
for _, raw := range messages {
msg, ok := raw.(map[string]any)
if !ok {
continue
}
role := normalizeOpenAIRoleForPrompt(strings.ToLower(strings.TrimSpace(asString(msg["role"]))))
content := strings.TrimSpace(buildOpenAIHistoryEntry(role, msg))
if content == "" {
continue
}
entry++
fmt.Fprintf(&b, "=== %d. %s ===\n%s\n\n", entry, strings.ToUpper(roleLabelForHistory(role)), content)
}
transcript := strings.TrimSpace(b.String())
if transcript == "" {
return ""
}
return transcript
return transcript + "\n"
}
func buildOpenAIHistoryEntry(role string, msg map[string]any) string {
switch role {
case "assistant":
return strings.TrimSpace(buildAssistantContentForPrompt(msg))
case "tool", "function":
return strings.TrimSpace(buildToolHistoryContent(msg))
case "system", "user":
return strings.TrimSpace(NormalizeOpenAIContentForPrompt(msg["content"]))
default:
return strings.TrimSpace(NormalizeOpenAIContentForPrompt(msg["content"]))
}
}
func buildToolHistoryContent(msg map[string]any) string {
content := strings.TrimSpace(NormalizeOpenAIContentForPrompt(msg["content"]))
parts := make([]string, 0, 2)
if name := strings.TrimSpace(asString(msg["name"])); name != "" {
parts = append(parts, "name="+name)
}
if callID := strings.TrimSpace(asString(msg["tool_call_id"])); callID != "" {
parts = append(parts, "tool_call_id="+callID)
}
header := ""
if len(parts) > 0 {
header = "[" + strings.Join(parts, " ") + "]"
}
switch {
case header != "" && content != "":
return header + "\n" + content
case header != "":
return header
default:
return content
}
}
func roleLabelForHistory(role string) string {
role = strings.ToLower(strings.TrimSpace(role))
switch role {
case "function":
return "tool"
case "":
return "unknown"
default:
return role
}
}

View File

@@ -88,6 +88,38 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
}
}
func TestBuildOpenAIFinalPromptPrependsOutputIntegrityGuard(t *testing.T) {
messages := []any{
map[string]any{"role": "system", "content": "You are helpful"},
map[string]any{"role": "user", "content": "请调用工具"},
}
tools := []any{
map[string]any{
"type": "function",
"function": map[string]any{
"name": "search",
"description": "search docs",
"parameters": map[string]any{
"type": "object",
},
},
},
}
finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "", false)
guardIdx := strings.Index(finalPrompt, "Output integrity guard")
toolIdx := strings.Index(finalPrompt, "TOOL CALL FORMAT")
if guardIdx < 0 {
t.Fatalf("expected output integrity guard in final prompt, got: %q", finalPrompt)
}
if toolIdx < 0 {
t.Fatalf("expected tool instructions in final prompt, got: %q", finalPrompt)
}
if guardIdx > toolIdx {
t.Fatalf("expected output integrity guard to precede tool instructions, got: %q", finalPrompt)
}
}
func TestBuildOpenAIFinalPromptReadLikeToolIncludesCacheGuard(t *testing.T) {
messages := []any{
map[string]any{"role": "user", "content": "请读取文件"},

View File

@@ -27,6 +27,7 @@ import (
"ds2api/internal/httpapi/openai/files"
"ds2api/internal/httpapi/openai/responses"
"ds2api/internal/httpapi/openai/shared"
"ds2api/internal/httpapi/requestbody"
"ds2api/internal/webui"
)
@@ -75,6 +76,7 @@ func NewApp() (*App, error) {
r.Use(filteredLogger())
r.Use(middleware.Recoverer)
r.Use(cors)
r.Use(requestbody.ValidateJSONUTF8)
r.Use(timeout(0))
healthzHandler := func(w http.ResponseWriter, _ *http.Request) {

View File

@@ -0,0 +1,89 @@
package server
import (
"bytes"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestJSONRequestsRejectInvalidUTF8BeforeDecode(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["managed-key"],"accounts":[{"email":"u@example.com","password":"p"}]}`)
t.Setenv("DS2API_ENV_WRITEBACK", "0")
app, err := NewApp()
if err != nil {
t.Fatalf("NewApp() error: %v", err)
}
body := append([]byte(`{"model":"deepseek-v4-flash","messages":[{"role":"user","content":"`), 0xff)
body = append(body, []byte(`"}]}`)...)
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json; charset=utf-8")
req.Header.Set("x-api-key", "direct-token")
rec := httptest.NewRecorder()
app.Router.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid utf-8 request body, got %d body=%q", rec.Code, rec.Body.String())
}
if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid json") {
t.Fatalf("expected invalid json error, got %q", rec.Body.String())
}
}
func TestKnownJSONRequestsRejectInvalidUTF8WithoutJSONContentType(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["managed-key"],"accounts":[{"email":"u@example.com","password":"p"}]}`)
t.Setenv("DS2API_ENV_WRITEBACK", "0")
app, err := NewApp()
if err != nil {
t.Fatalf("NewApp() error: %v", err)
}
body := append([]byte(`{"model":"deepseek-v4-flash","messages":[{"role":"user","content":"`), 0xff)
body = append(body, []byte(`"}]}`)...)
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
req.Header.Set("Content-Type", "text/plain")
req.Header.Set("x-api-key", "direct-token")
rec := httptest.NewRecorder()
app.Router.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid utf-8 request body, got %d body=%q", rec.Code, rec.Body.String())
}
if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid json") {
t.Fatalf("expected invalid json error, got %q", rec.Body.String())
}
}
func TestJSONRequestsRejectTrailingInvalidUTF8AfterCompleteJSON(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["managed-key"],"accounts":[{"email":"u@example.com","password":"p"}]}`)
t.Setenv("DS2API_ENV_WRITEBACK", "0")
app, err := NewApp()
if err != nil {
t.Fatalf("NewApp() error: %v", err)
}
body := append([]byte(`{"model":"deepseek-v4-flash","messages":[{"role":"user","content":"ok"}]}`), 0xff)
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-api-key", "direct-token")
rec := httptest.NewRecorder()
app.Router.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for trailing invalid utf-8, got %d body=%q", rec.Code, rec.Body.String())
}
if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid json") {
t.Fatalf("expected invalid json error, got %q", rec.Body.String())
}
}

View File

@@ -0,0 +1,164 @@
package toolcall
import (
"encoding/json"
"html"
"strings"
)
func parseLooseJSONArrayValue(raw, paramName string) ([]any, bool) {
if preservesCDATAStringParameter(paramName) {
return nil, false
}
trimmed := strings.TrimSpace(html.UnescapeString(raw))
if trimmed == "" {
return nil, false
}
if parsed, ok := parseLooseJSONArrayCandidate(trimmed, paramName); ok {
return parsed, true
}
segments, ok := splitTopLevelJSONValues(trimmed)
if !ok {
return nil, false
}
out := make([]any, 0, len(segments))
for _, segment := range segments {
parsed, ok := parseLooseArrayElementValue(segment)
if !ok {
return nil, false
}
out = append(out, parsed)
}
return out, true
}
func parseLooseJSONArrayCandidate(raw, paramName string) ([]any, bool) {
parsed, ok := parseLooseArrayElementValue(raw)
if !ok {
return nil, false
}
return coerceArrayValue(parsed, paramName)
}
func parseLooseArrayElementValue(raw string) (any, bool) {
trimmed := strings.TrimSpace(html.UnescapeString(raw))
if trimmed == "" {
return nil, false
}
var parsed any
if err := json.Unmarshal([]byte(trimmed), &parsed); err == nil {
return parsed, true
}
repairedBackslashes := repairInvalidJSONBackslashes(trimmed)
if repairedBackslashes != trimmed {
if err := json.Unmarshal([]byte(repairedBackslashes), &parsed); err == nil {
return parsed, true
}
}
repairedLoose := RepairLooseJSON(trimmed)
if repairedLoose != trimmed {
if err := json.Unmarshal([]byte(repairedLoose), &parsed); err == nil {
return parsed, true
}
}
if strings.Contains(trimmed, "<") && strings.Contains(trimmed, ">") {
if parsedXML, ok := parseXMLFragmentValue(trimmed); ok {
return parsedXML, true
}
}
return nil, false
}
func coerceArrayValue(value any, paramName string) ([]any, bool) {
switch x := value.(type) {
case []any:
return x, true
case map[string]any:
if len(x) != 1 {
return nil, false
}
if items, ok := x["item"]; ok {
if arr, ok := coerceArrayValue(items, ""); ok {
return arr, true
}
return []any{items}, true
}
if paramName != "" {
if wrapped, ok := x[paramName]; ok {
if arr, ok := coerceArrayValue(wrapped, ""); ok {
return arr, true
}
}
}
}
return nil, false
}
func splitTopLevelJSONValues(raw string) ([]string, bool) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return nil, false
}
values := make([]string, 0, 2)
start := 0
depth := 0
inString := false
escaped := false
for i, r := range trimmed {
if inString {
if escaped {
escaped = false
continue
}
switch r {
case '\\':
escaped = true
case '"':
inString = false
}
continue
}
switch r {
case '"':
inString = true
case '{', '[':
depth++
case '}', ']':
if depth > 0 {
depth--
}
case ',':
if depth == 0 {
segment := strings.TrimSpace(trimmed[start:i])
if segment == "" {
return nil, false
}
values = append(values, segment)
start = i + 1
}
}
}
last := strings.TrimSpace(trimmed[start:])
if last == "" {
return nil, false
}
values = append(values, last)
if len(values) < 2 {
return nil, false
}
return values, true
}

View File

@@ -298,11 +298,17 @@ func parseInvokeParameterValue(paramName, raw string) any {
}
if value, ok := extractStandaloneCDATA(trimmed); ok {
if parsed, ok := parseJSONLiteralValue(value); ok {
if parsedArray, ok := coerceArrayValue(parsed, paramName); ok {
return parsedArray
}
return parsed
}
if parsed, ok := parseStructuredCDATAParameterValue(paramName, value); ok {
return parsed
}
if parsed, ok := parseLooseJSONArrayValue(value, paramName); ok {
return parsed
}
return value
}
decoded := html.UnescapeString(extractRawTagValue(trimmed))
@@ -311,6 +317,9 @@ func parseInvokeParameterValue(paramName, raw string) any {
switch v := parsedValue.(type) {
case map[string]any:
if len(v) > 0 {
if parsedArray, ok := coerceArrayValue(v, paramName); ok {
return parsedArray
}
return v
}
case []any:
@@ -321,6 +330,12 @@ func parseInvokeParameterValue(paramName, raw string) any {
return ""
}
if parsedText, ok := parseJSONLiteralValue(text); ok {
if parsedArray, ok := coerceArrayValue(parsedText, paramName); ok {
return parsedArray
}
return parsedText
}
if parsedText, ok := parseLooseJSONArrayValue(text, paramName); ok {
return parsedText
}
return v
@@ -331,13 +346,25 @@ func parseInvokeParameterValue(paramName, raw string) any {
if parsed := parseStructuredToolCallInput(decoded); len(parsed) > 0 {
if len(parsed) == 1 {
if rawValue, ok := parsed["_raw"].(string); ok {
if parsedText, ok := parseLooseJSONArrayValue(rawValue, paramName); ok {
return parsedText
}
return rawValue
}
}
if parsedArray, ok := coerceArrayValue(parsed, paramName); ok {
return parsedArray
}
return parsed
}
}
if parsed, ok := parseJSONLiteralValue(decoded); ok {
if parsedArray, ok := coerceArrayValue(parsed, paramName); ok {
return parsedArray
}
return parsed
}
if parsed, ok := parseLooseJSONArrayValue(decoded, paramName); ok {
return parsed
}
return decoded

View File

@@ -294,6 +294,59 @@ func TestParseToolCallsTreatsSingleItemCDATAAsArray(t *testing.T) {
}
}
func TestParseToolCallsTreatsLooseJSONListAsArray(t *testing.T) {
tests := []struct {
name string
body string
}{
{
name: "plain text",
body: `{"content":"Test TodoWrite tool","status":"completed"}, {"content":"Another task","status":"pending"}`,
},
{
name: "cdata",
body: `<![CDATA[{"content":"Test TodoWrite tool","status":"completed"}, {"content":"Another task","status":"pending"}]]>`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
text := `<tool_calls><invoke name="TodoWrite"><parameter name="todos">` + tt.body + `</parameter></invoke></tool_calls>`
calls := ParseToolCalls(text, []string{"TodoWrite"})
if len(calls) != 1 {
t.Fatalf("expected one TodoWrite call, got %#v", calls)
}
items, ok := calls[0].Input["todos"].([]any)
if !ok || len(items) != 2 {
t.Fatalf("expected loose JSON list to parse as array, got %#v", calls[0].Input["todos"])
}
first, ok := items[0].(map[string]any)
if !ok {
t.Fatalf("expected first todo object, got %#v", items[0])
}
if first["content"] != "Test TodoWrite tool" || first["status"] != "completed" {
t.Fatalf("unexpected first todo: %#v", first)
}
})
}
}
func TestParseToolCallsKeepsPreservedTextParametersAsText(t *testing.T) {
text := `<tool_calls><invoke name="Write"><parameter name="content"><![CDATA[{"content":"Test TodoWrite tool","status":"completed"}, {"content":"Another task","status":"pending"}]]></parameter></invoke></tool_calls>`
calls := ParseToolCalls(text, []string{"Write"})
if len(calls) != 1 {
t.Fatalf("expected one Write call, got %#v", calls)
}
got, ok := calls[0].Input["content"].(string)
if !ok {
t.Fatalf("expected content to stay a string, got %#v", calls[0].Input["content"])
}
want := `{"content":"Test TodoWrite tool","status":"completed"}, {"content":"Another task","status":"pending"}`
if got != want {
t.Fatalf("expected content to stay raw, got %q", got)
}
}
func TestParseToolCallsTreatsCDATAObjectFragmentAsObject(t *testing.T) {
payload := `<question><![CDATA[Pick one]]></question><options><item><label><![CDATA[A]]></label></item><item><label><![CDATA[B]]></label></item></options>`
text := `<tool_calls><invoke name="AskUserQuestion"><parameter name="questions"><![CDATA[` + payload + `]]></parameter></invoke></tool_calls>`

View File

@@ -1,6 +1,7 @@
package util
import (
"strings"
"testing"
"ds2api/internal/config"
@@ -12,7 +13,10 @@ func TestMessagesPrepareBasic(t *testing.T) {
if got == "" {
t.Fatal("expected non-empty prompt")
}
if got != "<begin▁of▁sentence><User>Hello<Assistant>" {
if !strings.HasPrefix(got, "<begin▁of▁sentence><System>") {
t.Fatalf("expected output integrity guard at the start, got %q", got)
}
if !strings.Contains(got, "Hello") || !strings.HasSuffix(got, "<Assistant>") {
t.Fatalf("unexpected prompt: %q", got)
}
}
@@ -26,8 +30,11 @@ func TestMessagesPrepareRoles(t *testing.T) {
{"role": "user", "content": "How are you"},
}
got := MessagesPrepare(messages)
if !contains(got, "<System>You are helper<end▁of▁instructions><User>Hi") {
t.Fatalf("expected system/user separation in %q", got)
if !contains(got, "Output integrity guard") {
t.Fatalf("expected output integrity guard in %q", got)
}
if !contains(got, "You are helper") || !contains(got, "<User>Hi") {
t.Fatalf("expected system/user content in %q", got)
}
if !contains(got, "<begin▁of▁sentence>") {
t.Fatalf("expected begin marker in %q", got)
@@ -77,9 +84,12 @@ func TestMessagesPrepareArrayTextVariants(t *testing.T) {
},
}
got := MessagesPrepare(messages)
if got != "<begin▁of▁sentence><User>line1\nline2<Assistant>" {
if !contains(got, "line1\nline2") {
t.Fatalf("unexpected content from text variants: %q", got)
}
if !strings.Contains(got, "Output integrity guard") {
t.Fatalf("expected output integrity guard in %q", got)
}
}
func TestConvertClaudeToDeepSeek(t *testing.T) {

View File

@@ -4,19 +4,26 @@ package util
import (
"strings"
"sync"
tiktoken "github.com/hupe1980/go-tiktoken"
)
var (
tokenEncodingPools sync.Map
tokenEncodingUnsupported sync.Map
)
func countWithTokenizer(text, model string) int {
text = strings.TrimSpace(text)
if text == "" {
return 0
}
encoding, err := tiktoken.NewEncodingForModel(tokenizerModelForCount(model))
if err != nil {
encoding, release := tokenizerEncodingForCount(tokenizerModelForCount(model))
if encoding == nil {
return 0
}
defer release()
ids, _, err := encoding.Encode(text, nil, nil)
if err != nil {
return 0
@@ -24,6 +31,53 @@ func countWithTokenizer(text, model string) int {
return len(ids)
}
func tokenizerEncodingForCount(model string) (*tiktoken.Encoding, func()) {
model = strings.TrimSpace(model)
if model == "" {
model = defaultTokenizerModel
}
if _, ok := tokenEncodingUnsupported.Load(model); ok {
return nil, func() {}
}
if rawPool, ok := tokenEncodingPools.Load(model); ok {
pool, _ := rawPool.(*sync.Pool)
return getEncodingFromPool(pool)
}
encoding, err := tiktoken.NewEncodingForModel(model)
if err != nil {
tokenEncodingUnsupported.Store(model, struct{}{})
return nil, func() {}
}
pool := &sync.Pool{
New: func() any {
encoding, err := tiktoken.NewEncodingForModel(model)
if err != nil {
return nil
}
return encoding
},
}
actualPool, _ := tokenEncodingPools.LoadOrStore(model, pool)
pool, _ = actualPool.(*sync.Pool)
return encoding, func() {
pool.Put(encoding)
}
}
func getEncodingFromPool(pool *sync.Pool) (*tiktoken.Encoding, func()) {
if pool == nil {
return nil, func() {}
}
encoding, _ := pool.Get().(*tiktoken.Encoding)
if encoding == nil {
return nil, func() {}
}
return encoding, func() {
pool.Put(encoding)
}
}
func tokenizerModelForCount(model string) string {
model = strings.ToLower(strings.TrimSpace(model))
if model == "" {

View File

@@ -0,0 +1,35 @@
//go:build !386 && !arm && !mips && !mipsle && !wasm
package util
import "testing"
func TestTokenizerEncodingForCountCachesSupportedModel(t *testing.T) {
encoding, release := tokenizerEncodingForCount(defaultTokenizerModel)
if encoding == nil {
t.Fatalf("expected tokenizer encoding for %q", defaultTokenizerModel)
}
release()
if _, ok := tokenEncodingPools.Load(defaultTokenizerModel); !ok {
t.Fatalf("expected tokenizer encoding pool for %q", defaultTokenizerModel)
}
encoding, release = tokenizerEncodingForCount(defaultTokenizerModel)
if encoding == nil {
t.Fatalf("expected cached tokenizer encoding for %q", defaultTokenizerModel)
}
release()
}
func TestTokenizerEncodingForCountCachesUnsupportedModel(t *testing.T) {
const model = "__ds2api_unsupported_tokenizer_model__"
encoding, release := tokenizerEncodingForCount(model)
release()
if encoding != nil {
t.Fatalf("expected nil encoding for unsupported model %q", model)
}
if _, ok := tokenEncodingUnsupported.Load(model); !ok {
t.Fatalf("expected unsupported tokenizer model to be cached")
}
}