From 13a1969150a8840f5eeec325fe512f357e9477a4 Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 15 Aug 2025 10:02:06 +0700 Subject: [PATCH] feat: MCP - State update --- src-tauri/Cargo.toml | 3 +- .../permissions/schemas/schema.json | 2 +- .../permissions/schemas/schema.json | 2 +- src-tauri/src/core/mcp/commands.rs | 19 ++++--- src-tauri/src/core/mcp/helpers.rs | 54 +++++++++++++------ src-tauri/src/core/mcp/tests.rs | 4 +- src-tauri/src/core/state.rs | 34 ++++++++++-- 7 files changed, 89 insertions(+), 29 deletions(-) diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 744b84830..58a342a26 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -44,9 +44,10 @@ jan-utils = { path = "./utils" } libloading = "0.8.7" log = "0.4" reqwest = { version = "0.11", features = ["json", "blocking", "stream"] } -rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "3196c95f1dfafbffbdcdd6d365c94969ac975e6a", features = [ +rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "209dbac50f51737ad953c3a2c8e28f3619b6c277", features = [ "client", "transport-sse-client", + "transport-streamable-http-client", "transport-child-process", "tower", "reqwest", diff --git a/src-tauri/plugins/tauri-plugin-hardware/permissions/schemas/schema.json b/src-tauri/plugins/tauri-plugin-hardware/permissions/schemas/schema.json index 6848c3288..c5abe1f43 100644 --- a/src-tauri/plugins/tauri-plugin-hardware/permissions/schemas/schema.json +++ b/src-tauri/plugins/tauri-plugin-hardware/permissions/schemas/schema.json @@ -327,4 +327,4 @@ ] } } -} +} \ No newline at end of file diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/permissions/schemas/schema.json b/src-tauri/plugins/tauri-plugin-llamacpp/permissions/schemas/schema.json index f832b4560..70ccaf6f7 100644 --- a/src-tauri/plugins/tauri-plugin-llamacpp/permissions/schemas/schema.json +++ b/src-tauri/plugins/tauri-plugin-llamacpp/permissions/schemas/schema.json @@ -447,4 +447,4 @@ ] } } -} +} \ No newline at end of file diff --git a/src-tauri/src/core/mcp/commands.rs b/src-tauri/src/core/mcp/commands.rs index 56b1a6124..aa19cc2b9 100644 --- a/src-tauri/src/core/mcp/commands.rs +++ b/src-tauri/src/core/mcp/commands.rs @@ -1,14 +1,13 @@ use rmcp::model::{CallToolRequestParam, CallToolResult, Tool}; -use rmcp::{service::RunningService, RoleClient}; use serde_json::{Map, Value}; -use std::{collections::HashMap, sync::Arc}; use tauri::{AppHandle, Emitter, Runtime, State}; -use tokio::{sync::Mutex, time::timeout}; +use tokio::time::timeout; use super::{ constants::{DEFAULT_MCP_CONFIG, MCP_TOOL_CALL_TIMEOUT}, helpers::{restart_active_mcp_servers, start_mcp_server_with_restart, stop_mcp_servers}, }; +use crate::core::state::{RunningServiceEnum, SharedMcpServers}; use crate::core::{app::commands::get_jan_data_folder_path, state::AppState}; use std::fs; @@ -19,8 +18,7 @@ pub async fn activate_mcp_server( name: String, config: Value, ) -> Result<(), String> { - let servers: Arc>>> = - state.mcp_servers.clone(); + let servers: SharedMcpServers = state.mcp_servers.clone(); // Use the modified start_mcp_server_with_restart that returns first attempt result start_mcp_server_with_restart(app, servers, name, config, Some(3)).await @@ -63,7 +61,16 @@ pub async fn deactivate_mcp_server(state: State<'_, AppState>, name: String) -> // Release the lock before calling cancel drop(servers_map); - service.cancel().await.map_err(|e| e.to_string())?; + match service { + RunningServiceEnum::NoInit(service) => { + log::info!("Stopping server 
{name}..."); + service.cancel().await.map_err(|e| e.to_string())?; + } + RunningServiceEnum::WithInit(service) => { + log::info!("Stopping server {name} with initialization..."); + service.cancel().await.map_err(|e| e.to_string())?; + } + } log::info!("Server {name} stopped successfully and marked as deactivated."); Ok(()) } diff --git a/src-tauri/src/core/mcp/helpers.rs b/src-tauri/src/core/mcp/helpers.rs index e6b72488d..9d8cbeae5 100644 --- a/src-tauri/src/core/mcp/helpers.rs +++ b/src-tauri/src/core/mcp/helpers.rs @@ -1,4 +1,4 @@ -use rmcp::{service::RunningService, transport::TokioChildProcess, RoleClient, ServiceExt}; +use rmcp::{transport::TokioChildProcess, ServiceExt}; use serde_json::Value; use std::{collections::HashMap, env, sync::Arc, time::Duration}; use tauri::{AppHandle, Emitter, Manager, Runtime, State}; @@ -11,7 +11,10 @@ use tokio::{ use super::constants::{ MCP_BACKOFF_MULTIPLIER, MCP_BASE_RESTART_DELAY_MS, MCP_MAX_RESTART_DELAY_MS, }; -use crate::core::{app::commands::get_jan_data_folder_path, state::AppState}; +use crate::core::{ + app::commands::get_jan_data_folder_path, + state::{AppState, RunningServiceEnum, SharedMcpServers}, +}; use jan_utils::can_override_npx; /// Calculate exponential backoff delay with jitter @@ -72,7 +75,7 @@ pub fn calculate_exponential_backoff_delay(attempt: u32) -> u64 { /// * `Err(String)` if there was an error reading config or starting servers pub async fn run_mcp_commands( app: &AppHandle, - servers_state: Arc>>>, + servers_state: SharedMcpServers, ) -> Result<(), String> { let app_path = get_jan_data_folder_path(app.clone()); let app_path_str = app_path.to_str().unwrap().to_string(); @@ -168,7 +171,7 @@ pub async fn run_mcp_commands( /// Monitor MCP server health without removing it from the HashMap pub async fn monitor_mcp_server_handle( - servers_state: Arc>>>, + servers_state: SharedMcpServers, name: String, ) -> Option { log::info!("Monitoring MCP server {} health", name); @@ -213,7 +216,16 @@ pub async fn monitor_mcp_server_handle( let mut servers = servers_state.lock().await; if let Some(service) = servers.remove(&name) { // Try to cancel the service gracefully - let _ = service.cancel().await; + match service { + RunningServiceEnum::NoInit(service) => { + log::info!("Stopping server {name}..."); + let _ = service.cancel().await; + } + RunningServiceEnum::WithInit(service) => { + log::info!("Stopping server {name} with initialization..."); + let _ = service.cancel().await; + } + } } return Some(rmcp::service::QuitReason::Closed); } @@ -224,7 +236,7 @@ pub async fn monitor_mcp_server_handle( /// Returns the result of the first start attempt, then continues with restart monitoring pub async fn start_mcp_server_with_restart( app: AppHandle, - servers_state: Arc>>>, + servers_state: SharedMcpServers, name: String, config: Value, max_restarts: Option, @@ -297,7 +309,7 @@ pub async fn start_mcp_server_with_restart( /// Helper function to handle the restart loop logic pub async fn start_restart_loop( app: AppHandle, - servers_state: Arc>>>, + servers_state: SharedMcpServers, name: String, config: Value, max_restarts: u32, @@ -452,7 +464,7 @@ pub async fn start_restart_loop( pub async fn schedule_mcp_start_task( app: tauri::AppHandle, - servers: Arc>>>, + servers: SharedMcpServers, name: String, config: Value, ) -> Result<(), String> { @@ -540,7 +552,12 @@ pub async fn schedule_mcp_start_task( }; // Now move the service into the HashMap - servers.lock().await.insert(name.clone(), service); + // Now move the service into the HashMap + 
servers + .lock() + .await + .insert(name.clone(), RunningServiceEnum::NoInit(service)); + log::info!("Server {name} started successfully."); log::info!("Server {name} started successfully."); // Wait a short time to verify the server is stable before marking as connected @@ -604,7 +621,7 @@ pub fn extract_active_status(config: &Value) -> Option { /// Restart only servers that were previously active (like cortex restart behavior) pub async fn restart_active_mcp_servers( app: &AppHandle, - servers_state: Arc>>>, + servers_state: SharedMcpServers, ) -> Result<(), String> { let app_state = app.state::(); let active_servers = app_state.mcp_active_servers.lock().await; @@ -656,14 +673,21 @@ pub async fn clean_up_mcp_servers(state: State<'_, AppState>) { log::info!("MCP servers cleaned up successfully"); } -pub async fn stop_mcp_servers( - servers_state: Arc>>>, -) -> Result<(), String> { +pub async fn stop_mcp_servers(servers_state: SharedMcpServers) -> Result<(), String> { let mut servers_map = servers_state.lock().await; let keys: Vec = servers_map.keys().cloned().collect(); for key in keys { if let Some(service) = servers_map.remove(&key) { - service.cancel().await.map_err(|e| e.to_string())?; + match service { + RunningServiceEnum::NoInit(service) => { + log::info!("Stopping server {key}..."); + service.cancel().await.map_err(|e| e.to_string())?; + } + RunningServiceEnum::WithInit(service) => { + log::info!("Stopping server {key} with initialization..."); + service.cancel().await.map_err(|e| e.to_string())?; + } + } } } drop(servers_map); // Release the lock after stopping @@ -689,7 +713,7 @@ pub async fn reset_restart_count(restart_counts: &Arc /// Spawn the server monitoring task for handling restarts pub async fn spawn_server_monitoring_task( app: AppHandle, - servers_state: Arc>>>, + servers_state: SharedMcpServers, name: String, config: Value, max_restarts: u32, diff --git a/src-tauri/src/core/mcp/tests.rs b/src-tauri/src/core/mcp/tests.rs index 8346449b2..081a188e8 100644 --- a/src-tauri/src/core/mcp/tests.rs +++ b/src-tauri/src/core/mcp/tests.rs @@ -1,6 +1,6 @@ use super::helpers::run_mcp_commands; use crate::core::app::commands::get_jan_data_folder_path; -use rmcp::{service::RunningService, RoleClient}; +use crate::core::state::SharedMcpServers; use std::collections::HashMap; use std::fs::File; use std::io::Write; @@ -27,7 +27,7 @@ async fn test_run_mcp_commands() { .expect("Failed to write to config file"); // Call the run_mcp_commands function - let servers_state: Arc>>> = + let servers_state: SharedMcpServers = Arc::new(Mutex::new(HashMap::new())); let result = run_mcp_commands(app.handle(), servers_state).await; diff --git a/src-tauri/src/core/state.rs b/src-tauri/src/core/state.rs index bda2eb40c..3408052d4 100644 --- a/src-tauri/src/core/state.rs +++ b/src-tauri/src/core/state.rs @@ -1,20 +1,48 @@ use std::{collections::HashMap, sync::Arc}; use crate::core::downloads::models::DownloadManagerState; -use rmcp::{service::RunningService, RoleClient}; +use rmcp::{ + model::{CallToolRequestParam, CallToolResult, InitializeRequestParam, Tool}, + service::RunningService, + RoleClient, ServiceError, +}; +use tokio::sync::Mutex; use tokio::task::JoinHandle; /// Server handle type for managing the proxy server lifecycle pub type ServerHandle = JoinHandle>>; -use tokio::sync::Mutex; + +pub enum RunningServiceEnum { + NoInit(RunningService), + WithInit(RunningService), +} +pub type SharedMcpServers = Arc>>; #[derive(Default)] pub struct AppState { pub app_token: Option, - pub mcp_servers: 
Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
+    pub mcp_servers: SharedMcpServers,
     pub download_manager: Arc<Mutex<DownloadManagerState>>,
     pub mcp_restart_counts: Arc<Mutex<HashMap<String, u32>>>,
     pub mcp_active_servers: Arc<Mutex<HashMap<String, bool>>>,
     pub mcp_successfully_connected: Arc<Mutex<HashMap<String, bool>>>,
     pub server_handle: Arc<Mutex<Option<ServerHandle>>>,
 }
+
+impl RunningServiceEnum {
+    pub async fn list_all_tools(&self) -> Result<Vec<Tool>, ServiceError> {
+        match self {
+            Self::NoInit(s) => s.list_all_tools().await,
+            Self::WithInit(s) => s.list_all_tools().await,
+        }
+    }
+    pub async fn call_tool(
+        &self,
+        params: CallToolRequestParam,
+    ) -> Result<CallToolResult, ServiceError> {
+        match self {
+            Self::NoInit(s) => s.call_tool(params).await,
+            Self::WithInit(s) => s.call_tool(params).await,
+        }
+    }
+}
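
Usage note (outside the diff): the sketch below shows how calling code might route a tool call through the new RunningServiceEnum stored in SharedMcpServers, without caring whether the underlying client was created with or without InitializeRequestParam. This is a minimal sketch assuming the types added in src-tauri/src/core/state.rs by this patch; the helper name call_tool_named and the String error mapping are illustrative assumptions, not code introduced by this change.

// Minimal sketch, assuming the SharedMcpServers alias and RunningServiceEnum from this patch.
use crate::core::state::SharedMcpServers;
use rmcp::model::{CallToolRequestParam, CallToolResult};

// Hypothetical helper: look up a running MCP server by name and invoke a tool on it.
// RunningServiceEnum::call_tool dispatches to either variant, so the caller never
// matches on NoInit/WithInit itself.
pub async fn call_tool_named(
    servers: &SharedMcpServers,
    server_name: &str,
    params: CallToolRequestParam,
) -> Result<CallToolResult, String> {
    // The tokio::sync::Mutex guard can be held across .await, so the call runs while
    // the map is locked; drop the guard earlier if concurrent access is needed.
    let servers_map = servers.lock().await;
    let service = servers_map
        .get(server_name)
        .ok_or_else(|| format!("MCP server {server_name} is not running"))?;
    service.call_tool(params).await.map_err(|e| e.to_string())
}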