Relax CORS preflight handling across interfaces

This commit is contained in:
CJACK
2026-04-26 00:37:25 +08:00
parent f1ba805173
commit a44afb335a
10 changed files with 440 additions and 17 deletions

View File

@@ -0,0 +1,134 @@
'use strict';
const DEFAULT_CORS_ALLOW_HEADERS = [
'Content-Type',
'Authorization',
'X-API-Key',
'X-Ds2-Target-Account',
'X-Ds2-Source',
'X-Vercel-Protection-Bypass',
'X-Goog-Api-Key',
'Anthropic-Version',
'Anthropic-Beta',
];
const BLOCKED_CORS_REQUEST_HEADERS = new Set([
'x-ds2-internal-token',
]);
function setCorsHeaders(res, req) {
const origin = asString(readHeader(req, 'origin'));
res.setHeader('Access-Control-Allow-Origin', origin || '*');
res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, PUT, DELETE');
res.setHeader('Access-Control-Max-Age', '600');
res.setHeader(
'Access-Control-Allow-Headers',
buildCORSAllowHeaders(req),
);
addVaryHeader(res, 'Origin');
addVaryHeader(res, 'Access-Control-Request-Headers');
if (asString(readHeader(req, 'access-control-request-private-network')).toLowerCase() === 'true') {
res.setHeader('Access-Control-Allow-Private-Network', 'true');
addVaryHeader(res, 'Access-Control-Request-Private-Network');
}
}
function buildCORSAllowHeaders(req) {
const seen = new Set();
const headers = [];
for (const name of DEFAULT_CORS_ALLOW_HEADERS) {
appendCORSHeaderName(headers, seen, name);
}
for (const name of splitCORSRequestHeaders(readHeader(req, 'access-control-request-headers'))) {
appendCORSHeaderName(headers, seen, name);
}
return headers.join(', ');
}
function splitCORSRequestHeaders(raw) {
const text = asString(raw);
if (!text) {
return [];
}
return text
.split(',')
.map((part) => asString(part))
.filter((name) => isValidCORSHeaderToken(name))
.filter((name) => !BLOCKED_CORS_REQUEST_HEADERS.has(name.toLowerCase()));
}
function appendCORSHeaderName(headers, seen, name) {
const text = asString(name);
if (!isValidCORSHeaderToken(text)) {
return;
}
const lower = text.toLowerCase();
if (BLOCKED_CORS_REQUEST_HEADERS.has(lower) || seen.has(lower)) {
return;
}
seen.add(lower);
headers.push(text);
}
function isValidCORSHeaderToken(name) {
return /^[A-Za-z0-9!#$%&'*+.^_`|~-]+$/.test(asString(name));
}
function addVaryHeader(res, token) {
const text = asString(token);
if (!text || typeof res.setHeader !== 'function') {
return;
}
const current = typeof res.getHeader === 'function' ? res.getHeader('Vary') : '';
const seen = new Set();
const merged = [];
const addToken = (value) => {
const trimmed = asString(value);
if (!trimmed) {
return;
}
const lower = trimmed.toLowerCase();
if (seen.has(lower)) {
return;
}
seen.add(lower);
merged.push(trimmed);
};
if (Array.isArray(current)) {
for (const value of current) {
for (const part of String(value).split(',')) {
addToken(part);
}
}
} else {
for (const part of String(current || '').split(',')) {
addToken(part);
}
}
addToken(text);
res.setHeader('Vary', merged.join(', '));
}
function readHeader(req, key) {
if (!req || !req.headers) {
return '';
}
return req.headers[String(key).toLowerCase()];
}
function asString(v) {
if (typeof v === 'string') {
return v.trim();
}
if (Array.isArray(v)) {
return asString(v[0]);
}
if (v == null) {
return '';
}
return String(v).trim();
}
module.exports = {
setCorsHeaders,
};

View File

@@ -3,15 +3,9 @@
const {
writeOpenAIError,
} = require('./error_shape');
function setCorsHeaders(res) {
res.setHeader('Access-Control-Allow-Origin', '*');
res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS, PUT, DELETE');
res.setHeader(
'Access-Control-Allow-Headers',
'Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Vercel-Protection-Bypass',
);
}
const {
setCorsHeaders,
} = require('./cors');
function header(req, key) {
if (!req || !req.headers) {

View File

@@ -40,7 +40,7 @@ const {
} = require('./dedupe');
async function handler(req, res) {
setCorsHeaders(res);
setCorsHeaders(res, req);
if (req.method === 'OPTIONS') {
res.statusCode = 204;
res.end();

View File

@@ -140,11 +140,25 @@ func (noopLogEntry) Write(_ int, _ int, _ http.Header, _ time.Duration, _ interf
func (noopLogEntry) Panic(_ interface{}, _ []byte) {}
var defaultCORSAllowHeaders = []string{
"Content-Type",
"Authorization",
"X-API-Key",
"X-Ds2-Target-Account",
"X-Ds2-Source",
"X-Vercel-Protection-Bypass",
"X-Goog-Api-Key",
"Anthropic-Version",
"Anthropic-Beta",
}
var blockedCORSRequestHeaders = map[string]struct{}{
"x-ds2-internal-token": {},
}
func cors(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Key, X-Ds2-Target-Account, X-Ds2-Source, X-Vercel-Protection-Bypass")
setCORSHeaders(w, r)
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
@@ -153,6 +167,125 @@ func cors(next http.Handler) http.Handler {
})
}
func setCORSHeaders(w http.ResponseWriter, r *http.Request) {
origin := strings.TrimSpace(r.Header.Get("Origin"))
if origin == "" {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
w.Header().Set("Access-Control-Allow-Origin", origin)
addVaryHeaderToken(w.Header(), "Origin")
}
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS, PUT, DELETE")
w.Header().Set("Access-Control-Allow-Headers", buildCORSAllowHeaders(r))
w.Header().Set("Access-Control-Max-Age", "600")
addVaryHeaderToken(w.Header(), "Access-Control-Request-Headers")
if strings.EqualFold(strings.TrimSpace(r.Header.Get("Access-Control-Request-Private-Network")), "true") {
w.Header().Set("Access-Control-Allow-Private-Network", "true")
addVaryHeaderToken(w.Header(), "Access-Control-Request-Private-Network")
}
}
func buildCORSAllowHeaders(r *http.Request) string {
names := make([]string, 0, len(defaultCORSAllowHeaders)+4)
seen := make(map[string]struct{}, len(defaultCORSAllowHeaders)+4)
for _, name := range defaultCORSAllowHeaders {
appendCORSHeaderName(&names, seen, name)
}
if r == nil {
return strings.Join(names, ", ")
}
for _, name := range splitCORSRequestHeaders(r.Header.Get("Access-Control-Request-Headers")) {
appendCORSHeaderName(&names, seen, name)
}
return strings.Join(names, ", ")
}
func splitCORSRequestHeaders(raw string) []string {
if strings.TrimSpace(raw) == "" {
return nil
}
parts := strings.Split(raw, ",")
out := make([]string, 0, len(parts))
for _, part := range parts {
name := strings.TrimSpace(part)
if !isValidCORSHeaderToken(name) {
continue
}
if _, blocked := blockedCORSRequestHeaders[strings.ToLower(name)]; blocked {
continue
}
out = append(out, name)
}
return out
}
func appendCORSHeaderName(dst *[]string, seen map[string]struct{}, name string) {
name = strings.TrimSpace(name)
if !isValidCORSHeaderToken(name) {
return
}
key := strings.ToLower(name)
if _, blocked := blockedCORSRequestHeaders[key]; blocked {
return
}
if _, ok := seen[key]; ok {
return
}
seen[key] = struct{}{}
*dst = append(*dst, name)
}
func isValidCORSHeaderToken(v string) bool {
if v == "" {
return false
}
for i := 0; i < len(v); i++ {
c := v[i]
if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') {
continue
}
switch c {
case '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '_', '`', '|', '~':
continue
default:
return false
}
}
return true
}
func addVaryHeaderToken(h http.Header, token string) {
if h == nil {
return
}
token = strings.TrimSpace(token)
if token == "" {
return
}
current := h.Values("Vary")
seen := map[string]struct{}{}
merged := make([]string, 0, len(current)+1)
for _, value := range current {
for _, part := range strings.Split(value, ",") {
name := strings.TrimSpace(part)
if name == "" {
continue
}
key := strings.ToLower(name)
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
merged = append(merged, name)
}
}
key := strings.ToLower(token)
if _, ok := seen[key]; !ok {
merged = append(merged, token)
}
h.Set("Vary", strings.Join(merged, ", "))
}
func WriteUnhandledError(w http.ResponseWriter, err error) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)

View File

@@ -0,0 +1,119 @@
package server
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestCORSPreflightAllowsThirdPartyRequestedHeaders(t *testing.T) {
handler := cors(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusTeapot)
}))
req := httptest.NewRequest(http.MethodOptions, "/v1/chat/completions", nil)
req.Header.Set("Origin", "app://obsidian.md")
req.Header.Set("Access-Control-Request-Headers", "authorization, x-stainless-os, x-stainless-runtime, x-ds2-internal-token")
req.Header.Set("Access-Control-Request-Private-Network", "true")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusNoContent {
t.Fatalf("expected 204 for preflight, got %d", rec.Code)
}
if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "app://obsidian.md" {
t.Fatalf("expected origin echo, got %q", got)
}
if got := rec.Header().Get("Access-Control-Allow-Private-Network"); got != "true" {
t.Fatalf("expected private network allow header, got %q", got)
}
allowHeaders := strings.ToLower(rec.Header().Get("Access-Control-Allow-Headers"))
for _, want := range []string{"authorization", "x-stainless-os", "x-stainless-runtime"} {
if !strings.Contains(allowHeaders, want) {
t.Fatalf("expected allow headers to include %q, got %q", want, rec.Header().Get("Access-Control-Allow-Headers"))
}
}
if strings.Contains(allowHeaders, "x-ds2-internal-token") {
t.Fatalf("expected internal-only header to stay blocked, got %q", rec.Header().Get("Access-Control-Allow-Headers"))
}
vary := strings.ToLower(rec.Header().Get("Vary"))
for _, want := range []string{"origin", "access-control-request-headers", "access-control-request-private-network"} {
if !strings.Contains(vary, want) {
t.Fatalf("expected vary to include %q, got %q", want, rec.Header().Get("Vary"))
}
}
}
func TestBuildCORSAllowHeadersKeepsDefaultsWithoutRequest(t *testing.T) {
got := strings.ToLower(buildCORSAllowHeaders(nil))
for _, want := range []string{"content-type", "x-goog-api-key", "anthropic-version", "x-ds2-source"} {
if !strings.Contains(got, want) {
t.Fatalf("expected default allow headers to include %q, got %q", want, got)
}
}
}
func TestAppCORSPreflightIsUnifiedAcrossInterfaces(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@example.com","password":"p"}]}`)
t.Setenv("DS2API_ENV_WRITEBACK", "0")
app, err := NewApp()
if err != nil {
t.Fatalf("NewApp() error: %v", err)
}
cases := []struct {
name string
path string
headers string
}{
{
name: "openai",
path: "/v1/chat/completions",
headers: "authorization, x-stainless-os",
},
{
name: "claude",
path: "/anthropic/v1/messages",
headers: "x-api-key, anthropic-version, x-stainless-os",
},
{
name: "gemini",
path: "/v1beta/models/gemini-2.5-pro:generateContent",
headers: "x-goog-api-key, x-client-version",
},
{
name: "admin",
path: "/admin/login",
headers: "content-type, x-requested-with",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodOptions, tc.path, nil)
req.Header.Set("Origin", "app://obsidian.md")
req.Header.Set("Access-Control-Request-Headers", tc.headers)
rec := httptest.NewRecorder()
app.Router.ServeHTTP(rec, req)
if rec.Code != http.StatusNoContent {
t.Fatalf("expected %s preflight status 204, got %d", tc.path, rec.Code)
}
if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "app://obsidian.md" {
t.Fatalf("expected origin echo for %s, got %q", tc.path, got)
}
allowHeaders := strings.ToLower(rec.Header().Get("Access-Control-Allow-Headers"))
for _, want := range splitCORSRequestHeaders(tc.headers) {
if !strings.Contains(allowHeaders, strings.ToLower(want)) {
t.Fatalf("expected allow headers for %s to include %q, got %q", tc.path, want, rec.Header().Get("Access-Control-Allow-Headers"))
}
}
})
}
}