From 25043dda7be8b8c9b30d93f9e8f296df1b57dc7e Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 15 Aug 2025 10:12:41 +0700 Subject: [PATCH] feat: MCP streamable http and sse transports --- src-tauri/src/core/mcp/helpers.rs | 352 ++++++++++++++++++++---------- src-tauri/src/core/mcp/mod.rs | 1 + src-tauri/src/core/mcp/models.rs | 11 + 3 files changed, 251 insertions(+), 113 deletions(-) create mode 100644 src-tauri/src/core/mcp/models.rs diff --git a/src-tauri/src/core/mcp/helpers.rs b/src-tauri/src/core/mcp/helpers.rs index 9d8cbeae5..8b2fc0cba 100644 --- a/src-tauri/src/core/mcp/helpers.rs +++ b/src-tauri/src/core/mcp/helpers.rs @@ -1,7 +1,15 @@ -use rmcp::{transport::TokioChildProcess, ServiceExt}; +use rmcp::{ + model::{ClientCapabilities, ClientInfo, Implementation}, + transport::{ + streamable_http_client::StreamableHttpClientTransportConfig, SseClientTransport, + StreamableHttpClientTransport, TokioChildProcess, + }, + ServiceExt, +}; use serde_json::Value; use std::{collections::HashMap, env, sync::Arc, time::Duration}; use tauri::{AppHandle, Emitter, Manager, Runtime, State}; +use tauri_plugin_http::reqwest; use tokio::{ process::Command, sync::Mutex, @@ -12,8 +20,7 @@ use super::constants::{ MCP_BACKOFF_MULTIPLIER, MCP_BASE_RESTART_DELAY_MS, MCP_MAX_RESTART_DELAY_MS, }; use crate::core::{ - app::commands::get_jan_data_folder_path, - state::{AppState, RunningServiceEnum, SharedMcpServers}, + app::commands::get_jan_data_folder_path, mcp::models::McpServerConfig, state::{AppState, RunningServiceEnum, SharedMcpServers} }; use jan_utils::can_override_npx; @@ -462,7 +469,7 @@ pub async fn start_restart_loop( } } -pub async fn schedule_mcp_start_task( +async fn schedule_mcp_start_task( app: tauri::AppHandle, servers: SharedMcpServers, name: String, @@ -475,141 +482,260 @@ pub async fn schedule_mcp_start_task( .expect("Executable must have a parent directory"); let bin_path = exe_parent_path.to_path_buf(); - let (command, args, envs) = extract_command_args(&config) + let config_params = extract_command_args(&config) .ok_or_else(|| format!("Failed to extract command args from config for {name}"))?; - let mut cmd = Command::new(command.clone()); + if config_params.transport_type.as_deref() == Some("http") && config_params.url.is_some() { + let transport = StreamableHttpClientTransport::with_client( + reqwest::Client::builder() + .default_headers({ + // Map envs to request headers + let mut headers = reqwest::header::HeaderMap::new(); + for (key, value) in config_params.envs.iter() { + if let Some(v_str) = value.as_str() { + // Try to map env keys to HTTP header names (case-insensitive) + // Most HTTP headers are Title-Case, so we try to convert + let header_name = + reqwest::header::HeaderName::from_bytes(key.as_bytes()); + if let Ok(header_name) = header_name { + if let Ok(header_value) = + reqwest::header::HeaderValue::from_str(v_str) + { + headers.insert(header_name, header_value); + } + } + } + } + headers + }) + .build() + .unwrap(), + StreamableHttpClientTransportConfig { + uri: config_params.url.unwrap().into(), + ..Default::default() + }, + ); - if command == "npx" && can_override_npx() { - let mut cache_dir = app_path.clone(); - cache_dir.push(".npx"); - let bun_x_path = format!("{}/bun", bin_path.display()); - cmd = Command::new(bun_x_path); - cmd.arg("x"); - cmd.env("BUN_INSTALL", cache_dir.to_str().unwrap().to_string()); - } + let client_info = ClientInfo { + protocol_version: Default::default(), + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "Jan Streamable Client".to_string(), + version: "0.0.1".to_string(), + }, + }; + let client = client_info.serve(transport).await.inspect_err(|e| { + log::error!("client error: {:?}", e); + }); - if command == "uvx" { - let mut cache_dir = app_path.clone(); - cache_dir.push(".uvx"); - let bun_x_path = format!("{}/uv", bin_path.display()); - cmd = Command::new(bun_x_path); - cmd.arg("tool"); - cmd.arg("run"); - cmd.env("UV_CACHE_DIR", cache_dir.to_str().unwrap().to_string()); - } + match client { + Ok(client) => { + log::info!("Connected to server: {:?}", client.peer_info()); + servers + .lock() + .await + .insert(name.clone(), RunningServiceEnum::WithInit(client)); - #[cfg(windows)] - { - cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW: prevents shell window on Windows - } - - let app_path_str = app_path.to_str().unwrap().to_string(); - let log_file_path = format!("{}/logs/app.log", app_path_str); - match std::fs::OpenOptions::new() - .create(true) - .append(true) - .open(log_file_path) - { - Ok(file) => { - cmd.stderr(std::process::Stdio::from(file)); + // Mark server as successfully connected (for restart policy) + { + let app_state = app.state::(); + let mut connected = app_state.mcp_successfully_connected.lock().await; + connected.insert(name.clone(), true); + log::info!("Marked MCP server {} as successfully connected", name); + } + } + Err(e) => { + log::error!("Failed to connect to server: {}", e); + return Err(format!("Failed to connect to server: {}", e)); + } } - Err(err) => { - log::error!("Failed to open log file: {}", err); - } - }; - - cmd.kill_on_drop(true); - log::trace!("Command: {cmd:#?}"); - - args.iter().filter_map(Value::as_str).for_each(|arg| { - cmd.arg(arg); - }); - envs.iter().for_each(|(k, v)| { - if let Some(v_str) = v.as_str() { - cmd.env(k, v_str); - } - }); - - let process = TokioChildProcess::new(cmd).map_err(|e| { - log::error!("Failed to run command {name}: {e}"); - format!("Failed to run command {name}: {e}") - })?; - - let service = () - .serve(process) + } else if config_params.transport_type.as_deref() == Some("sse") && config_params.url.is_some() { + let transport = SseClientTransport::start_with_client( + reqwest::Client::builder() + .default_headers({ + // Map envs to request headers + let mut headers = reqwest::header::HeaderMap::new(); + for (key, value) in config_params.envs.iter() { + if let Some(v_str) = value.as_str() { + // Try to map env keys to HTTP header names (case-insensitive) + // Most HTTP headers are Title-Case, so we try to convert + let header_name = + reqwest::header::HeaderName::from_bytes(key.as_bytes()); + if let Ok(header_name) = header_name { + if let Ok(header_value) = + reqwest::header::HeaderValue::from_str(v_str) + { + headers.insert(header_name, header_value); + } + } + } + } + headers + }) + .build() + .unwrap(), + rmcp::transport::sse_client::SseClientConfig { + sse_endpoint: config_params.url.unwrap().into(), + ..Default::default() + }, + ) .await - .map_err(|e| format!("Failed to start MCP server {name}: {e}"))?; + .map_err(|e| { + log::error!("transport error: {:?}", e); + format!("Failed to start SSE transport: {}", e) + })?; - // Get peer info and clone the needed values before moving the service - let (server_name, server_version) = { + let client_info = ClientInfo { + protocol_version: Default::default(), + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "Jan SSE Client".to_string(), + version: "0.0.1".to_string(), + }, + }; + let client = client_info.serve(transport).await.map_err(|e| { + log::error!("client error: {:?}", e); + e.to_string() + }); + + match client { + Ok(client) => { + log::info!("Connected to server: {:?}", client.peer_info()); + servers + .lock() + .await + .insert(name.clone(), RunningServiceEnum::WithInit(client)); + + // Mark server as successfully connected (for restart policy) + { + let app_state = app.state::(); + let mut connected = app_state.mcp_successfully_connected.lock().await; + connected.insert(name.clone(), true); + log::info!("Marked MCP server {} as successfully connected", name); + } + } + Err(e) => { + log::error!("Failed to connect to server: {}", e); + return Err(format!("Failed to connect to server: {}", e)); + } + } + } else { + let mut cmd = Command::new(config_params.command.clone()); + if config_params.command.clone() == "npx" && can_override_npx() { + let mut cache_dir = app_path.clone(); + cache_dir.push(".npx"); + let bun_x_path = format!("{}/bun", bin_path.display()); + cmd = Command::new(bun_x_path); + cmd.arg("x"); + cmd.env("BUN_INSTALL", cache_dir.to_str().unwrap().to_string()); + } + if config_params.command.clone() == "uvx" { + let mut cache_dir = app_path.clone(); + cache_dir.push(".uvx"); + let bun_x_path = format!("{}/uv", bin_path.display()); + cmd = Command::new(bun_x_path); + cmd.arg("tool"); + cmd.arg("run"); + cmd.env("UV_CACHE_DIR", cache_dir.to_str().unwrap().to_string()); + } + #[cfg(windows)] + { + cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW: prevents shell window on Windows + } + let app_path_str = app_path.to_str().unwrap().to_string(); + let log_file_path = format!("{}/logs/app.log", app_path_str); + match std::fs::OpenOptions::new() + .create(true) + .append(true) + .open(log_file_path) + { + Ok(file) => { + cmd.stderr(std::process::Stdio::from(file)); + } + Err(err) => { + log::error!("Failed to open log file: {}", err); + } + }; + + cmd.kill_on_drop(true); + log::trace!("Command: {cmd:#?}"); + + config_params.args.iter().filter_map(Value::as_str).for_each(|arg| { + cmd.arg(arg); + }); + config_params.envs.iter().for_each(|(k, v)| { + if let Some(v_str) = v.as_str() { + cmd.env(k, v_str); + } + }); + + let process = TokioChildProcess::new(cmd).map_err(|e| { + log::error!("Failed to run command {name}: {e}"); + format!("Failed to run command {name}: {e}") + })?; + + let service = () + .serve(process) + .await + .map_err(|e| format!("Failed to start MCP server {name}: {e}"))?; + + // Get peer info and clone the needed values before moving the service let server_info = service.peer_info(); log::trace!("Connected to server: {server_info:#?}"); - ( - server_info.unwrap().server_info.name.clone(), - server_info.unwrap().server_info.version.clone(), - ) - }; - // Now move the service into the HashMap - // Now move the service into the HashMap - servers - .lock() - .await - .insert(name.clone(), RunningServiceEnum::NoInit(service)); - log::info!("Server {name} started successfully."); - log::info!("Server {name} started successfully."); + // Now move the service into the HashMap + servers + .lock() + .await + .insert(name.clone(), RunningServiceEnum::NoInit(service)); + log::info!("Server {name} started successfully."); - // Wait a short time to verify the server is stable before marking as connected - // This prevents race conditions where the server quits immediately - let verification_delay = Duration::from_millis(500); - sleep(verification_delay).await; + // Wait a short time to verify the server is stable before marking as connected + // This prevents race conditions where the server quits immediately + let verification_delay = Duration::from_millis(500); + sleep(verification_delay).await; - // Check if server is still running after the verification delay - let server_still_running = { - let servers_map = servers.lock().await; - servers_map.contains_key(&name) - }; + // Check if server is still running after the verification delay + let server_still_running = { + let servers_map = servers.lock().await; + servers_map.contains_key(&name) + }; - if !server_still_running { - return Err(format!( - "MCP server {} quit immediately after starting", - name - )); + if !server_still_running { + return Err(format!( + "MCP server {} quit immediately after starting", + name + )); + } + // Mark server as successfully connected (for restart policy) + { + let app_state = app.state::(); + let mut connected = app_state.mcp_successfully_connected.lock().await; + connected.insert(name.clone(), true); + log::info!("Marked MCP server {} as successfully connected", name); + } } - - // Mark server as successfully connected (for restart policy) - { - let app_state = app.state::(); - let mut connected = app_state.mcp_successfully_connected.lock().await; - connected.insert(name.clone(), true); - log::info!("Marked MCP server {} as successfully connected", name); - } - - // Emit event to the frontend - let event = format!("mcp-connected"); - let payload = serde_json::json!({ - "name": server_name, - "version": server_version, - }); - app.emit(&event, payload) - .map_err(|e| format!("Failed to emit event: {}", e))?; - Ok(()) } -pub fn extract_command_args( - config: &Value, -) -> Option<(String, Vec, serde_json::Map)> { +pub fn extract_command_args(config: &Value) -> Option { let obj = config.as_object()?; let command = obj.get("command")?.as_str()?.to_string(); let args = obj.get("args")?.as_array()?.clone(); + let url = obj.get("url").and_then(|u| u.as_str()).map(String::from); + let transport_type = obj.get("type").and_then(|t| t.as_str()).map(String::from); let envs = obj .get("env") .unwrap_or(&Value::Object(serde_json::Map::new())) .as_object()? .clone(); - Some((command, args, envs)) + Some(McpServerConfig { + transport_type, + url, + command, + args, + envs, + }) } pub fn extract_active_status(config: &Value) -> Option { diff --git a/src-tauri/src/core/mcp/mod.rs b/src-tauri/src/core/mcp/mod.rs index 5b20160de..b9627f02f 100644 --- a/src-tauri/src/core/mcp/mod.rs +++ b/src-tauri/src/core/mcp/mod.rs @@ -1,6 +1,7 @@ pub mod commands; mod constants; pub mod helpers; +pub mod models; #[cfg(test)] mod tests; diff --git a/src-tauri/src/core/mcp/models.rs b/src-tauri/src/core/mcp/models.rs new file mode 100644 index 000000000..408f290ac --- /dev/null +++ b/src-tauri/src/core/mcp/models.rs @@ -0,0 +1,11 @@ +use serde_json::Value; + +/// Configuration parameters extracted from MCP server config +#[derive(Debug, Clone)] +pub struct McpServerConfig { + pub transport_type: Option, + pub url: Option, + pub command: String, + pub args: Vec, + pub envs: serde_json::Map, +}