feat: Migrate Jan's API server to llamacpp-extension
Things to ponder: - Now, the v1/models endpoint of the API server will return an empty list if no models are loaded - Streaming v1/chat/completion routing works as well as v1/models; needs further testing
This commit is contained in:
parent
e3faf09ab2
commit
d5ffc6a476
@ -333,25 +333,24 @@ pub fn app_token(state: State<'_, AppState>) -> Option<String> {
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn start_server(
|
||||
app: AppHandle,
|
||||
state: State<'_, AppState>,
|
||||
host: String,
|
||||
port: u16,
|
||||
prefix: String,
|
||||
api_key: String,
|
||||
trusted_hosts: Vec<String>,
|
||||
) -> Result<bool, String> {
|
||||
let state = app.state::<AppState>();
|
||||
let auth_token = state.app_token.clone().unwrap_or_default();
|
||||
let server_handle = state.server_handle.clone();
|
||||
let sessions = state.llama_server_process.clone();
|
||||
|
||||
server::start_server(
|
||||
server_handle,
|
||||
sessions,
|
||||
host,
|
||||
port,
|
||||
prefix,
|
||||
auth_token,
|
||||
api_key,
|
||||
trusted_hosts,
|
||||
vec![trusted_hosts],
|
||||
)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
@ -1,25 +1,24 @@
|
||||
use flate2::read::GzDecoder;
|
||||
use futures_util::StreamExt;
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{Body, Request, Response, Server, StatusCode};
|
||||
use hyper::body::Bytes;
|
||||
use reqwest::Client;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::Infallible;
|
||||
use std::io::Read;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use serde_json;
|
||||
|
||||
use crate::core::state::ServerHandle;
|
||||
|
||||
use crate::core::state::{LLamaBackendSession, ServerHandle};
|
||||
|
||||
/// Configuration for the proxy server
|
||||
#[derive(Clone)]
|
||||
struct ProxyConfig {
|
||||
upstream: String,
|
||||
prefix: String,
|
||||
auth_token: String,
|
||||
trusted_hosts: Vec<String>,
|
||||
api_key: String,
|
||||
proxy_api_key: String,
|
||||
trusted_hosts: Vec<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Removes a prefix from a path, ensuring proper formatting
|
||||
@ -30,8 +29,10 @@ fn remove_prefix(path: &str, prefix: &str) -> String {
|
||||
let result = path[prefix.len()..].to_string();
|
||||
if result.is_empty() {
|
||||
"/".to_string()
|
||||
} else {
|
||||
} else if result.starts_with('/') {
|
||||
result
|
||||
} else {
|
||||
format!("/{}", result)
|
||||
}
|
||||
} else {
|
||||
path.to_string()
|
||||
@ -40,25 +41,7 @@ fn remove_prefix(path: &str, prefix: &str) -> String {
|
||||
|
||||
/// Determines the final destination path based on the original request path
|
||||
fn get_destination_path(original_path: &str, prefix: &str) -> String {
|
||||
let removed_prefix_path = remove_prefix(original_path, prefix);
|
||||
|
||||
// Special paths don't need the /v1 prefix
|
||||
if !original_path.contains(prefix)
|
||||
|| removed_prefix_path.contains("/healthz")
|
||||
|| removed_prefix_path.contains("/process")
|
||||
{
|
||||
original_path.to_string()
|
||||
} else {
|
||||
format!("/v1{}", removed_prefix_path)
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates the full upstream URL for the proxied request
|
||||
fn build_upstream_url(upstream: &str, path: &str) -> String {
|
||||
let upstream_clean = upstream.trim_end_matches('/');
|
||||
let path_clean = path.trim_start_matches('/');
|
||||
|
||||
format!("{}/{}", upstream_clean, path_clean)
|
||||
remove_prefix(original_path, prefix)
|
||||
}
|
||||
|
||||
/// Handles the proxy request logic
|
||||
@ -66,17 +49,8 @@ async fn proxy_request(
|
||||
req: Request<Body>,
|
||||
client: Client,
|
||||
config: ProxyConfig,
|
||||
sessions: Arc<Mutex<HashMap<i32, LLamaBackendSession>>>,
|
||||
) -> Result<Response<Body>, hyper::Error> {
|
||||
// Handle OPTIONS requests for CORS preflight
|
||||
log::debug!(
|
||||
"Received request: {} {} {:?} {:?} {:?}",
|
||||
req.method(),
|
||||
req.uri().path(),
|
||||
req.headers().get(hyper::header::HOST),
|
||||
req.headers().get(hyper::header::ORIGIN),
|
||||
req.headers()
|
||||
.get(hyper::header::ACCESS_CONTROL_REQUEST_METHOD)
|
||||
);
|
||||
if req.method() == hyper::Method::OPTIONS {
|
||||
log::debug!(
|
||||
"Handling CORS preflight request from {:?} {:?}",
|
||||
@ -85,21 +59,18 @@ async fn proxy_request(
|
||||
.get(hyper::header::ACCESS_CONTROL_REQUEST_METHOD)
|
||||
);
|
||||
|
||||
// Get the Host header to validate the target (where request is going)
|
||||
let host = req
|
||||
.headers()
|
||||
.get(hyper::header::HOST)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
// Get the Origin header for CORS response
|
||||
let origin = req
|
||||
.headers()
|
||||
.get(hyper::header::ORIGIN)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
// Validate requested method
|
||||
let requested_method = req
|
||||
.headers()
|
||||
.get("Access-Control-Request-Method")
|
||||
@ -120,7 +91,6 @@ async fn proxy_request(
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
// Check if the host (target) is trusted, but bypass for whitelisted paths
|
||||
let request_path = req.uri().path();
|
||||
let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"];
|
||||
let is_whitelisted_path = whitelisted_paths.contains(&request_path);
|
||||
@ -133,9 +103,9 @@ async fn proxy_request(
|
||||
true
|
||||
} else if !host.is_empty() {
|
||||
log::debug!(
|
||||
"CORS preflight: Host is '{}', trusted hosts: [{}]",
|
||||
"CORS preflight: Host is '{}', trusted hosts: {:?}",
|
||||
host,
|
||||
&config.trusted_hosts.join(", ")
|
||||
&config.trusted_hosts
|
||||
);
|
||||
is_valid_host(host, &config.trusted_hosts)
|
||||
} else {
|
||||
@ -155,14 +125,12 @@ async fn proxy_request(
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
// Get and validate requested headers
|
||||
let requested_headers = req
|
||||
.headers()
|
||||
.get("Access-Control-Request-Headers")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
// Allow common headers plus our required ones
|
||||
let allowed_headers = [
|
||||
"accept",
|
||||
"accept-language",
|
||||
@ -216,7 +184,6 @@ async fn proxy_request(
|
||||
.unwrap());
|
||||
}
|
||||
|
||||
// Build CORS response
|
||||
let mut response = Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("Access-Control-Allow-Methods", allowed_methods.join(", "))
|
||||
@ -227,13 +194,11 @@ async fn proxy_request(
|
||||
"Origin, Access-Control-Request-Method, Access-Control-Request-Headers",
|
||||
);
|
||||
|
||||
// Set Access-Control-Allow-Origin based on origin presence
|
||||
if !origin.is_empty() {
|
||||
response = response
|
||||
.header("Access-Control-Allow-Origin", origin)
|
||||
.header("Access-Control-Allow-Credentials", "true");
|
||||
} else {
|
||||
// No origin header - allow all origins (useful for non-browser clients)
|
||||
response = response.header("Access-Control-Allow-Origin", "*");
|
||||
}
|
||||
|
||||
@ -245,26 +210,26 @@ async fn proxy_request(
|
||||
return Ok(response.body(Body::empty()).unwrap());
|
||||
}
|
||||
|
||||
// Extract headers early for validation and CORS responses
|
||||
let origin_header = req
|
||||
.headers()
|
||||
let (parts, body) = req.into_parts();
|
||||
|
||||
let origin_header = parts.headers
|
||||
.get(hyper::header::ORIGIN)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let host_header = req
|
||||
.headers()
|
||||
let host_header = parts.headers
|
||||
.get(hyper::header::HOST)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let original_path = req.uri().path();
|
||||
let path = get_destination_path(original_path, &config.prefix);
|
||||
let method = req.method().clone();
|
||||
let original_path = parts.uri.path();
|
||||
let headers = parts.headers.clone();
|
||||
|
||||
let path = get_destination_path(original_path, &config.prefix);
|
||||
let method = parts.method.clone();
|
||||
|
||||
// Verify Host header (check target), but bypass for whitelisted paths
|
||||
let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"];
|
||||
let is_whitelisted_path = whitelisted_paths.contains(&path.as_str());
|
||||
|
||||
@ -298,12 +263,11 @@ async fn proxy_request(
|
||||
log::debug!("Bypassing host validation for whitelisted path: {}", path);
|
||||
}
|
||||
|
||||
// Skip authorization check for whitelisted paths
|
||||
if !is_whitelisted_path && !config.api_key.is_empty() {
|
||||
if let Some(authorization) = req.headers().get(hyper::header::AUTHORIZATION) {
|
||||
if !is_whitelisted_path && !config.proxy_api_key.is_empty() {
|
||||
if let Some(authorization) = parts.headers.get(hyper::header::AUTHORIZATION) {
|
||||
let auth_str = authorization.to_str().unwrap_or("");
|
||||
|
||||
if auth_str.strip_prefix("Bearer ") != Some(config.api_key.as_str()) {
|
||||
if auth_str.strip_prefix("Bearer ") != Some(config.proxy_api_key.as_str()) {
|
||||
let mut error_response = Response::builder().status(StatusCode::UNAUTHORIZED);
|
||||
error_response = add_cors_headers_with_host_and_origin(
|
||||
error_response,
|
||||
@ -334,7 +298,6 @@ async fn proxy_request(
|
||||
);
|
||||
}
|
||||
|
||||
// Block access to /configs endpoint
|
||||
if path.contains("/configs") {
|
||||
let mut error_response = Response::builder().status(StatusCode::NOT_FOUND);
|
||||
error_response = add_cors_headers_with_host_and_origin(
|
||||
@ -346,64 +309,23 @@ async fn proxy_request(
|
||||
return Ok(error_response.body(Body::from("Not Found")).unwrap());
|
||||
}
|
||||
|
||||
// Build the outbound request
|
||||
let upstream_url = build_upstream_url(&config.upstream, &path);
|
||||
log::debug!("Proxying request to: {}", upstream_url);
|
||||
let mut target_port: Option<i32> = None;
|
||||
let mut session_api_key: Option<String> = None;
|
||||
let mut buffered_body: Option<Bytes> = None;
|
||||
let original_path = parts.uri.path();
|
||||
let destination_path = get_destination_path(original_path, &config.prefix);
|
||||
|
||||
let mut outbound_req = client.request(req.method().clone(), &upstream_url);
|
||||
|
||||
// Copy original headers
|
||||
for (name, value) in req.headers() {
|
||||
// Skip host & authorization header
|
||||
if name != hyper::header::HOST && name != hyper::header::AUTHORIZATION {
|
||||
outbound_req = outbound_req.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
// Add authorization header
|
||||
outbound_req = outbound_req.header("Authorization", format!("Bearer {}", config.auth_token));
|
||||
|
||||
// Send the request and handle the response
|
||||
match outbound_req.body(req.into_body()).send().await {
|
||||
Ok(response) => {
|
||||
let status = response.status();
|
||||
log::debug!("Received response with status: {}", status);
|
||||
|
||||
let mut builder = Response::builder().status(status);
|
||||
|
||||
// Copy response headers, excluding CORS headers and Content-Length to avoid conflicts
|
||||
for (name, value) in response.headers() {
|
||||
// Skip CORS headers from upstream to avoid duplicates
|
||||
// Skip Content-Length header when filtering models response to avoid mismatch
|
||||
if !is_cors_header(name.as_str()) && name != hyper::header::CONTENT_LENGTH {
|
||||
builder = builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
// Add our own CORS headers
|
||||
builder = add_cors_headers_with_host_and_origin(
|
||||
builder,
|
||||
&host_header,
|
||||
&origin_header,
|
||||
&config.trusted_hosts,
|
||||
match (method.clone(), destination_path.as_str()) {
|
||||
(hyper::Method::POST, "/chat/completions")
|
||||
| (hyper::Method::POST, "/completions")
|
||||
| (hyper::Method::POST, "/embeddings") => {
|
||||
log::debug!(
|
||||
"Handling POST request to {} requiring model lookup in body",
|
||||
destination_path
|
||||
);
|
||||
|
||||
// Handle streaming vs non-streaming responses
|
||||
if path.contains("/models") && method == hyper::Method::GET {
|
||||
// For /models endpoint, we need to buffer and filter the response
|
||||
match response.bytes().await {
|
||||
Ok(bytes) => match filter_models_response(&bytes) {
|
||||
Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()),
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"Failed to filter models response: {}, returning original",
|
||||
e
|
||||
);
|
||||
Ok(builder.body(Body::from(bytes)).unwrap())
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
log::error!("Failed to read response body: {}", e);
|
||||
let body_bytes = match hyper::body::to_bytes(body).await {
|
||||
Ok(bytes) => bytes,
|
||||
Err(_) => {
|
||||
let mut error_response =
|
||||
Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
error_response = add_cors_headers_with_host_and_origin(
|
||||
@ -412,17 +334,238 @@ async fn proxy_request(
|
||||
&origin_header,
|
||||
&config.trusted_hosts,
|
||||
);
|
||||
Ok(error_response
|
||||
.body(Body::from("Error reading upstream response"))
|
||||
.unwrap())
|
||||
return Ok(error_response
|
||||
.body(Body::from("Failed to read request body"))
|
||||
.unwrap());
|
||||
}
|
||||
};
|
||||
buffered_body = Some(body_bytes.clone());
|
||||
|
||||
match serde_json::from_slice::<serde_json::Value>(&body_bytes) {
|
||||
Ok(json_body) => {
|
||||
if let Some(model_id) = json_body.get("model").and_then(|v| v.as_str()) {
|
||||
log::debug!("Extracted model_id: {}", model_id);
|
||||
let sessions_guard = sessions.lock().await;
|
||||
|
||||
if sessions_guard.is_empty() {
|
||||
log::warn!("Request for model '{}' but no backend servers are running.", model_id);
|
||||
let mut error_response = Response::builder().status(StatusCode::SERVICE_UNAVAILABLE);
|
||||
error_response = add_cors_headers_with_host_and_origin(
|
||||
error_response,
|
||||
&host_header,
|
||||
&origin_header,
|
||||
&config.trusted_hosts,
|
||||
);
|
||||
return Ok(error_response.body(Body::from("No backend model servers are available")).unwrap());
|
||||
}
|
||||
|
||||
if let Some(session) = sessions_guard
|
||||
.values()
|
||||
.find(|s| s.info.model_id == model_id)
|
||||
{
|
||||
target_port = Some(session.info.port);
|
||||
session_api_key = Some(session.info.api_key.clone());
|
||||
log::debug!(
|
||||
"Found session for model_id {} on port {}",
|
||||
model_id,
|
||||
session.info.port
|
||||
);
|
||||
} else {
|
||||
log::warn!("No running session found for model_id: {}", model_id);
|
||||
let mut error_response =
|
||||
Response::builder().status(StatusCode::NOT_FOUND);
|
||||
error_response = add_cors_headers_with_host_and_origin(
|
||||
error_response,
|
||||
&host_header,
|
||||
&origin_header,
|
||||
&config.trusted_hosts,
|
||||
);
|
||||
return Ok(error_response
|
||||
.body(Body::from(format!(
|
||||
"No running server found for model '{}'",
|
||||
model_id
|
||||
)))
|
||||
.unwrap());
|
||||
}
|
||||
} else {
|
||||
// For streaming endpoints (like chat completions), we need to collect and forward the stream
|
||||
log::warn!(
|
||||
"POST body for {} is missing 'model' field or it's not a string",
|
||||
destination_path
|
||||
);
|
||||
let mut error_response =
|
||||
Response::builder().status(StatusCode::BAD_REQUEST);
|
||||
error_response = add_cors_headers_with_host_and_origin(
|
||||
error_response,
|
||||
&host_header,
|
||||
&origin_header,
|
||||
&config.trusted_hosts,
|
||||
);
|
||||
return Ok(error_response
|
||||
.body(Body::from("Request body must contain a 'model' field"))
|
||||
.unwrap());
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"Failed to parse POST body for {} as JSON: {}",
|
||||
destination_path,
|
||||
e
|
||||
);
|
||||
let mut error_response = Response::builder().status(StatusCode::BAD_REQUEST);
|
||||
error_response = add_cors_headers_with_host_and_origin(
|
||||
error_response,
|
||||
&host_header,
|
||||
&origin_header,
|
||||
&config.trusted_hosts,
|
||||
);
|
||||
return Ok(error_response
|
||||
.body(Body::from("Invalid JSON body"))
|
||||
.unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
(hyper::Method::GET, "/models") => {
|
||||
log::debug!("Handling GET /v1/models request");
|
||||
let sessions_guard = sessions.lock().await;
|
||||
|
||||
let models_data: Vec<_> = sessions_guard
|
||||
.values()
|
||||
.map(|session| {
|
||||
serde_json::json!({
|
||||
"id": session.info.model_id,
|
||||
"object": "model",
|
||||
"created": 1,
|
||||
"owned_by": "user"
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let response_json = serde_json::json!({
|
||||
"object": "list",
|
||||
"data": models_data
|
||||
});
|
||||
|
||||
let body_str = serde_json::to_string(&response_json).unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
let mut response_builder = Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(hyper::header::CONTENT_TYPE, "application/json");
|
||||
|
||||
response_builder = add_cors_headers_with_host_and_origin(
|
||||
response_builder,
|
||||
&host_header,
|
||||
&origin_header,
|
||||
&config.trusted_hosts,
|
||||
);
|
||||
|
||||
return Ok(response_builder.body(Body::from(body_str)).unwrap());
|
||||
}
|
||||
_ => {
|
||||
let is_explicitly_whitelisted_get = method == hyper::Method::GET
|
||||
&& whitelisted_paths.contains(&destination_path.as_str());
|
||||
if is_explicitly_whitelisted_get {
|
||||
log::debug!("Handled whitelisted GET path: {}", destination_path);
|
||||
let mut error_response = Response::builder().status(StatusCode::NOT_FOUND);
|
||||
error_response = add_cors_headers_with_host_and_origin(
|
||||
error_response,
|
||||
&host_header,
|
||||
&origin_header,
|
||||
&config.trusted_hosts,
|
||||
);
|
||||
return Ok(error_response.body(Body::from("Not Found")).unwrap());
|
||||
} else {
|
||||
log::warn!(
|
||||
"Unhandled method/path for dynamic routing: {} {}",
|
||||
method,
|
||||
destination_path
|
||||
);
|
||||
let mut error_response = Response::builder().status(StatusCode::NOT_FOUND);
|
||||
error_response = add_cors_headers_with_host_and_origin(
|
||||
error_response,
|
||||
&host_header,
|
||||
&origin_header,
|
||||
&config.trusted_hosts,
|
||||
);
|
||||
return Ok(error_response.body(Body::from("Not Found")).unwrap());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let port = match target_port {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
log::error!("Internal routing error: target_port is None after successful lookup");
|
||||
let mut error_response = Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
error_response = add_cors_headers_with_host_and_origin(
|
||||
error_response,
|
||||
&host_header,
|
||||
&origin_header,
|
||||
&config.trusted_hosts,
|
||||
);
|
||||
return Ok(error_response
|
||||
.body(Body::from("Internal routing error"))
|
||||
.unwrap());
|
||||
}
|
||||
};
|
||||
|
||||
let upstream_url = format!("http://127.0.0.1:{}{}", port, destination_path);
|
||||
log::debug!("Proxying request to: {}", upstream_url);
|
||||
|
||||
let mut outbound_req = client.request(method.clone(), &upstream_url);
|
||||
|
||||
for (name, value) in headers.iter() {
|
||||
if name != hyper::header::HOST && name != hyper::header::AUTHORIZATION {
|
||||
outbound_req = outbound_req.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(key) = session_api_key {
|
||||
log::debug!("Adding session Authorization header");
|
||||
outbound_req = outbound_req.header("Authorization", format!("Bearer {}", key));
|
||||
} else {
|
||||
log::debug!("No session API key available for this request");
|
||||
}
|
||||
|
||||
let outbound_req_with_body = if let Some(bytes) = buffered_body {
|
||||
log::debug!("Sending buffered body ({} bytes)", bytes.len());
|
||||
outbound_req.body(bytes)
|
||||
} else {
|
||||
log::error!("Internal logic error: Request reached proxy stage without a buffered body.");
|
||||
let mut error_response = Response::builder().status(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
error_response = add_cors_headers_with_host_and_origin(
|
||||
error_response,
|
||||
&host_header,
|
||||
&origin_header,
|
||||
&config.trusted_hosts,
|
||||
);
|
||||
return Ok(error_response
|
||||
.body(Body::from("Internal server error: unhandled request path"))
|
||||
.unwrap());
|
||||
};
|
||||
|
||||
match outbound_req_with_body.send().await {
|
||||
Ok(response) => {
|
||||
let status = response.status();
|
||||
log::debug!("Received response with status: {}", status);
|
||||
|
||||
let mut builder = Response::builder().status(status);
|
||||
|
||||
for (name, value) in response.headers() {
|
||||
if !is_cors_header(name.as_str()) && name != hyper::header::CONTENT_LENGTH {
|
||||
builder = builder.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
builder = add_cors_headers_with_host_and_origin(
|
||||
builder,
|
||||
&host_header,
|
||||
&origin_header,
|
||||
&config.trusted_hosts,
|
||||
);
|
||||
|
||||
let mut stream = response.bytes_stream();
|
||||
let (mut sender, body) = hyper::Body::channel();
|
||||
|
||||
// Spawn a task to forward the stream
|
||||
tokio::spawn(async move {
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
@ -438,13 +581,14 @@ async fn proxy_request(
|
||||
}
|
||||
}
|
||||
}
|
||||
log::debug!("Streaming complete to client");
|
||||
});
|
||||
|
||||
Ok(builder.body(body).unwrap())
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("Proxy request failed: {}", e);
|
||||
let error_msg = format!("Proxy request to {} failed: {}", upstream_url, e);
|
||||
log::error!("{}", error_msg);
|
||||
let mut error_response = Response::builder().status(StatusCode::BAD_GATEWAY);
|
||||
error_response = add_cors_headers_with_host_and_origin(
|
||||
error_response,
|
||||
@ -452,148 +596,45 @@ async fn proxy_request(
|
||||
&origin_header,
|
||||
&config.trusted_hosts,
|
||||
);
|
||||
Ok(error_response
|
||||
.body(Body::from(format!("Upstream error: {}", e)))
|
||||
.unwrap())
|
||||
Ok(error_response.body(Body::from(error_msg)).unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if the byte array starts with gzip magic number
|
||||
fn is_gzip_encoded(bytes: &[u8]) -> bool {
|
||||
bytes.len() >= 2 && bytes[0] == 0x1f && bytes[1] == 0x8b
|
||||
}
|
||||
|
||||
/// Decompresses gzip-encoded bytes
|
||||
fn decompress_gzip(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut decoder = GzDecoder::new(bytes);
|
||||
let mut decompressed = Vec::new();
|
||||
decoder.read_to_end(&mut decompressed)?;
|
||||
Ok(decompressed)
|
||||
}
|
||||
|
||||
/// Compresses bytes using gzip
|
||||
fn compress_gzip(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
use flate2::write::GzEncoder;
|
||||
use flate2::Compression;
|
||||
use std::io::Write;
|
||||
|
||||
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
|
||||
encoder.write_all(bytes)?;
|
||||
let compressed = encoder.finish()?;
|
||||
Ok(compressed)
|
||||
}
|
||||
|
||||
/// Filters models response to keep only models with status "downloaded"
|
||||
fn filter_models_response(
|
||||
bytes: &[u8],
|
||||
) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// Try to decompress if it's gzip-encoded
|
||||
let decompressed_bytes = if is_gzip_encoded(bytes) {
|
||||
log::debug!("Response is gzip-encoded, decompressing...");
|
||||
decompress_gzip(bytes)?
|
||||
} else {
|
||||
bytes.to_vec()
|
||||
};
|
||||
|
||||
let response_text = std::str::from_utf8(&decompressed_bytes)?;
|
||||
let mut response_json: Value = serde_json::from_str(response_text)?;
|
||||
|
||||
// Check if this is a ListModelsResponseDto format with data array
|
||||
if let Some(data_array) = response_json.get_mut("data") {
|
||||
if let Some(models) = data_array.as_array_mut() {
|
||||
// Keep only models where status == "downloaded"
|
||||
models.retain(|model| {
|
||||
if let Some(status) = model.get("status") {
|
||||
if let Some(status_str) = status.as_str() {
|
||||
status_str == "downloaded"
|
||||
} else {
|
||||
false // Remove models without string status
|
||||
}
|
||||
} else {
|
||||
false // Remove models without status field
|
||||
}
|
||||
});
|
||||
log::debug!(
|
||||
"Filtered models response: {} downloaded models remaining",
|
||||
models.len()
|
||||
);
|
||||
}
|
||||
} else if response_json.is_array() {
|
||||
// Handle direct array format
|
||||
if let Some(models) = response_json.as_array_mut() {
|
||||
models.retain(|model| {
|
||||
if let Some(status) = model.get("status") {
|
||||
if let Some(status_str) = status.as_str() {
|
||||
status_str == "downloaded"
|
||||
} else {
|
||||
false // Remove models without string status
|
||||
}
|
||||
} else {
|
||||
false // Remove models without status field
|
||||
}
|
||||
});
|
||||
log::debug!(
|
||||
"Filtered models response: {} downloaded models remaining",
|
||||
models.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let filtered_json = serde_json::to_vec(&response_json)?;
|
||||
|
||||
// If original was gzip-encoded, re-compress the filtered response
|
||||
if is_gzip_encoded(bytes) {
|
||||
log::debug!("Re-compressing filtered response with gzip");
|
||||
compress_gzip(&filtered_json)
|
||||
} else {
|
||||
Ok(filtered_json)
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if a header is a CORS-related header that should be filtered out from upstream responses
|
||||
fn is_cors_header(header_name: &str) -> bool {
|
||||
let header_lower = header_name.to_lowercase();
|
||||
header_lower.starts_with("access-control-")
|
||||
}
|
||||
|
||||
/// Adds CORS headers to a response builder using host for validation and origin for response
|
||||
fn add_cors_headers_with_host_and_origin(
|
||||
builder: hyper::http::response::Builder,
|
||||
host: &str,
|
||||
origin: &str,
|
||||
trusted_hosts: &[String],
|
||||
trusted_hosts: &[Vec<String>],
|
||||
) -> hyper::http::response::Builder {
|
||||
let mut builder = builder;
|
||||
|
||||
// Check if host (target) is trusted - this is what we validate
|
||||
let is_trusted = if !host.is_empty() {
|
||||
is_valid_host(host, trusted_hosts)
|
||||
let allow_origin_header = if !origin.is_empty() && is_valid_host(host, trusted_hosts) {
|
||||
origin.to_string()
|
||||
} else if !origin.is_empty() {
|
||||
origin.to_string()
|
||||
} else {
|
||||
false // Host is required for validation
|
||||
"*".to_string()
|
||||
};
|
||||
|
||||
// Set CORS headers using origin for the response
|
||||
if !origin.is_empty() && is_trusted {
|
||||
builder = builder
|
||||
.header("Access-Control-Allow-Origin", origin)
|
||||
.header("Access-Control-Allow-Credentials", "true");
|
||||
} else if !origin.is_empty() {
|
||||
builder = builder.header("Access-Control-Allow-Origin", origin);
|
||||
} else {
|
||||
builder = builder.header("Access-Control-Allow-Origin", "*");
|
||||
}
|
||||
|
||||
builder = builder
|
||||
.header("Access-Control-Allow-Origin", allow_origin_header.clone())
|
||||
.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, PATCH")
|
||||
.header("Access-Control-Allow-Headers", "Authorization, Content-Type, Host, Accept, Accept-Language, Cache-Control, Connection, DNT, If-Modified-Since, Keep-Alive, Origin, User-Agent, X-Requested-With, X-CSRF-Token, X-Forwarded-For, X-Forwarded-Proto, X-Forwarded-Host, authorization, content-type, x-api-key")
|
||||
.header("Vary", "Origin");
|
||||
|
||||
if allow_origin_header != "*" {
|
||||
builder = builder.header("Access-Control-Allow-Credentials", "true");
|
||||
}
|
||||
|
||||
builder
|
||||
}
|
||||
|
||||
// Validates if the host header is allowed
|
||||
fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool {
|
||||
fn is_valid_host(host: &str, trusted_hosts: &[Vec<String>]) -> bool {
|
||||
if host.is_empty() {
|
||||
return false;
|
||||
}
|
||||
@ -608,7 +649,6 @@ fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool {
|
||||
};
|
||||
let default_valid_hosts = ["localhost", "127.0.0.1", "0.0.0.0"];
|
||||
|
||||
// Check default valid hosts (host part only)
|
||||
if default_valid_hosts
|
||||
.iter()
|
||||
.any(|&valid| host_without_port.to_lowercase() == valid.to_lowercase())
|
||||
@ -616,17 +656,14 @@ fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check trusted hosts - support both full host:port and host-only formats
|
||||
trusted_hosts.iter().any(|valid| {
|
||||
trusted_hosts.iter().flatten().any(|valid| {
|
||||
let host_lower = host.to_lowercase();
|
||||
let valid_lower = valid.to_lowercase();
|
||||
|
||||
// First check exact match (including port)
|
||||
if host_lower == valid_lower {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Then check host part only (without port)
|
||||
let valid_without_port = if valid.starts_with('[') {
|
||||
valid
|
||||
.split(']')
|
||||
@ -643,68 +680,54 @@ fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool {
|
||||
|
||||
pub async fn is_server_running(server_handle: Arc<Mutex<Option<ServerHandle>>>) -> bool {
|
||||
let handle_guard = server_handle.lock().await;
|
||||
|
||||
if handle_guard.is_some() {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
handle_guard.is_some()
|
||||
}
|
||||
|
||||
/// Starts the proxy server
|
||||
pub async fn start_server(
|
||||
server_handle: Arc<Mutex<Option<ServerHandle>>>,
|
||||
sessions: Arc<Mutex<HashMap<i32, LLamaBackendSession>>>,
|
||||
host: String,
|
||||
port: u16,
|
||||
prefix: String,
|
||||
auth_token: String,
|
||||
api_key: String,
|
||||
trusted_hosts: Vec<String>,
|
||||
proxy_api_key: String,
|
||||
trusted_hosts: Vec<Vec<String>>,
|
||||
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// Check if server is already running
|
||||
let mut handle_guard = server_handle.lock().await;
|
||||
if handle_guard.is_some() {
|
||||
return Err("Server is already running".into());
|
||||
}
|
||||
|
||||
// Create server address
|
||||
let addr: SocketAddr = format!("{}:{}", host, port)
|
||||
.parse()
|
||||
.map_err(|e| format!("Invalid address: {}", e))?;
|
||||
|
||||
// Configure proxy settings
|
||||
let config = ProxyConfig {
|
||||
upstream: "http://127.0.0.1:39291".to_string(),
|
||||
prefix,
|
||||
auth_token,
|
||||
api_key,
|
||||
proxy_api_key,
|
||||
trusted_hosts,
|
||||
};
|
||||
|
||||
// Create HTTP client with longer timeout for streaming
|
||||
let client = Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300)) // 5 minutes for streaming
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.pool_max_idle_per_host(10)
|
||||
.pool_idle_timeout(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
// Create service handler
|
||||
let make_svc = make_service_fn(move |_conn| {
|
||||
let client = client.clone();
|
||||
let config = config.clone();
|
||||
let sessions = sessions.clone();
|
||||
|
||||
async move {
|
||||
Ok::<_, Infallible>(service_fn(move |req| {
|
||||
proxy_request(req, client.clone(), config.clone())
|
||||
proxy_request(req, client.clone(), config.clone(), sessions.clone())
|
||||
}))
|
||||
}
|
||||
});
|
||||
|
||||
// Create and start the server
|
||||
let server = Server::bind(&addr).serve(make_svc);
|
||||
log::info!("Proxy server started on http://{}", addr);
|
||||
|
||||
// Spawn server task
|
||||
let server_task = tokio::spawn(async move {
|
||||
if let Err(e) = server.await {
|
||||
log::error!("Server error: {}", e);
|
||||
@ -717,7 +740,6 @@ pub async fn start_server(
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Stops the currently running proxy server
|
||||
pub async fn stop_server(
|
||||
server_handle: Arc<Mutex<Option<ServerHandle>>>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
@ -725,7 +747,6 @@ pub async fn stop_server(
|
||||
|
||||
if let Some(handle) = handle_guard.take() {
|
||||
handle.abort();
|
||||
// remove the handle to prevent future use
|
||||
*handle_guard = None;
|
||||
log::info!("Proxy server stopped");
|
||||
} else {
|
||||
@ -734,139 +755,3 @@ pub async fn stop_server(
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_filter_models_response_with_downloaded_status() {
|
||||
let test_response = json!({
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "model1",
|
||||
"name": "Model 1",
|
||||
"status": "downloaded"
|
||||
},
|
||||
{
|
||||
"id": "model2",
|
||||
"name": "Model 2",
|
||||
"status": "available"
|
||||
},
|
||||
{
|
||||
"id": "model3",
|
||||
"name": "Model 3"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let response_bytes = serde_json::to_vec(&test_response).unwrap();
|
||||
let filtered_bytes = filter_models_response(&response_bytes).unwrap();
|
||||
let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap();
|
||||
|
||||
let data = filtered_response["data"].as_array().unwrap();
|
||||
assert_eq!(data.len(), 1); // Should have 1 model (only model1 with "downloaded" status)
|
||||
|
||||
// Verify only model1 (with "downloaded" status) is kept
|
||||
assert!(data.iter().any(|model| model["id"] == "model1"));
|
||||
|
||||
// Verify model2 and model3 are filtered out
|
||||
assert!(!data.iter().any(|model| model["id"] == "model2"));
|
||||
assert!(!data.iter().any(|model| model["id"] == "model3"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_filter_models_response_direct_array() {
|
||||
let test_response = json!([
|
||||
{
|
||||
"id": "model1",
|
||||
"name": "Model 1",
|
||||
"status": "downloaded"
|
||||
},
|
||||
{
|
||||
"id": "model2",
|
||||
"name": "Model 2",
|
||||
"status": "available"
|
||||
}
|
||||
]);
|
||||
|
||||
let response_bytes = serde_json::to_vec(&test_response).unwrap();
|
||||
let filtered_bytes = filter_models_response(&response_bytes).unwrap();
|
||||
let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap();
|
||||
|
||||
let data = filtered_response.as_array().unwrap();
|
||||
assert_eq!(data.len(), 1); // Should have 1 model (only model1 with "downloaded" status)
|
||||
assert!(data.iter().any(|model| model["id"] == "model1"));
|
||||
assert!(!data.iter().any(|model| model["id"] == "model2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_filter_models_response_no_status_field() {
|
||||
let test_response = json!({
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "model1",
|
||||
"name": "Model 1"
|
||||
},
|
||||
{
|
||||
"id": "model2",
|
||||
"name": "Model 2"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let response_bytes = serde_json::to_vec(&test_response).unwrap();
|
||||
let filtered_bytes = filter_models_response(&response_bytes).unwrap();
|
||||
let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap();
|
||||
|
||||
let data = filtered_response["data"].as_array().unwrap();
|
||||
assert_eq!(data.len(), 0); // Should remove all models when no status field (no "downloaded" status)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_filter_models_response_multiple_downloaded() {
|
||||
let test_response = json!({
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "model1",
|
||||
"name": "Model 1",
|
||||
"status": "downloaded"
|
||||
},
|
||||
{
|
||||
"id": "model2",
|
||||
"name": "Model 2",
|
||||
"status": "available"
|
||||
},
|
||||
{
|
||||
"id": "model3",
|
||||
"name": "Model 3",
|
||||
"status": "downloaded"
|
||||
},
|
||||
{
|
||||
"id": "model4",
|
||||
"name": "Model 4",
|
||||
"status": "installing"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let response_bytes = serde_json::to_vec(&test_response).unwrap();
|
||||
let filtered_bytes = filter_models_response(&response_bytes).unwrap();
|
||||
let filtered_response: serde_json::Value = serde_json::from_slice(&filtered_bytes).unwrap();
|
||||
|
||||
let data = filtered_response["data"].as_array().unwrap();
|
||||
assert_eq!(data.len(), 2); // Should have 2 models (model1 and model3 with "downloaded" status)
|
||||
|
||||
// Verify only models with "downloaded" status are kept
|
||||
assert!(data.iter().any(|model| model["id"] == "model1"));
|
||||
assert!(data.iter().any(|model| model["id"] == "model3"));
|
||||
|
||||
// Verify other models are filtered out
|
||||
assert!(!data.iter().any(|model| model["id"] == "model2"));
|
||||
assert!(!data.iter().any(|model| model["id"] == "model4"));
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user