diff --git a/src-tauri/src/core/server.rs b/src-tauri/src/core/server.rs index e5a784670..f4f270106 100644 --- a/src-tauri/src/core/server.rs +++ b/src-tauri/src/core/server.rs @@ -1,6 +1,7 @@ use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Request, Response, Server, StatusCode}; use reqwest::Client; +use serde_json::Value; use std::convert::Infallible; use std::net::SocketAddr; use std::sync::LazyLock; @@ -263,6 +264,7 @@ async fn proxy_request( let original_path = req.uri().path(); let path = get_destination_path(original_path, &config.prefix); + let method = req.method().clone(); // Verify Host header (check target), but bypass for whitelisted paths let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"]; @@ -368,10 +370,11 @@ async fn proxy_request( let mut builder = Response::builder().status(status); - // Copy response headers, excluding CORS headers to avoid conflicts + // Copy response headers, excluding CORS headers and Content-Length to avoid conflicts for (name, value) in response.headers() { // Skip CORS headers from upstream to avoid duplicates - if !is_cors_header(name.as_str()) { + // Skip Content-Length header when filtering models response to avoid mismatch + if !is_cors_header(name.as_str()) && name != hyper::header::CONTENT_LENGTH { builder = builder.header(name, value); } } @@ -386,7 +389,20 @@ async fn proxy_request( // Read response body match response.bytes().await { - Ok(bytes) => Ok(builder.body(Body::from(bytes)).unwrap()), + Ok(bytes) => { + // Check if this is a /models endpoint request and filter the response + if path.contains("/models") && method == hyper::Method::GET { + match filter_models_response(&bytes) { + Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()), + Err(e) => { + log::warn!("Failed to filter models response: {}, returning original", e); + Ok(builder.body(Body::from(bytes)).unwrap()) + } + } + } else { + Ok(builder.body(Body::from(bytes)).unwrap()) + } + }, Err(e) => { log::error!("Failed to read response body: {}", e); let mut error_response = @@ -419,6 +435,50 @@ async fn proxy_request( } } +/// Filters models response to keep only models with status "downloaded" +fn filter_models_response(bytes: &[u8]) -> Result, Box> { + let response_text = std::str::from_utf8(bytes)?; + let mut response_json: Value = serde_json::from_str(response_text)?; + + // Check if this is a ListModelsResponseDto format with data array + if let Some(data_array) = response_json.get_mut("data") { + if let Some(models) = data_array.as_array_mut() { + // Keep only models where status == "downloaded" + models.retain(|model| { + if let Some(status) = model.get("status") { + if let Some(status_str) = status.as_str() { + status_str == "downloaded" + } else { + false // Remove models without string status + } + } else { + false // Remove models without status field + } + }); + log::debug!("Filtered models response: {} downloaded models remaining", models.len()); + } + } else if response_json.is_array() { + // Handle direct array format + if let Some(models) = response_json.as_array_mut() { + models.retain(|model| { + if let Some(status) = model.get("status") { + if let Some(status_str) = status.as_str() { + status_str == "downloaded" + } else { + false // Remove models without string status + } + } else { + false // Remove models without status field + } + }); + log::debug!("Filtered models response: {} downloaded models remaining", models.len()); + } + } + + let filtered_response = serde_json::to_vec(&response_json)?; + Ok(filtered_response) +} + /// Checks if a header is a CORS-related header that should be filtered out from upstream responses fn is_cors_header(header_name: &str) -> bool { let header_lower = header_name.to_lowercase(); @@ -585,3 +645,139 @@ pub async fn stop_server() -> Result<(), Box