feat(server): filter /models endpoint to show only downloaded models (#5343)
- Add filtering logic to proxy server for GET /models requests - Keep only models with status "downloaded" in response - Remove Content-Length header to prevent mismatch after filtering - Support both ListModelsResponseDto and direct array formats - Add comprehensive tests for filtering functionality - Fix Content-Length header conflict causing empty responses Fixes issue where all models were returned regardless of download status.
This commit is contained in:
parent
771105a5b2
commit
369ba5ac75
@ -1,6 +1,7 @@
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{Body, Request, Response, Server, StatusCode};
|
||||
use reqwest::Client;
|
||||
use serde_json::Value;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::LazyLock;
|
||||
@ -263,6 +264,7 @@ async fn proxy_request(
|
||||
|
||||
let original_path = req.uri().path();
|
||||
let path = get_destination_path(original_path, &config.prefix);
|
||||
let method = req.method().clone();
|
||||
|
||||
// Verify Host header (check target), but bypass for whitelisted paths
|
||||
let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"];
|
||||
@ -368,10 +370,11 @@ async fn proxy_request(
|
||||
|
||||
let mut builder = Response::builder().status(status);
|
||||
|
||||
// Copy response headers, excluding CORS headers to avoid conflicts
|
||||
// Copy response headers, excluding CORS headers and Content-Length to avoid conflicts
|
||||
for (name, value) in response.headers() {
|
||||
// Skip CORS headers from upstream to avoid duplicates
|
||||
if !is_cors_header(name.as_str()) {
|
||||
// Skip Content-Length header when filtering models response to avoid mismatch
|
||||
if !is_cors_header(name.as_str()) && name != hyper::header::CONTENT_LENGTH {
|
||||
builder = builder.header(name, value);
|
||||
}
|
||||
}
|
||||
@ -386,7 +389,20 @@ async fn proxy_request(
|
||||
|
||||
// Read response body
|
||||
match response.bytes().await {
|
||||
Ok(bytes) => Ok(builder.body(Body::from(bytes)).unwrap()),
|
||||
Ok(bytes) => {
|
||||
// Check if this is a /models endpoint request and filter the response
|
||||
if path.contains("/models") && method == hyper::Method::GET {
|
||||
match filter_models_response(&bytes) {
|
||||
Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()),
|
||||
Err(e) => {
|
||||
log::warn!("Failed to filter models response: {}, returning original", e);
|
||||
Ok(builder.body(Body::from(bytes)).unwrap())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(builder.body(Body::from(bytes)).unwrap())
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
log::error!("Failed to read response body: {}", e);
|
||||
let mut error_response =
|
||||
@ -419,6 +435,50 @@ async fn proxy_request(
|
||||
}
|
||||
}
|
||||
|
||||
/// Filters models response to keep only models with status "downloaded"
|
||||
fn filter_models_response(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let response_text = std::str::from_utf8(bytes)?;
|
||||
let mut response_json: Value = serde_json::from_str(response_text)?;
|
||||
|
||||
// Check if this is a ListModelsResponseDto format with data array
|
||||
if let Some(data_array) = response_json.get_mut("data") {
|
||||
if let Some(models) = data_array.as_array_mut() {
|
||||
// Keep only models where status == "downloaded"
|
||||
models.retain(|model| {
|
||||
if let Some(status) = model.get("status") {
|
||||
if let Some(status_str) = status.as_str() {
|
||||
status_str == "downloaded"
|
||||
} else {
|
||||
false // Remove models without string status
|
||||
}
|
||||
} else {
|
||||
false // Remove models without status field
|
||||
}
|
||||
});
|
||||
log::debug!("Filtered models response: {} downloaded models remaining", models.len());
|
||||
}
|
||||
} else if response_json.is_array() {
|
||||
// Handle direct array format
|
||||
if let Some(models) = response_json.as_array_mut() {
|
||||
models.retain(|model| {
|
||||
if let Some(status) = model.get("status") {
|
||||
if let Some(status_str) = status.as_str() {
|
||||
status_str == "downloaded"
|
||||
} else {
|
||||
false // Remove models without string status
|
||||
}
|
||||
} else {
|
||||
false // Remove models without status field
|
||||
}
|
||||
});
|
||||
log::debug!("Filtered models response: {} downloaded models remaining", models.len());
|
||||
}
|
||||
}
|
||||
|
||||
let filtered_response = serde_json::to_vec(&response_json)?;
|
||||
Ok(filtered_response)
|
||||
}
|
||||
|
||||
/// Checks if a header is a CORS-related header that should be filtered out from upstream responses
|
||||
fn is_cors_header(header_name: &str) -> bool {
|
||||
let header_lower = header_name.to_lowercase();
|
||||
@ -585,3 +645,139 @@ pub async fn stop_server() -> Result<(), Box<dyn std::error::Error + Send + Sync
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_filter_models_response_with_downloaded_status() {
|
||||
let test_response = json!({
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "model1",
|
||||
"name": "Model 1",
|
||||
"status": "downloaded"
|
||||
},
|
||||
{
|
||||
"id": "model2",
|
||||
"name": "Model 2",
|
||||
"status": "available"
|
||||
},
|
||||
{
|
||||
"id": "model3",
|
||||
"name": "Model 3"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let response_bytes = serde_json::to_vec(&test_response).unwrap();
|
||||
let filtered_bytes = filter_models_response(&response_bytes).unwrap();
|
||||
let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap();
|
||||
|
||||
let data = filtered_response["data"].as_array().unwrap();
|
||||
assert_eq!(data.len(), 1); // Should have 1 model (only model1 with "downloaded" status)
|
||||
|
||||
// Verify only model1 (with "downloaded" status) is kept
|
||||
assert!(data.iter().any(|model| model["id"] == "model1"));
|
||||
|
||||
// Verify model2 and model3 are filtered out
|
||||
assert!(!data.iter().any(|model| model["id"] == "model2"));
|
||||
assert!(!data.iter().any(|model| model["id"] == "model3"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_filter_models_response_direct_array() {
|
||||
let test_response = json!([
|
||||
{
|
||||
"id": "model1",
|
||||
"name": "Model 1",
|
||||
"status": "downloaded"
|
||||
},
|
||||
{
|
||||
"id": "model2",
|
||||
"name": "Model 2",
|
||||
"status": "available"
|
||||
}
|
||||
]);
|
||||
|
||||
let response_bytes = serde_json::to_vec(&test_response).unwrap();
|
||||
let filtered_bytes = filter_models_response(&response_bytes).unwrap();
|
||||
let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap();
|
||||
|
||||
let data = filtered_response.as_array().unwrap();
|
||||
assert_eq!(data.len(), 1); // Should have 1 model (only model1 with "downloaded" status)
|
||||
assert!(data.iter().any(|model| model["id"] == "model1"));
|
||||
assert!(!data.iter().any(|model| model["id"] == "model2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_filter_models_response_no_status_field() {
|
||||
let test_response = json!({
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "model1",
|
||||
"name": "Model 1"
|
||||
},
|
||||
{
|
||||
"id": "model2",
|
||||
"name": "Model 2"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let response_bytes = serde_json::to_vec(&test_response).unwrap();
|
||||
let filtered_bytes = filter_models_response(&response_bytes).unwrap();
|
||||
let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap();
|
||||
|
||||
let data = filtered_response["data"].as_array().unwrap();
|
||||
assert_eq!(data.len(), 0); // Should remove all models when no status field (no "downloaded" status)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_filter_models_response_multiple_downloaded() {
|
||||
let test_response = json!({
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "model1",
|
||||
"name": "Model 1",
|
||||
"status": "downloaded"
|
||||
},
|
||||
{
|
||||
"id": "model2",
|
||||
"name": "Model 2",
|
||||
"status": "available"
|
||||
},
|
||||
{
|
||||
"id": "model3",
|
||||
"name": "Model 3",
|
||||
"status": "downloaded"
|
||||
},
|
||||
{
|
||||
"id": "model4",
|
||||
"name": "Model 4",
|
||||
"status": "installing"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let response_bytes = serde_json::to_vec(&test_response).unwrap();
|
||||
let filtered_bytes = filter_models_response(&response_bytes).unwrap();
|
||||
let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap();
|
||||
|
||||
let data = filtered_response["data"].as_array().unwrap();
|
||||
assert_eq!(data.len(), 2); // Should have 2 models (model1 and model3 with "downloaded" status)
|
||||
|
||||
// Verify only models with "downloaded" status are kept
|
||||
assert!(data.iter().any(|model| model["id"] == "model1"));
|
||||
assert!(data.iter().any(|model| model["id"] == "model3"));
|
||||
|
||||
// Verify other models are filtered out
|
||||
assert!(!data.iter().any(|model| model["id"] == "model2"));
|
||||
assert!(!data.iter().any(|model| model["id"] == "model4"));
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user