diff --git a/src-tauri/src/core/cmd.rs b/src-tauri/src/core/cmd.rs index 8027eb5d5..8bb80d2f8 100644 --- a/src-tauri/src/core/cmd.rs +++ b/src-tauri/src/core/cmd.rs @@ -333,25 +333,24 @@ pub fn app_token(state: State<'_, AppState>) -> Option { #[tauri::command] pub async fn start_server( - app: AppHandle, + state: State<'_, AppState>, host: String, port: u16, prefix: String, api_key: String, trusted_hosts: Vec, ) -> Result { - let state = app.state::(); - let auth_token = state.app_token.clone().unwrap_or_default(); let server_handle = state.server_handle.clone(); + let sessions = state.llama_server_process.clone(); server::start_server( server_handle, + sessions, host, port, prefix, - auth_token, api_key, - trusted_hosts, + vec![trusted_hosts], ) .await .map_err(|e| e.to_string())?; diff --git a/src-tauri/src/core/server.rs b/src-tauri/src/core/server.rs index 6da4ebf9b..d934ef9f9 100644 --- a/src-tauri/src/core/server.rs +++ b/src-tauri/src/core/server.rs @@ -1,25 +1,24 @@ -use flate2::read::GzDecoder; use futures_util::StreamExt; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Request, Response, Server, StatusCode}; +use hyper::body::Bytes; use reqwest::Client; -use serde_json::Value; +use std::collections::HashMap; use std::convert::Infallible; -use std::io::Read; use std::net::SocketAddr; use std::sync::Arc; use tokio::sync::Mutex; +use serde_json; -use crate::core::state::ServerHandle; + +use crate::core::state::{LLamaBackendSession, ServerHandle}; /// Configuration for the proxy server #[derive(Clone)] struct ProxyConfig { - upstream: String, prefix: String, - auth_token: String, - trusted_hosts: Vec, - api_key: String, + proxy_api_key: String, + trusted_hosts: Vec>, } /// Removes a prefix from a path, ensuring proper formatting @@ -30,8 +29,10 @@ fn remove_prefix(path: &str, prefix: &str) -> String { let result = path[prefix.len()..].to_string(); if result.is_empty() { "/".to_string() - } else { + } else if result.starts_with('/') { result + } else { + format!("/{}", result) } } else { path.to_string() @@ -40,25 +41,7 @@ fn remove_prefix(path: &str, prefix: &str) -> String { /// Determines the final destination path based on the original request path fn get_destination_path(original_path: &str, prefix: &str) -> String { - let removed_prefix_path = remove_prefix(original_path, prefix); - - // Special paths don't need the /v1 prefix - if !original_path.contains(prefix) - || removed_prefix_path.contains("/healthz") - || removed_prefix_path.contains("/process") - { - original_path.to_string() - } else { - format!("/v1{}", removed_prefix_path) - } -} - -/// Creates the full upstream URL for the proxied request -fn build_upstream_url(upstream: &str, path: &str) -> String { - let upstream_clean = upstream.trim_end_matches('/'); - let path_clean = path.trim_start_matches('/'); - - format!("{}/{}", upstream_clean, path_clean) + remove_prefix(original_path, prefix) } /// Handles the proxy request logic @@ -66,17 +49,8 @@ async fn proxy_request( req: Request, client: Client, config: ProxyConfig, + sessions: Arc>>, ) -> Result, hyper::Error> { - // Handle OPTIONS requests for CORS preflight - log::debug!( - "Received request: {} {} {:?} {:?} {:?}", - req.method(), - req.uri().path(), - req.headers().get(hyper::header::HOST), - req.headers().get(hyper::header::ORIGIN), - req.headers() - .get(hyper::header::ACCESS_CONTROL_REQUEST_METHOD) - ); if req.method() == hyper::Method::OPTIONS { log::debug!( "Handling CORS preflight request from {:?} {:?}", @@ -85,21 +59,18 @@ async fn proxy_request( .get(hyper::header::ACCESS_CONTROL_REQUEST_METHOD) ); - // Get the Host header to validate the target (where request is going) let host = req .headers() .get(hyper::header::HOST) .and_then(|v| v.to_str().ok()) .unwrap_or(""); - // Get the Origin header for CORS response let origin = req .headers() .get(hyper::header::ORIGIN) .and_then(|v| v.to_str().ok()) .unwrap_or(""); - // Validate requested method let requested_method = req .headers() .get("Access-Control-Request-Method") @@ -120,7 +91,6 @@ async fn proxy_request( .unwrap()); } - // Check if the host (target) is trusted, but bypass for whitelisted paths let request_path = req.uri().path(); let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"]; let is_whitelisted_path = whitelisted_paths.contains(&request_path); @@ -133,9 +103,9 @@ async fn proxy_request( true } else if !host.is_empty() { log::debug!( - "CORS preflight: Host is '{}', trusted hosts: [{}]", + "CORS preflight: Host is '{}', trusted hosts: {:?}", host, - &config.trusted_hosts.join(", ") + &config.trusted_hosts ); is_valid_host(host, &config.trusted_hosts) } else { @@ -155,14 +125,12 @@ async fn proxy_request( .unwrap()); } - // Get and validate requested headers let requested_headers = req .headers() .get("Access-Control-Request-Headers") .and_then(|v| v.to_str().ok()) .unwrap_or(""); - // Allow common headers plus our required ones let allowed_headers = [ "accept", "accept-language", @@ -216,7 +184,6 @@ async fn proxy_request( .unwrap()); } - // Build CORS response let mut response = Response::builder() .status(StatusCode::OK) .header("Access-Control-Allow-Methods", allowed_methods.join(", ")) @@ -227,13 +194,11 @@ async fn proxy_request( "Origin, Access-Control-Request-Method, Access-Control-Request-Headers", ); - // Set Access-Control-Allow-Origin based on origin presence if !origin.is_empty() { response = response .header("Access-Control-Allow-Origin", origin) .header("Access-Control-Allow-Credentials", "true"); } else { - // No origin header - allow all origins (useful for non-browser clients) response = response.header("Access-Control-Allow-Origin", "*"); } @@ -245,26 +210,26 @@ async fn proxy_request( return Ok(response.body(Body::empty()).unwrap()); } - // Extract headers early for validation and CORS responses - let origin_header = req - .headers() + let (parts, body) = req.into_parts(); + + let origin_header = parts.headers .get(hyper::header::ORIGIN) .and_then(|v| v.to_str().ok()) .unwrap_or("") .to_string(); - let host_header = req - .headers() + let host_header = parts.headers .get(hyper::header::HOST) .and_then(|v| v.to_str().ok()) .unwrap_or("") .to_string(); - let original_path = req.uri().path(); - let path = get_destination_path(original_path, &config.prefix); - let method = req.method().clone(); + let original_path = parts.uri.path(); + let headers = parts.headers.clone(); + + let path = get_destination_path(original_path, &config.prefix); + let method = parts.method.clone(); - // Verify Host header (check target), but bypass for whitelisted paths let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"]; let is_whitelisted_path = whitelisted_paths.contains(&path.as_str()); @@ -298,12 +263,11 @@ async fn proxy_request( log::debug!("Bypassing host validation for whitelisted path: {}", path); } - // Skip authorization check for whitelisted paths - if !is_whitelisted_path && !config.api_key.is_empty() { - if let Some(authorization) = req.headers().get(hyper::header::AUTHORIZATION) { + if !is_whitelisted_path && !config.proxy_api_key.is_empty() { + if let Some(authorization) = parts.headers.get(hyper::header::AUTHORIZATION) { let auth_str = authorization.to_str().unwrap_or(""); - if auth_str.strip_prefix("Bearer ") != Some(config.api_key.as_str()) { + if auth_str.strip_prefix("Bearer ") != Some(config.proxy_api_key.as_str()) { let mut error_response = Response::builder().status(StatusCode::UNAUTHORIZED); error_response = add_cors_headers_with_host_and_origin( error_response, @@ -334,7 +298,6 @@ async fn proxy_request( ); } - // Block access to /configs endpoint if path.contains("/configs") { let mut error_response = Response::builder().status(StatusCode::NOT_FOUND); error_response = add_cors_headers_with_host_and_origin( @@ -346,41 +309,253 @@ async fn proxy_request( return Ok(error_response.body(Body::from("Not Found")).unwrap()); } - // Build the outbound request - let upstream_url = build_upstream_url(&config.upstream, &path); + let mut target_port: Option = None; + let mut session_api_key: Option = None; + let mut buffered_body: Option = None; + let original_path = parts.uri.path(); + let destination_path = get_destination_path(original_path, &config.prefix); + + match (method.clone(), destination_path.as_str()) { + (hyper::Method::POST, "/chat/completions") + | (hyper::Method::POST, "/completions") + | (hyper::Method::POST, "/embeddings") => { + log::debug!( + "Handling POST request to {} requiring model lookup in body", + destination_path + ); + let body_bytes = match hyper::body::to_bytes(body).await { + Ok(bytes) => bytes, + Err(_) => { + let mut error_response = + Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR); + error_response = add_cors_headers_with_host_and_origin( + error_response, + &host_header, + &origin_header, + &config.trusted_hosts, + ); + return Ok(error_response + .body(Body::from("Failed to read request body")) + .unwrap()); + } + }; + buffered_body = Some(body_bytes.clone()); + + match serde_json::from_slice::(&body_bytes) { + Ok(json_body) => { + if let Some(model_id) = json_body.get("model").and_then(|v| v.as_str()) { + log::debug!("Extracted model_id: {}", model_id); + let sessions_guard = sessions.lock().await; + + if sessions_guard.is_empty() { + log::warn!("Request for model '{}' but no backend servers are running.", model_id); + let mut error_response = Response::builder().status(StatusCode::SERVICE_UNAVAILABLE); + error_response = add_cors_headers_with_host_and_origin( + error_response, + &host_header, + &origin_header, + &config.trusted_hosts, + ); + return Ok(error_response.body(Body::from("No backend model servers are available")).unwrap()); + } + + if let Some(session) = sessions_guard + .values() + .find(|s| s.info.model_id == model_id) + { + target_port = Some(session.info.port); + session_api_key = Some(session.info.api_key.clone()); + log::debug!( + "Found session for model_id {} on port {}", + model_id, + session.info.port + ); + } else { + log::warn!("No running session found for model_id: {}", model_id); + let mut error_response = + Response::builder().status(StatusCode::NOT_FOUND); + error_response = add_cors_headers_with_host_and_origin( + error_response, + &host_header, + &origin_header, + &config.trusted_hosts, + ); + return Ok(error_response + .body(Body::from(format!( + "No running server found for model '{}'", + model_id + ))) + .unwrap()); + } + } else { + log::warn!( + "POST body for {} is missing 'model' field or it's not a string", + destination_path + ); + let mut error_response = + Response::builder().status(StatusCode::BAD_REQUEST); + error_response = add_cors_headers_with_host_and_origin( + error_response, + &host_header, + &origin_header, + &config.trusted_hosts, + ); + return Ok(error_response + .body(Body::from("Request body must contain a 'model' field")) + .unwrap()); + } + } + Err(e) => { + log::warn!( + "Failed to parse POST body for {} as JSON: {}", + destination_path, + e + ); + let mut error_response = Response::builder().status(StatusCode::BAD_REQUEST); + error_response = add_cors_headers_with_host_and_origin( + error_response, + &host_header, + &origin_header, + &config.trusted_hosts, + ); + return Ok(error_response + .body(Body::from("Invalid JSON body")) + .unwrap()); + } + } + } + (hyper::Method::GET, "/models") => { + log::debug!("Handling GET /v1/models request"); + let sessions_guard = sessions.lock().await; + + let models_data: Vec<_> = sessions_guard + .values() + .map(|session| { + serde_json::json!({ + "id": session.info.model_id, + "object": "model", + "created": 1, + "owned_by": "user" + }) + }) + .collect(); + + let response_json = serde_json::json!({ + "object": "list", + "data": models_data + }); + + let body_str = serde_json::to_string(&response_json).unwrap_or_else(|_| "{}".to_string()); + + let mut response_builder = Response::builder() + .status(StatusCode::OK) + .header(hyper::header::CONTENT_TYPE, "application/json"); + + response_builder = add_cors_headers_with_host_and_origin( + response_builder, + &host_header, + &origin_header, + &config.trusted_hosts, + ); + + return Ok(response_builder.body(Body::from(body_str)).unwrap()); + } + _ => { + let is_explicitly_whitelisted_get = method == hyper::Method::GET + && whitelisted_paths.contains(&destination_path.as_str()); + if is_explicitly_whitelisted_get { + log::debug!("Handled whitelisted GET path: {}", destination_path); + let mut error_response = Response::builder().status(StatusCode::NOT_FOUND); + error_response = add_cors_headers_with_host_and_origin( + error_response, + &host_header, + &origin_header, + &config.trusted_hosts, + ); + return Ok(error_response.body(Body::from("Not Found")).unwrap()); + } else { + log::warn!( + "Unhandled method/path for dynamic routing: {} {}", + method, + destination_path + ); + let mut error_response = Response::builder().status(StatusCode::NOT_FOUND); + error_response = add_cors_headers_with_host_and_origin( + error_response, + &host_header, + &origin_header, + &config.trusted_hosts, + ); + return Ok(error_response.body(Body::from("Not Found")).unwrap()); + } + } + } + + let port = match target_port { + Some(p) => p, + None => { + log::error!("Internal routing error: target_port is None after successful lookup"); + let mut error_response = Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR); + error_response = add_cors_headers_with_host_and_origin( + error_response, + &host_header, + &origin_header, + &config.trusted_hosts, + ); + return Ok(error_response + .body(Body::from("Internal routing error")) + .unwrap()); + } + }; + + let upstream_url = format!("http://127.0.0.1:{}{}", port, destination_path); log::debug!("Proxying request to: {}", upstream_url); - let mut outbound_req = client.request(req.method().clone(), &upstream_url); + let mut outbound_req = client.request(method.clone(), &upstream_url); - // Copy original headers - for (name, value) in req.headers() { - // Skip host & authorization header + for (name, value) in headers.iter() { if name != hyper::header::HOST && name != hyper::header::AUTHORIZATION { outbound_req = outbound_req.header(name, value); } } - // Add authorization header - outbound_req = outbound_req.header("Authorization", format!("Bearer {}", config.auth_token)); + if let Some(key) = session_api_key { + log::debug!("Adding session Authorization header"); + outbound_req = outbound_req.header("Authorization", format!("Bearer {}", key)); + } else { + log::debug!("No session API key available for this request"); + } - // Send the request and handle the response - match outbound_req.body(req.into_body()).send().await { + let outbound_req_with_body = if let Some(bytes) = buffered_body { + log::debug!("Sending buffered body ({} bytes)", bytes.len()); + outbound_req.body(bytes) + } else { + log::error!("Internal logic error: Request reached proxy stage without a buffered body."); + let mut error_response = Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR); + error_response = add_cors_headers_with_host_and_origin( + error_response, + &host_header, + &origin_header, + &config.trusted_hosts, + ); + return Ok(error_response + .body(Body::from("Internal server error: unhandled request path")) + .unwrap()); + }; + + match outbound_req_with_body.send().await { Ok(response) => { let status = response.status(); log::debug!("Received response with status: {}", status); let mut builder = Response::builder().status(status); - // Copy response headers, excluding CORS headers and Content-Length to avoid conflicts for (name, value) in response.headers() { - // Skip CORS headers from upstream to avoid duplicates - // Skip Content-Length header when filtering models response to avoid mismatch if !is_cors_header(name.as_str()) && name != hyper::header::CONTENT_LENGTH { builder = builder.header(name, value); } } - // Add our own CORS headers builder = add_cors_headers_with_host_and_origin( builder, &host_header, @@ -388,63 +563,32 @@ async fn proxy_request( &config.trusted_hosts, ); - // Handle streaming vs non-streaming responses - if path.contains("/models") && method == hyper::Method::GET { - // For /models endpoint, we need to buffer and filter the response - match response.bytes().await { - Ok(bytes) => match filter_models_response(&bytes) { - Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()), - Err(e) => { - log::warn!( - "Failed to filter models response: {}, returning original", - e - ); - Ok(builder.body(Body::from(bytes)).unwrap()) - } - }, - Err(e) => { - log::error!("Failed to read response body: {}", e); - let mut error_response = - Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR); - error_response = add_cors_headers_with_host_and_origin( - error_response, - &host_header, - &origin_header, - &config.trusted_hosts, - ); - Ok(error_response - .body(Body::from("Error reading upstream response")) - .unwrap()) - } - } - } else { - // For streaming endpoints (like chat completions), we need to collect and forward the stream - let mut stream = response.bytes_stream(); - let (mut sender, body) = hyper::Body::channel(); + let mut stream = response.bytes_stream(); + let (mut sender, body) = hyper::Body::channel(); - // Spawn a task to forward the stream - tokio::spawn(async move { - while let Some(chunk_result) = stream.next().await { - match chunk_result { - Ok(chunk) => { - if sender.send_data(chunk).await.is_err() { - log::debug!("Client disconnected during streaming"); - break; - } - } - Err(e) => { - log::error!("Stream error: {}", e); + tokio::spawn(async move { + while let Some(chunk_result) = stream.next().await { + match chunk_result { + Ok(chunk) => { + if sender.send_data(chunk).await.is_err() { + log::debug!("Client disconnected during streaming"); break; } } + Err(e) => { + log::error!("Stream error: {}", e); + break; + } } - }); + } + log::debug!("Streaming complete to client"); + }); - Ok(builder.body(body).unwrap()) - } + Ok(builder.body(body).unwrap()) } Err(e) => { - log::error!("Proxy request failed: {}", e); + let error_msg = format!("Proxy request to {} failed: {}", upstream_url, e); + log::error!("{}", error_msg); let mut error_response = Response::builder().status(StatusCode::BAD_GATEWAY); error_response = add_cors_headers_with_host_and_origin( error_response, @@ -452,148 +596,45 @@ async fn proxy_request( &origin_header, &config.trusted_hosts, ); - Ok(error_response - .body(Body::from(format!("Upstream error: {}", e))) - .unwrap()) + Ok(error_response.body(Body::from(error_msg)).unwrap()) } } } -/// Checks if the byte array starts with gzip magic number -fn is_gzip_encoded(bytes: &[u8]) -> bool { - bytes.len() >= 2 && bytes[0] == 0x1f && bytes[1] == 0x8b -} - -/// Decompresses gzip-encoded bytes -fn decompress_gzip(bytes: &[u8]) -> Result, Box> { - let mut decoder = GzDecoder::new(bytes); - let mut decompressed = Vec::new(); - decoder.read_to_end(&mut decompressed)?; - Ok(decompressed) -} - -/// Compresses bytes using gzip -fn compress_gzip(bytes: &[u8]) -> Result, Box> { - use flate2::write::GzEncoder; - use flate2::Compression; - use std::io::Write; - - let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); - encoder.write_all(bytes)?; - let compressed = encoder.finish()?; - Ok(compressed) -} - -/// Filters models response to keep only models with status "downloaded" -fn filter_models_response( - bytes: &[u8], -) -> Result, Box> { - // Try to decompress if it's gzip-encoded - let decompressed_bytes = if is_gzip_encoded(bytes) { - log::debug!("Response is gzip-encoded, decompressing..."); - decompress_gzip(bytes)? - } else { - bytes.to_vec() - }; - - let response_text = std::str::from_utf8(&decompressed_bytes)?; - let mut response_json: Value = serde_json::from_str(response_text)?; - - // Check if this is a ListModelsResponseDto format with data array - if let Some(data_array) = response_json.get_mut("data") { - if let Some(models) = data_array.as_array_mut() { - // Keep only models where status == "downloaded" - models.retain(|model| { - if let Some(status) = model.get("status") { - if let Some(status_str) = status.as_str() { - status_str == "downloaded" - } else { - false // Remove models without string status - } - } else { - false // Remove models without status field - } - }); - log::debug!( - "Filtered models response: {} downloaded models remaining", - models.len() - ); - } - } else if response_json.is_array() { - // Handle direct array format - if let Some(models) = response_json.as_array_mut() { - models.retain(|model| { - if let Some(status) = model.get("status") { - if let Some(status_str) = status.as_str() { - status_str == "downloaded" - } else { - false // Remove models without string status - } - } else { - false // Remove models without status field - } - }); - log::debug!( - "Filtered models response: {} downloaded models remaining", - models.len() - ); - } - } - - let filtered_json = serde_json::to_vec(&response_json)?; - - // If original was gzip-encoded, re-compress the filtered response - if is_gzip_encoded(bytes) { - log::debug!("Re-compressing filtered response with gzip"); - compress_gzip(&filtered_json) - } else { - Ok(filtered_json) - } -} - -/// Checks if a header is a CORS-related header that should be filtered out from upstream responses fn is_cors_header(header_name: &str) -> bool { let header_lower = header_name.to_lowercase(); header_lower.starts_with("access-control-") } -/// Adds CORS headers to a response builder using host for validation and origin for response fn add_cors_headers_with_host_and_origin( builder: hyper::http::response::Builder, host: &str, origin: &str, - trusted_hosts: &[String], + trusted_hosts: &[Vec], ) -> hyper::http::response::Builder { let mut builder = builder; - - // Check if host (target) is trusted - this is what we validate - let is_trusted = if !host.is_empty() { - is_valid_host(host, trusted_hosts) + let allow_origin_header = if !origin.is_empty() && is_valid_host(host, trusted_hosts) { + origin.to_string() + } else if !origin.is_empty() { + origin.to_string() } else { - false // Host is required for validation + "*".to_string() }; - // Set CORS headers using origin for the response - if !origin.is_empty() && is_trusted { - builder = builder - .header("Access-Control-Allow-Origin", origin) - .header("Access-Control-Allow-Credentials", "true"); - } else if !origin.is_empty() { - builder = builder.header("Access-Control-Allow-Origin", origin); - } else { - builder = builder.header("Access-Control-Allow-Origin", "*"); - } - builder = builder + .header("Access-Control-Allow-Origin", allow_origin_header.clone()) .header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, PATCH") .header("Access-Control-Allow-Headers", "Authorization, Content-Type, Host, Accept, Accept-Language, Cache-Control, Connection, DNT, If-Modified-Since, Keep-Alive, Origin, User-Agent, X-Requested-With, X-CSRF-Token, X-Forwarded-For, X-Forwarded-Proto, X-Forwarded-Host, authorization, content-type, x-api-key") .header("Vary", "Origin"); + if allow_origin_header != "*" { + builder = builder.header("Access-Control-Allow-Credentials", "true"); + } + builder } -// Validates if the host header is allowed -fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool { +fn is_valid_host(host: &str, trusted_hosts: &[Vec]) -> bool { if host.is_empty() { return false; } @@ -608,7 +649,6 @@ fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool { }; let default_valid_hosts = ["localhost", "127.0.0.1", "0.0.0.0"]; - // Check default valid hosts (host part only) if default_valid_hosts .iter() .any(|&valid| host_without_port.to_lowercase() == valid.to_lowercase()) @@ -616,17 +656,14 @@ fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool { return true; } - // Check trusted hosts - support both full host:port and host-only formats - trusted_hosts.iter().any(|valid| { + trusted_hosts.iter().flatten().any(|valid| { let host_lower = host.to_lowercase(); let valid_lower = valid.to_lowercase(); - // First check exact match (including port) if host_lower == valid_lower { return true; } - // Then check host part only (without port) let valid_without_port = if valid.starts_with('[') { valid .split(']') @@ -643,68 +680,54 @@ fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool { pub async fn is_server_running(server_handle: Arc>>) -> bool { let handle_guard = server_handle.lock().await; - - if handle_guard.is_some() { - true - } else { - false - } + handle_guard.is_some() } -/// Starts the proxy server pub async fn start_server( server_handle: Arc>>, + sessions: Arc>>, host: String, port: u16, prefix: String, - auth_token: String, - api_key: String, - trusted_hosts: Vec, + proxy_api_key: String, + trusted_hosts: Vec>, ) -> Result> { - // Check if server is already running let mut handle_guard = server_handle.lock().await; if handle_guard.is_some() { return Err("Server is already running".into()); } - // Create server address let addr: SocketAddr = format!("{}:{}", host, port) .parse() .map_err(|e| format!("Invalid address: {}", e))?; - // Configure proxy settings let config = ProxyConfig { - upstream: "http://127.0.0.1:39291".to_string(), prefix, - auth_token, - api_key, + proxy_api_key, trusted_hosts, }; - // Create HTTP client with longer timeout for streaming let client = Client::builder() - .timeout(std::time::Duration::from_secs(300)) // 5 minutes for streaming + .timeout(std::time::Duration::from_secs(300)) .pool_max_idle_per_host(10) .pool_idle_timeout(std::time::Duration::from_secs(30)) .build()?; - // Create service handler let make_svc = make_service_fn(move |_conn| { let client = client.clone(); let config = config.clone(); + let sessions = sessions.clone(); async move { Ok::<_, Infallible>(service_fn(move |req| { - proxy_request(req, client.clone(), config.clone()) + proxy_request(req, client.clone(), config.clone(), sessions.clone()) })) } }); - // Create and start the server let server = Server::bind(&addr).serve(make_svc); log::info!("Proxy server started on http://{}", addr); - // Spawn server task let server_task = tokio::spawn(async move { if let Err(e) = server.await { log::error!("Server error: {}", e); @@ -717,7 +740,6 @@ pub async fn start_server( Ok(true) } -/// Stops the currently running proxy server pub async fn stop_server( server_handle: Arc>>, ) -> Result<(), Box> { @@ -725,7 +747,6 @@ pub async fn stop_server( if let Some(handle) = handle_guard.take() { handle.abort(); - // remove the handle to prevent future use *handle_guard = None; log::info!("Proxy server stopped"); } else { @@ -734,139 +755,3 @@ pub async fn stop_server( Ok(()) } - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn test_filter_models_response_with_downloaded_status() { - let test_response = json!({ - "object": "list", - "data": [ - { - "id": "model1", - "name": "Model 1", - "status": "downloaded" - }, - { - "id": "model2", - "name": "Model 2", - "status": "available" - }, - { - "id": "model3", - "name": "Model 3" - } - ] - }); - - let response_bytes = serde_json::to_vec(&test_response).unwrap(); - let filtered_bytes = filter_models_response(&response_bytes).unwrap(); - let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap(); - - let data = filtered_response["data"].as_array().unwrap(); - assert_eq!(data.len(), 1); // Should have 1 model (only model1 with "downloaded" status) - - // Verify only model1 (with "downloaded" status) is kept - assert!(data.iter().any(|model| model["id"] == "model1")); - - // Verify model2 and model3 are filtered out - assert!(!data.iter().any(|model| model["id"] == "model2")); - assert!(!data.iter().any(|model| model["id"] == "model3")); - } - - #[test] - fn test_filter_models_response_direct_array() { - let test_response = json!([ - { - "id": "model1", - "name": "Model 1", - "status": "downloaded" - }, - { - "id": "model2", - "name": "Model 2", - "status": "available" - } - ]); - - let response_bytes = serde_json::to_vec(&test_response).unwrap(); - let filtered_bytes = filter_models_response(&response_bytes).unwrap(); - let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap(); - - let data = filtered_response.as_array().unwrap(); - assert_eq!(data.len(), 1); // Should have 1 model (only model1 with "downloaded" status) - assert!(data.iter().any(|model| model["id"] == "model1")); - assert!(!data.iter().any(|model| model["id"] == "model2")); - } - - #[test] - fn test_filter_models_response_no_status_field() { - let test_response = json!({ - "object": "list", - "data": [ - { - "id": "model1", - "name": "Model 1" - }, - { - "id": "model2", - "name": "Model 2" - } - ] - }); - - let response_bytes = serde_json::to_vec(&test_response).unwrap(); - let filtered_bytes = filter_models_response(&response_bytes).unwrap(); - let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap(); - - let data = filtered_response["data"].as_array().unwrap(); - assert_eq!(data.len(), 0); // Should remove all models when no status field (no "downloaded" status) - } - - #[test] - fn test_filter_models_response_multiple_downloaded() { - let test_response = json!({ - "object": "list", - "data": [ - { - "id": "model1", - "name": "Model 1", - "status": "downloaded" - }, - { - "id": "model2", - "name": "Model 2", - "status": "available" - }, - { - "id": "model3", - "name": "Model 3", - "status": "downloaded" - }, - { - "id": "model4", - "name": "Model 4", - "status": "installing" - } - ] - }); - - let response_bytes = serde_json::to_vec(&test_response).unwrap(); - let filtered_bytes = filter_models_response(&response_bytes).unwrap(); - let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap(); - - let data = filtered_response["data"].as_array().unwrap(); - assert_eq!(data.len(), 2); // Should have 2 models (model1 and model3 with "downloaded" status) - - // Verify only models with "downloaded" status are kept - assert!(data.iter().any(|model| model["id"] == "model1")); - assert!(data.iter().any(|model| model["id"] == "model3")); - - // Verify other models are filtered out - assert!(!data.iter().any(|model| model["id"] == "model2")); - assert!(!data.iter().any(|model| model["id"] == "model4")); - } -}