feat: MCP - State update

This commit is contained in:
Louis 2025-08-15 10:02:06 +07:00
parent e1c8d98bf2
commit 13a1969150
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
7 changed files with 89 additions and 29 deletions

View File

@ -44,9 +44,10 @@ jan-utils = { path = "./utils" }
libloading = "0.8.7"
log = "0.4"
reqwest = { version = "0.11", features = ["json", "blocking", "stream"] }
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "3196c95f1dfafbffbdcdd6d365c94969ac975e6a", features = [
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "209dbac50f51737ad953c3a2c8e28f3619b6c277", features = [
"client",
"transport-sse-client",
"transport-streamable-http-client",
"transport-child-process",
"tower",
"reqwest",

View File

@ -1,14 +1,13 @@
use rmcp::model::{CallToolRequestParam, CallToolResult, Tool};
use rmcp::{service::RunningService, RoleClient};
use serde_json::{Map, Value};
use std::{collections::HashMap, sync::Arc};
use tauri::{AppHandle, Emitter, Runtime, State};
use tokio::{sync::Mutex, time::timeout};
use tokio::time::timeout;
use super::{
constants::{DEFAULT_MCP_CONFIG, MCP_TOOL_CALL_TIMEOUT},
helpers::{restart_active_mcp_servers, start_mcp_server_with_restart, stop_mcp_servers},
};
use crate::core::state::{RunningServiceEnum, SharedMcpServers};
use crate::core::{app::commands::get_jan_data_folder_path, state::AppState};
use std::fs;
@ -19,8 +18,7 @@ pub async fn activate_mcp_server<R: Runtime>(
name: String,
config: Value,
) -> Result<(), String> {
let servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>> =
state.mcp_servers.clone();
let servers: SharedMcpServers = state.mcp_servers.clone();
// Use the modified start_mcp_server_with_restart that returns first attempt result
start_mcp_server_with_restart(app, servers, name, config, Some(3)).await
@ -63,7 +61,16 @@ pub async fn deactivate_mcp_server(state: State<'_, AppState>, name: String) ->
// Release the lock before calling cancel
drop(servers_map);
service.cancel().await.map_err(|e| e.to_string())?;
match service {
RunningServiceEnum::NoInit(service) => {
log::info!("Stopping server {name}...");
service.cancel().await.map_err(|e| e.to_string())?;
}
RunningServiceEnum::WithInit(service) => {
log::info!("Stopping server {name} with initialization...");
service.cancel().await.map_err(|e| e.to_string())?;
}
}
log::info!("Server {name} stopped successfully and marked as deactivated.");
Ok(())
}

View File

@ -1,4 +1,4 @@
use rmcp::{service::RunningService, transport::TokioChildProcess, RoleClient, ServiceExt};
use rmcp::{transport::TokioChildProcess, ServiceExt};
use serde_json::Value;
use std::{collections::HashMap, env, sync::Arc, time::Duration};
use tauri::{AppHandle, Emitter, Manager, Runtime, State};
@ -11,7 +11,10 @@ use tokio::{
use super::constants::{
MCP_BACKOFF_MULTIPLIER, MCP_BASE_RESTART_DELAY_MS, MCP_MAX_RESTART_DELAY_MS,
};
use crate::core::{app::commands::get_jan_data_folder_path, state::AppState};
use crate::core::{
app::commands::get_jan_data_folder_path,
state::{AppState, RunningServiceEnum, SharedMcpServers},
};
use jan_utils::can_override_npx;
/// Calculate exponential backoff delay with jitter
@ -72,7 +75,7 @@ pub fn calculate_exponential_backoff_delay(attempt: u32) -> u64 {
/// * `Err(String)` if there was an error reading config or starting servers
pub async fn run_mcp_commands<R: Runtime>(
app: &AppHandle<R>,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
servers_state: SharedMcpServers,
) -> Result<(), String> {
let app_path = get_jan_data_folder_path(app.clone());
let app_path_str = app_path.to_str().unwrap().to_string();
@ -168,7 +171,7 @@ pub async fn run_mcp_commands<R: Runtime>(
/// Monitor MCP server health without removing it from the HashMap
pub async fn monitor_mcp_server_handle(
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
servers_state: SharedMcpServers,
name: String,
) -> Option<rmcp::service::QuitReason> {
log::info!("Monitoring MCP server {} health", name);
@ -213,7 +216,16 @@ pub async fn monitor_mcp_server_handle(
let mut servers = servers_state.lock().await;
if let Some(service) = servers.remove(&name) {
// Try to cancel the service gracefully
let _ = service.cancel().await;
match service {
RunningServiceEnum::NoInit(service) => {
log::info!("Stopping server {name}...");
let _ = service.cancel().await;
}
RunningServiceEnum::WithInit(service) => {
log::info!("Stopping server {name} with initialization...");
let _ = service.cancel().await;
}
}
}
return Some(rmcp::service::QuitReason::Closed);
}
@ -224,7 +236,7 @@ pub async fn monitor_mcp_server_handle(
/// Returns the result of the first start attempt, then continues with restart monitoring
pub async fn start_mcp_server_with_restart<R: Runtime>(
app: AppHandle<R>,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
servers_state: SharedMcpServers,
name: String,
config: Value,
max_restarts: Option<u32>,
@ -297,7 +309,7 @@ pub async fn start_mcp_server_with_restart<R: Runtime>(
/// Helper function to handle the restart loop logic
pub async fn start_restart_loop<R: Runtime>(
app: AppHandle<R>,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
servers_state: SharedMcpServers,
name: String,
config: Value,
max_restarts: u32,
@ -452,7 +464,7 @@ pub async fn start_restart_loop<R: Runtime>(
pub async fn schedule_mcp_start_task<R: Runtime>(
app: tauri::AppHandle<R>,
servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
servers: SharedMcpServers,
name: String,
config: Value,
) -> Result<(), String> {
@ -540,7 +552,12 @@ pub async fn schedule_mcp_start_task<R: Runtime>(
};
// Now move the service into the HashMap
servers.lock().await.insert(name.clone(), service);
// Now move the service into the HashMap
servers
.lock()
.await
.insert(name.clone(), RunningServiceEnum::NoInit(service));
log::info!("Server {name} started successfully.");
log::info!("Server {name} started successfully.");
// Wait a short time to verify the server is stable before marking as connected
@ -604,7 +621,7 @@ pub fn extract_active_status(config: &Value) -> Option<bool> {
/// Restart only servers that were previously active (like cortex restart behavior)
pub async fn restart_active_mcp_servers<R: Runtime>(
app: &AppHandle<R>,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
servers_state: SharedMcpServers,
) -> Result<(), String> {
let app_state = app.state::<AppState>();
let active_servers = app_state.mcp_active_servers.lock().await;
@ -656,14 +673,21 @@ pub async fn clean_up_mcp_servers(state: State<'_, AppState>) {
log::info!("MCP servers cleaned up successfully");
}
pub async fn stop_mcp_servers(
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
) -> Result<(), String> {
pub async fn stop_mcp_servers(servers_state: SharedMcpServers) -> Result<(), String> {
let mut servers_map = servers_state.lock().await;
let keys: Vec<String> = servers_map.keys().cloned().collect();
for key in keys {
if let Some(service) = servers_map.remove(&key) {
service.cancel().await.map_err(|e| e.to_string())?;
match service {
RunningServiceEnum::NoInit(service) => {
log::info!("Stopping server {key}...");
service.cancel().await.map_err(|e| e.to_string())?;
}
RunningServiceEnum::WithInit(service) => {
log::info!("Stopping server {key} with initialization...");
service.cancel().await.map_err(|e| e.to_string())?;
}
}
}
}
drop(servers_map); // Release the lock after stopping
@ -689,7 +713,7 @@ pub async fn reset_restart_count(restart_counts: &Arc<Mutex<HashMap<String, u32>
/// Spawn the server monitoring task for handling restarts
pub async fn spawn_server_monitoring_task<R: Runtime>(
app: AppHandle<R>,
servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
servers_state: SharedMcpServers,
name: String,
config: Value,
max_restarts: u32,

View File

@ -1,6 +1,6 @@
use super::helpers::run_mcp_commands;
use crate::core::app::commands::get_jan_data_folder_path;
use rmcp::{service::RunningService, RoleClient};
use crate::core::state::SharedMcpServers;
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
@ -27,7 +27,7 @@ async fn test_run_mcp_commands() {
.expect("Failed to write to config file");
// Call the run_mcp_commands function
let servers_state: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>> =
let servers_state: SharedMcpServers =
Arc::new(Mutex::new(HashMap::new()));
let result = run_mcp_commands(app.handle(), servers_state).await;

View File

@ -1,20 +1,48 @@
use std::{collections::HashMap, sync::Arc};
use crate::core::downloads::models::DownloadManagerState;
use rmcp::{service::RunningService, RoleClient};
use rmcp::{
model::{CallToolRequestParam, CallToolResult, InitializeRequestParam, Tool},
service::RunningService,
RoleClient, ServiceError,
};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
/// Server handle type for managing the proxy server lifecycle
pub type ServerHandle = JoinHandle<Result<(), Box<dyn std::error::Error + Send + Sync>>>;
use tokio::sync::Mutex;
pub enum RunningServiceEnum {
NoInit(RunningService<RoleClient, ()>),
WithInit(RunningService<RoleClient, InitializeRequestParam>),
}
pub type SharedMcpServers = Arc<Mutex<HashMap<String, RunningServiceEnum>>>;
#[derive(Default)]
pub struct AppState {
pub app_token: Option<String>,
pub mcp_servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
pub mcp_servers: SharedMcpServers,
pub download_manager: Arc<Mutex<DownloadManagerState>>,
pub mcp_restart_counts: Arc<Mutex<HashMap<String, u32>>>,
pub mcp_active_servers: Arc<Mutex<HashMap<String, serde_json::Value>>>,
pub mcp_successfully_connected: Arc<Mutex<HashMap<String, bool>>>,
pub server_handle: Arc<Mutex<Option<ServerHandle>>>,
}
impl RunningServiceEnum {
pub async fn list_all_tools(&self) -> Result<Vec<Tool>, ServiceError> {
match self {
Self::NoInit(s) => s.list_all_tools().await,
Self::WithInit(s) => s.list_all_tools().await,
}
}
pub async fn call_tool(
&self,
params: CallToolRequestParam,
) -> Result<CallToolResult, ServiceError> {
match self {
Self::NoInit(s) => s.call_tool(params).await,
Self::WithInit(s) => s.call_tool(params).await,
}
}
}