fix(server): enhance CORS handling for local API network access (#5236)

* fix(server): enhance CORS handling for local API network access

- Fix CORS preflight validation to use Host header for target validation
- Use Origin header correctly for CORS response headers
- Improve host validation to support both host:port and host-only formats
- Filter upstream CORS headers to prevent duplicate Access-Control-Allow-Origin
- Add CORS headers to all error responses for consistent behavior
- Fix host matching logic to handle trusted hosts with and without ports
- Ensure single Access-Control-Allow-Origin header per response

This resolves CORS preflight failures that were blocking cross-origin
requests to the local API server, enabling proper network access from
web applications and external tools.

Fixes: OPTIONS requests being rejected due to incorrect host validation
Resolves: "access control allow origin cannot contain more than one origin" error

* fix(proxy): bypass host and authorization checks for root path in CORS preflight

* fix(proxy): bypass host and authorization checks for whitelisted paths
This commit is contained in:
Sam Hoang Van 2025-06-11 09:44:17 +07:00 committed by GitHub
parent 1799bfed3f
commit eef37defb4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -68,50 +68,279 @@ async fn proxy_request(
client: Client,
config: ProxyConfig,
) -> Result<Response<Body>, hyper::Error> {
// Handle OPTIONS requests for CORS preflight
log::debug!(
"Received request: {} {} {:?} {:?} {:?}",
req.method(),
req.uri().path(),
req.headers().get(hyper::header::HOST),
req.headers().get(hyper::header::ORIGIN),
req.headers()
.get(hyper::header::ACCESS_CONTROL_REQUEST_METHOD)
);
if req.method() == hyper::Method::OPTIONS {
log::debug!(
"Handling CORS preflight request from {:?} {:?}",
req.headers().get(hyper::header::HOST),
req.headers()
.get(hyper::header::ACCESS_CONTROL_REQUEST_METHOD)
);
// Get the Host header to validate the target (where request is going)
let host = req
.headers()
.get(hyper::header::HOST)
.and_then(|v| v.to_str().ok())
.unwrap_or("");
// Get the Origin header for CORS response
let origin = req
.headers()
.get(hyper::header::ORIGIN)
.and_then(|v| v.to_str().ok())
.unwrap_or("");
// Validate requested method
let requested_method = req
.headers()
.get("Access-Control-Request-Method")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
let allowed_methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"];
let method_allowed = requested_method.is_empty()
|| allowed_methods
.iter()
.any(|&method| method.eq_ignore_ascii_case(requested_method));
if !method_allowed {
log::warn!("CORS preflight: Method '{}' not allowed", requested_method);
return Ok(Response::builder()
.status(StatusCode::METHOD_NOT_ALLOWED)
.body(Body::from("Method not allowed"))
.unwrap());
}
// Check if the host (target) is trusted, but bypass for whitelisted paths
let request_path = req.uri().path();
let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"];
let is_whitelisted_path = whitelisted_paths.contains(&request_path);
let is_trusted = if is_whitelisted_path {
log::debug!(
"CORS preflight: Bypassing host check for whitelisted path: {}",
request_path
);
true
} else if !host.is_empty() {
log::debug!(
"CORS preflight: Host is '{}', trusted hosts: [{}]",
host,
&config.trusted_hosts.join(", ")
);
is_valid_host(host, &config.trusted_hosts)
} else {
log::warn!("CORS preflight: No Host header present");
false
};
if !is_trusted {
log::warn!(
"CORS preflight: Host '{}' not trusted for path '{}'",
host,
request_path
);
return Ok(Response::builder()
.status(StatusCode::FORBIDDEN)
.body(Body::from("Host not allowed"))
.unwrap());
}
// Get and validate requested headers
let requested_headers = req
.headers()
.get("Access-Control-Request-Headers")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
// Allow common headers plus our required ones
let allowed_headers = [
"accept",
"accept-language",
"authorization",
"cache-control",
"connection",
"content-type",
"dnt",
"host",
"if-modified-since",
"keep-alive",
"origin",
"user-agent",
"x-api-key",
"x-csrf-token",
"x-forwarded-for",
"x-forwarded-host",
"x-forwarded-proto",
"x-requested-with",
"x-stainless-arch",
"x-stainless-lang",
"x-stainless-os",
"x-stainless-package-version",
"x-stainless-retry-count",
"x-stainless-runtime",
"x-stainless-runtime-version",
"x-stainless-timeout",
];
let headers_valid = if requested_headers.is_empty() {
true
} else {
requested_headers
.split(',')
.map(|h| h.trim())
.all(|header| {
allowed_headers
.iter()
.any(|&allowed| allowed.eq_ignore_ascii_case(header))
})
};
if !headers_valid {
log::warn!(
"CORS preflight: Some requested headers not allowed: {}",
requested_headers
);
return Ok(Response::builder()
.status(StatusCode::FORBIDDEN)
.body(Body::from("Headers not allowed"))
.unwrap());
}
// Build CORS response
let mut response = Response::builder()
.status(StatusCode::OK)
.header("Access-Control-Allow-Methods", allowed_methods.join(", "))
.header("Access-Control-Allow-Headers", allowed_headers.join(", "))
.header("Access-Control-Max-Age", "86400")
.header(
"Vary",
"Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
);
// Set Access-Control-Allow-Origin based on origin presence
if !origin.is_empty() {
response = response
.header("Access-Control-Allow-Origin", origin)
.header("Access-Control-Allow-Credentials", "true");
} else {
// No origin header - allow all origins (useful for non-browser clients)
response = response.header("Access-Control-Allow-Origin", "*");
}
log::debug!(
"CORS preflight response: host_trusted={}, origin='{}'",
is_trusted,
origin
);
return Ok(response.body(Body::empty()).unwrap());
}
// Extract headers early for validation and CORS responses
let origin_header = req
.headers()
.get(hyper::header::ORIGIN)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let host_header = req
.headers()
.get(hyper::header::HOST)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let original_path = req.uri().path();
let path = get_destination_path(original_path, &config.prefix);
// Verify Host header
if let Some(host) = req.headers().get(hyper::header::HOST) {
let host_str = host.to_str().unwrap_or("");
if !is_valid_host(host_str, &config.trusted_hosts) {
return Ok(Response::builder()
.status(StatusCode::FORBIDDEN)
.body(Body::from("Invalid host header"))
// Verify Host header (check target), but bypass for whitelisted paths
let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"];
let is_whitelisted_path = whitelisted_paths.contains(&path.as_str());
if !is_whitelisted_path {
if !host_header.is_empty() {
if !is_valid_host(&host_header, &config.trusted_hosts) {
let mut error_response = Response::builder().status(StatusCode::FORBIDDEN);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
return Ok(error_response
.body(Body::from("Invalid host header"))
.unwrap());
}
} else {
let mut error_response = Response::builder().status(StatusCode::BAD_REQUEST);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
return Ok(error_response
.body(Body::from("Missing host header"))
.unwrap());
}
} else {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("Missing host header"))
.unwrap());
log::debug!("Bypassing host validation for whitelisted path: {}", path);
}
if !config.api_key.is_empty() {
// Skip authorization check for whitelisted paths
if !is_whitelisted_path && !config.api_key.is_empty() {
if let Some(authorization) = req.headers().get(hyper::header::AUTHORIZATION) {
let auth_str = authorization.to_str().unwrap_or("");
if auth_str.strip_prefix("Bearer ") != Some(config.api_key.as_str())
{
return Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
if auth_str.strip_prefix("Bearer ") != Some(config.api_key.as_str()) {
let mut error_response = Response::builder().status(StatusCode::UNAUTHORIZED);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
return Ok(error_response
.body(Body::from("Invalid or missing authorization token"))
.unwrap());
}
} else {
return Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
let mut error_response = Response::builder().status(StatusCode::UNAUTHORIZED);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
return Ok(error_response
.body(Body::from("Missing authorization header"))
.unwrap());
}
} else if is_whitelisted_path {
log::debug!("Bypassing authorization check for whitelisted path: {}", path);
}
// Block access to /configs endpoint
if path.contains("/configs") {
return Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("Not Found"))
.unwrap());
let mut error_response = Response::builder().status(StatusCode::NOT_FOUND);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
return Ok(error_response.body(Body::from("Not Found")).unwrap());
}
// Build the outbound request
@ -139,18 +368,36 @@ async fn proxy_request(
let mut builder = Response::builder().status(status);
// Copy response headers
// Copy response headers, excluding CORS headers to avoid conflicts
for (name, value) in response.headers() {
builder = builder.header(name, value);
// Skip CORS headers from upstream to avoid duplicates
if !is_cors_header(name.as_str()) {
builder = builder.header(name, value);
}
}
// Add our own CORS headers
builder = add_cors_headers_with_host_and_origin(
builder,
&host_header,
&origin_header,
&config.trusted_hosts,
);
// Read response body
match response.bytes().await {
Ok(bytes) => Ok(builder.body(Body::from(bytes)).unwrap()),
Err(e) => {
log::error!("Failed to read response body: {}", e);
Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
let mut error_response =
Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
Ok(error_response
.body(Body::from("Error reading upstream response"))
.unwrap())
}
@ -158,14 +405,61 @@ async fn proxy_request(
}
Err(e) => {
log::error!("Proxy request failed: {}", e);
Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
let mut error_response = Response::builder().status(StatusCode::BAD_GATEWAY);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
Ok(error_response
.body(Body::from(format!("Upstream error: {}", e)))
.unwrap())
}
}
}
/// Checks if a header is a CORS-related header that should be filtered out from upstream responses
fn is_cors_header(header_name: &str) -> bool {
let header_lower = header_name.to_lowercase();
header_lower.starts_with("access-control-")
}
/// Adds CORS headers to a response builder using host for validation and origin for response
fn add_cors_headers_with_host_and_origin(
builder: hyper::http::response::Builder,
host: &str,
origin: &str,
trusted_hosts: &[String],
) -> hyper::http::response::Builder {
let mut builder = builder;
// Check if host (target) is trusted - this is what we validate
let is_trusted = if !host.is_empty() {
is_valid_host(host, trusted_hosts)
} else {
false // Host is required for validation
};
// Set CORS headers using origin for the response
if !origin.is_empty() && is_trusted {
builder = builder
.header("Access-Control-Allow-Origin", origin)
.header("Access-Control-Allow-Credentials", "true");
} else if !origin.is_empty() {
builder = builder.header("Access-Control-Allow-Origin", origin);
} else {
builder = builder.header("Access-Control-Allow-Origin", "*");
}
builder = builder
.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, PATCH")
.header("Access-Control-Allow-Headers", "Authorization, Content-Type, Host, Accept, Accept-Language, Cache-Control, Connection, DNT, If-Modified-Since, Keep-Alive, Origin, User-Agent, X-Requested-With, X-CSRF-Token, X-Forwarded-For, X-Forwarded-Proto, X-Forwarded-Host, authorization, content-type, x-api-key")
.header("Vary", "Origin");
builder
}
// Validates if the host header is allowed
fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool {
if host.is_empty() {
@ -182,6 +476,7 @@ fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool {
};
let default_valid_hosts = ["localhost", "127.0.0.1", "0.0.0.0"];
// Check default valid hosts (host part only)
if default_valid_hosts
.iter()
.any(|&valid| host_without_port.to_lowercase() == valid.to_lowercase())
@ -189,9 +484,29 @@ fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool {
return true;
}
trusted_hosts
.iter()
.any(|valid| host_without_port.to_lowercase() == valid.to_lowercase())
// Check trusted hosts - support both full host:port and host-only formats
trusted_hosts.iter().any(|valid| {
let host_lower = host.to_lowercase();
let valid_lower = valid.to_lowercase();
// First check exact match (including port)
if host_lower == valid_lower {
return true;
}
// Then check host part only (without port)
let valid_without_port = if valid.starts_with('[') {
valid
.split(']')
.next()
.unwrap_or(valid)
.trim_start_matches('[')
} else {
valid.split(':').next().unwrap_or(valid)
};
host_without_port.to_lowercase() == valid_without_port.to_lowercase()
})
}
/// Starts the proxy server