feat: MCP - State update

Louis 2025-08-15 10:02:06 +07:00
parent e1c8d98bf2
commit 13a1969150
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
7 changed files with 89 additions and 29 deletions

View File

@@ -44,9 +44,10 @@ jan-utils = { path = "./utils" }
 libloading = "0.8.7"
 log = "0.4"
 reqwest = { version = "0.11", features = ["json", "blocking", "stream"] }
-rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "3196c95f1dfafbffbdcdd6d365c94969ac975e6a", features = [
+rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "209dbac50f51737ad953c3a2c8e28f3619b6c277", features = [
     "client",
     "transport-sse-client",
+    "transport-streamable-http-client",
     "transport-child-process",
     "tower",
     "reqwest",

View File

@@ -327,4 +327,4 @@
 ]
 }
 }
 }

View File

@@ -447,4 +447,4 @@
 ]
 }
 }
 }

View File

@@ -1,14 +1,13 @@
 use rmcp::model::{CallToolRequestParam, CallToolResult, Tool};
-use rmcp::{service::RunningService, RoleClient};
 use serde_json::{Map, Value};
-use std::{collections::HashMap, sync::Arc};
 use tauri::{AppHandle, Emitter, Runtime, State};
-use tokio::{sync::Mutex, time::timeout};
+use tokio::time::timeout;
 
 use super::{
     constants::{DEFAULT_MCP_CONFIG, MCP_TOOL_CALL_TIMEOUT},
     helpers::{restart_active_mcp_servers, start_mcp_server_with_restart, stop_mcp_servers},
 };
+use crate::core::state::{RunningServiceEnum, SharedMcpServers};
 use crate::core::{app::commands::get_jan_data_folder_path, state::AppState};
 use std::fs;
@@ -19,8 +18,7 @@ pub async fn activate_mcp_server<R: Runtime>(
     name: String,
     config: Value,
 ) -> Result<(), String> {
-    let servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>> =
-        state.mcp_servers.clone();
+    let servers: SharedMcpServers = state.mcp_servers.clone();
 
     // Use the modified start_mcp_server_with_restart that returns first attempt result
     start_mcp_server_with_restart(app, servers, name, config, Some(3)).await
@@ -63,7 +61,16 @@ pub async fn deactivate_mcp_server(state: State<'_, AppState>, name: String) ->
     // Release the lock before calling cancel
     drop(servers_map);
-    service.cancel().await.map_err(|e| e.to_string())?;
+    match service {
+        RunningServiceEnum::NoInit(service) => {
+            log::info!("Stopping server {name}...");
+            service.cancel().await.map_err(|e| e.to_string())?;
+        }
+        RunningServiceEnum::WithInit(service) => {
+            log::info!("Stopping server {name} with initialization...");
+            service.cancel().await.map_err(|e| e.to_string())?;
+        }
+    }
 
     log::info!("Server {name} stopped successfully and marked as deactivated.");
     Ok(())
 }
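
The NoInit/WithInit match around cancel() above recurs twice more in the helpers below, and both arms run identical code apart from the log message. A delegating method on RunningServiceEnum, in the style of the list_all_tools and call_tool delegates this commit adds in state.rs, would collapse the duplication. A sketch of that refactor (hypothetical, not part of this commit):

    impl RunningServiceEnum {
        // Consumes the service, mirroring RunningService::cancel as used above.
        pub async fn cancel(self) -> Result<(), String> {
            match self {
                Self::NoInit(s) => s.cancel().await.map(|_| ()).map_err(|e| e.to_string()),
                Self::WithInit(s) => s.cancel().await.map(|_| ()).map_err(|e| e.to_string()),
            }
        }
    }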

View File

@@ -1,4 +1,4 @@
-use rmcp::{service::RunningService, transport::TokioChildProcess, RoleClient, ServiceExt};
+use rmcp::{transport::TokioChildProcess, ServiceExt};
 use serde_json::Value;
 use std::{collections::HashMap, env, sync::Arc, time::Duration};
 use tauri::{AppHandle, Emitter, Manager, Runtime, State};
@@ -11,7 +11,10 @@ use tokio::{
 use super::constants::{
     MCP_BACKOFF_MULTIPLIER, MCP_BASE_RESTART_DELAY_MS, MCP_MAX_RESTART_DELAY_MS,
 };
-use crate::core::{app::commands::get_jan_data_folder_path, state::AppState};
+use crate::core::{
+    app::commands::get_jan_data_folder_path,
+    state::{AppState, RunningServiceEnum, SharedMcpServers},
+};
 use jan_utils::can_override_npx;
 
 /// Calculate exponential backoff delay with jitter
@@ -72,7 +75,7 @@ pub fn calculate_exponential_backoff_delay(attempt: u32) -> u64 {
 /// * `Err(String)` if there was an error reading config or starting servers
 pub async fn run_mcp_commands<R: Runtime>(
     app: &AppHandle<R>,
-    servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
+    servers_state: SharedMcpServers,
 ) -> Result<(), String> {
     let app_path = get_jan_data_folder_path(app.clone());
     let app_path_str = app_path.to_str().unwrap().to_string();
@@ -168,7 +171,7 @@
 /// Monitor MCP server health without removing it from the HashMap
 pub async fn monitor_mcp_server_handle(
-    servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
+    servers_state: SharedMcpServers,
     name: String,
 ) -> Option<rmcp::service::QuitReason> {
     log::info!("Monitoring MCP server {} health", name);
@@ -213,7 +216,16 @@ pub async fn monitor_mcp_server_handle(
     let mut servers = servers_state.lock().await;
     if let Some(service) = servers.remove(&name) {
         // Try to cancel the service gracefully
-        let _ = service.cancel().await;
+        match service {
+            RunningServiceEnum::NoInit(service) => {
+                log::info!("Stopping server {name}...");
+                let _ = service.cancel().await;
+            }
+            RunningServiceEnum::WithInit(service) => {
+                log::info!("Stopping server {name} with initialization...");
+                let _ = service.cancel().await;
+            }
+        }
     }
     return Some(rmcp::service::QuitReason::Closed);
 }
@@ -224,7 +236,7 @@
 /// Returns the result of the first start attempt, then continues with restart monitoring
 pub async fn start_mcp_server_with_restart<R: Runtime>(
     app: AppHandle<R>,
-    servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
+    servers_state: SharedMcpServers,
     name: String,
     config: Value,
     max_restarts: Option<u32>,
@@ -297,7 +309,7 @@
 /// Helper function to handle the restart loop logic
 pub async fn start_restart_loop<R: Runtime>(
     app: AppHandle<R>,
-    servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
+    servers_state: SharedMcpServers,
     name: String,
     config: Value,
     max_restarts: u32,
@@ -452,7 +464,7 @@
 pub async fn schedule_mcp_start_task<R: Runtime>(
     app: tauri::AppHandle<R>,
-    servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
+    servers: SharedMcpServers,
     name: String,
     config: Value,
 ) -> Result<(), String> {
@@ -540,7 +552,12 @@ pub async fn schedule_mcp_start_task<R: Runtime>(
         };
         // Now move the service into the HashMap
-        servers.lock().await.insert(name.clone(), service);
+        // Now move the service into the HashMap
+        servers
+            .lock()
+            .await
+            .insert(name.clone(), RunningServiceEnum::NoInit(service));
+        log::info!("Server {name} started successfully.");
 
         log::info!("Server {name} started successfully.");
 
         // Wait a short time to verify the server is stable before marking as connected
@@ -604,7 +621,7 @@ pub fn extract_active_status(config: &Value) -> Option<bool> {
 /// Restart only servers that were previously active (like cortex restart behavior)
 pub async fn restart_active_mcp_servers<R: Runtime>(
     app: &AppHandle<R>,
-    servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
+    servers_state: SharedMcpServers,
 ) -> Result<(), String> {
     let app_state = app.state::<AppState>();
     let active_servers = app_state.mcp_active_servers.lock().await;
@@ -656,14 +673,21 @@ pub async fn clean_up_mcp_servers(state: State<'_, AppState>) {
     log::info!("MCP servers cleaned up successfully");
 }
 
-pub async fn stop_mcp_servers(
-    servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
-) -> Result<(), String> {
+pub async fn stop_mcp_servers(servers_state: SharedMcpServers) -> Result<(), String> {
     let mut servers_map = servers_state.lock().await;
     let keys: Vec<String> = servers_map.keys().cloned().collect();
     for key in keys {
         if let Some(service) = servers_map.remove(&key) {
-            service.cancel().await.map_err(|e| e.to_string())?;
+            match service {
+                RunningServiceEnum::NoInit(service) => {
+                    log::info!("Stopping server {key}...");
+                    service.cancel().await.map_err(|e| e.to_string())?;
+                }
+                RunningServiceEnum::WithInit(service) => {
+                    log::info!("Stopping server {key} with initialization...");
+                    service.cancel().await.map_err(|e| e.to_string())?;
+                }
+            }
         }
     }
     drop(servers_map); // Release the lock after stopping
@@ -689,7 +713,7 @@ pub async fn reset_restart_count(restart_counts: &Arc<Mutex<HashMap<String, u32>
 /// Spawn the server monitoring task for handling restarts
 pub async fn spawn_server_monitoring_task<R: Runtime>(
     app: AppHandle<R>,
-    servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
+    servers_state: SharedMcpServers,
     name: String,
     config: Value,
     max_restarts: u32,
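
For context on the restart path above: this diff shows only the doc comment and call sites of calculate_exponential_backoff_delay, not its body. A representative capped-exponential-backoff-with-jitter implementation, assuming MCP_BASE_RESTART_DELAY_MS and MCP_MAX_RESTART_DELAY_MS are u64 milliseconds and MCP_BACKOFF_MULTIPLIER is an f64 (types are not shown in this diff), might look like:

    use std::time::{SystemTime, UNIX_EPOCH};

    pub fn calculate_exponential_backoff_delay(attempt: u32) -> u64 {
        // Exponential growth: base * multiplier^attempt, capped at the maximum.
        let raw = MCP_BASE_RESTART_DELAY_MS as f64 * MCP_BACKOFF_MULTIPLIER.powi(attempt as i32);
        let capped = (raw as u64).min(MCP_MAX_RESTART_DELAY_MS);
        // Cheap jitter (up to ~10%) derived from the clock, avoiding an RNG dependency;
        // it spreads out simultaneous restarts so servers don't retry in lockstep.
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.subsec_nanos() as u64)
            .unwrap_or(0);
        capped + nanos % (capped / 10).max(1)
    }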

View File

@@ -1,6 +1,6 @@
 use super::helpers::run_mcp_commands;
 use crate::core::app::commands::get_jan_data_folder_path;
-use rmcp::{service::RunningService, RoleClient};
+use crate::core::state::SharedMcpServers;
 use std::collections::HashMap;
 use std::fs::File;
 use std::io::Write;
@@ -27,7 +27,7 @@ async fn test_run_mcp_commands() {
         .expect("Failed to write to config file");
 
     // Call the run_mcp_commands function
-    let servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>> =
+    let servers_state: SharedMcpServers =
         Arc::new(Mutex::new(HashMap::new()));
     let result = run_mcp_commands(app.handle(), servers_state).await;

View File

@@ -1,20 +1,48 @@
 use std::{collections::HashMap, sync::Arc};
 
 use crate::core::downloads::models::DownloadManagerState;
-use rmcp::{service::RunningService, RoleClient};
+use rmcp::{
+    model::{CallToolRequestParam, CallToolResult, InitializeRequestParam, Tool},
+    service::RunningService,
+    RoleClient, ServiceError,
+};
+use tokio::sync::Mutex;
 use tokio::task::JoinHandle;
 
 /// Server handle type for managing the proxy server lifecycle
 pub type ServerHandle = JoinHandle<Result<(), Box<dyn std::error::Error + Send + Sync>>>;
-use tokio::sync::Mutex;
+
+pub enum RunningServiceEnum {
+    NoInit(RunningService<RoleClient, ()>),
+    WithInit(RunningService<RoleClient, InitializeRequestParam>),
+}
+
+pub type SharedMcpServers = Arc<Mutex<HashMap<String, RunningServiceEnum>>>;
 
 #[derive(Default)]
 pub struct AppState {
     pub app_token: Option<String>,
-    pub mcp_servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
+    pub mcp_servers: SharedMcpServers,
     pub download_manager: Arc<Mutex<DownloadManagerState>>,
     pub mcp_restart_counts: Arc<Mutex<HashMap<String, u32>>>,
     pub mcp_active_servers: Arc<Mutex<HashMap<String, serde_json::Value>>>,
     pub mcp_successfully_connected: Arc<Mutex<HashMap<String, bool>>>,
     pub server_handle: Arc<Mutex<Option<ServerHandle>>>,
 }
+
+impl RunningServiceEnum {
+    pub async fn list_all_tools(&self) -> Result<Vec<Tool>, ServiceError> {
+        match self {
+            Self::NoInit(s) => s.list_all_tools().await,
+            Self::WithInit(s) => s.list_all_tools().await,
+        }
+    }
+
+    pub async fn call_tool(
+        &self,
+        params: CallToolRequestParam,
+    ) -> Result<CallToolResult, ServiceError> {
+        match self {
+            Self::NoInit(s) => s.call_tool(params).await,
+            Self::WithInit(s) => s.call_tool(params).await,
+        }
+    }
+}
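
Taken together, SharedMcpServers plus the delegating methods above keep call sites agnostic to whether a server completed MCP initialization. A minimal usage sketch (hypothetical call site; the server name "fetch" and its tool name are illustrative, not from this commit):

    use rmcp::model::{CallToolRequestParam, CallToolResult};

    async fn call_tool_on(servers: SharedMcpServers, server: &str) -> Result<CallToolResult, String> {
        let servers_map = servers.lock().await;
        // Works for both NoInit and WithInit variants via the enum's delegate.
        let service = servers_map
            .get(server)
            .ok_or_else(|| format!("server {server} not running"))?;
        service
            .call_tool(CallToolRequestParam {
                name: "fetch".into(),
                arguments: None,
            })
            .await
            .map_err(|e| e.to_string())
    }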