feat: introduce JSON UTF-8 validation middleware and prepend output integrity guard system prompt to messages

This commit is contained in:
CJACK
2026-05-02 02:22:34 +08:00
parent 55abf64717
commit e2756f800d
13 changed files with 514 additions and 18 deletions

View File

@@ -3,12 +3,14 @@ package claude
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"ds2api/internal/config"
"ds2api/internal/httpapi/requestbody"
streamengine "ds2api/internal/stream"
"ds2api/internal/translatorcliproxy"
"ds2api/internal/util"
@@ -33,7 +35,11 @@ func (h *Handler) Messages(w http.ResponseWriter, r *http.Request) {
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, store ConfigReader) bool {
raw, err := io.ReadAll(r.Body)
if err != nil {
writeClaudeError(w, http.StatusBadRequest, "invalid body")
if errors.Is(err, requestbody.ErrInvalidUTF8Body) {
writeClaudeError(w, http.StatusBadRequest, "invalid json")
} else {
writeClaudeError(w, http.StatusBadRequest, "invalid body")
}
return true
}
var req map[string]any

View File

@@ -2,8 +2,8 @@ package gemini
import (
"bytes"
"ds2api/internal/toolcall"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
@@ -11,7 +11,9 @@ import (
"github.com/go-chi/chi/v5"
"ds2api/internal/httpapi/requestbody"
"ds2api/internal/sse"
"ds2api/internal/toolcall"
"ds2api/internal/translatorcliproxy"
"ds2api/internal/util"
@@ -32,7 +34,11 @@ func (h *Handler) handleGenerateContent(w http.ResponseWriter, r *http.Request,
func (h *Handler) proxyViaOpenAI(w http.ResponseWriter, r *http.Request, stream bool) bool {
raw, err := io.ReadAll(r.Body)
if err != nil {
writeGeminiError(w, http.StatusBadRequest, "invalid body")
if errors.Is(err, requestbody.ErrInvalidUTF8Body) {
writeGeminiError(w, http.StatusBadRequest, "invalid json")
} else {
writeGeminiError(w, http.StatusBadRequest, "invalid body")
}
return true
}
routeModel := strings.TrimSpace(chi.URLParam(r, "model"))

View File

@@ -0,0 +1,134 @@
package requestbody
import (
"bytes"
"errors"
"io"
"mime"
"net/http"
"strings"
"unicode/utf8"
)
var (
ErrInvalidUTF8Body = errors.New("invalid utf-8 request body")
errRequestBodyTooLarge = errors.New("request body too large")
)
const maxJSONUTF8ValidationSize = 100 << 20
// ValidateJSONUTF8 validates complete JSON request bodies before downstream
// decoders can silently replace malformed UTF-8 or stop before trailing bytes.
func ValidateJSONUTF8(next http.Handler) http.Handler {
if next == nil {
return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if shouldValidateJSONBody(r) {
r.Body = validateAndReplayBody(r.Body)
}
next.ServeHTTP(w, r)
})
}
func shouldValidateJSONBody(r *http.Request) bool {
if r == nil || r.Body == nil {
return false
}
path := ""
if r.URL != nil {
path = r.URL.Path
}
return isJSONContentType(r.Header.Get("Content-Type")) || isKnownJSONRequestPath(r.Method, path)
}
func isJSONContentType(raw string) bool {
raw = strings.TrimSpace(raw)
if raw == "" {
return false
}
mediaType, _, err := mime.ParseMediaType(raw)
if err != nil {
mediaType = raw
}
mediaType = strings.ToLower(strings.TrimSpace(mediaType))
return strings.Contains(mediaType, "json")
}
func isKnownJSONRequestPath(method, path string) bool {
switch strings.ToUpper(strings.TrimSpace(method)) {
case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete:
default:
return false
}
path = strings.TrimSpace(path)
if path == "" {
return false
}
switch {
case path == "/v1/chat/completions" || path == "/chat/completions":
return true
case path == "/v1/responses" || path == "/responses":
return true
case path == "/v1/embeddings" || path == "/embeddings":
return true
case path == "/anthropic/v1/messages" || path == "/v1/messages" || path == "/messages":
return true
case path == "/anthropic/v1/messages/count_tokens" || path == "/v1/messages/count_tokens" || path == "/messages/count_tokens":
return true
case strings.HasPrefix(path, "/v1beta/models/") || strings.HasPrefix(path, "/v1/models/"):
return strings.Contains(path, ":generateContent") || strings.Contains(path, ":streamGenerateContent")
case strings.HasPrefix(path, "/admin/"):
return true
default:
return false
}
}
func validateAndReplayBody(body io.ReadCloser) io.ReadCloser {
if body == nil {
return body
}
raw, err := io.ReadAll(io.LimitReader(body, maxJSONUTF8ValidationSize+1))
if err != nil {
return &errorReadCloser{err: err, closer: body}
}
if len(raw) > maxJSONUTF8ValidationSize {
return &errorReadCloser{err: errRequestBodyTooLarge, closer: body}
}
if !utf8.Valid(raw) {
return &errorReadCloser{err: ErrInvalidUTF8Body, closer: body}
}
return &replayReadCloser{Reader: bytes.NewReader(raw), closer: body}
}
type replayReadCloser struct {
*bytes.Reader
closer io.Closer
}
func (r *replayReadCloser) Close() error {
if r == nil || r.closer == nil {
return nil
}
return r.closer.Close()
}
type errorReadCloser struct {
err error
closer io.Closer
}
func (r *errorReadCloser) Read([]byte) (int, error) {
if r == nil || r.err == nil {
return 0, io.EOF
}
return 0, r.err
}
func (r *errorReadCloser) Close() error {
if r == nil || r.closer == nil {
return nil
}
return r.closer.Close()
}

View File

@@ -0,0 +1,158 @@
package requestbody
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
type singleByteReadCloser struct {
data []byte
pos int
}
func (r *singleByteReadCloser) Read(p []byte) (int, error) {
if r.pos >= len(r.data) {
return 0, io.EOF
}
p[0] = r.data[r.pos]
r.pos++
return 1, nil
}
func (r *singleByteReadCloser) Close() error {
return nil
}
func TestValidateJSONUTF8AllowsSplitMultibyteRunes(t *testing.T) {
body := []byte(`{"text":"你好"}`)
handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("unexpected decode error: %v", err)
}
w.WriteHeader(http.StatusNoContent)
}))
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", &singleByteReadCloser{data: body})
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNoContent {
t.Fatalf("expected 204 for valid utf-8 json, got %d body=%q", rec.Code, rec.Body.String())
}
}
func TestValidateJSONUTF8RejectsInvalidBytesBeforeJSONDecode(t *testing.T) {
body := append([]byte(`{"text":"`), 0xff)
body = append(body, []byte(`"}`)...)
handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(err.Error()))
return
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json; charset=utf-8")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid utf-8 json, got %d body=%q", rec.Code, rec.Body.String())
}
if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid utf-8") {
t.Fatalf("expected utf-8 validation error, got %q", rec.Body.String())
}
}
func TestValidateJSONUTF8RejectsInvalidBytesWithoutJSONContentTypeOnKnownPath(t *testing.T) {
body := append([]byte(`{"text":"`), 0xff)
body = append(body, []byte(`"}`)...)
handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(err.Error()))
return
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
req.Header.Set("Content-Type", "text/plain")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for invalid utf-8 json, got %d body=%q", rec.Code, rec.Body.String())
}
if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid utf-8") {
t.Fatalf("expected utf-8 validation error, got %q", rec.Body.String())
}
}
func TestValidateJSONUTF8RejectsTrailingInvalidBytesAfterJSONValue(t *testing.T) {
body := append([]byte(`{"text":"ok"}`), 0xff)
handler := ValidateJSONUTF8(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req map[string]any
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(err.Error()))
return
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("expected 400 for trailing invalid utf-8, got %d body=%q", rec.Code, rec.Body.String())
}
if !strings.Contains(strings.ToLower(rec.Body.String()), "invalid utf-8") {
t.Fatalf("expected utf-8 validation error, got %q", rec.Body.String())
}
}
func TestIsJSONContentType(t *testing.T) {
for _, raw := range []string{
"application/json",
"application/json; charset=utf-8",
"application/problem+json",
"application/vnd.api+json",
} {
if !isJSONContentType(raw) {
t.Fatalf("expected %q to be recognized as json", raw)
}
}
for _, raw := range []string{
"multipart/form-data; boundary=abc",
"text/plain",
"application/octet-stream",
} {
if isJSONContentType(raw) {
t.Fatalf("expected %q not to be recognized as json", raw)
}
}
}
func TestIsKnownJSONRequestPathIncludesGeminiStream(t *testing.T) {
if !isKnownJSONRequestPath(http.MethodPost, "/v1beta/models/gemini-pro:streamGenerateContent") {
t.Fatal("expected Gemini stream generate path to be recognized as json")
}
}