fix: stop api server on page unload (#5356)
* fix: stop api server on page unload * fix: check api server status on reload * refactor: api server state * fix: should not pop the guard
This commit is contained in:
parent
5b60116d21
commit
22396111be
@ -348,23 +348,41 @@ pub async fn start_server(
|
||||
api_key: String,
|
||||
trusted_hosts: Vec<String>,
|
||||
) -> Result<bool, String> {
|
||||
let auth_token = app
|
||||
.state::<AppState>()
|
||||
.app_token
|
||||
.clone()
|
||||
.unwrap_or_default();
|
||||
server::start_server(host, port, prefix, auth_token, api_key, trusted_hosts)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
let state = app.state::<AppState>();
|
||||
let auth_token = state.app_token.clone().unwrap_or_default();
|
||||
let server_handle = state.server_handle.clone();
|
||||
|
||||
server::start_server(
|
||||
server_handle,
|
||||
host,
|
||||
port,
|
||||
prefix,
|
||||
auth_token,
|
||||
api_key,
|
||||
trusted_hosts,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn stop_server() -> Result<(), String> {
|
||||
server::stop_server().await.map_err(|e| e.to_string())?;
|
||||
pub async fn stop_server(state: State<'_, AppState>) -> Result<(), String> {
|
||||
let server_handle = state.server_handle.clone();
|
||||
|
||||
server::stop_server(server_handle)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn get_server_status(state: State<'_, AppState>) -> Result<bool, String> {
|
||||
let server_handle = state.server_handle.clone();
|
||||
|
||||
Ok(server::is_server_running(server_handle).await)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn read_logs(app: AppHandle) -> Result<String, String> {
|
||||
let log_path = get_jan_data_folder_path(app).join("logs").join("app.log");
|
||||
|
||||
@ -1,21 +1,16 @@
|
||||
use flate2::read::GzDecoder;
|
||||
use futures_util::StreamExt;
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{Body, Request, Response, Server, StatusCode};
|
||||
use reqwest::Client;
|
||||
use serde_json::Value;
|
||||
use std::convert::Infallible;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::LazyLock;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinHandle;
|
||||
use futures_util::StreamExt;
|
||||
use flate2::read::GzDecoder;
|
||||
use std::io::Read;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Server handle type for managing the proxy server lifecycle
|
||||
type ServerHandle = JoinHandle<Result<(), Box<dyn std::error::Error + Send + Sync>>>;
|
||||
|
||||
/// Global singleton for the current server instance
|
||||
static SERVER_HANDLE: LazyLock<Mutex<Option<ServerHandle>>> = LazyLock::new(|| Mutex::new(None));
|
||||
use crate::core::state::ServerHandle;
|
||||
|
||||
/// Configuration for the proxy server
|
||||
#[derive(Clone)]
|
||||
@ -272,7 +267,7 @@ async fn proxy_request(
|
||||
// Verify Host header (check target), but bypass for whitelisted paths
|
||||
let whitelisted_paths = ["/", "/openapi.json", "/favicon.ico"];
|
||||
let is_whitelisted_path = whitelisted_paths.contains(&path.as_str());
|
||||
|
||||
|
||||
if !is_whitelisted_path {
|
||||
if !host_header.is_empty() {
|
||||
if !is_valid_host(&host_header, &config.trusted_hosts) {
|
||||
@ -333,7 +328,10 @@ async fn proxy_request(
|
||||
.unwrap());
|
||||
}
|
||||
} else if is_whitelisted_path {
|
||||
log::debug!("Bypassing authorization check for whitelisted path: {}", path);
|
||||
log::debug!(
|
||||
"Bypassing authorization check for whitelisted path: {}",
|
||||
path
|
||||
);
|
||||
}
|
||||
|
||||
// Block access to /configs endpoint
|
||||
@ -394,13 +392,14 @@ async fn proxy_request(
|
||||
if path.contains("/models") && method == hyper::Method::GET {
|
||||
// For /models endpoint, we need to buffer and filter the response
|
||||
match response.bytes().await {
|
||||
Ok(bytes) => {
|
||||
match filter_models_response(&bytes) {
|
||||
Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()),
|
||||
Err(e) => {
|
||||
log::warn!("Failed to filter models response: {}, returning original", e);
|
||||
Ok(builder.body(Body::from(bytes)).unwrap())
|
||||
}
|
||||
Ok(bytes) => match filter_models_response(&bytes) {
|
||||
Ok(filtered_bytes) => Ok(builder.body(Body::from(filtered_bytes)).unwrap()),
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"Failed to filter models response: {}, returning original",
|
||||
e
|
||||
);
|
||||
Ok(builder.body(Body::from(bytes)).unwrap())
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
@ -422,7 +421,7 @@ async fn proxy_request(
|
||||
// For streaming endpoints (like chat completions), we need to collect and forward the stream
|
||||
let mut stream = response.bytes_stream();
|
||||
let (mut sender, body) = hyper::Body::channel();
|
||||
|
||||
|
||||
// Spawn a task to forward the stream
|
||||
tokio::spawn(async move {
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
@ -440,7 +439,7 @@ async fn proxy_request(
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
Ok(builder.body(body).unwrap())
|
||||
}
|
||||
}
|
||||
@ -478,7 +477,7 @@ fn compress_gzip(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Se
|
||||
use flate2::write::GzEncoder;
|
||||
use flate2::Compression;
|
||||
use std::io::Write;
|
||||
|
||||
|
||||
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
|
||||
encoder.write_all(bytes)?;
|
||||
let compressed = encoder.finish()?;
|
||||
@ -486,7 +485,9 @@ fn compress_gzip(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Se
|
||||
}
|
||||
|
||||
/// Filters models response to keep only models with status "downloaded"
|
||||
fn filter_models_response(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
fn filter_models_response(
|
||||
bytes: &[u8],
|
||||
) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// Try to decompress if it's gzip-encoded
|
||||
let decompressed_bytes = if is_gzip_encoded(bytes) {
|
||||
log::debug!("Response is gzip-encoded, decompressing...");
|
||||
@ -494,10 +495,10 @@ fn filter_models_response(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::E
|
||||
} else {
|
||||
bytes.to_vec()
|
||||
};
|
||||
|
||||
|
||||
let response_text = std::str::from_utf8(&decompressed_bytes)?;
|
||||
let mut response_json: Value = serde_json::from_str(response_text)?;
|
||||
|
||||
|
||||
// Check if this is a ListModelsResponseDto format with data array
|
||||
if let Some(data_array) = response_json.get_mut("data") {
|
||||
if let Some(models) = data_array.as_array_mut() {
|
||||
@ -513,7 +514,10 @@ fn filter_models_response(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::E
|
||||
false // Remove models without status field
|
||||
}
|
||||
});
|
||||
log::debug!("Filtered models response: {} downloaded models remaining", models.len());
|
||||
log::debug!(
|
||||
"Filtered models response: {} downloaded models remaining",
|
||||
models.len()
|
||||
);
|
||||
}
|
||||
} else if response_json.is_array() {
|
||||
// Handle direct array format
|
||||
@ -529,12 +533,15 @@ fn filter_models_response(bytes: &[u8]) -> Result<Vec<u8>, Box<dyn std::error::E
|
||||
false // Remove models without status field
|
||||
}
|
||||
});
|
||||
log::debug!("Filtered models response: {} downloaded models remaining", models.len());
|
||||
log::debug!(
|
||||
"Filtered models response: {} downloaded models remaining",
|
||||
models.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let filtered_json = serde_json::to_vec(&response_json)?;
|
||||
|
||||
|
||||
// If original was gzip-encoded, re-compress the filtered response
|
||||
if is_gzip_encoded(bytes) {
|
||||
log::debug!("Re-compressing filtered response with gzip");
|
||||
@ -634,8 +641,19 @@ fn is_valid_host(host: &str, trusted_hosts: &[String]) -> bool {
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn is_server_running(server_handle: Arc<Mutex<Option<ServerHandle>>>) -> bool {
|
||||
let handle_guard = server_handle.lock().await;
|
||||
|
||||
if handle_guard.is_some() {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Starts the proxy server
|
||||
pub async fn start_server(
|
||||
server_handle: Arc<Mutex<Option<ServerHandle>>>,
|
||||
host: String,
|
||||
port: u16,
|
||||
prefix: String,
|
||||
@ -644,7 +662,7 @@ pub async fn start_server(
|
||||
trusted_hosts: Vec<String>,
|
||||
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
|
||||
// Check if server is already running
|
||||
let mut handle_guard = SERVER_HANDLE.lock().await;
|
||||
let mut handle_guard = server_handle.lock().await;
|
||||
if handle_guard.is_some() {
|
||||
return Err("Server is already running".into());
|
||||
}
|
||||
@ -687,7 +705,7 @@ pub async fn start_server(
|
||||
log::info!("Proxy server started on http://{}", addr);
|
||||
|
||||
// Spawn server task
|
||||
let server_handle = tokio::spawn(async move {
|
||||
let server_task = tokio::spawn(async move {
|
||||
if let Err(e) = server.await {
|
||||
log::error!("Server error: {}", e);
|
||||
return Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>);
|
||||
@ -695,16 +713,20 @@ pub async fn start_server(
|
||||
Ok(())
|
||||
});
|
||||
|
||||
*handle_guard = Some(server_handle);
|
||||
*handle_guard = Some(server_task);
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// Stops the currently running proxy server
|
||||
pub async fn stop_server() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut handle_guard = SERVER_HANDLE.lock().await;
|
||||
pub async fn stop_server(
|
||||
server_handle: Arc<Mutex<Option<ServerHandle>>>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let mut handle_guard = server_handle.lock().await;
|
||||
|
||||
if let Some(handle) = handle_guard.take() {
|
||||
handle.abort();
|
||||
// remove the handle to prevent future use
|
||||
*handle_guard = None;
|
||||
log::info!("Proxy server stopped");
|
||||
} else {
|
||||
log::debug!("No server was running");
|
||||
@ -746,10 +768,10 @@ mod tests {
|
||||
|
||||
let data = filtered_response["data"].as_array().unwrap();
|
||||
assert_eq!(data.len(), 1); // Should have 1 model (only model1 with "downloaded" status)
|
||||
|
||||
|
||||
// Verify only model1 (with "downloaded" status) is kept
|
||||
assert!(data.iter().any(|model| model["id"] == "model1"));
|
||||
|
||||
|
||||
// Verify model2 and model3 are filtered out
|
||||
assert!(!data.iter().any(|model| model["id"] == "model2"));
|
||||
assert!(!data.iter().any(|model| model["id"] == "model3"));
|
||||
@ -838,11 +860,11 @@ mod tests {
|
||||
|
||||
let data = filtered_response["data"].as_array().unwrap();
|
||||
assert_eq!(data.len(), 2); // Should have 2 models (model1 and model3 with "downloaded" status)
|
||||
|
||||
|
||||
// Verify only models with "downloaded" status are kept
|
||||
assert!(data.iter().any(|model| model["id"] == "model1"));
|
||||
assert!(data.iter().any(|model| model["id"] == "model3"));
|
||||
|
||||
|
||||
// Verify other models are filtered out
|
||||
assert!(!data.iter().any(|model| model["id"] == "model2"));
|
||||
assert!(!data.iter().any(|model| model["id"] == "model4"));
|
||||
|
||||
@ -4,6 +4,10 @@ use crate::core::utils::download::DownloadManagerState;
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use rmcp::{service::RunningService, RoleClient};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
/// Server handle type for managing the proxy server lifecycle
|
||||
pub type ServerHandle = JoinHandle<Result<(), Box<dyn std::error::Error + Send + Sync>>>;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct AppState {
|
||||
@ -12,6 +16,7 @@ pub struct AppState {
|
||||
pub download_manager: Arc<Mutex<DownloadManagerState>>,
|
||||
pub cortex_restart_count: Arc<Mutex<u32>>,
|
||||
pub cortex_killed_intentionally: Arc<Mutex<bool>>,
|
||||
pub server_handle: Arc<Mutex<Option<ServerHandle>>>,
|
||||
}
|
||||
pub fn generate_app_token() -> String {
|
||||
rand::thread_rng()
|
||||
|
||||
@ -55,6 +55,7 @@ pub fn run() {
|
||||
core::cmd::app_token,
|
||||
core::cmd::start_server,
|
||||
core::cmd::stop_server,
|
||||
core::cmd::get_server_status,
|
||||
core::cmd::read_logs,
|
||||
core::cmd::change_app_data_folder,
|
||||
core::cmd::reset_cortex_restart_count,
|
||||
@ -92,6 +93,7 @@ pub fn run() {
|
||||
download_manager: Arc::new(Mutex::new(DownloadManagerState::default())),
|
||||
cortex_restart_count: Arc::new(Mutex::new(0)),
|
||||
cortex_killed_intentionally: Arc::new(Mutex::new(false)),
|
||||
server_handle: Arc::new(Mutex::new(None)),
|
||||
})
|
||||
.setup(|app| {
|
||||
app.handle().plugin(
|
||||
|
||||
@ -18,7 +18,6 @@ import { AnalyticProvider } from '@/providers/AnalyticProvider'
|
||||
import { useLeftPanel } from '@/hooks/useLeftPanel'
|
||||
import { cn } from '@/lib/utils'
|
||||
import ToolApproval from '@/containers/dialogs/ToolApproval'
|
||||
import { useEffect } from 'react'
|
||||
|
||||
export const Route = createRootRoute({
|
||||
component: RootLayout,
|
||||
@ -83,13 +82,6 @@ function RootLayout() {
|
||||
router.location.pathname === route.systemMonitor ||
|
||||
router.location.pathname === route.appLogs
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
// This is to attempt to stop the local API server when the app is closed or reloaded.
|
||||
window.core?.api?.stopServer()
|
||||
}
|
||||
}, [])
|
||||
|
||||
return (
|
||||
<Fragment>
|
||||
<ThemeProvider />
|
||||
|
||||
@ -17,7 +17,8 @@ import { windowKey } from '@/constants/windows'
|
||||
import { IconLogs } from '@tabler/icons-react'
|
||||
import { cn } from '@/lib/utils'
|
||||
import { ApiKeyInput } from '@/containers/ApiKeyInput'
|
||||
import { useState } from 'react'
|
||||
import { useEffect, useState } from 'react'
|
||||
import { invoke } from '@tauri-apps/api/core'
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
export const Route = createFileRoute(route.settings.local_api_server as any)({
|
||||
@ -44,6 +45,17 @@ function LocalAPIServer() {
|
||||
!apiKey || apiKey.toString().trim().length === 0
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
const checkServerStatus = async () => {
|
||||
invoke('get_server_status').then((running) => {
|
||||
if (running) {
|
||||
setServerStatus('running')
|
||||
}
|
||||
})
|
||||
}
|
||||
checkServerStatus()
|
||||
}, [setServerStatus])
|
||||
|
||||
const handleApiKeyValidation = (isValid: boolean) => {
|
||||
setIsApiKeyEmpty(!isValid)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user