fix: stop api server on page unload (#5356)

* fix: stop api server on page unload

* fix: check api server status on reload

* refactor: api server state

* fix: should not pop the guard
Louis authored on 2025-06-19 00:12:03 +07:00, committed by GitHub
parent 5b60116d21
commit 22396111be
6 changed files with 109 additions and 58 deletions
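
As a rough sketch of the behavior these changes target, a frontend can ask the backend to stop the local API server before the page unloads and re-check the server status after a reload via the stop_server and get_server_status commands added in this commit. The beforeunload wiring and the syncServerStatus helper below are illustrative assumptions, not code from this commit:

import { invoke } from '@tauri-apps/api/core'

// Assumed wiring: request a server shutdown before the page goes away.
window.addEventListener('beforeunload', () => {
  invoke('stop_server').catch((e) => console.error('failed to stop API server', e))
})

// Hypothetical helper: after a reload, re-sync the UI with the actual backend state.
export async function syncServerStatus(setServerStatus: (status: string) => void) {
  const running = await invoke<boolean>('get_server_status')
  if (running) setServerStatus('running')
}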

View File

@@ -348,23 +348,41 @@ pub async fn start_server(
api_key: String,
trusted_hosts: Vec<String>,
) -> Result<bool, String> {
let auth_token = app
.state::<AppState>()
.app_token
.clone()
.unwrap_or_default();
server::start_server(host, port, prefix, auth_token, api_key, trusted_hosts)
.await
.map_err(|e| e.to_string())?;
let state = app.state::<AppState>();
let auth_token = state.app_token.clone().unwrap_or_default();
let server_handle = state.server_handle.clone();
server::start_server(
server_handle,
host,
port,
prefix,
auth_token,
api_key,
trusted_hosts,
)
.await
.map_err(|e| e.to_string())?;
Ok(true)
}
#[tauri::command]
pub async fn stop_server() -> Result<(), String> {
server::stop_server().await.map_err(|e| e.to_string())?;
pub async fn stop_server(state: State<'_, AppState>) -> Result<(), String> {
let server_handle = state.server_handle.clone();
server::stop_server(server_handle)
.await
.map_err(|e| e.to_string())?;
Ok(())
}
#[tauri::command]
pub async fn get_server_status(state: State<'_, AppState>) -> Result<bool, String> {
let server_handle = state.server_handle.clone();
Ok(server::is_server_running(server_handle).await)
}
#[tauri::command]
pub async fn read_logs(app: AppHandle) -> Result<String, String> {
let log_path = get_jan_data_folder_path(app).join("logs").join("app.log");

View File

@@ -1,21 +1,16 @@
use flate2::read::GzDecoder;
use futures_util::StreamExt;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server, StatusCode};
use reqwest::Client;
use serde_json::Value;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::LazyLock;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use futures_util::StreamExt;
use flate2::read::GzDecoder;
use std::io::Read;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Server handle type for managing the proxy server lifecycle
type ServerHandle = JoinHandle<Result<(), Box<dyn std::error::Error + Send + Sync>>>;
/// Global singleton for the current server instance
static SERVER_HANDLE: LazyLock<Mutex<Option<ServerHandle>>> = LazyLock::new(|| Mutex::new(None));
use crate::core::state::ServerHandle;
/// Configuration for the proxy server
#[derive(Clone)]
@@ -272,7 +267,7 @@ async fn proxy_request(
// Verify Host header (check target), but bypass for whitelisted paths
let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"];
let is_whitelisted_path = whitelisted_paths.contains(&path.as_str());
if !is_whitelisted_path {
if !host_header.is_empty() {
if !is_valid_host(&host_header, &config.trusted_hosts) {
@@ -333,7 +328,10 @@ async fn proxy_request(
.unwrap());
}
} else if is_whitelisted_path {
log::debug!("Bypassing authorization check for whitelisted path: {}", path);
log::debug!(
"Bypassing authorization check for whitelisted path: {}",
path
);
}
// Block access to /configs endpoint
@@ -394,13 +392,14 @@ async fn proxy_request(
if path.contains("/models") && method == hyper::Method::GET {
// For /models endpoint, we need to buffer and filter the response
match response.bytes().await {
Ok(bytes) => {
match filter_models_response(&bytes) {
Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()),
Err(e) => {
log::warn!("Failed to filter models response: {}, returning original", e);
Ok(builder.body(Body::from(bytes)).unwrap())
}
Ok(bytes) => match filter_models_response(&bytes) {
Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()),
Err(e) => {
log::warn!(
"Failed to filter models response: {}, returning original",
e
);
Ok(builder.body(Body::from(bytes)).unwrap())
}
},
Err(e) => {
@@ -422,7 +421,7 @@ async fn proxy_request(
// For streaming endpoints (like chat completions), we need to collect and forward the stream
let mut stream = response.bytes_stream();
let (mut sender, body) = hyper::Body::channel();
// Spawn a task to forward the stream
tokio::spawn(async move {
while let Some(chunk_result) = stream.next().await {
@@ -440,7 +439,7 @@ async fn proxy_request(
}
}
});
Ok(builder.body(body).unwrap())
}
}
@@ -478,7 +477,7 @@ fn compress_gzip(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Se
use flate2::write::GzEncoder;
use flate2::Compression;
use std::io::Write;
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder.write_all(bytes)?;
let compressed = encoder.finish()?;
@@ -486,7 +485,7 @@ fn compress_gzip(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Se
}
/// Filters models response to keep only models with status "downloaded"
fn filter_models_response(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
fn filter_models_response(
bytes: &[u8],
) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
// Try to decompress if it's gzip-encoded
let decompressed_bytes = if is_gzip_encoded(bytes) {
log::debug!("Response is gzip-encoded, decompressing...");
@@ -494,10 +495,10 @@ fn filter_models_response(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::E
} else {
bytes.to_vec()
};
let response_text = std::str::from_utf8(&decompressed_bytes)?;
let mut response_json: Value = serde_json::from_str(response_text)?;
// Check if this is a ListModelsResponseDto format with data array
if let Some(data_array) = response_json.get_mut("data") {
if let Some(models) = data_array.as_array_mut() {
@@ -513,7 +514,10 @@ fn filter_models_response(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::E
false // Remove models without status field
}
});
log::debug!("Filtered models response: {} downloaded models remaining", models.len());
log::debug!(
"Filtered models response: {} downloaded models remaining",
models.len()
);
}
} else if response_json.is_array() {
// Handle direct array format
@@ -529,12 +533,15 @@ fn filter_models_response(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::E
false // Remove models without status field
}
});
log::debug!("Filtered models response: {} downloaded models remaining", models.len());
log::debug!(
"Filtered models response: {} downloaded models remaining",
models.len()
);
}
}
let filtered_json = serde_json::to_vec(&response_json)?;
// If original was gzip-encoded, re-compress the filtered response
if is_gzip_encoded(bytes) {
log::debug!("Re-compressing filtered response with gzip");
@@ -634,8 +641,19 @@ fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool {
})
}
pub async fn is_server_running(server_handle: Arc<Mutex<Option<ServerHandle>>>) -> bool {
let handle_guard = server_handle.lock().await;
handle_guard.is_some()
}
/// Starts the proxy server
pub async fn start_server(
server_handle: Arc<Mutex<Option<ServerHandle>>>,
host: String,
port: u16,
prefix: String,
@@ -644,7 +662,7 @@ pub async fn start_server(
trusted_hosts: Vec<String>,
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
// Check if server is already running
let mut handle_guard = SERVER_HANDLE.lock().await;
let mut handle_guard = server_handle.lock().await;
if handle_guard.is_some() {
return Err("Server is already running".into());
}
@@ -687,7 +705,7 @@ pub async fn start_server(
log::info!("Proxy server started on http://{}", addr);
// Spawn server task
let server_handle = tokio::spawn(async move {
let server_task = tokio::spawn(async move {
if let Err(e) = server.await {
log::error!("Server error: {}", e);
return Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>);
@@ -695,16 +713,20 @@ pub async fn start_server(
Ok(())
});
*handle_guard = Some(server_handle);
*handle_guard = Some(server_task);
Ok(true)
}
/// Stops the currently running proxy server
pub async fn stop_server() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut handle_guard = SERVER_HANDLE.lock().await;
pub async fn stop_server(
server_handle: Arc<Mutex<Option<ServerHandle>>>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let mut handle_guard = server_handle.lock().await;
if let Some(handle) = handle_guard.take() {
handle.abort();
// remove the handle to prevent future use
*handle_guard = None;
log::info!("Proxy server stopped");
} else {
log::debug!("No server was running");
@@ -746,10 +768,10 @@ mod tests {
let data = filtered_response["data"].as_array().unwrap();
assert_eq!(data.len(), 1); // Should have 1 model (only model1 with "downloaded" status)
// Verify only model1 (with "downloaded" status) is kept
assert!(data.iter().any(|model| model["id"] == "model1"));
// Verify model2 and model3 are filtered out
assert!(!data.iter().any(|model| model["id"] == "model2"));
assert!(!data.iter().any(|model| model["id"] == "model3"));
@@ -838,11 +860,11 @@ mod tests {
let data = filtered_response["data"].as_array().unwrap();
assert_eq!(data.len(), 2); // Should have 2 models (model1 and model3 with "downloaded" status)
// Verify only models with "downloaded" status are kept
assert!(data.iter().any(|model| model["id"] == "model1"));
assert!(data.iter().any(|model| model["id"] == "model3"));
// Verify other models are filtered out
assert!(!data.iter().any(|model| model["id"] == "model2"));
assert!(!data.iter().any(|model| model["id"] == "model4"));

View File

@@ -4,6 +4,10 @@ use crate::core::utils::download::DownloadManagerState;
use rand::{distributions::Alphanumeric, Rng};
use rmcp::{service::RunningService, RoleClient};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
/// Server handle type for managing the proxy server lifecycle
pub type ServerHandle = JoinHandle<Result<(), Box<dyn std::error::Error + Send + Sync>>>;
#[derive(Default)]
pub struct AppState {
@@ -12,6 +16,7 @@ pub struct AppState {
pub download_manager: Arc<Mutex<DownloadManagerState>>,
pub cortex_restart_count: Arc<Mutex<u32>>,
pub cortex_killed_intentionally: Arc<Mutex<bool>>,
pub server_handle: Arc<Mutex<Option<ServerHandle>>>,
}
pub fn generate_app_token() -> String {
rand::thread_rng()

View File

@@ -55,6 +55,7 @@ pub fn run() {
core::cmd::app_token,
core::cmd::start_server,
core::cmd::stop_server,
core::cmd::get_server_status,
core::cmd::read_logs,
core::cmd::change_app_data_folder,
core::cmd::reset_cortex_restart_count,
@@ -92,6 +93,7 @@ pub fn run() {
download_manager: Arc::new(Mutex::new(DownloadManagerState::default())),
cortex_restart_count: Arc::new(Mutex::new(0)),
cortex_killed_intentionally: Arc::new(Mutex::new(false)),
server_handle: Arc::new(Mutex::new(None)),
})
.setup(|app| {
app.handle().plugin(

View File

@@ -18,7 +18,6 @@ import { AnalyticProvider } from '@/providers/AnalyticProvider'
import { useLeftPanel } from '@/hooks/useLeftPanel'
import { cn } from '@/lib/utils'
import ToolApproval from '@/containers/dialogs/ToolApproval'
import { useEffect } from 'react'
export const Route = createRootRoute({
component: RootLayout,
@@ -83,13 +82,6 @@ function RootLayout() {
router.location.pathname === route.systemMonitor ||
router.location.pathname === route.appLogs
useEffect(() => {
return () => {
// This is to attempt to stop the local API server when the app is closed or reloaded.
window.core?.api?.stopServer()
}
}, [])
return (
<Fragment>
<ThemeProvider />

View File

@@ -17,7 +17,8 @@ import { windowKey } from '@/constants/windows'
import { IconLogs } from '@tabler/icons-react'
import { cn } from '@/lib/utils'
import { ApiKeyInput } from '@/containers/ApiKeyInput'
import { useState } from 'react'
import { useEffect, useState } from 'react'
import { invoke } from '@tauri-apps/api/core'
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const Route = createFileRoute(route.settings.local_api_server as any)({
@@ -44,6 +45,17 @@ function LocalAPIServer() {
!apiKey || apiKey.toString().trim().length === 0
)
useEffect(() => {
const checkServerStatus = async () => {
invoke('get_server_status').then((running) => {
if (running) {
setServerStatus('running')
}
})
}
checkServerStatus()
}, [setServerStatus])
const handleApiKeyValidation = (isValid: boolean) => {
setIsApiKeyEmpty(!isValid)
}