Files
ds2api/internal/httpapi/requestbody/json_utf8.go

135 lines
3.3 KiB
Go

package requestbody
import (
"bytes"
"errors"
"io"
"mime"
"net/http"
"strings"
"unicode/utf8"
)
var (
// ErrInvalidUTF8Body is the read error carried by the body that
// validateAndReplayBody returns when the buffered request bytes fail
// utf8.Valid.
ErrInvalidUTF8Body = errors.New("invalid utf-8 request body")
// errRequestBodyTooLarge is the read error carried by the body that
// validateAndReplayBody returns when more than maxJSONUTF8ValidationSize
// bytes were read from the request.
errRequestBodyTooLarge = errors.New("request body too large")
)
// maxJSONUTF8ValidationSize is the largest request body, in bytes, that is
// buffered for UTF-8 validation: 100 << 20 (100 MiB).
const maxJSONUTF8ValidationSize = 100 << 20
// ValidateJSONUTF8 returns middleware that checks JSON request bodies in full
// before next runs, so downstream decoders cannot silently replace malformed
// UTF-8 or stop before trailing bytes.
//
// For every request selected by shouldValidateJSONBody, r.Body is swapped for
// the reader produced by validateAndReplayBody; all other requests reach next
// untouched. A nil next yields a handler that does nothing.
func ValidateJSONUTF8(next http.Handler) http.Handler {
	if next == nil {
		// There is nothing to delegate to, so hand back a no-op handler
		// rather than one that would dereference a nil handler.
		return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
	}

	serve := func(w http.ResponseWriter, req *http.Request) {
		if shouldValidateJSONBody(req) {
			req.Body = validateAndReplayBody(req.Body)
		}
		next.ServeHTTP(w, req)
	}
	return http.HandlerFunc(serve)
}
// shouldValidateJSONBody reports whether r carries a body that should be
// validated: either its Content-Type is accepted by isJSONContentType, or its
// method and URL path are accepted by isKnownJSONRequestPath. A nil request or
// a request without a body is never validated.
func shouldValidateJSONBody(r *http.Request) bool {
	if r == nil || r.Body == nil {
		return false
	}
	if isJSONContentType(r.Header.Get("Content-Type")) {
		return true
	}

	// A request may lack a URL; fall back to the empty path in that case.
	var urlPath string
	if u := r.URL; u != nil {
		urlPath = u.Path
	}
	return isKnownJSONRequestPath(r.Method, urlPath)
}
// isJSONContentType reports whether the Content-Type header value raw names a
// JSON media type. The value is parsed with mime.ParseMediaType so parameters
// are ignored; if parsing fails, the whole trimmed header value is inspected
// instead. The match is case-insensitive and succeeds whenever the media type
// contains "json". An empty or whitespace-only value is not JSON.
func isJSONContentType(raw string) bool {
	header := strings.TrimSpace(raw)
	if len(header) == 0 {
		return false
	}

	// Default to the raw header and only narrow it to the parsed media type
	// when the header is well formed.
	candidate := header
	if parsed, _, err := mime.ParseMediaType(header); err == nil {
		candidate = parsed
	}

	normalized := strings.ToLower(strings.TrimSpace(candidate))
	return strings.Contains(normalized, "json")
}
// isKnownJSONRequestPath reports whether the method/path pair is on the
// built-in list of JSON endpoints.
//
// Only POST, PUT, PATCH and DELETE qualify (compared case-insensitively after
// trimming whitespace). The trimmed path must then be one of the exact
// chat-completions, responses, embeddings, messages or count_tokens routes, a
// "/v1beta/models/" or "/v1/models/" route containing ":generateContent" or
// ":streamGenerateContent", or anything under "/admin/".
func isKnownJSONRequestPath(method, path string) bool {
	verb := strings.ToUpper(strings.TrimSpace(method))
	acceptedVerb := verb == http.MethodPost ||
		verb == http.MethodPut ||
		verb == http.MethodPatch ||
		verb == http.MethodDelete
	if !acceptedVerb {
		return false
	}

	route := strings.TrimSpace(path)
	switch route {
	case "":
		return false
	case "/v1/chat/completions", "/chat/completions",
		"/v1/responses", "/responses",
		"/v1/embeddings", "/embeddings",
		"/anthropic/v1/messages", "/v1/messages", "/messages",
		"/anthropic/v1/messages/count_tokens", "/v1/messages/count_tokens", "/messages/count_tokens":
		return true
	}

	// Model routes only count when they name one of the generate actions.
	if strings.HasPrefix(route, "/v1beta/models/") || strings.HasPrefix(route, "/v1/models/") {
		return strings.Contains(route, ":generateContent") ||
			strings.Contains(route, ":streamGenerateContent")
	}
	return strings.HasPrefix(route, "/admin/")
}
// validateAndReplayBody buffers body, checks it, and returns a replacement
// reader whose Close still closes the original body.
//
// On success the replacement replays the buffered bytes. Otherwise every Read
// of the replacement fails with: the underlying read error, if reading body
// failed; errRequestBodyTooLarge, if more than maxJSONUTF8ValidationSize bytes
// were available; or ErrInvalidUTF8Body, if the bytes are not valid UTF-8.
// A nil body is returned unchanged.
func validateAndReplayBody(body io.ReadCloser) io.ReadCloser {
	if body == nil {
		return nil
	}

	// Allow one byte beyond the cap so a body exactly at the limit can be
	// distinguished from one that exceeds it.
	capped := io.LimitReader(body, maxJSONUTF8ValidationSize+1)
	payload, readErr := io.ReadAll(capped)

	switch {
	case readErr != nil:
		return &errorReadCloser{err: readErr, closer: body}
	case len(payload) > maxJSONUTF8ValidationSize:
		return &errorReadCloser{err: errRequestBodyTooLarge, closer: body}
	case !utf8.Valid(payload):
		return &errorReadCloser{err: ErrInvalidUTF8Body, closer: body}
	}
	return &replayReadCloser{Reader: bytes.NewReader(payload), closer: body}
}
// replayReadCloser serves reads from an embedded in-memory *bytes.Reader while
// forwarding Close to a separately held io.Closer.
type replayReadCloser struct {
	*bytes.Reader
	closer io.Closer
}

// Close closes the held closer and returns its result. It returns nil when the
// receiver is nil or no closer is held.
func (r *replayReadCloser) Close() error {
	if r != nil && r.closer != nil {
		return r.closer.Close()
	}
	return nil
}
// errorReadCloser is a body whose every Read fails with a fixed error while
// Close is forwarded to a separately held io.Closer.
type errorReadCloser struct {
	err    error
	closer io.Closer
}

// Read never yields data. It returns the stored error, or io.EOF when the
// receiver is nil or no error is stored.
func (r *errorReadCloser) Read([]byte) (int, error) {
	if r != nil && r.err != nil {
		return 0, r.err
	}
	return 0, io.EOF
}

// Close closes the held closer and returns its result. It returns nil when the
// receiver is nil or no closer is held.
func (r *errorReadCloser) Close() error {
	if r != nil && r.closer != nil {
		return r.closer.Close()
	}
	return nil
}