From e2756f800d32953f53fcf39b6d65574b3924c61e Mon Sep 17 00:00:00 2001 From: CJACK Date: Sat, 2 May 2026 02:22:34 +0800 Subject: [PATCH] feat: introduce JSON UTF-8 validation middleware and prepend output integrity guard system prompt to messages --- API.en.md | 2 + API.md | 2 + docs/prompt-compatibility.md | 5 + internal/httpapi/claude/handler_messages.go | 8 +- internal/httpapi/gemini/handler_generate.go | 10 +- internal/httpapi/requestbody/json_utf8.go | 134 +++++++++++++++ .../httpapi/requestbody/json_utf8_test.go | 158 ++++++++++++++++++ internal/prompt/messages.go | 51 +++++- internal/prompt/messages_test.go | 21 ++- internal/promptcompat/prompt_build_test.go | 32 ++++ internal/server/router.go | 2 + internal/server/router_utf8_test.go | 89 ++++++++++ internal/util/messages_test.go | 18 +- 13 files changed, 514 insertions(+), 18 deletions(-) create mode 100644 internal/httpapi/requestbody/json_utf8.go create mode 100644 internal/httpapi/requestbody/json_utf8_test.go create mode 100644 internal/server/router_utf8_test.go diff --git a/API.en.md b/API.en.md index 53ab8a6..a07c304 100644 --- a/API.en.md +++ b/API.en.md @@ -33,6 +33,8 @@ Docs: [Overview](README.en.md) / [Architecture](docs/ARCHITECTURE.en.md) / [Depl | Health probes | `GET /healthz`, `GET /readyz` | | CORS | Enabled (uniformly covers `/v1/*`, `/anthropic/*`, `/v1beta/models/*`, and `/admin/*`; echoes the browser `Origin` when present, otherwise `*`; default allow-list includes `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Ds2-Source`, `X-Vercel-Protection-Bypass`, `X-Goog-Api-Key`, `Anthropic-Version`, `Anthropic-Beta`, and also accepts third-party preflight-requested headers such as `x-stainless-*`; `/v1/chat/completions` on Vercel Node Runtime matches the same behavior; internal-only `X-Ds2-Internal-Token` remains blocked) | +- All JSON request bodies must be valid UTF-8; malformed byte sequences are rejected on ingress with `400 invalid json`. + ### 3.0 Adapter-Layer Notes - OpenAI / Claude / Gemini protocols are now mounted on one shared `chi` router tree assembled in `internal/server/router.go`. diff --git a/API.md b/API.md index 7c7bd9b..0733c18 100644 --- a/API.md +++ b/API.md @@ -33,6 +33,8 @@ | 健康检查 | `GET /healthz`、`GET /readyz` | | CORS | 已启用(统一覆盖 `/v1/*`、`/anthropic/*`、`/v1beta/models/*`、`/admin/*`;浏览器有 `Origin` 时回显该 Origin,否则为 `*`;默认允许 `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Ds2-Source`, `X-Vercel-Protection-Bypass`, `X-Goog-Api-Key`, `Anthropic-Version`, `Anthropic-Beta`,并会放行预检里声明的第三方请求头,如 `x-stainless-*`;Vercel 上 `/v1/chat/completions` 的 Node Runtime 也对齐相同行为;内部专用头 `X-Ds2-Internal-Token` 仍被拦截) | +- 所有 JSON 请求体都必须是合法 UTF-8;非法字节序列会在入站阶段被拒绝为 `400 invalid json`。 + ### 3.0 接口适配层说明 - OpenAI / Claude / Gemini 三套协议已统一挂在同一 `chi` 路由树上,由 `internal/server/router.go` 负责装配。 diff --git a/docs/prompt-compatibility.md b/docs/prompt-compatibility.md index fcc70a5..951fc90 100644 --- a/docs/prompt-compatibility.md +++ b/docs/prompt-compatibility.md @@ -117,6 +117,11 @@ OpenAI Chat / Responses 在标准化后、current input file 之前,会默认 - 普通请求会直接出现在最终 `prompt` 的最新 user block 末尾。 - 如果触发 current input file,它会进入完整上下文文件中。 +另外,`MessagesPrepareWithThinking` 还会在最终 prompt 的最前面预置一段固定的 system 级“输出完整性约束(Output integrity guard)”: + +- 如果上游上下文、工具输出或解析后的文本出现乱码、损坏、部分解析、重复或其他畸形片段,不要模仿、不要回显,只输出给用户的正确内容。 +- 这段约束位于普通 system / tool prompt 之前,因此是当前最终 prompt 里的最高优先级前置指令。 + ### 5.1 角色标记 最终 prompt 使用 DeepSeek 风格角色标记: diff --git a/internal/httpapi/claude/handler_messages.go b/internal/httpapi/claude/handler_messages.go index e7ed4cd..ed66475 100644 --- a/internal/httpapi/claude/handler_messages.go +++ b/internal/httpapi/claude/handler_messages.go @@ -3,12 +3,14 @@ package claude import ( "bytes" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" "strings" "ds2api/internal/config" + "ds2api/internal/httpapi/requestbody" streamengine "ds2api/internal/stream" "ds2api/internal/translatorcliproxy" "ds2api/internal/util" @@ -33,7 +35,11 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) { func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, store ConfigReader) bool { raw, err := io.ReadAll(r.Body) if err != nil { - writeClaudeError(w, http.StatusBadRequest, "invalid body") + if errors.Is(err, requestbody.ErrInvalidUTF8Body) { + writeClaudeError(w, http.StatusBadRequest, "invalid json") + } else { + writeClaudeError(w, http.StatusBadRequest, "invalid body") + } return true } var req map[string]any diff --git a/internal/httpapi/gemini/handler_generate.go b/internal/httpapi/gemini/handler_generate.go index 00c4655..085a29c 100644 --- a/internal/httpapi/gemini/handler_generate.go +++ b/internal/httpapi/gemini/handler_generate.go @@ -2,8 +2,8 @@ package gemini import ( "bytes" - "ds2api/internal/toolcall" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" @@ -11,7 +11,9 @@ import ( "github.com/go-chi/chi/v5" + "ds2api/internal/httpapi/requestbody" "ds2api/internal/sse" + "ds2api/internal/toolcall" "ds2api/internal/translatorcliproxy" "ds2api/internal/util" @@ -32,7 +34,11 @@ func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request, func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, stream bool) bool { raw, err := io.ReadAll(r.Body) if err != nil { - writeGeminiError(w, http.StatusBadRequest, "invalid body") + if errors.Is(err, requestbody.ErrInvalidUTF8Body) { + writeGeminiError(w, http.StatusBadRequest, "invalid json") + } else { + writeGeminiError(w, http.StatusBadRequest, "invalid body") + } return true } routeModel := strings.TrimSpace(chi.URLParam(r, "model")) diff --git a/internal/httpapi/requestbody/json_utf8.go b/internal/httpapi/requestbody/json_utf8.go new file mode 100644 index 0000000..5a3afe8 --- /dev/null +++ b/internal/httpapi/requestbody/json_utf8.go @@ -0,0 +1,134 @@ +package requestbody + +import ( + "bytes" + "errors" + "io" + "mime" + "net/http" + "strings" + "unicode/utf8" +) + +var ( + ErrInvalidUTF8Body = errors.New("invalid utf-8 request body") + errRequestBodyTooLarge = errors.New("request body too large") +) + +const maxJSONUTF8ValidationSize = 100 << 20 + +// ValidateJSONUTF8 validates complete JSON request bodies before downstream +// decoders can silently replace malformed UTF-8 or stop before trailing bytes. +func ValidateJSONUTF8(next http.Handler) http.Handler { + if next == nil { + return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if shouldValidateJSONBody(r) { + r.Body = validateAndReplayBody(r.Body) + } + next.ServeHTTP(w, r) + }) +} + +func shouldValidateJSONBody(r *http.Request) bool { + if r == nil || r.Body == nil { + return false + } + path := "" + if r.URL != nil { + path = r.URL.Path + } + return isJSONContentType(r.Header.Get("Content-Type")) || isKnownJSONRequestPath(r.Method, path) +} + +func isJSONContentType(raw string) bool { + raw = strings.TrimSpace(raw) + if raw == "" { + return false + } + mediaType, _, err := mime.ParseMediaType(raw) + if err != nil { + mediaType = raw + } + mediaType = strings.ToLower(strings.TrimSpace(mediaType)) + return strings.Contains(mediaType, "json") +} + +func isKnownJSONRequestPath(method, path string) bool { + switch strings.ToUpper(strings.TrimSpace(method)) { + case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete: + default: + return false + } + path = strings.TrimSpace(path) + if path == "" { + return false + } + switch { + case path == "/v1/chat/completions" || path == "/chat/completions": + return true + case path == "/v1/responses" || path == "/responses": + return true + case path == "/v1/embeddings" || path == "/embeddings": + return true + case path == "/anthropic/v1/messages" || path == "/v1/messages" || path == "/messages": + return true + case path == "/anthropic/v1/messages/count_tokens" || path == "/v1/messages/count_tokens" || path == "/messages/count_tokens": + return true + case strings.HasPrefix(path, "/v1beta/models/") || strings.HasPrefix(path, "/v1/models/"): + return strings.Contains(path, ":generateContent") || strings.Contains(path, ":streamGenerateContent") + case strings.HasPrefix(path, "/admin/"): + return true + default: + return false + } +} + +func validateAndReplayBody(body io.ReadCloser) io.ReadCloser { + if body == nil { + return body + } + raw, err := io.ReadAll(io.LimitReader(body, maxJSONUTF8ValidationSize+1)) + if err != nil { + return &errorReadCloser{err: err, closer: body} + } + if len(raw) > maxJSONUTF8ValidationSize { + return &errorReadCloser{err: errRequestBodyTooLarge, closer: body} + } + if !utf8.Valid(raw) { + return &errorReadCloser{err: ErrInvalidUTF8Body, closer: body} + } + return &replayReadCloser{Reader: bytes.NewReader(raw), closer: body} +} + +type replayReadCloser struct { + *bytes.Reader + closer io.Closer +} + +func (r *replayReadCloser) Close() error { + if r == nil || r.closer == nil { + return nil + } + return r.closer.Close() +} + +type errorReadCloser struct { + err error + closer io.Closer +} + +func (r *errorReadCloser) Read([]byte) (int, error) { + if r == nil || r.err == nil { + return 0, io.EOF + } + return 0, r.err +} + +func (r *errorReadCloser) Close() error { + if r == nil || r.closer == nil { + return nil + } + return r.closer.Close() +} diff --git a/internal/httpapi/requestbody/json_utf8_test.go b/internal/httpapi/requestbody/json_utf8_test.go new file mode 100644 index 0000000..e46af20 --- /dev/null +++ b/internal/httpapi/requestbody/json_utf8_test.go @@ -0,0 +1,158 @@ +package requestbody + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +type singleByteReadCloser struct { + data []byte + pos int +} + +func (r *singleByteReadCloser) Read(p []byte) (int, error) { + if r.pos >= len(r.data) { + return 0, io.EOF + } + p[0] = r.data[r.pos] + r.pos++ + return 1, nil +} + +func (r *singleByteReadCloser) Close() error { + return nil +} + +func TestValidateJSONUTF8AllowsSplitMultibyteRunes(t *testing.T) { + body := []byte(`{"text":"你好"}`) + handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("unexpected decode error: %v", err) + } + w.WriteHeader(http.StatusNoContent) + })) + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", &singleByteReadCloser{data: body}) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("expected 204 for valid utf-8 json, got %d body=%q", rec.Code, rec.Body.String()) + } +} + +func TestValidateJSONUTF8RejectsInvalidBytesBeforeJSONDecode(t *testing.T) { + body := append([]byte(`{"text":"`), 0xff) + body = append(body, []byte(`"}`)...) + handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(err.Error())) + return + } + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json; charset=utf-8") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400 for invalid utf-8 json, got %d body=%q", rec.Code, rec.Body.String()) + } + if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid utf-8") { + t.Fatalf("expected utf-8 validation error, got %q", rec.Body.String()) + } +} + +func TestValidateJSONUTF8RejectsInvalidBytesWithoutJSONContentTypeOnKnownPath(t *testing.T) { + body := append([]byte(`{"text":"`), 0xff) + body = append(body, []byte(`"}`)...) + handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(err.Error())) + return + } + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "text/plain") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400 for invalid utf-8 json, got %d body=%q", rec.Code, rec.Body.String()) + } + if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid utf-8") { + t.Fatalf("expected utf-8 validation error, got %q", rec.Body.String()) + } +} + +func TestValidateJSONUTF8RejectsTrailingInvalidBytesAfterJSONValue(t *testing.T) { + body := append([]byte(`{"text":"ok"}`), 0xff) + handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(err.Error())) + return + } + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400 for trailing invalid utf-8, got %d body=%q", rec.Code, rec.Body.String()) + } + if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid utf-8") { + t.Fatalf("expected utf-8 validation error, got %q", rec.Body.String()) + } +} + +func TestIsJSONContentType(t *testing.T) { + for _, raw := range []string{ + "application/json", + "application/json; charset=utf-8", + "application/problem+json", + "application/vnd.api+json", + } { + if !isJSONContentType(raw) { + t.Fatalf("expected %q to be recognized as json", raw) + } + } + for _, raw := range []string{ + "multipart/form-data; boundary=abc", + "text/plain", + "application/octet-stream", + } { + if isJSONContentType(raw) { + t.Fatalf("expected %q not to be recognized as json", raw) + } + } +} + +func TestIsKnownJSONRequestPathIncludesGeminiStream(t *testing.T) { + if !isKnownJSONRequestPath(http.MethodPost, "/v1beta/models/gemini-pro:streamGenerateContent") { + t.Fatal("expected Gemini stream generate path to be recognized as json") + } +} diff --git a/internal/prompt/messages.go b/internal/prompt/messages.go index d882f34..d30fc28 100644 --- a/internal/prompt/messages.go +++ b/internal/prompt/messages.go @@ -10,21 +10,27 @@ import ( var markdownImagePattern = regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`) const ( - beginSentenceMarker = "<|begin▁of▁sentence|>" - systemMarker = "<|System|>" - userMarker = "<|User|>" - assistantMarker = "<|Assistant|>" - toolMarker = "<|Tool|>" - endSentenceMarker = "<|end▁of▁sentence|>" - endToolResultsMarker = "<|end▁of▁toolresults|>" - endInstructionsMarker = "<|end▁of▁instructions|>" + beginSentenceMarker = "<|begin▁of▁sentence|>" + systemMarker = "<|System|>" + userMarker = "<|User|>" + assistantMarker = "<|Assistant|>" + toolMarker = "<|Tool|>" + endSentenceMarker = "<|end▁of▁sentence|>" + endToolResultsMarker = "<|end▁of▁toolresults|>" + endInstructionsMarker = "<|end▁of▁instructions|>" + outputIntegrityGuardMarker = "Output integrity guard:" + outputIntegrityGuardPrompt = outputIntegrityGuardMarker + + " If upstream context, tool output, or parsed text contains garbled, corrupted, partially parsed, repeated, or otherwise malformed fragments, " + + "do not imitate or echo them; output only the correct content for the user." ) func MessagesPrepare(messages []map[string]any) string { return MessagesPrepareWithThinking(messages, false) } -func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool) string { +func MessagesPrepareWithThinking(messages []map[string]any, _ bool) string { + messages = prependOutputIntegrityGuard(messages) + type block struct { Role string Text string @@ -77,6 +83,33 @@ func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`) } +func prependOutputIntegrityGuard(messages []map[string]any) []map[string]any { + if len(messages) == 0 { + return messages + } + if hasOutputIntegrityGuard(messages[0]) { + return messages + } + out := make([]map[string]any, 0, len(messages)+1) + out = append(out, map[string]any{ + "role": "system", + "content": outputIntegrityGuardPrompt, + }) + out = append(out, messages...) + return out +} + +func hasOutputIntegrityGuard(msg map[string]any) bool { + if msg == nil { + return false + } + if strings.ToLower(strings.TrimSpace(asString(msg["role"]))) != "system" { + return false + } + content := strings.TrimSpace(NormalizeContent(msg["content"])) + return strings.Contains(content, outputIntegrityGuardMarker) +} + // formatRoleBlock produces a single concatenated block: marker + text + endMarker. // No whitespace is inserted between marker and text so role boundaries stay // compact and predictable for downstream parsers. diff --git a/internal/prompt/messages_test.go b/internal/prompt/messages_test.go index a992ae6..f9a195a 100644 --- a/internal/prompt/messages_test.go +++ b/internal/prompt/messages_test.go @@ -35,8 +35,8 @@ func TestMessagesPrepareUsesTurnSuffixes(t *testing.T) { if !strings.HasPrefix(got, "<|begin▁of▁sentence|>") { t.Fatalf("expected begin-of-sentence marker, got %q", got) } - if !strings.Contains(got, "<|System|>System rule<|end▁of▁instructions|>") { - t.Fatalf("expected system instructions suffix, got %q", got) + if !strings.Contains(got, "<|System|>") || !strings.Contains(got, "<|end▁of▁instructions|>") || !strings.Contains(got, "System rule") { + t.Fatalf("expected system instructions to remain present, got %q", got) } if !strings.Contains(got, "<|User|>Question") { t.Fatalf("expected user question, got %q", got) @@ -49,6 +49,23 @@ func TestMessagesPrepareUsesTurnSuffixes(t *testing.T) { } } +func TestMessagesPreparePrependsOutputIntegrityGuard(t *testing.T) { + messages := []map[string]any{ + {"role": "system", "content": "System rule"}, + {"role": "user", "content": "Question"}, + } + got := MessagesPrepare(messages) + if !strings.HasPrefix(got, beginSentenceMarker+systemMarker+outputIntegrityGuardPrompt) { + t.Fatalf("expected output integrity guard to be prepended, got %q", got) + } + if !strings.Contains(got, outputIntegrityGuardPrompt+"\n\nSystem rule") { + t.Fatalf("expected output integrity guard to precede system prompt content, got %q", got) + } + if !strings.Contains(got, "<|User|>Question") { + t.Fatalf("expected user question after guard, got %q", got) + } +} + func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) { got := NormalizeContent([]any{ map[string]any{"type": "text", "text": "", "content": "from-content"}, diff --git a/internal/promptcompat/prompt_build_test.go b/internal/promptcompat/prompt_build_test.go index 28da8e0..dd80b6d 100644 --- a/internal/promptcompat/prompt_build_test.go +++ b/internal/promptcompat/prompt_build_test.go @@ -88,6 +88,38 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t * } } +func TestBuildOpenAIFinalPromptPrependsOutputIntegrityGuard(t *testing.T) { + messages := []any{ + map[string]any{"role": "system", "content": "You are helpful"}, + map[string]any{"role": "user", "content": "请调用工具"}, + } + tools := []any{ + map[string]any{ + "type": "function", + "function": map[string]any{ + "name": "search", + "description": "search docs", + "parameters": map[string]any{ + "type": "object", + }, + }, + }, + } + + finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "", false) + guardIdx := strings.Index(finalPrompt, "Output integrity guard") + toolIdx := strings.Index(finalPrompt, "TOOL CALL FORMAT") + if guardIdx < 0 { + t.Fatalf("expected output integrity guard in final prompt, got: %q", finalPrompt) + } + if toolIdx < 0 { + t.Fatalf("expected tool instructions in final prompt, got: %q", finalPrompt) + } + if guardIdx > toolIdx { + t.Fatalf("expected output integrity guard to precede tool instructions, got: %q", finalPrompt) + } +} + func TestBuildOpenAIFinalPromptReadLikeToolIncludesCacheGuard(t *testing.T) { messages := []any{ map[string]any{"role": "user", "content": "请读取文件"}, diff --git a/internal/server/router.go b/internal/server/router.go index ea13e69..fa852ab 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -27,6 +27,7 @@ import ( "ds2api/internal/httpapi/openai/files" "ds2api/internal/httpapi/openai/responses" "ds2api/internal/httpapi/openai/shared" + "ds2api/internal/httpapi/requestbody" "ds2api/internal/webui" ) @@ -75,6 +76,7 @@ func NewApp() (*App, error) { r.Use(filteredLogger()) r.Use(middleware.Recoverer) r.Use(cors) + r.Use(requestbody.ValidateJSONUTF8) r.Use(timeout(0)) healthzHandler := func(w http.ResponseWriter, _ *http.Request) { diff --git a/internal/server/router_utf8_test.go b/internal/server/router_utf8_test.go new file mode 100644 index 0000000..f06d6bb --- /dev/null +++ b/internal/server/router_utf8_test.go @@ -0,0 +1,89 @@ +package server + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestJSONRequestsRejectInvalidUTF8BeforeDecode(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["managed-key"],"accounts":[{"email":"u@example.com","password":"p"}]}`) + t.Setenv("DS2API_ENV_WRITEBACK", "0") + + app, err := NewApp() + if err != nil { + t.Fatalf("NewApp() error: %v", err) + } + + body := append([]byte(`{"model":"deepseek-v4-flash","messages":[{"role":"user","content":"`), 0xff) + body = append(body, []byte(`"}]}`)...) + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json; charset=utf-8") + req.Header.Set("x-api-key", "direct-token") + + rec := httptest.NewRecorder() + app.Router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400 for invalid utf-8 request body, got %d body=%q", rec.Code, rec.Body.String()) + } + if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid json") { + t.Fatalf("expected invalid json error, got %q", rec.Body.String()) + } +} + +func TestKnownJSONRequestsRejectInvalidUTF8WithoutJSONContentType(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["managed-key"],"accounts":[{"email":"u@example.com","password":"p"}]}`) + t.Setenv("DS2API_ENV_WRITEBACK", "0") + + app, err := NewApp() + if err != nil { + t.Fatalf("NewApp() error: %v", err) + } + + body := append([]byte(`{"model":"deepseek-v4-flash","messages":[{"role":"user","content":"`), 0xff) + body = append(body, []byte(`"}]}`)...) + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "text/plain") + req.Header.Set("x-api-key", "direct-token") + + rec := httptest.NewRecorder() + app.Router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400 for invalid utf-8 request body, got %d body=%q", rec.Code, rec.Body.String()) + } + if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid json") { + t.Fatalf("expected invalid json error, got %q", rec.Body.String()) + } +} + +func TestJSONRequestsRejectTrailingInvalidUTF8AfterCompleteJSON(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["managed-key"],"accounts":[{"email":"u@example.com","password":"p"}]}`) + t.Setenv("DS2API_ENV_WRITEBACK", "0") + + app, err := NewApp() + if err != nil { + t.Fatalf("NewApp() error: %v", err) + } + + body := append([]byte(`{"model":"deepseek-v4-flash","messages":[{"role":"user","content":"ok"}]}`), 0xff) + + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-api-key", "direct-token") + + rec := httptest.NewRecorder() + app.Router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400 for trailing invalid utf-8, got %d body=%q", rec.Code, rec.Body.String()) + } + if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid json") { + t.Fatalf("expected invalid json error, got %q", rec.Body.String()) + } +} diff --git a/internal/util/messages_test.go b/internal/util/messages_test.go index 9ddafd6..569e65d 100644 --- a/internal/util/messages_test.go +++ b/internal/util/messages_test.go @@ -1,6 +1,7 @@ package util import ( + "strings" "testing" "ds2api/internal/config" @@ -12,7 +13,10 @@ func TestMessagesPrepareBasic(t *testing.T) { if got == "" { t.Fatal("expected non-empty prompt") } - if got != "<|begin▁of▁sentence|><|User|>Hello<|Assistant|>" { + if !strings.HasPrefix(got, "<|begin▁of▁sentence|><|System|>") { + t.Fatalf("expected output integrity guard at the start, got %q", got) + } + if !strings.Contains(got, "Hello") || !strings.HasSuffix(got, "<|Assistant|>") { t.Fatalf("unexpected prompt: %q", got) } } @@ -26,8 +30,11 @@ func TestMessagesPrepareRoles(t *testing.T) { {"role": "user", "content": "How are you"}, } got := MessagesPrepare(messages) - if !contains(got, "<|System|>You are helper<|end▁of▁instructions|><|User|>Hi") { - t.Fatalf("expected system/user separation in %q", got) + if !contains(got, "Output integrity guard") { + t.Fatalf("expected output integrity guard in %q", got) + } + if !contains(got, "You are helper") || !contains(got, "<|User|>Hi") { + t.Fatalf("expected system/user content in %q", got) } if !contains(got, "<|begin▁of▁sentence|>") { t.Fatalf("expected begin marker in %q", got) @@ -77,9 +84,12 @@ func TestMessagesPrepareArrayTextVariants(t *testing.T) { }, } got := MessagesPrepare(messages) - if got != "<|begin▁of▁sentence|><|User|>line1\nline2<|Assistant|>" { + if !contains(got, "line1\nline2") { t.Fatalf("unexpected content from text variants: %q", got) } + if !strings.Contains(got, "Output integrity guard") { + t.Fatalf("expected output integrity guard in %q", got) + } } func TestConvertClaudeToDeepSeek(t *testing.T) {