mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-05 08:55:28 +08:00
135 lines
3.3 KiB
Go
135 lines
3.3 KiB
Go
package requestbody
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"io"
|
|
"mime"
|
|
"net/http"
|
|
"strings"
|
|
"unicode/utf8"
|
|
)
|
|
|
|
var (
	// errRequestBodyTooLarge is what a validated body's Read returns when the
	// request carried more than maxJSONUTF8ValidationSize bytes. It is
	// unexported, so code outside this package cannot match it by identity.
	errRequestBodyTooLarge = errors.New("request body too large")

	// ErrInvalidUTF8Body is what a validated body's Read returns when the
	// buffered request bytes are not well-formed UTF-8. It is exported so that
	// handlers outside this package can recognise it (e.g. with errors.Is).
	ErrInvalidUTF8Body = errors.New("invalid utf-8 request body")
)
|
|
|
|
// maxJSONUTF8ValidationSize caps how many bytes are buffered from a single
// request body for UTF-8 validation (100 MiB). Reads of bodies larger than
// this fail with errRequestBodyTooLarge instead of being validated.
const maxJSONUTF8ValidationSize = 100 * 1024 * 1024
|
|
|
|
// ValidateJSONUTF8 validates complete JSON request bodies before downstream
|
|
// decoders can silently replace malformed UTF-8 or stop before trailing bytes.
|
|
func ValidateJSONUTF8(next http.Handler) http.Handler {
|
|
if next == nil {
|
|
return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
|
}
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if shouldValidateJSONBody(r) {
|
|
r.Body = validateAndReplayBody(r.Body)
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func shouldValidateJSONBody(r *http.Request) bool {
|
|
if r == nil || r.Body == nil {
|
|
return false
|
|
}
|
|
path := ""
|
|
if r.URL != nil {
|
|
path = r.URL.Path
|
|
}
|
|
return isJSONContentType(r.Header.Get("Content-Type")) || isKnownJSONRequestPath(r.Method, path)
|
|
}
|
|
|
|
// isJSONContentType reports whether the Content-Type header value raw names a
// media type containing "json" (case-insensitively), e.g. application/json or
// application/vnd.api+json.
//
// When raw parses as a media type, only the type/subtype is inspected and
// parameters such as charset are ignored. When it does not parse, the whole
// trimmed value is searched instead. An empty or blank value is never JSON.
func isJSONContentType(raw string) bool {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return false
	}

	candidate := trimmed
	if parsed, _, err := mime.ParseMediaType(trimmed); err == nil {
		candidate = parsed
	}
	return strings.Contains(strings.ToLower(strings.TrimSpace(candidate)), "json")
}
|
|
|
|
// isKnownJSONRequestPath reports whether method and path identify a route that
// is treated as taking a JSON body regardless of Content-Type.
//
// Only POST, PUT, PATCH and DELETE qualify. The method is compared
// case-insensitively after trimming whitespace. The path is trimmed and then
// matched exactly, with no cleaning or trailing-slash normalisation. It
// matches when it is:
//   - one of the fixed chat/completions, responses, embeddings, messages or
//     messages/count_tokens routes listed below;
//   - under /v1beta/models/ or /v1/models/ and contains ":generateContent"
//     or ":streamGenerateContent";
//   - anything under /admin/.
func isKnownJSONRequestPath(method, path string) bool {
	verb := strings.ToUpper(strings.TrimSpace(method))
	if verb != http.MethodPost && verb != http.MethodPut && verb != http.MethodPatch && verb != http.MethodDelete {
		return false
	}

	p := strings.TrimSpace(path)
	switch p {
	case "":
		return false
	case "/v1/chat/completions", "/chat/completions",
		"/v1/responses", "/responses",
		"/v1/embeddings", "/embeddings",
		"/anthropic/v1/messages", "/v1/messages", "/messages",
		"/anthropic/v1/messages/count_tokens", "/v1/messages/count_tokens", "/messages/count_tokens":
		return true
	}

	if strings.HasPrefix(p, "/v1beta/models/") || strings.HasPrefix(p, "/v1/models/") {
		return strings.Contains(p, ":generateContent") || strings.Contains(p, ":streamGenerateContent")
	}
	return strings.HasPrefix(p, "/admin/")
}
|
|
|
|
func validateAndReplayBody(body io.ReadCloser) io.ReadCloser {
|
|
if body == nil {
|
|
return body
|
|
}
|
|
raw, err := io.ReadAll(io.LimitReader(body, maxJSONUTF8ValidationSize+1))
|
|
if err != nil {
|
|
return &errorReadCloser{err: err, closer: body}
|
|
}
|
|
if len(raw) > maxJSONUTF8ValidationSize {
|
|
return &errorReadCloser{err: errRequestBodyTooLarge, closer: body}
|
|
}
|
|
if !utf8.Valid(raw) {
|
|
return &errorReadCloser{err: ErrInvalidUTF8Body, closer: body}
|
|
}
|
|
return &replayReadCloser{Reader: bytes.NewReader(raw), closer: body}
|
|
}
|
|
|
|
// replayReadCloser serves request bytes that were already buffered in memory.
// Reads come from the embedded *bytes.Reader; Close is forwarded to the
// original body it replaces.
type replayReadCloser struct {
	*bytes.Reader
	closer io.Closer // original body; may be nil
}

// Close closes the original body and returns its error. A nil receiver or a
// missing closer makes Close a no-op that returns nil.
func (rc *replayReadCloser) Close() error {
	if rc != nil && rc.closer != nil {
		return rc.closer.Close()
	}
	return nil
}
|
|
|
|
// errorReadCloser stands in for a request body that failed validation: every
// Read reports err, and Close is forwarded to the original body so it can
// still be released.
type errorReadCloser struct {
	err    error     // reported by every Read; nil makes Read return io.EOF
	closer io.Closer // original body; may be nil
}

// Read never fills the buffer. It returns the stored error. When the receiver
// is nil or holds no error, it returns io.EOF, so the zero value behaves like
// an empty body.
func (e *errorReadCloser) Read([]byte) (int, error) {
	if e != nil && e.err != nil {
		return 0, e.err
	}
	return 0, io.EOF
}

// Close closes the original body and returns its error. A nil receiver or a
// missing closer makes Close a no-op that returns nil.
func (e *errorReadCloser) Close() error {
	if e != nil && e.closer != nil {
		return e.closer.Close()
	}
	return nil
}
|