feat(server): filter /models endpoint to show only downloaded models (#5343)

- Add filtering logic to proxy server for GET /models requests
- Keep only models with status "downloaded" in response
- Remove Content-Length header to prevent mismatch after filtering
- Support both ListModelsResponseDto and direct array formats
- Add comprehensive tests for filtering functionality
- Fix Content-Length header conflict causing empty responses

Fixes issue where all models were returned regardless of download status.
This commit is contained in:
Sam Hoang Van 2025-06-18 14:11:53 +07:00 committed by GitHub
parent 771105a5b2
commit 369ba5ac75
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,7 @@
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server, StatusCode};
use reqwest::Client;
use serde_json::Value;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::LazyLock;
@ -263,6 +264,7 @@ async fn proxy_request(
let original_path = req.uri().path();
let path = get_destination_path(original_path, &config.prefix);
let method = req.method().clone();
// Verify Host header (check target), but bypass for whitelisted paths
let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"];
@ -368,10 +370,11 @@ async fn proxy_request(
let mut builder = Response::builder().status(status);
// Copy response headers, excluding CORS headers to avoid conflicts
// Copy response headers, excluding CORS headers and Content-Length to avoid conflicts
for (name, value) in response.headers() {
// Skip CORS headers from upstream to avoid duplicates
if !is_cors_header(name.as_str()) {
// Skip Content-Length header when filtering models response to avoid mismatch
if !is_cors_header(name.as_str()) && name != hyper::header::CONTENT_LENGTH {
builder = builder.header(name, value);
}
}
@ -386,7 +389,20 @@ async fn proxy_request(
// Read response body
match response.bytes().await {
Ok(bytes) => Ok(builder.body(Body::from(bytes)).unwrap()),
Ok(bytes) => {
// Check if this is a /models endpoint request and filter the response
if path.contains("/models") && method == hyper::Method::GET {
match filter_models_response(&bytes) {
Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()),
Err(e) => {
log::warn!("Failed to filter models response: {}, returning original", e);
Ok(builder.body(Body::from(bytes)).unwrap())
}
}
} else {
Ok(builder.body(Body::from(bytes)).unwrap())
}
},
Err(e) => {
log::error!("Failed to read response body: {}", e);
let mut error_response =
@ -419,6 +435,50 @@ async fn proxy_request(
}
}
/// Filters models response to keep only models with status "downloaded"
fn filter_models_response(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
let response_text = std::str::from_utf8(bytes)?;
let mut response_json: Value = serde_json::from_str(response_text)?;
// Check if this is a ListModelsResponseDto format with data array
if let Some(data_array) = response_json.get_mut("data") {
if let Some(models) = data_array.as_array_mut() {
// Keep only models where status == "downloaded"
models.retain(|model| {
if let Some(status) = model.get("status") {
if let Some(status_str) = status.as_str() {
status_str == "downloaded"
} else {
false // Remove models without string status
}
} else {
false // Remove models without status field
}
});
log::debug!("Filtered models response: {} downloaded models remaining", models.len());
}
} else if response_json.is_array() {
// Handle direct array format
if let Some(models) = response_json.as_array_mut() {
models.retain(|model| {
if let Some(status) = model.get("status") {
if let Some(status_str) = status.as_str() {
status_str == "downloaded"
} else {
false // Remove models without string status
}
} else {
false // Remove models without status field
}
});
log::debug!("Filtered models response: {} downloaded models remaining", models.len());
}
}
let filtered_response = serde_json::to_vec(&response_json)?;
Ok(filtered_response)
}
/// Checks if a header is a CORS-related header that should be filtered out from upstream responses
fn is_cors_header(header_name: &str) -> bool {
let header_lower = header_name.to_lowercase();
@ -585,3 +645,139 @@ pub async fn stop_server() -> Result<(), Box<dyn std::error::Error + Send + Sync
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serializes `payload`, passes it through `filter_models_response`, and
    /// parses the filtered bytes back into a JSON value for inspection.
    fn run_filter(payload: &serde_json::Value) -> serde_json::Value {
        let raw = serde_json::to_vec(payload).unwrap();
        let filtered = filter_models_response(&raw).unwrap();
        serde_json::from_slice(&filtered).unwrap()
    }

    /// Reports whether any entry in `models` has the given `id`.
    fn contains_id(models: &[serde_json::Value], id: &str) -> bool {
        models.iter().any(|model| model["id"] == id)
    }

    #[test]
    fn test_filter_models_response_with_downloaded_status() {
        let payload = json!({
            "object": "list",
            "data": [
                { "id": "model1", "name": "Model 1", "status": "downloaded" },
                { "id": "model2", "name": "Model 2", "status": "available" },
                { "id": "model3", "name": "Model 3" }
            ]
        });

        let filtered = run_filter(&payload);
        let models = filtered["data"].as_array().unwrap();

        // Exactly one model survives: the one marked "downloaded".
        assert_eq!(models.len(), 1);
        assert!(contains_id(models, "model1"));

        // The "available" model and the one without a status are gone.
        assert!(!contains_id(models, "model2"));
        assert!(!contains_id(models, "model3"));
    }

    #[test]
    fn test_filter_models_response_direct_array() {
        let payload = json!([
            { "id": "model1", "name": "Model 1", "status": "downloaded" },
            { "id": "model2", "name": "Model 2", "status": "available" }
        ]);

        let filtered = run_filter(&payload);
        let models = filtered.as_array().unwrap();

        // Only the "downloaded" entry of the bare array is kept.
        assert_eq!(models.len(), 1);
        assert!(contains_id(models, "model1"));
        assert!(!contains_id(models, "model2"));
    }

    #[test]
    fn test_filter_models_response_no_status_field() {
        let payload = json!({
            "object": "list",
            "data": [
                { "id": "model1", "name": "Model 1" },
                { "id": "model2", "name": "Model 2" }
            ]
        });

        let filtered = run_filter(&payload);
        let models = filtered["data"].as_array().unwrap();

        // With no status field anywhere, nothing counts as downloaded.
        assert_eq!(models.len(), 0);
    }

    #[test]
    fn test_filter_models_response_multiple_downloaded() {
        let payload = json!({
            "object": "list",
            "data": [
                { "id": "model1", "name": "Model 1", "status": "downloaded" },
                { "id": "model2", "name": "Model 2", "status": "available" },
                { "id": "model3", "name": "Model 3", "status": "downloaded" },
                { "id": "model4", "name": "Model 4", "status": "installing" }
            ]
        });

        let filtered = run_filter(&payload);
        let models = filtered["data"].as_array().unwrap();

        // Both "downloaded" models are retained.
        assert_eq!(models.len(), 2);
        assert!(contains_id(models, "model1"));
        assert!(contains_id(models, "model3"));

        // Every other status is filtered out.
        assert!(!contains_id(models, "model2"));
        assert!(!contains_id(models, "model4"));
    }
}