feat: Migrate Jan's API server to llamacpp-extension

Things to ponder:
- Now, the v1/models endpoint of the API server will return an empty
  list if no models are loaded
- Streaming v1/chat/completion routing works as well as v1/models; needs
  further testing
This commit is contained in:
Akarshan 2025-07-07 20:50:22 +05:30
parent e3faf09ab2
commit d5ffc6a476
No known key found for this signature in database
GPG Key ID: D75C9634A870665F
2 changed files with 296 additions and 412 deletions

View File

@ -333,25 +333,24 @@ pub fn app_token(state: State<'_, AppState>) -> Option<String> {
#[tauri::command] #[tauri::command]
pub async fn start_server( pub async fn start_server(
app: AppHandle, state: State<'_, AppState>,
host: String, host: String,
port: u16, port: u16,
prefix: String, prefix: String,
api_key: String, api_key: String,
trusted_hosts: Vec<String>, trusted_hosts: Vec<String>,
) -> Result<bool, String> { ) -> Result<bool, String> {
let state = app.state::<AppState>();
let auth_token = state.app_token.clone().unwrap_or_default();
let server_handle = state.server_handle.clone(); let server_handle = state.server_handle.clone();
let sessions = state.llama_server_process.clone();
server::start_server( server::start_server(
server_handle, server_handle,
sessions,
host, host,
port, port,
prefix, prefix,
auth_token,
api_key, api_key,
trusted_hosts, vec![trusted_hosts],
) )
.await .await
.map_err(|e| e.to_string())?; .map_err(|e| e.to_string())?;

View File

@ -1,25 +1,24 @@
use flate2::read::GzDecoder;
use futures_util::StreamExt; use futures_util::StreamExt;
use hyper::service::{make_service_fn, service_fn}; use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server, StatusCode}; use hyper::{Body, Request, Response, Server, StatusCode};
use hyper::body::Bytes;
use reqwest::Client; use reqwest::Client;
use serde_json::Value; use std::collections::HashMap;
use std::convert::Infallible; use std::convert::Infallible;
use std::io::Read;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use serde_json;
use crate::core::state::ServerHandle;
use crate::core::state::{LLamaBackendSession, ServerHandle};
/// Configuration for the proxy server /// Configuration for the proxy server
#[derive(Clone)] #[derive(Clone)]
struct ProxyConfig { struct ProxyConfig {
upstream: String,
prefix: String, prefix: String,
auth_token: String, proxy_api_key: String,
trusted_hosts: Vec<String>, trusted_hosts: Vec<Vec<String>>,
api_key: String,
} }
/// Removes a prefix from a path, ensuring proper formatting /// Removes a prefix from a path, ensuring proper formatting
@ -30,8 +29,10 @@ fn remove_prefix(path: &str, prefix: &str) -> String {
let result = path[prefix.len()..].to_string(); let result = path[prefix.len()..].to_string();
if result.is_empty() { if result.is_empty() {
"/".to_string() "/".to_string()
} else { } else if result.starts_with('/') {
result result
} else {
format!("/{}", result)
} }
} else { } else {
path.to_string() path.to_string()
@ -40,25 +41,7 @@ fn remove_prefix(path: &str, prefix: &str) -> String {
/// Determines the final destination path based on the original request path /// Determines the final destination path based on the original request path
fn get_destination_path(original_path: &str, prefix: &str) -> String { fn get_destination_path(original_path: &str, prefix: &str) -> String {
let removed_prefix_path = remove_prefix(original_path, prefix); remove_prefix(original_path, prefix)
// Special paths don't need the /v1 prefix
if !original_path.contains(prefix)
|| removed_prefix_path.contains("/healthz")
|| removed_prefix_path.contains("/process")
{
original_path.to_string()
} else {
format!("/v1{}", removed_prefix_path)
}
}
/// Creates the full upstream URL for the proxied request
fn build_upstream_url(upstream: &str, path: &str) -> String {
let upstream_clean = upstream.trim_end_matches('/');
let path_clean = path.trim_start_matches('/');
format!("{}/{}", upstream_clean, path_clean)
} }
/// Handles the proxy request logic /// Handles the proxy request logic
@ -66,17 +49,8 @@ async fn proxy_request(
req: Request<Body>, req: Request<Body>,
client: Client, client: Client,
config: ProxyConfig, config: ProxyConfig,
sessions: Arc<Mutex<HashMap<i32, LLamaBackendSession>>>,
) -> Result<Response<Body>, hyper::Error> { ) -> Result<Response<Body>, hyper::Error> {
// Handle OPTIONS requests for CORS preflight
log::debug!(
"Received request: {} {} {:?} {:?} {:?}",
req.method(),
req.uri().path(),
req.headers().get(hyper::header::HOST),
req.headers().get(hyper::header::ORIGIN),
req.headers()
.get(hyper::header::ACCESS_CONTROL_REQUEST_METHOD)
);
if req.method() == hyper::Method::OPTIONS { if req.method() == hyper::Method::OPTIONS {
log::debug!( log::debug!(
"Handling CORS preflight request from {:?} {:?}", "Handling CORS preflight request from {:?} {:?}",
@ -85,21 +59,18 @@ async fn proxy_request(
.get(hyper::header::ACCESS_CONTROL_REQUEST_METHOD) .get(hyper::header::ACCESS_CONTROL_REQUEST_METHOD)
); );
// Get the Host header to validate the target (where request is going)
let host = req let host = req
.headers() .headers()
.get(hyper::header::HOST) .get(hyper::header::HOST)
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.unwrap_or(""); .unwrap_or("");
// Get the Origin header for CORS response
let origin = req let origin = req
.headers() .headers()
.get(hyper::header::ORIGIN) .get(hyper::header::ORIGIN)
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.unwrap_or(""); .unwrap_or("");
// Validate requested method
let requested_method = req let requested_method = req
.headers() .headers()
.get("Access-Control-Request-Method") .get("Access-Control-Request-Method")
@ -120,7 +91,6 @@ async fn proxy_request(
.unwrap()); .unwrap());
} }
// Check if the host (target) is trusted, but bypass for whitelisted paths
let request_path = req.uri().path(); let request_path = req.uri().path();
let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"]; let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"];
let is_whitelisted_path = whitelisted_paths.contains(&request_path); let is_whitelisted_path = whitelisted_paths.contains(&request_path);
@ -133,9 +103,9 @@ async fn proxy_request(
true true
} else if !host.is_empty() { } else if !host.is_empty() {
log::debug!( log::debug!(
"CORS preflight: Host is '{}', trusted hosts: [{}]", "CORS preflight: Host is '{}', trusted hosts: {:?}",
host, host,
&config.trusted_hosts.join(", ") &config.trusted_hosts
); );
is_valid_host(host, &config.trusted_hosts) is_valid_host(host, &config.trusted_hosts)
} else { } else {
@ -155,14 +125,12 @@ async fn proxy_request(
.unwrap()); .unwrap());
} }
// Get and validate requested headers
let requested_headers = req let requested_headers = req
.headers() .headers()
.get("Access-Control-Request-Headers") .get("Access-Control-Request-Headers")
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.unwrap_or(""); .unwrap_or("");
// Allow common headers plus our required ones
let allowed_headers = [ let allowed_headers = [
"accept", "accept",
"accept-language", "accept-language",
@ -216,7 +184,6 @@ async fn proxy_request(
.unwrap()); .unwrap());
} }
// Build CORS response
let mut response = Response::builder() let mut response = Response::builder()
.status(StatusCode::OK) .status(StatusCode::OK)
.header("Access-Control-Allow-Methods", allowed_methods.join(", ")) .header("Access-Control-Allow-Methods", allowed_methods.join(", "))
@ -227,13 +194,11 @@ async fn proxy_request(
"Origin, Access-Control-Request-Method, Access-Control-Request-Headers", "Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
); );
// Set Access-Control-Allow-Origin based on origin presence
if !origin.is_empty() { if !origin.is_empty() {
response = response response = response
.header("Access-Control-Allow-Origin", origin) .header("Access-Control-Allow-Origin", origin)
.header("Access-Control-Allow-Credentials", "true"); .header("Access-Control-Allow-Credentials", "true");
} else { } else {
// No origin header - allow all origins (useful for non-browser clients)
response = response.header("Access-Control-Allow-Origin", "*"); response = response.header("Access-Control-Allow-Origin", "*");
} }
@ -245,26 +210,26 @@ async fn proxy_request(
return Ok(response.body(Body::empty()).unwrap()); return Ok(response.body(Body::empty()).unwrap());
} }
// Extract headers early for validation and CORS responses let (parts, body) = req.into_parts();
let origin_header = req
.headers() let origin_header = parts.headers
.get(hyper::header::ORIGIN) .get(hyper::header::ORIGIN)
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.unwrap_or("") .unwrap_or("")
.to_string(); .to_string();
let host_header = req let host_header = parts.headers
.headers()
.get(hyper::header::HOST) .get(hyper::header::HOST)
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.unwrap_or("") .unwrap_or("")
.to_string(); .to_string();
let original_path = req.uri().path(); let original_path = parts.uri.path();
let path = get_destination_path(original_path, &config.prefix); let headers = parts.headers.clone();
let method = req.method().clone();
let path = get_destination_path(original_path, &config.prefix);
let method = parts.method.clone();
// Verify Host header (check target), but bypass for whitelisted paths
let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"]; let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"];
let is_whitelisted_path = whitelisted_paths.contains(&path.as_str()); let is_whitelisted_path = whitelisted_paths.contains(&path.as_str());
@ -298,12 +263,11 @@ async fn proxy_request(
log::debug!("Bypassing host validation for whitelisted path: {}", path); log::debug!("Bypassing host validation for whitelisted path: {}", path);
} }
// Skip authorization check for whitelisted paths if !is_whitelisted_path && !config.proxy_api_key.is_empty() {
if !is_whitelisted_path && !config.api_key.is_empty() { if let Some(authorization) = parts.headers.get(hyper::header::AUTHORIZATION) {
if let Some(authorization) = req.headers().get(hyper::header::AUTHORIZATION) {
let auth_str = authorization.to_str().unwrap_or(""); let auth_str = authorization.to_str().unwrap_or("");
if auth_str.strip_prefix("Bearer ") != Some(config.api_key.as_str()) { if auth_str.strip_prefix("Bearer ") != Some(config.proxy_api_key.as_str()) {
let mut error_response = Response::builder().status(StatusCode::UNAUTHORIZED); let mut error_response = Response::builder().status(StatusCode::UNAUTHORIZED);
error_response = add_cors_headers_with_host_and_origin( error_response = add_cors_headers_with_host_and_origin(
error_response, error_response,
@ -334,7 +298,6 @@ async fn proxy_request(
); );
} }
// Block access to /configs endpoint
if path.contains("/configs") { if path.contains("/configs") {
let mut error_response = Response::builder().status(StatusCode::NOT_FOUND); let mut error_response = Response::builder().status(StatusCode::NOT_FOUND);
error_response = add_cors_headers_with_host_and_origin( error_response = add_cors_headers_with_host_and_origin(
@ -346,41 +309,253 @@ async fn proxy_request(
return Ok(error_response.body(Body::from("Not Found")).unwrap()); return Ok(error_response.body(Body::from("Not Found")).unwrap());
} }
// Build the outbound request let mut target_port: Option<i32> = None;
let upstream_url = build_upstream_url(&config.upstream, &path); let mut session_api_key: Option<String> = None;
let mut buffered_body: Option<Bytes> = None;
let original_path = parts.uri.path();
let destination_path = get_destination_path(original_path, &config.prefix);
match (method.clone(), destination_path.as_str()) {
(hyper::Method::POST, "/chat/completions")
| (hyper::Method::POST, "/completions")
| (hyper::Method::POST, "/embeddings") => {
log::debug!(
"Handling POST request to {} requiring model lookup in body",
destination_path
);
let body_bytes = match hyper::body::to_bytes(body).await {
Ok(bytes) => bytes,
Err(_) => {
let mut error_response =
Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
return Ok(error_response
.body(Body::from("Failed to read request body"))
.unwrap());
}
};
buffered_body = Some(body_bytes.clone());
match serde_json::from_slice::<serde_json::Value>(&body_bytes) {
Ok(json_body) => {
if let Some(model_id) = json_body.get("model").and_then(|v| v.as_str()) {
log::debug!("Extracted model_id: {}", model_id);
let sessions_guard = sessions.lock().await;
if sessions_guard.is_empty() {
log::warn!("Request for model '{}' but no backend servers are running.", model_id);
let mut error_response = Response::builder().status(StatusCode::SERVICE_UNAVAILABLE);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
return Ok(error_response.body(Body::from("No backend model servers are available")).unwrap());
}
if let Some(session) = sessions_guard
.values()
.find(|s| s.info.model_id == model_id)
{
target_port = Some(session.info.port);
session_api_key = Some(session.info.api_key.clone());
log::debug!(
"Found session for model_id {} on port {}",
model_id,
session.info.port
);
} else {
log::warn!("No running session found for model_id: {}", model_id);
let mut error_response =
Response::builder().status(StatusCode::NOT_FOUND);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
return Ok(error_response
.body(Body::from(format!(
"No running server found for model '{}'",
model_id
)))
.unwrap());
}
} else {
log::warn!(
"POST body for {} is missing 'model' field or it's not a string",
destination_path
);
let mut error_response =
Response::builder().status(StatusCode::BAD_REQUEST);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
return Ok(error_response
.body(Body::from("Request body must contain a 'model' field"))
.unwrap());
}
}
Err(e) => {
log::warn!(
"Failed to parse POST body for {} as JSON: {}",
destination_path,
e
);
let mut error_response = Response::builder().status(StatusCode::BAD_REQUEST);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
return Ok(error_response
.body(Body::from("Invalid JSON body"))
.unwrap());
}
}
}
(hyper::Method::GET, "/models") => {
log::debug!("Handling GET /v1/models request");
let sessions_guard = sessions.lock().await;
let models_data: Vec<_> = sessions_guard
.values()
.map(|session| {
serde_json::json!({
"id": session.info.model_id,
"object": "model",
"created": 1,
"owned_by": "user"
})
})
.collect();
let response_json = serde_json::json!({
"object": "list",
"data": models_data
});
let body_str = serde_json::to_string(&response_json).unwrap_or_else(|_| "{}".to_string());
let mut response_builder = Response::builder()
.status(StatusCode::OK)
.header(hyper::header::CONTENT_TYPE, "application/json");
response_builder = add_cors_headers_with_host_and_origin(
response_builder,
&host_header,
&origin_header,
&config.trusted_hosts,
);
return Ok(response_builder.body(Body::from(body_str)).unwrap());
}
_ => {
let is_explicitly_whitelisted_get = method == hyper::Method::GET
&& whitelisted_paths.contains(&destination_path.as_str());
if is_explicitly_whitelisted_get {
log::debug!("Handled whitelisted GET path: {}", destination_path);
let mut error_response = Response::builder().status(StatusCode::NOT_FOUND);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
return Ok(error_response.body(Body::from("Not Found")).unwrap());
} else {
log::warn!(
"Unhandled method/path for dynamic routing: {} {}",
method,
destination_path
);
let mut error_response = Response::builder().status(StatusCode::NOT_FOUND);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
return Ok(error_response.body(Body::from("Not Found")).unwrap());
}
}
}
let port = match target_port {
Some(p) => p,
None => {
log::error!("Internal routing error: target_port is None after successful lookup");
let mut error_response = Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
return Ok(error_response
.body(Body::from("Internal routing error"))
.unwrap());
}
};
let upstream_url = format!("http://127.0.0.1:{}{}", port, destination_path);
log::debug!("Proxying request to: {}", upstream_url); log::debug!("Proxying request to: {}", upstream_url);
let mut outbound_req = client.request(req.method().clone(), &upstream_url); let mut outbound_req = client.request(method.clone(), &upstream_url);
// Copy original headers for (name, value) in headers.iter() {
for (name, value) in req.headers() {
// Skip host & authorization header
if name != hyper::header::HOST && name != hyper::header::AUTHORIZATION { if name != hyper::header::HOST && name != hyper::header::AUTHORIZATION {
outbound_req = outbound_req.header(name, value); outbound_req = outbound_req.header(name, value);
} }
} }
// Add authorization header if let Some(key) = session_api_key {
outbound_req = outbound_req.header("Authorization", format!("Bearer {}", config.auth_token)); log::debug!("Adding session Authorization header");
outbound_req = outbound_req.header("Authorization", format!("Bearer {}", key));
} else {
log::debug!("No session API key available for this request");
}
// Send the request and handle the response let outbound_req_with_body = if let Some(bytes) = buffered_body {
match outbound_req.body(req.into_body()).send().await { log::debug!("Sending buffered body ({} bytes)", bytes.len());
outbound_req.body(bytes)
} else {
log::error!("Internal logic error: Request reached proxy stage without a buffered body.");
let mut error_response = Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
return Ok(error_response
.body(Body::from("Internal server error: unhandled request path"))
.unwrap());
};
match outbound_req_with_body.send().await {
Ok(response) => { Ok(response) => {
let status = response.status(); let status = response.status();
log::debug!("Received response with status: {}", status); log::debug!("Received response with status: {}", status);
let mut builder = Response::builder().status(status); let mut builder = Response::builder().status(status);
// Copy response headers, excluding CORS headers and Content-Length to avoid conflicts
for (name, value) in response.headers() { for (name, value) in response.headers() {
// Skip CORS headers from upstream to avoid duplicates
// Skip Content-Length header when filtering models response to avoid mismatch
if !is_cors_header(name.as_str()) && name != hyper::header::CONTENT_LENGTH { if !is_cors_header(name.as_str()) && name != hyper::header::CONTENT_LENGTH {
builder = builder.header(name, value); builder = builder.header(name, value);
} }
} }
// Add our own CORS headers
builder = add_cors_headers_with_host_and_origin( builder = add_cors_headers_with_host_and_origin(
builder, builder,
&host_header, &host_header,
@ -388,63 +563,32 @@ async fn proxy_request(
&config.trusted_hosts, &config.trusted_hosts,
); );
// Handle streaming vs non-streaming responses let mut stream = response.bytes_stream();
if path.contains("/models") && method == hyper::Method::GET { let (mut sender, body) = hyper::Body::channel();
// For /models endpoint, we need to buffer and filter the response
match response.bytes().await {
Ok(bytes) => match filter_models_response(&bytes) {
Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()),
Err(e) => {
log::warn!(
"Failed to filter models response: {}, returning original",
e
);
Ok(builder.body(Body::from(bytes)).unwrap())
}
},
Err(e) => {
log::error!("Failed to read response body: {}", e);
let mut error_response =
Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR);
error_response = add_cors_headers_with_host_and_origin(
error_response,
&host_header,
&origin_header,
&config.trusted_hosts,
);
Ok(error_response
.body(Body::from("Error reading upstream response"))
.unwrap())
}
}
} else {
// For streaming endpoints (like chat completions), we need to collect and forward the stream
let mut stream = response.bytes_stream();
let (mut sender, body) = hyper::Body::channel();
// Spawn a task to forward the stream tokio::spawn(async move {
tokio::spawn(async move { while let Some(chunk_result) = stream.next().await {
while let Some(chunk_result) = stream.next().await { match chunk_result {
match chunk_result { Ok(chunk) => {
Ok(chunk) => { if sender.send_data(chunk).await.is_err() {
if sender.send_data(chunk).await.is_err() { log::debug!("Client disconnected during streaming");
log::debug!("Client disconnected during streaming");
break;
}
}
Err(e) => {
log::error!("Stream error: {}", e);
break; break;
} }
} }
Err(e) => {
log::error!("Stream error: {}", e);
break;
}
} }
}); }
log::debug!("Streaming complete to client");
});
Ok(builder.body(body).unwrap()) Ok(builder.body(body).unwrap())
}
} }
Err(e) => { Err(e) => {
log::error!("Proxy request failed: {}", e); let error_msg = format!("Proxy request to {} failed: {}", upstream_url, e);
log::error!("{}", error_msg);
let mut error_response = Response::builder().status(StatusCode::BAD_GATEWAY); let mut error_response = Response::builder().status(StatusCode::BAD_GATEWAY);
error_response = add_cors_headers_with_host_and_origin( error_response = add_cors_headers_with_host_and_origin(
error_response, error_response,
@ -452,148 +596,45 @@ async fn proxy_request(
&origin_header, &origin_header,
&config.trusted_hosts, &config.trusted_hosts,
); );
Ok(error_response Ok(error_response.body(Body::from(error_msg)).unwrap())
.body(Body::from(format!("Upstream error: {}", e)))
.unwrap())
} }
} }
} }
/// Checks if the byte array starts with gzip magic number
fn is_gzip_encoded(bytes: &[u8]) -> bool {
bytes.len() >= 2 && bytes[0] == 0x1f && bytes[1] == 0x8b
}
/// Decompresses gzip-encoded bytes
fn decompress_gzip(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
let mut decoder = GzDecoder::new(bytes);
let mut decompressed = Vec::new();
decoder.read_to_end(&mut decompressed)?;
Ok(decompressed)
}
/// Compresses bytes using gzip
fn compress_gzip(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
use flate2::write::GzEncoder;
use flate2::Compression;
use std::io::Write;
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder.write_all(bytes)?;
let compressed = encoder.finish()?;
Ok(compressed)
}
/// Filters models response to keep only models with status "downloaded"
fn filter_models_response(
bytes: &[u8],
) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
// Try to decompress if it's gzip-encoded
let decompressed_bytes = if is_gzip_encoded(bytes) {
log::debug!("Response is gzip-encoded, decompressing...");
decompress_gzip(bytes)?
} else {
bytes.to_vec()
};
let response_text = std::str::from_utf8(&decompressed_bytes)?;
let mut response_json: Value = serde_json::from_str(response_text)?;
// Check if this is a ListModelsResponseDto format with data array
if let Some(data_array) = response_json.get_mut("data") {
if let Some(models) = data_array.as_array_mut() {
// Keep only models where status == "downloaded"
models.retain(|model| {
if let Some(status) = model.get("status") {
if let Some(status_str) = status.as_str() {
status_str == "downloaded"
} else {
false // Remove models without string status
}
} else {
false // Remove models without status field
}
});
log::debug!(
"Filtered models response: {} downloaded models remaining",
models.len()
);
}
} else if response_json.is_array() {
// Handle direct array format
if let Some(models) = response_json.as_array_mut() {
models.retain(|model| {
if let Some(status) = model.get("status") {
if let Some(status_str) = status.as_str() {
status_str == "downloaded"
} else {
false // Remove models without string status
}
} else {
false // Remove models without status field
}
});
log::debug!(
"Filtered models response: {} downloaded models remaining",
models.len()
);
}
}
let filtered_json = serde_json::to_vec(&response_json)?;
// If original was gzip-encoded, re-compress the filtered response
if is_gzip_encoded(bytes) {
log::debug!("Re-compressing filtered response with gzip");
compress_gzip(&filtered_json)
} else {
Ok(filtered_json)
}
}
/// Checks if a header is a CORS-related header that should be filtered out from upstream responses
fn is_cors_header(header_name: &str) -> bool { fn is_cors_header(header_name: &str) -> bool {
let header_lower = header_name.to_lowercase(); let header_lower = header_name.to_lowercase();
header_lower.starts_with("access-control-") header_lower.starts_with("access-control-")
} }
/// Adds CORS headers to a response builder using host for validation and origin for response
fn add_cors_headers_with_host_and_origin( fn add_cors_headers_with_host_and_origin(
builder: hyper::http::response::Builder, builder: hyper::http::response::Builder,
host: &str, host: &str,
origin: &str, origin: &str,
trusted_hosts: &[String], trusted_hosts: &[Vec<String>],
) -> hyper::http::response::Builder { ) -> hyper::http::response::Builder {
let mut builder = builder; let mut builder = builder;
let allow_origin_header = if !origin.is_empty() && is_valid_host(host, trusted_hosts) {
// Check if host (target) is trusted - this is what we validate origin.to_string()
let is_trusted = if !host.is_empty() { } else if !origin.is_empty() {
is_valid_host(host, trusted_hosts) origin.to_string()
} else { } else {
false // Host is required for validation "*".to_string()
}; };
// Set CORS headers using origin for the response
if !origin.is_empty() && is_trusted {
builder = builder
.header("Access-Control-Allow-Origin", origin)
.header("Access-Control-Allow-Credentials", "true");
} else if !origin.is_empty() {
builder = builder.header("Access-Control-Allow-Origin", origin);
} else {
builder = builder.header("Access-Control-Allow-Origin", "*");
}
builder = builder builder = builder
.header("Access-Control-Allow-Origin", allow_origin_header.clone())
.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, PATCH") .header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, PATCH")
.header("Access-Control-Allow-Headers", "Authorization, Content-Type, Host, Accept, Accept-Language, Cache-Control, Connection, DNT, If-Modified-Since, Keep-Alive, Origin, User-Agent, X-Requested-With, X-CSRF-Token, X-Forwarded-For, X-Forwarded-Proto, X-Forwarded-Host, authorization, content-type, x-api-key") .header("Access-Control-Allow-Headers", "Authorization, Content-Type, Host, Accept, Accept-Language, Cache-Control, Connection, DNT, If-Modified-Since, Keep-Alive, Origin, User-Agent, X-Requested-With, X-CSRF-Token, X-Forwarded-For, X-Forwarded-Proto, X-Forwarded-Host, authorization, content-type, x-api-key")
.header("Vary", "Origin"); .header("Vary", "Origin");
if allow_origin_header != "*" {
builder = builder.header("Access-Control-Allow-Credentials", "true");
}
builder builder
} }
// Validates if the host header is allowed fn is_valid_host(host: &str, trusted_hosts: &[Vec<String>]) -> bool {
fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool {
if host.is_empty() { if host.is_empty() {
return false; return false;
} }
@ -608,7 +649,6 @@ fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool {
}; };
let default_valid_hosts = ["localhost", "127.0.0.1", "0.0.0.0"]; let default_valid_hosts = ["localhost", "127.0.0.1", "0.0.0.0"];
// Check default valid hosts (host part only)
if default_valid_hosts if default_valid_hosts
.iter() .iter()
.any(|&valid| host_without_port.to_lowercase() == valid.to_lowercase()) .any(|&valid| host_without_port.to_lowercase() == valid.to_lowercase())
@ -616,17 +656,14 @@ fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool {
return true; return true;
} }
// Check trusted hosts - support both full host:port and host-only formats trusted_hosts.iter().flatten().any(|valid| {
trusted_hosts.iter().any(|valid| {
let host_lower = host.to_lowercase(); let host_lower = host.to_lowercase();
let valid_lower = valid.to_lowercase(); let valid_lower = valid.to_lowercase();
// First check exact match (including port)
if host_lower == valid_lower { if host_lower == valid_lower {
return true; return true;
} }
// Then check host part only (without port)
let valid_without_port = if valid.starts_with('[') { let valid_without_port = if valid.starts_with('[') {
valid valid
.split(']') .split(']')
@ -643,68 +680,54 @@ fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool {
pub async fn is_server_running(server_handle: Arc<Mutex<Option<ServerHandle>>>) -> bool { pub async fn is_server_running(server_handle: Arc<Mutex<Option<ServerHandle>>>) -> bool {
let handle_guard = server_handle.lock().await; let handle_guard = server_handle.lock().await;
handle_guard.is_some()
if handle_guard.is_some() {
true
} else {
false
}
} }
/// Starts the proxy server
pub async fn start_server( pub async fn start_server(
server_handle: Arc<Mutex<Option<ServerHandle>>>, server_handle: Arc<Mutex<Option<ServerHandle>>>,
sessions: Arc<Mutex<HashMap<i32, LLamaBackendSession>>>,
host: String, host: String,
port: u16, port: u16,
prefix: String, prefix: String,
auth_token: String, proxy_api_key: String,
api_key: String, trusted_hosts: Vec<Vec<String>>,
trusted_hosts: Vec<String>,
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
// Check if server is already running
let mut handle_guard = server_handle.lock().await; let mut handle_guard = server_handle.lock().await;
if handle_guard.is_some() { if handle_guard.is_some() {
return Err("Server is already running".into()); return Err("Server is already running".into());
} }
// Create server address
let addr: SocketAddr = format!("{}:{}", host, port) let addr: SocketAddr = format!("{}:{}", host, port)
.parse() .parse()
.map_err(|e| format!("Invalid address: {}", e))?; .map_err(|e| format!("Invalid address: {}", e))?;
// Configure proxy settings
let config = ProxyConfig { let config = ProxyConfig {
upstream: "http://127.0.0.1:39291".to_string(),
prefix, prefix,
auth_token, proxy_api_key,
api_key,
trusted_hosts, trusted_hosts,
}; };
// Create HTTP client with longer timeout for streaming
let client = Client::builder() let client = Client::builder()
.timeout(std::time::Duration::from_secs(300)) // 5 minutes for streaming .timeout(std::time::Duration::from_secs(300))
.pool_max_idle_per_host(10) .pool_max_idle_per_host(10)
.pool_idle_timeout(std::time::Duration::from_secs(30)) .pool_idle_timeout(std::time::Duration::from_secs(30))
.build()?; .build()?;
// Create service handler
let make_svc = make_service_fn(move |_conn| { let make_svc = make_service_fn(move |_conn| {
let client = client.clone(); let client = client.clone();
let config = config.clone(); let config = config.clone();
let sessions = sessions.clone();
async move { async move {
Ok::<_, Infallible>(service_fn(move |req| { Ok::<_, Infallible>(service_fn(move |req| {
proxy_request(req, client.clone(), config.clone()) proxy_request(req, client.clone(), config.clone(), sessions.clone())
})) }))
} }
}); });
// Create and start the server
let server = Server::bind(&addr).serve(make_svc); let server = Server::bind(&addr).serve(make_svc);
log::info!("Proxy server started on http://{}", addr); log::info!("Proxy server started on http://{}", addr);
// Spawn server task
let server_task = tokio::spawn(async move { let server_task = tokio::spawn(async move {
if let Err(e) = server.await { if let Err(e) = server.await {
log::error!("Server error: {}", e); log::error!("Server error: {}", e);
@ -717,7 +740,6 @@ pub async fn start_server(
Ok(true) Ok(true)
} }
/// Stops the currently running proxy server
pub async fn stop_server( pub async fn stop_server(
server_handle: Arc<Mutex<Option<ServerHandle>>>, server_handle: Arc<Mutex<Option<ServerHandle>>>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> { ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
@ -725,7 +747,6 @@ pub async fn stop_server(
if let Some(handle) = handle_guard.take() { if let Some(handle) = handle_guard.take() {
handle.abort(); handle.abort();
// remove the handle to prevent future use
*handle_guard = None; *handle_guard = None;
log::info!("Proxy server stopped"); log::info!("Proxy server stopped");
} else { } else {
@ -734,139 +755,3 @@ pub async fn stop_server(
Ok(()) Ok(())
} }
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_filter_models_response_with_downloaded_status() {
let test_response = json!({
"object": "list",
"data": [
{
"id": "model1",
"name": "Model 1",
"status": "downloaded"
},
{
"id": "model2",
"name": "Model 2",
"status": "available"
},
{
"id": "model3",
"name": "Model 3"
}
]
});
let response_bytes = serde_json::to_vec(&test_response).unwrap();
let filtered_bytes = filter_models_response(&response_bytes).unwrap();
let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap();
let data = filtered_response["data"].as_array().unwrap();
assert_eq!(data.len(), 1); // Should have 1 model (only model1 with "downloaded" status)
// Verify only model1 (with "downloaded" status) is kept
assert!(data.iter().any(|model| model["id"] == "model1"));
// Verify model2 and model3 are filtered out
assert!(!data.iter().any(|model| model["id"] == "model2"));
assert!(!data.iter().any(|model| model["id"] == "model3"));
}
#[test]
fn test_filter_models_response_direct_array() {
let test_response = json!([
{
"id": "model1",
"name": "Model 1",
"status": "downloaded"
},
{
"id": "model2",
"name": "Model 2",
"status": "available"
}
]);
let response_bytes = serde_json::to_vec(&test_response).unwrap();
let filtered_bytes = filter_models_response(&response_bytes).unwrap();
let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap();
let data = filtered_response.as_array().unwrap();
assert_eq!(data.len(), 1); // Should have 1 model (only model1 with "downloaded" status)
assert!(data.iter().any(|model| model["id"] == "model1"));
assert!(!data.iter().any(|model| model["id"] == "model2"));
}
#[test]
fn test_filter_models_response_no_status_field() {
let test_response = json!({
"object": "list",
"data": [
{
"id": "model1",
"name": "Model 1"
},
{
"id": "model2",
"name": "Model 2"
}
]
});
let response_bytes = serde_json::to_vec(&test_response).unwrap();
let filtered_bytes = filter_models_response(&response_bytes).unwrap();
let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap();
let data = filtered_response["data"].as_array().unwrap();
assert_eq!(data.len(), 0); // Should remove all models when no status field (no "downloaded" status)
}
#[test]
fn test_filter_models_response_multiple_downloaded() {
let test_response = json!({
"object": "list",
"data": [
{
"id": "model1",
"name": "Model 1",
"status": "downloaded"
},
{
"id": "model2",
"name": "Model 2",
"status": "available"
},
{
"id": "model3",
"name": "Model 3",
"status": "downloaded"
},
{
"id": "model4",
"name": "Model 4",
"status": "installing"
}
]
});
let response_bytes = serde_json::to_vec(&test_response).unwrap();
let filtered_bytes = filter_models_response(&response_bytes).unwrap();
let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap();
let data = filtered_response["data"].as_array().unwrap();
assert_eq!(data.len(), 2); // Should have 2 models (model1 and model3 with "downloaded" status)
// Verify only models with "downloaded" status are kept
assert!(data.iter().any(|model| model["id"] == "model1"));
assert!(data.iter().any(|model| model["id"] == "model3"));
// Verify other models are filtered out
assert!(!data.iter().any(|model| model["id"] == "model2"));
assert!(!data.iter().any(|model| model["id"] == "model4"));
}
}