mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 00:45:29 +08:00
feat: introduce JSON UTF-8 validation middleware and prepend output integrity guard system prompt to messages
This commit is contained in:
@@ -33,6 +33,8 @@ Docs: [Overview](README.en.md) / [Architecture](docs/ARCHITECTURE.en.md) / [Depl
|
||||
| Health probes | `GET /healthz`, `GET /readyz` |
|
||||
| CORS | Enabled (uniformly covers `/v1/*`, `/anthropic/*`, `/v1beta/models/*`, and `/admin/*`; echoes the browser `Origin` when present, otherwise `*`; default allow-list includes `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Ds2-Source`, `X-Vercel-Protection-Bypass`, `X-Goog-Api-Key`, `Anthropic-Version`, `Anthropic-Beta`, and also accepts third-party preflight-requested headers such as `x-stainless-*`; `/v1/chat/completions` on Vercel Node Runtime matches the same behavior; internal-only `X-Ds2-Internal-Token` remains blocked) |
|
||||
|
||||
- All JSON request bodies must be valid UTF-8; malformed byte sequences are rejected on ingress with `400 invalid json`.
|
||||
|
||||
### 3.0 Adapter-Layer Notes
|
||||
|
||||
- OpenAI / Claude / Gemini protocols are now mounted on one shared `chi` router tree assembled in `internal/server/router.go`.
|
||||
|
||||
2
API.md
2
API.md
@@ -33,6 +33,8 @@
|
||||
| 健康检查 | `GET /healthz`、`GET /readyz` |
|
||||
| CORS | 已启用(统一覆盖 `/v1/*`、`/anthropic/*`、`/v1beta/models/*`、`/admin/*`;浏览器有 `Origin` 时回显该 Origin,否则为 `*`;默认允许 `Content-Type`, `Authorization`, `X-API-Key`, `X-Ds2-Target-Account`, `X-Ds2-Source`, `X-Vercel-Protection-Bypass`, `X-Goog-Api-Key`, `Anthropic-Version`, `Anthropic-Beta`,并会放行预检里声明的第三方请求头,如 `x-stainless-*`;Vercel 上 `/v1/chat/completions` 的 Node Runtime 也对齐相同行为;内部专用头 `X-Ds2-Internal-Token` 仍被拦截) |
|
||||
|
||||
- 所有 JSON 请求体都必须是合法 UTF-8;非法字节序列会在入站阶段被拒绝为 `400 invalid json`。
|
||||
|
||||
### 3.0 接口适配层说明
|
||||
|
||||
- OpenAI / Claude / Gemini 三套协议已统一挂在同一 `chi` 路由树上,由 `internal/server/router.go` 负责装配。
|
||||
|
||||
@@ -117,6 +117,11 @@ OpenAI Chat / Responses 在标准化后、current input file 之前,会默认
|
||||
- 普通请求会直接出现在最终 `prompt` 的最新 user block 末尾。
|
||||
- 如果触发 current input file,它会进入完整上下文文件中。
|
||||
|
||||
另外,`MessagesPrepareWithThinking` 还会在最终 prompt 的最前面预置一段固定的 system 级“输出完整性约束(Output integrity guard)”:
|
||||
|
||||
- 如果上游上下文、工具输出或解析后的文本出现乱码、损坏、部分解析、重复或其他畸形片段,不要模仿、不要回显,只输出给用户的正确内容。
|
||||
- 这段约束位于普通 system / tool prompt 之前,因此是当前最终 prompt 里的最高优先级前置指令。
|
||||
|
||||
### 5.1 角色标记
|
||||
|
||||
最终 prompt 使用 DeepSeek 风格角色标记:
|
||||
|
||||
@@ -3,12 +3,14 @@ package claude
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"ds2api/internal/config"
|
||||
"ds2api/internal/httpapi/requestbody"
|
||||
streamengine "ds2api/internal/stream"
|
||||
"ds2api/internal/translatorcliproxy"
|
||||
"ds2api/internal/util"
|
||||
@@ -33,7 +35,11 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, store ConfigReader) bool {
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid body")
|
||||
if errors.Is(err, requestbody.ErrInvalidUTF8Body) {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid json")
|
||||
} else {
|
||||
writeClaudeError(w, http.StatusBadRequest, "invalid body")
|
||||
}
|
||||
return true
|
||||
}
|
||||
var req map[string]any
|
||||
|
||||
@@ -2,8 +2,8 @@ package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"ds2api/internal/toolcall"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -11,7 +11,9 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"ds2api/internal/httpapi/requestbody"
|
||||
"ds2api/internal/sse"
|
||||
"ds2api/internal/toolcall"
|
||||
"ds2api/internal/translatorcliproxy"
|
||||
"ds2api/internal/util"
|
||||
|
||||
@@ -32,7 +34,11 @@ func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request,
|
||||
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, stream bool) bool {
|
||||
raw, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
writeGeminiError(w, http.StatusBadRequest, "invalid body")
|
||||
if errors.Is(err, requestbody.ErrInvalidUTF8Body) {
|
||||
writeGeminiError(w, http.StatusBadRequest, "invalid json")
|
||||
} else {
|
||||
writeGeminiError(w, http.StatusBadRequest, "invalid body")
|
||||
}
|
||||
return true
|
||||
}
|
||||
routeModel := strings.TrimSpace(chi.URLParam(r, "model"))
|
||||
|
||||
134
internal/httpapi/requestbody/json_utf8.go
Normal file
134
internal/httpapi/requestbody/json_utf8.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package requestbody
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidUTF8Body = errors.New("invalid utf-8 request body")
|
||||
errRequestBodyTooLarge = errors.New("request body too large")
|
||||
)
|
||||
|
||||
const maxJSONUTF8ValidationSize = 100 << 20
|
||||
|
||||
// ValidateJSONUTF8 validates complete JSON request bodies before downstream
|
||||
// decoders can silently replace malformed UTF-8 or stop before trailing bytes.
|
||||
func ValidateJSONUTF8(next http.Handler) http.Handler {
|
||||
if next == nil {
|
||||
return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if shouldValidateJSONBody(r) {
|
||||
r.Body = validateAndReplayBody(r.Body)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func shouldValidateJSONBody(r *http.Request) bool {
|
||||
if r == nil || r.Body == nil {
|
||||
return false
|
||||
}
|
||||
path := ""
|
||||
if r.URL != nil {
|
||||
path = r.URL.Path
|
||||
}
|
||||
return isJSONContentType(r.Header.Get("Content-Type")) || isKnownJSONRequestPath(r.Method, path)
|
||||
}
|
||||
|
||||
func isJSONContentType(raw string) bool {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return false
|
||||
}
|
||||
mediaType, _, err := mime.ParseMediaType(raw)
|
||||
if err != nil {
|
||||
mediaType = raw
|
||||
}
|
||||
mediaType = strings.ToLower(strings.TrimSpace(mediaType))
|
||||
return strings.Contains(mediaType, "json")
|
||||
}
|
||||
|
||||
func isKnownJSONRequestPath(method, path string) bool {
|
||||
switch strings.ToUpper(strings.TrimSpace(method)) {
|
||||
case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return false
|
||||
}
|
||||
switch {
|
||||
case path == "/v1/chat/completions" || path == "/chat/completions":
|
||||
return true
|
||||
case path == "/v1/responses" || path == "/responses":
|
||||
return true
|
||||
case path == "/v1/embeddings" || path == "/embeddings":
|
||||
return true
|
||||
case path == "/anthropic/v1/messages" || path == "/v1/messages" || path == "/messages":
|
||||
return true
|
||||
case path == "/anthropic/v1/messages/count_tokens" || path == "/v1/messages/count_tokens" || path == "/messages/count_tokens":
|
||||
return true
|
||||
case strings.HasPrefix(path, "/v1beta/models/") || strings.HasPrefix(path, "/v1/models/"):
|
||||
return strings.Contains(path, ":generateContent") || strings.Contains(path, ":streamGenerateContent")
|
||||
case strings.HasPrefix(path, "/admin/"):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func validateAndReplayBody(body io.ReadCloser) io.ReadCloser {
|
||||
if body == nil {
|
||||
return body
|
||||
}
|
||||
raw, err := io.ReadAll(io.LimitReader(body, maxJSONUTF8ValidationSize+1))
|
||||
if err != nil {
|
||||
return &errorReadCloser{err: err, closer: body}
|
||||
}
|
||||
if len(raw) > maxJSONUTF8ValidationSize {
|
||||
return &errorReadCloser{err: errRequestBodyTooLarge, closer: body}
|
||||
}
|
||||
if !utf8.Valid(raw) {
|
||||
return &errorReadCloser{err: ErrInvalidUTF8Body, closer: body}
|
||||
}
|
||||
return &replayReadCloser{Reader: bytes.NewReader(raw), closer: body}
|
||||
}
|
||||
|
||||
type replayReadCloser struct {
|
||||
*bytes.Reader
|
||||
closer io.Closer
|
||||
}
|
||||
|
||||
func (r *replayReadCloser) Close() error {
|
||||
if r == nil || r.closer == nil {
|
||||
return nil
|
||||
}
|
||||
return r.closer.Close()
|
||||
}
|
||||
|
||||
type errorReadCloser struct {
|
||||
err error
|
||||
closer io.Closer
|
||||
}
|
||||
|
||||
func (r *errorReadCloser) Read([]byte) (int, error) {
|
||||
if r == nil || r.err == nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return 0, r.err
|
||||
}
|
||||
|
||||
func (r *errorReadCloser) Close() error {
|
||||
if r == nil || r.closer == nil {
|
||||
return nil
|
||||
}
|
||||
return r.closer.Close()
|
||||
}
|
||||
158
internal/httpapi/requestbody/json_utf8_test.go
Normal file
158
internal/httpapi/requestbody/json_utf8_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package requestbody
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type singleByteReadCloser struct {
|
||||
data []byte
|
||||
pos int
|
||||
}
|
||||
|
||||
func (r *singleByteReadCloser) Read(p []byte) (int, error) {
|
||||
if r.pos >= len(r.data) {
|
||||
return 0, io.EOF
|
||||
}
|
||||
p[0] = r.data[r.pos]
|
||||
r.pos++
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
func (r *singleByteReadCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestValidateJSONUTF8AllowsSplitMultibyteRunes(t *testing.T) {
|
||||
body := []byte(`{"text":"你好"}`)
|
||||
handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("unexpected decode error: %v", err)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", &singleByteReadCloser{data: body})
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNoContent {
|
||||
t.Fatalf("expected 204 for valid utf-8 json, got %d body=%q", rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJSONUTF8RejectsInvalidBytesBeforeJSONDecode(t *testing.T) {
|
||||
body := append([]byte(`{"text":"`), 0xff)
|
||||
body = append(body, []byte(`"}`)...)
|
||||
handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for invalid utf-8 json, got %d body=%q", rec.Code, rec.Body.String())
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid utf-8") {
|
||||
t.Fatalf("expected utf-8 validation error, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJSONUTF8RejectsInvalidBytesWithoutJSONContentTypeOnKnownPath(t *testing.T) {
|
||||
body := append([]byte(`{"text":"`), 0xff)
|
||||
body = append(body, []byte(`"}`)...)
|
||||
handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for invalid utf-8 json, got %d body=%q", rec.Code, rec.Body.String())
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid utf-8") {
|
||||
t.Fatalf("expected utf-8 validation error, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJSONUTF8RejectsTrailingInvalidBytesAfterJSONValue(t *testing.T) {
|
||||
body := append([]byte(`{"text":"ok"}`), 0xff)
|
||||
handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var req map[string]any
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for trailing invalid utf-8, got %d body=%q", rec.Code, rec.Body.String())
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid utf-8") {
|
||||
t.Fatalf("expected utf-8 validation error, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsJSONContentType(t *testing.T) {
|
||||
for _, raw := range []string{
|
||||
"application/json",
|
||||
"application/json; charset=utf-8",
|
||||
"application/problem+json",
|
||||
"application/vnd.api+json",
|
||||
} {
|
||||
if !isJSONContentType(raw) {
|
||||
t.Fatalf("expected %q to be recognized as json", raw)
|
||||
}
|
||||
}
|
||||
for _, raw := range []string{
|
||||
"multipart/form-data; boundary=abc",
|
||||
"text/plain",
|
||||
"application/octet-stream",
|
||||
} {
|
||||
if isJSONContentType(raw) {
|
||||
t.Fatalf("expected %q not to be recognized as json", raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsKnownJSONRequestPathIncludesGeminiStream(t *testing.T) {
|
||||
if !isKnownJSONRequestPath(http.MethodPost, "/v1beta/models/gemini-pro:streamGenerateContent") {
|
||||
t.Fatal("expected Gemini stream generate path to be recognized as json")
|
||||
}
|
||||
}
|
||||
@@ -10,21 +10,27 @@ import (
|
||||
var markdownImagePattern = regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`)
|
||||
|
||||
const (
|
||||
beginSentenceMarker = "<|begin▁of▁sentence|>"
|
||||
systemMarker = "<|System|>"
|
||||
userMarker = "<|User|>"
|
||||
assistantMarker = "<|Assistant|>"
|
||||
toolMarker = "<|Tool|>"
|
||||
endSentenceMarker = "<|end▁of▁sentence|>"
|
||||
endToolResultsMarker = "<|end▁of▁toolresults|>"
|
||||
endInstructionsMarker = "<|end▁of▁instructions|>"
|
||||
beginSentenceMarker = "<|begin▁of▁sentence|>"
|
||||
systemMarker = "<|System|>"
|
||||
userMarker = "<|User|>"
|
||||
assistantMarker = "<|Assistant|>"
|
||||
toolMarker = "<|Tool|>"
|
||||
endSentenceMarker = "<|end▁of▁sentence|>"
|
||||
endToolResultsMarker = "<|end▁of▁toolresults|>"
|
||||
endInstructionsMarker = "<|end▁of▁instructions|>"
|
||||
outputIntegrityGuardMarker = "Output integrity guard:"
|
||||
outputIntegrityGuardPrompt = outputIntegrityGuardMarker +
|
||||
" If upstream context, tool output, or parsed text contains garbled, corrupted, partially parsed, repeated, or otherwise malformed fragments, " +
|
||||
"do not imitate or echo them; output only the correct content for the user."
|
||||
)
|
||||
|
||||
func MessagesPrepare(messages []map[string]any) string {
|
||||
return MessagesPrepareWithThinking(messages, false)
|
||||
}
|
||||
|
||||
func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool) string {
|
||||
func MessagesPrepareWithThinking(messages []map[string]any, _ bool) string {
|
||||
messages = prependOutputIntegrityGuard(messages)
|
||||
|
||||
type block struct {
|
||||
Role string
|
||||
Text string
|
||||
@@ -77,6 +83,33 @@ func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool
|
||||
return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`)
|
||||
}
|
||||
|
||||
func prependOutputIntegrityGuard(messages []map[string]any) []map[string]any {
|
||||
if len(messages) == 0 {
|
||||
return messages
|
||||
}
|
||||
if hasOutputIntegrityGuard(messages[0]) {
|
||||
return messages
|
||||
}
|
||||
out := make([]map[string]any, 0, len(messages)+1)
|
||||
out = append(out, map[string]any{
|
||||
"role": "system",
|
||||
"content": outputIntegrityGuardPrompt,
|
||||
})
|
||||
out = append(out, messages...)
|
||||
return out
|
||||
}
|
||||
|
||||
func hasOutputIntegrityGuard(msg map[string]any) bool {
|
||||
if msg == nil {
|
||||
return false
|
||||
}
|
||||
if strings.ToLower(strings.TrimSpace(asString(msg["role"]))) != "system" {
|
||||
return false
|
||||
}
|
||||
content := strings.TrimSpace(NormalizeContent(msg["content"]))
|
||||
return strings.Contains(content, outputIntegrityGuardMarker)
|
||||
}
|
||||
|
||||
// formatRoleBlock produces a single concatenated block: marker + text + endMarker.
|
||||
// No whitespace is inserted between marker and text so role boundaries stay
|
||||
// compact and predictable for downstream parsers.
|
||||
|
||||
@@ -35,8 +35,8 @@ func TestMessagesPrepareUsesTurnSuffixes(t *testing.T) {
|
||||
if !strings.HasPrefix(got, "<|begin▁of▁sentence|>") {
|
||||
t.Fatalf("expected begin-of-sentence marker, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "<|System|>System rule<|end▁of▁instructions|>") {
|
||||
t.Fatalf("expected system instructions suffix, got %q", got)
|
||||
if !strings.Contains(got, "<|System|>") || !strings.Contains(got, "<|end▁of▁instructions|>") || !strings.Contains(got, "System rule") {
|
||||
t.Fatalf("expected system instructions to remain present, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "<|User|>Question") {
|
||||
t.Fatalf("expected user question, got %q", got)
|
||||
@@ -49,6 +49,23 @@ func TestMessagesPrepareUsesTurnSuffixes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesPreparePrependsOutputIntegrityGuard(t *testing.T) {
|
||||
messages := []map[string]any{
|
||||
{"role": "system", "content": "System rule"},
|
||||
{"role": "user", "content": "Question"},
|
||||
}
|
||||
got := MessagesPrepare(messages)
|
||||
if !strings.HasPrefix(got, beginSentenceMarker+systemMarker+outputIntegrityGuardPrompt) {
|
||||
t.Fatalf("expected output integrity guard to be prepended, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, outputIntegrityGuardPrompt+"\n\nSystem rule") {
|
||||
t.Fatalf("expected output integrity guard to precede system prompt content, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "<|User|>Question") {
|
||||
t.Fatalf("expected user question after guard, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) {
|
||||
got := NormalizeContent([]any{
|
||||
map[string]any{"type": "text", "text": "", "content": "from-content"},
|
||||
|
||||
@@ -88,6 +88,38 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t *
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIFinalPromptPrependsOutputIntegrityGuard(t *testing.T) {
|
||||
messages := []any{
|
||||
map[string]any{"role": "system", "content": "You are helpful"},
|
||||
map[string]any{"role": "user", "content": "请调用工具"},
|
||||
}
|
||||
tools := []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "search",
|
||||
"description": "search docs",
|
||||
"parameters": map[string]any{
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "", false)
|
||||
guardIdx := strings.Index(finalPrompt, "Output integrity guard")
|
||||
toolIdx := strings.Index(finalPrompt, "TOOL CALL FORMAT")
|
||||
if guardIdx < 0 {
|
||||
t.Fatalf("expected output integrity guard in final prompt, got: %q", finalPrompt)
|
||||
}
|
||||
if toolIdx < 0 {
|
||||
t.Fatalf("expected tool instructions in final prompt, got: %q", finalPrompt)
|
||||
}
|
||||
if guardIdx > toolIdx {
|
||||
t.Fatalf("expected output integrity guard to precede tool instructions, got: %q", finalPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIFinalPromptReadLikeToolIncludesCacheGuard(t *testing.T) {
|
||||
messages := []any{
|
||||
map[string]any{"role": "user", "content": "请读取文件"},
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"ds2api/internal/httpapi/openai/files"
|
||||
"ds2api/internal/httpapi/openai/responses"
|
||||
"ds2api/internal/httpapi/openai/shared"
|
||||
"ds2api/internal/httpapi/requestbody"
|
||||
"ds2api/internal/webui"
|
||||
)
|
||||
|
||||
@@ -75,6 +76,7 @@ func NewApp() (*App, error) {
|
||||
r.Use(filteredLogger())
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(cors)
|
||||
r.Use(requestbody.ValidateJSONUTF8)
|
||||
r.Use(timeout(0))
|
||||
|
||||
healthzHandler := func(w http.ResponseWriter, _ *http.Request) {
|
||||
|
||||
89
internal/server/router_utf8_test.go
Normal file
89
internal/server/router_utf8_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestJSONRequestsRejectInvalidUTF8BeforeDecode(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["managed-key"],"accounts":[{"email":"u@example.com","password":"p"}]}`)
|
||||
t.Setenv("DS2API_ENV_WRITEBACK", "0")
|
||||
|
||||
app, err := NewApp()
|
||||
if err != nil {
|
||||
t.Fatalf("NewApp() error: %v", err)
|
||||
}
|
||||
|
||||
body := append([]byte(`{"model":"deepseek-v4-flash","messages":[{"role":"user","content":"`), 0xff)
|
||||
body = append(body, []byte(`"}]}`)...)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
req.Header.Set("x-api-key", "direct-token")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
app.Router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for invalid utf-8 request body, got %d body=%q", rec.Code, rec.Body.String())
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid json") {
|
||||
t.Fatalf("expected invalid json error, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestKnownJSONRequestsRejectInvalidUTF8WithoutJSONContentType(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["managed-key"],"accounts":[{"email":"u@example.com","password":"p"}]}`)
|
||||
t.Setenv("DS2API_ENV_WRITEBACK", "0")
|
||||
|
||||
app, err := NewApp()
|
||||
if err != nil {
|
||||
t.Fatalf("NewApp() error: %v", err)
|
||||
}
|
||||
|
||||
body := append([]byte(`{"model":"deepseek-v4-flash","messages":[{"role":"user","content":"`), 0xff)
|
||||
body = append(body, []byte(`"}]}`)...)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
req.Header.Set("x-api-key", "direct-token")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
app.Router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for invalid utf-8 request body, got %d body=%q", rec.Code, rec.Body.String())
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid json") {
|
||||
t.Fatalf("expected invalid json error, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONRequestsRejectTrailingInvalidUTF8AfterCompleteJSON(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["managed-key"],"accounts":[{"email":"u@example.com","password":"p"}]}`)
|
||||
t.Setenv("DS2API_ENV_WRITEBACK", "0")
|
||||
|
||||
app, err := NewApp()
|
||||
if err != nil {
|
||||
t.Fatalf("NewApp() error: %v", err)
|
||||
}
|
||||
|
||||
body := append([]byte(`{"model":"deepseek-v4-flash","messages":[{"role":"user","content":"ok"}]}`), 0xff)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-api-key", "direct-token")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
app.Router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400 for trailing invalid utf-8, got %d body=%q", rec.Code, rec.Body.String())
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid json") {
|
||||
t.Fatalf("expected invalid json error, got %q", rec.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ds2api/internal/config"
|
||||
@@ -12,7 +13,10 @@ func TestMessagesPrepareBasic(t *testing.T) {
|
||||
if got == "" {
|
||||
t.Fatal("expected non-empty prompt")
|
||||
}
|
||||
if got != "<|begin▁of▁sentence|><|User|>Hello<|Assistant|>" {
|
||||
if !strings.HasPrefix(got, "<|begin▁of▁sentence|><|System|>") {
|
||||
t.Fatalf("expected output integrity guard at the start, got %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "Hello") || !strings.HasSuffix(got, "<|Assistant|>") {
|
||||
t.Fatalf("unexpected prompt: %q", got)
|
||||
}
|
||||
}
|
||||
@@ -26,8 +30,11 @@ func TestMessagesPrepareRoles(t *testing.T) {
|
||||
{"role": "user", "content": "How are you"},
|
||||
}
|
||||
got := MessagesPrepare(messages)
|
||||
if !contains(got, "<|System|>You are helper<|end▁of▁instructions|><|User|>Hi") {
|
||||
t.Fatalf("expected system/user separation in %q", got)
|
||||
if !contains(got, "Output integrity guard") {
|
||||
t.Fatalf("expected output integrity guard in %q", got)
|
||||
}
|
||||
if !contains(got, "You are helper") || !contains(got, "<|User|>Hi") {
|
||||
t.Fatalf("expected system/user content in %q", got)
|
||||
}
|
||||
if !contains(got, "<|begin▁of▁sentence|>") {
|
||||
t.Fatalf("expected begin marker in %q", got)
|
||||
@@ -77,9 +84,12 @@ func TestMessagesPrepareArrayTextVariants(t *testing.T) {
|
||||
},
|
||||
}
|
||||
got := MessagesPrepare(messages)
|
||||
if got != "<|begin▁of▁sentence|><|User|>line1\nline2<|Assistant|>" {
|
||||
if !contains(got, "line1\nline2") {
|
||||
t.Fatalf("unexpected content from text variants: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, "Output integrity guard") {
|
||||
t.Fatalf("expected output integrity guard in %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeToDeepSeek(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user