diff --git a/src-tauri/src/core/cmd.rs b/src-tauri/src/core/cmd.rs index 4a48e63d3..4b4463d12 100644 --- a/src-tauri/src/core/cmd.rs +++ b/src-tauri/src/core/cmd.rs @@ -348,23 +348,41 @@ pub async fn start_server( api_key: String, trusted_hosts: Vec, ) -> Result { - let auth_token = app - .state::() - .app_token - .clone() - .unwrap_or_default(); - server::start_server(host, port, prefix, auth_token, api_key, trusted_hosts) - .await - .map_err(|e| e.to_string())?; + let state = app.state::(); + let auth_token = state.app_token.clone().unwrap_or_default(); + let server_handle = state.server_handle.clone(); + + server::start_server( + server_handle, + host, + port, + prefix, + auth_token, + api_key, + trusted_hosts, + ) + .await + .map_err(|e| e.to_string())?; Ok(true) } #[tauri::command] -pub async fn stop_server() -> Result<(), String> { - server::stop_server().await.map_err(|e| e.to_string())?; +pub async fn stop_server(state: State<'_, AppState>) -> Result<(), String> { + let server_handle = state.server_handle.clone(); + + server::stop_server(server_handle) + .await + .map_err(|e| e.to_string())?; Ok(()) } +#[tauri::command] +pub async fn get_server_status(state: State<'_, AppState>) -> Result { + let server_handle = state.server_handle.clone(); + + Ok(server::is_server_running(server_handle).await) +} + #[tauri::command] pub async fn read_logs(app: AppHandle) -> Result { let log_path = get_jan_data_folder_path(app).join("logs").join("app.log"); diff --git a/src-tauri/src/core/server.rs b/src-tauri/src/core/server.rs index ee8b1cbb1..6da4ebf9b 100644 --- a/src-tauri/src/core/server.rs +++ b/src-tauri/src/core/server.rs @@ -1,21 +1,16 @@ +use flate2::read::GzDecoder; +use futures_util::StreamExt; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Request, Response, Server, StatusCode}; use reqwest::Client; use serde_json::Value; use std::convert::Infallible; -use std::net::SocketAddr; -use std::sync::LazyLock; -use tokio::sync::Mutex; -use tokio::task::JoinHandle; -use futures_util::StreamExt; -use flate2::read::GzDecoder; use std::io::Read; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::sync::Mutex; -/// Server handle type for managing the proxy server lifecycle -type ServerHandle = JoinHandle>>; - -/// Global singleton for the current server instance -static SERVER_HANDLE: LazyLock>> = LazyLock::new(|| Mutex::new(None)); +use crate::core::state::ServerHandle; /// Configuration for the proxy server #[derive(Clone)] @@ -272,7 +267,7 @@ async fn proxy_request( // Verify Host header (check target), but bypass for whitelisted paths let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"]; let is_whitelisted_path = whitelisted_paths.contains(&path.as_str()); - + if !is_whitelisted_path { if !host_header.is_empty() { if !is_valid_host(&host_header, &config.trusted_hosts) { @@ -333,7 +328,10 @@ async fn proxy_request( .unwrap()); } } else if is_whitelisted_path { - log::debug!("Bypassing authorization check for whitelisted path: {}", path); + log::debug!( + "Bypassing authorization check for whitelisted path: {}", + path + ); } // Block access to /configs endpoint @@ -394,13 +392,14 @@ async fn proxy_request( if path.contains("/models") && method == hyper::Method::GET { // For /models endpoint, we need to buffer and filter the response match response.bytes().await { - Ok(bytes) => { - match filter_models_response(&bytes) { - Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()), - Err(e) => { - log::warn!("Failed to filter models response: {}, returning original", e); - Ok(builder.body(Body::from(bytes)).unwrap()) - } + Ok(bytes) => match filter_models_response(&bytes) { + Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()), + Err(e) => { + log::warn!( + "Failed to filter models response: {}, returning original", + e + ); + Ok(builder.body(Body::from(bytes)).unwrap()) } }, Err(e) => { @@ -422,7 +421,7 @@ async fn proxy_request( // For streaming endpoints (like chat completions), we need to collect and forward the stream let mut stream = response.bytes_stream(); let (mut sender, body) = hyper::Body::channel(); - + // Spawn a task to forward the stream tokio::spawn(async move { while let Some(chunk_result) = stream.next().await { @@ -440,7 +439,7 @@ async fn proxy_request( } } }); - + Ok(builder.body(body).unwrap()) } } @@ -478,7 +477,7 @@ fn compress_gzip(bytes: &[u8]) -> Result, Box Result, Box Result, Box> { +fn filter_models_response( + bytes: &[u8], +) -> Result, Box> { // Try to decompress if it's gzip-encoded let decompressed_bytes = if is_gzip_encoded(bytes) { log::debug!("Response is gzip-encoded, decompressing..."); @@ -494,10 +495,10 @@ fn filter_models_response(bytes: &[u8]) -> Result, Box Result, Box Result, Box bool { }) } +pub async fn is_server_running(server_handle: Arc>>) -> bool { + let handle_guard = server_handle.lock().await; + + if handle_guard.is_some() { + true + } else { + false + } +} + /// Starts the proxy server pub async fn start_server( + server_handle: Arc>>, host: String, port: u16, prefix: String, @@ -644,7 +662,7 @@ pub async fn start_server( trusted_hosts: Vec, ) -> Result> { // Check if server is already running - let mut handle_guard = SERVER_HANDLE.lock().await; + let mut handle_guard = server_handle.lock().await; if handle_guard.is_some() { return Err("Server is already running".into()); } @@ -687,7 +705,7 @@ pub async fn start_server( log::info!("Proxy server started on http://{}", addr); // Spawn server task - let server_handle = tokio::spawn(async move { + let server_task = tokio::spawn(async move { if let Err(e) = server.await { log::error!("Server error: {}", e); return Err(Box::new(e) as Box); @@ -695,16 +713,20 @@ pub async fn start_server( Ok(()) }); - *handle_guard = Some(server_handle); + *handle_guard = Some(server_task); Ok(true) } /// Stops the currently running proxy server -pub async fn stop_server() -> Result<(), Box> { - let mut handle_guard = SERVER_HANDLE.lock().await; +pub async fn stop_server( + server_handle: Arc>>, +) -> Result<(), Box> { + let mut handle_guard = server_handle.lock().await; if let Some(handle) = handle_guard.take() { handle.abort(); + // remove the handle to prevent future use + *handle_guard = None; log::info!("Proxy server stopped"); } else { log::debug!("No server was running"); @@ -746,10 +768,10 @@ mod tests { let data = filtered_response["data"].as_array().unwrap(); assert_eq!(data.len(), 1); // Should have 1 model (only model1 with "downloaded" status) - + // Verify only model1 (with "downloaded" status) is kept assert!(data.iter().any(|model| model["id"] == "model1")); - + // Verify model2 and model3 are filtered out assert!(!data.iter().any(|model| model["id"] == "model2")); assert!(!data.iter().any(|model| model["id"] == "model3")); @@ -838,11 +860,11 @@ mod tests { let data = filtered_response["data"].as_array().unwrap(); assert_eq!(data.len(), 2); // Should have 2 models (model1 and model3 with "downloaded" status) - + // Verify only models with "downloaded" status are kept assert!(data.iter().any(|model| model["id"] == "model1")); assert!(data.iter().any(|model| model["id"] == "model3")); - + // Verify other models are filtered out assert!(!data.iter().any(|model| model["id"] == "model2")); assert!(!data.iter().any(|model| model["id"] == "model4")); diff --git a/src-tauri/src/core/state.rs b/src-tauri/src/core/state.rs index cb6a5d3fa..9957ba92e 100644 --- a/src-tauri/src/core/state.rs +++ b/src-tauri/src/core/state.rs @@ -4,6 +4,10 @@ use crate::core::utils::download::DownloadManagerState; use rand::{distributions::Alphanumeric, Rng}; use rmcp::{service::RunningService, RoleClient}; use tokio::sync::Mutex; +use tokio::task::JoinHandle; + +/// Server handle type for managing the proxy server lifecycle +pub type ServerHandle = JoinHandle>>; #[derive(Default)] pub struct AppState { @@ -12,6 +16,7 @@ pub struct AppState { pub download_manager: Arc>, pub cortex_restart_count: Arc>, pub cortex_killed_intentionally: Arc>, + pub server_handle: Arc>>, } pub fn generate_app_token() -> String { rand::thread_rng() diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 076984106..4ed6ecee7 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -55,6 +55,7 @@ pub fn run() { core::cmd::app_token, core::cmd::start_server, core::cmd::stop_server, + core::cmd::get_server_status, core::cmd::read_logs, core::cmd::change_app_data_folder, core::cmd::reset_cortex_restart_count, @@ -92,6 +93,7 @@ pub fn run() { download_manager: Arc::new(Mutex::new(DownloadManagerState::default())), cortex_restart_count: Arc::new(Mutex::new(0)), cortex_killed_intentionally: Arc::new(Mutex::new(false)), + server_handle: Arc::new(Mutex::new(None)), }) .setup(|app| { app.handle().plugin( diff --git a/web-app/src/routes/__root.tsx b/web-app/src/routes/__root.tsx index 6f6099cbb..67e88ed90 100644 --- a/web-app/src/routes/__root.tsx +++ b/web-app/src/routes/__root.tsx @@ -18,7 +18,6 @@ import { AnalyticProvider } from '@/providers/AnalyticProvider' import { useLeftPanel } from '@/hooks/useLeftPanel' import { cn } from '@/lib/utils' import ToolApproval from '@/containers/dialogs/ToolApproval' -import { useEffect } from 'react' export const Route = createRootRoute({ component: RootLayout, @@ -83,13 +82,6 @@ function RootLayout() { router.location.pathname === route.systemMonitor || router.location.pathname === route.appLogs - useEffect(() => { - return () => { - // This is to attempt to stop the local API server when the app is closed or reloaded. - window.core?.api?.stopServer() - } - }, []) - return ( diff --git a/web-app/src/routes/settings/local-api-server.tsx b/web-app/src/routes/settings/local-api-server.tsx index dd7561be5..94f577074 100644 --- a/web-app/src/routes/settings/local-api-server.tsx +++ b/web-app/src/routes/settings/local-api-server.tsx @@ -17,7 +17,8 @@ import { windowKey } from '@/constants/windows' import { IconLogs } from '@tabler/icons-react' import { cn } from '@/lib/utils' import { ApiKeyInput } from '@/containers/ApiKeyInput' -import { useState } from 'react' +import { useEffect, useState } from 'react' +import { invoke } from '@tauri-apps/api/core' // eslint-disable-next-line @typescript-eslint/no-explicit-any export const Route = createFileRoute(route.settings.local_api_server as any)({ @@ -44,6 +45,17 @@ function LocalAPIServer() { !apiKey || apiKey.toString().trim().length === 0 ) + useEffect(() => { + const checkServerStatus = async () => { + invoke('get_server_status').then((running) => { + if (running) { + setServerStatus('running') + } + }) + } + checkServerStatus() + }, [setServerStatus]) + const handleApiKeyValidation = (isValid: boolean) => { setIsApiKeyEmpty(!isValid) }