diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 744b84830..58a342a26 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -44,9 +44,10 @@ jan-utils = { path = "./utils" } libloading = "0.8.7" log = "0.4" reqwest = { version = "0.11", features = ["json", "blocking", "stream"] } -rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "3196c95f1dfafbffbdcdd6d365c94969ac975e6a", features = [ +rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", rev = "209dbac50f51737ad953c3a2c8e28f3619b6c277", features = [ "client", "transport-sse-client", + "transport-streamable-http-client", "transport-child-process", "tower", "reqwest", diff --git a/src-tauri/plugins/tauri-plugin-hardware/permissions/schemas/schema.json b/src-tauri/plugins/tauri-plugin-hardware/permissions/schemas/schema.json index 6848c3288..c5abe1f43 100644 --- a/src-tauri/plugins/tauri-plugin-hardware/permissions/schemas/schema.json +++ b/src-tauri/plugins/tauri-plugin-hardware/permissions/schemas/schema.json @@ -327,4 +327,4 @@ ] } } -} +} \ No newline at end of file diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/permissions/schemas/schema.json b/src-tauri/plugins/tauri-plugin-llamacpp/permissions/schemas/schema.json index f832b4560..70ccaf6f7 100644 --- a/src-tauri/plugins/tauri-plugin-llamacpp/permissions/schemas/schema.json +++ b/src-tauri/plugins/tauri-plugin-llamacpp/permissions/schemas/schema.json @@ -447,4 +447,4 @@ ] } } -} +} \ No newline at end of file diff --git a/src-tauri/src/core/mcp/commands.rs b/src-tauri/src/core/mcp/commands.rs index 56b1a6124..02caca827 100644 --- a/src-tauri/src/core/mcp/commands.rs +++ b/src-tauri/src/core/mcp/commands.rs @@ -1,15 +1,17 @@ -use rmcp::model::{CallToolRequestParam, CallToolResult, Tool}; -use rmcp::{service::RunningService, RoleClient}; +use rmcp::model::{CallToolRequestParam, CallToolResult}; use serde_json::{Map, Value}; -use std::{collections::HashMap, sync::Arc}; use tauri::{AppHandle, Emitter, Runtime, State}; -use tokio::{sync::Mutex, time::timeout}; +use tokio::time::timeout; use super::{ constants::{DEFAULT_MCP_CONFIG, MCP_TOOL_CALL_TIMEOUT}, helpers::{restart_active_mcp_servers, start_mcp_server_with_restart, stop_mcp_servers}, }; use crate::core::{app::commands::get_jan_data_folder_path, state::AppState}; +use crate::core::{ + mcp::models::ToolWithServer, + state::{RunningServiceEnum, SharedMcpServers}, +}; use std::fs; #[tauri::command] @@ -19,8 +21,7 @@ pub async fn activate_mcp_server( name: String, config: Value, ) -> Result<(), String> { - let servers: Arc>>> = - state.mcp_servers.clone(); + let servers: SharedMcpServers = state.mcp_servers.clone(); // Use the modified start_mcp_server_with_restart that returns first attempt result start_mcp_server_with_restart(app, servers, name, config, Some(3)).await @@ -63,7 +64,16 @@ pub async fn deactivate_mcp_server(state: State<'_, AppState>, name: String) -> // Release the lock before calling cancel drop(servers_map); - service.cancel().await.map_err(|e| e.to_string())?; + match service { + RunningServiceEnum::NoInit(service) => { + log::info!("Stopping server {name}..."); + service.cancel().await.map_err(|e| e.to_string())?; + } + RunningServiceEnum::WithInit(service) => { + log::info!("Stopping server {name} with initialization..."); + service.cancel().await.map_err(|e| e.to_string())?; + } + } log::info!("Server {name} stopped successfully and marked as deactivated."); Ok(()) } @@ -116,7 +126,7 @@ pub async fn get_connected_servers( Ok(servers_map.keys().cloned().collect()) } -/// Retrieves all available tools from all MCP servers +/// Retrieves all available tools from all MCP servers with server information /// /// # Arguments /// * `state` - Application state containing MCP server connections @@ -128,14 +138,15 @@ pub async fn get_connected_servers( /// 1. Locks the MCP servers mutex to access server connections /// 2. Iterates through all connected servers /// 3. Gets the list of tools from each server -/// 4. Combines all tools into a single vector -/// 5. Returns the combined list of all available tools +/// 4. Associates each tool with its parent server name +/// 5. Combines all tools into a single vector +/// 6. Returns the combined list of all available tools with server information #[tauri::command] -pub async fn get_tools(state: State<'_, AppState>) -> Result, String> { +pub async fn get_tools(state: State<'_, AppState>) -> Result, String> { let servers = state.mcp_servers.lock().await; - let mut all_tools: Vec = Vec::new(); + let mut all_tools: Vec = Vec::new(); - for (_, service) in servers.iter() { + for (server_name, service) in servers.iter() { // List tools with timeout let tools_future = service.list_all_tools(); let tools = match timeout(MCP_TOOL_CALL_TIMEOUT, tools_future).await { @@ -150,7 +161,12 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result, String> }; for tool in tools { - all_tools.push(tool); + all_tools.push(ToolWithServer { + name: tool.name.to_string(), + description: tool.description.as_ref().map(|d| d.to_string()), + input_schema: serde_json::Value::Object((*tool.input_schema).clone()), + server: server_name.clone(), + }); } } diff --git a/src-tauri/src/core/mcp/helpers.rs b/src-tauri/src/core/mcp/helpers.rs index e6b72488d..75a1bba3a 100644 --- a/src-tauri/src/core/mcp/helpers.rs +++ b/src-tauri/src/core/mcp/helpers.rs @@ -1,7 +1,15 @@ -use rmcp::{service::RunningService, transport::TokioChildProcess, RoleClient, ServiceExt}; +use rmcp::{ + model::{ClientCapabilities, ClientInfo, Implementation}, + transport::{ + streamable_http_client::StreamableHttpClientTransportConfig, SseClientTransport, + StreamableHttpClientTransport, TokioChildProcess, + }, + ServiceExt, +}; use serde_json::Value; use std::{collections::HashMap, env, sync::Arc, time::Duration}; use tauri::{AppHandle, Emitter, Manager, Runtime, State}; +use tauri_plugin_http::reqwest; use tokio::{ process::Command, sync::Mutex, @@ -11,7 +19,11 @@ use tokio::{ use super::constants::{ MCP_BACKOFF_MULTIPLIER, MCP_BASE_RESTART_DELAY_MS, MCP_MAX_RESTART_DELAY_MS, }; -use crate::core::{app::commands::get_jan_data_folder_path, state::AppState}; +use crate::core::{ + app::commands::get_jan_data_folder_path, + mcp::models::McpServerConfig, + state::{AppState, RunningServiceEnum, SharedMcpServers}, +}; use jan_utils::can_override_npx; /// Calculate exponential backoff delay with jitter @@ -72,7 +84,7 @@ pub fn calculate_exponential_backoff_delay(attempt: u32) -> u64 { /// * `Err(String)` if there was an error reading config or starting servers pub async fn run_mcp_commands( app: &AppHandle, - servers_state: Arc>>>, + servers_state: SharedMcpServers, ) -> Result<(), String> { let app_path = get_jan_data_folder_path(app.clone()); let app_path_str = app_path.to_str().unwrap().to_string(); @@ -168,7 +180,7 @@ pub async fn run_mcp_commands( /// Monitor MCP server health without removing it from the HashMap pub async fn monitor_mcp_server_handle( - servers_state: Arc>>>, + servers_state: SharedMcpServers, name: String, ) -> Option { log::info!("Monitoring MCP server {} health", name); @@ -213,7 +225,16 @@ pub async fn monitor_mcp_server_handle( let mut servers = servers_state.lock().await; if let Some(service) = servers.remove(&name) { // Try to cancel the service gracefully - let _ = service.cancel().await; + match service { + RunningServiceEnum::NoInit(service) => { + log::info!("Stopping server {name}..."); + let _ = service.cancel().await; + } + RunningServiceEnum::WithInit(service) => { + log::info!("Stopping server {name} with initialization..."); + let _ = service.cancel().await; + } + } } return Some(rmcp::service::QuitReason::Closed); } @@ -224,7 +245,7 @@ pub async fn monitor_mcp_server_handle( /// Returns the result of the first start attempt, then continues with restart monitoring pub async fn start_mcp_server_with_restart( app: AppHandle, - servers_state: Arc>>>, + servers_state: SharedMcpServers, name: String, config: Value, max_restarts: Option, @@ -297,7 +318,7 @@ pub async fn start_mcp_server_with_restart( /// Helper function to handle the restart loop logic pub async fn start_restart_loop( app: AppHandle, - servers_state: Arc>>>, + servers_state: SharedMcpServers, name: String, config: Value, max_restarts: u32, @@ -450,9 +471,9 @@ pub async fn start_restart_loop( } } -pub async fn schedule_mcp_start_task( +async fn schedule_mcp_start_task( app: tauri::AppHandle, - servers: Arc>>>, + servers: SharedMcpServers, name: String, config: Value, ) -> Result<(), String> { @@ -463,136 +484,278 @@ pub async fn schedule_mcp_start_task( .expect("Executable must have a parent directory"); let bin_path = exe_parent_path.to_path_buf(); - let (command, args, envs) = extract_command_args(&config) + let config_params = extract_command_args(&config) .ok_or_else(|| format!("Failed to extract command args from config for {name}"))?; - let mut cmd = Command::new(command.clone()); + if config_params.transport_type.as_deref() == Some("http") && config_params.url.is_some() { + let transport = StreamableHttpClientTransport::with_client( + reqwest::Client::builder() + .default_headers({ + // Map envs to request headers + let mut headers: tauri::http::HeaderMap = reqwest::header::HeaderMap::new(); + for (key, value) in config_params.headers.iter() { + if let Some(v_str) = value.as_str() { + // Try to map env keys to HTTP header names (case-insensitive) + // Most HTTP headers are Title-Case, so we try to convert + let header_name = + reqwest::header::HeaderName::from_bytes(key.as_bytes()); + if let Ok(header_name) = header_name { + if let Ok(header_value) = + reqwest::header::HeaderValue::from_str(v_str) + { + headers.insert(header_name, header_value); + } + } + } + } + headers + }) + .connect_timeout(config_params.timeout.unwrap_or(Duration::MAX)) + .build() + .unwrap(), + StreamableHttpClientTransportConfig { + uri: config_params.url.unwrap().into(), + ..Default::default() + }, + ); - if command == "npx" && can_override_npx() { - let mut cache_dir = app_path.clone(); - cache_dir.push(".npx"); - let bun_x_path = format!("{}/bun", bin_path.display()); - cmd = Command::new(bun_x_path); - cmd.arg("x"); - cmd.env("BUN_INSTALL", cache_dir.to_str().unwrap().to_string()); - } + let client_info = ClientInfo { + protocol_version: Default::default(), + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "Jan Streamable Client".to_string(), + version: "0.0.1".to_string(), + }, + }; + let client = client_info.serve(transport).await.inspect_err(|e| { + log::error!("client error: {:?}", e); + }); - if command == "uvx" { - let mut cache_dir = app_path.clone(); - cache_dir.push(".uvx"); - let bun_x_path = format!("{}/uv", bin_path.display()); - cmd = Command::new(bun_x_path); - cmd.arg("tool"); - cmd.arg("run"); - cmd.env("UV_CACHE_DIR", cache_dir.to_str().unwrap().to_string()); - } + match client { + Ok(client) => { + log::info!("Connected to server: {:?}", client.peer_info()); + servers + .lock() + .await + .insert(name.clone(), RunningServiceEnum::WithInit(client)); - #[cfg(windows)] + // Mark server as successfully connected (for restart policy) + { + let app_state = app.state::(); + let mut connected = app_state.mcp_successfully_connected.lock().await; + connected.insert(name.clone(), true); + log::info!("Marked MCP server {} as successfully connected", name); + } + } + Err(e) => { + log::error!("Failed to connect to server: {}", e); + return Err(format!("Failed to connect to server: {}", e)); + } + } + } else if config_params.transport_type.as_deref() == Some("sse") && config_params.url.is_some() { - cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW: prevents shell window on Windows - } - - let app_path_str = app_path.to_str().unwrap().to_string(); - let log_file_path = format!("{}/logs/app.log", app_path_str); - match std::fs::OpenOptions::new() - .create(true) - .append(true) - .open(log_file_path) - { - Ok(file) => { - cmd.stderr(std::process::Stdio::from(file)); - } - Err(err) => { - log::error!("Failed to open log file: {}", err); - } - }; - - cmd.kill_on_drop(true); - log::trace!("Command: {cmd:#?}"); - - args.iter().filter_map(Value::as_str).for_each(|arg| { - cmd.arg(arg); - }); - envs.iter().for_each(|(k, v)| { - if let Some(v_str) = v.as_str() { - cmd.env(k, v_str); - } - }); - - let process = TokioChildProcess::new(cmd).map_err(|e| { - log::error!("Failed to run command {name}: {e}"); - format!("Failed to run command {name}: {e}") - })?; - - let service = () - .serve(process) + let transport = SseClientTransport::start_with_client( + reqwest::Client::builder() + .default_headers({ + // Map envs to request headers + let mut headers = reqwest::header::HeaderMap::new(); + for (key, value) in config_params.headers.iter() { + if let Some(v_str) = value.as_str() { + // Try to map env keys to HTTP header names (case-insensitive) + // Most HTTP headers are Title-Case, so we try to convert + let header_name = + reqwest::header::HeaderName::from_bytes(key.as_bytes()); + if let Ok(header_name) = header_name { + if let Ok(header_value) = + reqwest::header::HeaderValue::from_str(v_str) + { + headers.insert(header_name, header_value); + } + } + } + } + headers + }) + .connect_timeout(config_params.timeout.unwrap_or(Duration::MAX)) + .build() + .unwrap(), + rmcp::transport::sse_client::SseClientConfig { + sse_endpoint: config_params.url.unwrap().into(), + ..Default::default() + }, + ) .await - .map_err(|e| format!("Failed to start MCP server {name}: {e}"))?; + .map_err(|e| { + log::error!("transport error: {:?}", e); + format!("Failed to start SSE transport: {}", e) + })?; - // Get peer info and clone the needed values before moving the service - let (server_name, server_version) = { + let client_info = ClientInfo { + protocol_version: Default::default(), + capabilities: ClientCapabilities::default(), + client_info: Implementation { + name: "Jan SSE Client".to_string(), + version: "0.0.1".to_string(), + }, + }; + let client = client_info.serve(transport).await.map_err(|e| { + log::error!("client error: {:?}", e); + e.to_string() + }); + + match client { + Ok(client) => { + log::info!("Connected to server: {:?}", client.peer_info()); + servers + .lock() + .await + .insert(name.clone(), RunningServiceEnum::WithInit(client)); + + // Mark server as successfully connected (for restart policy) + { + let app_state = app.state::(); + let mut connected = app_state.mcp_successfully_connected.lock().await; + connected.insert(name.clone(), true); + log::info!("Marked MCP server {} as successfully connected", name); + } + } + Err(e) => { + log::error!("Failed to connect to server: {}", e); + return Err(format!("Failed to connect to server: {}", e)); + } + } + } else { + let mut cmd = Command::new(config_params.command.clone()); + if config_params.command.clone() == "npx" && can_override_npx() { + let mut cache_dir = app_path.clone(); + cache_dir.push(".npx"); + let bun_x_path = format!("{}/bun", bin_path.display()); + cmd = Command::new(bun_x_path); + cmd.arg("x"); + cmd.env("BUN_INSTALL", cache_dir.to_str().unwrap().to_string()); + } + if config_params.command.clone() == "uvx" { + let mut cache_dir = app_path.clone(); + cache_dir.push(".uvx"); + let bun_x_path = format!("{}/uv", bin_path.display()); + cmd = Command::new(bun_x_path); + cmd.arg("tool"); + cmd.arg("run"); + cmd.env("UV_CACHE_DIR", cache_dir.to_str().unwrap().to_string()); + } + #[cfg(windows)] + { + cmd.creation_flags(0x08000000); // CREATE_NO_WINDOW: prevents shell window on Windows + } + let app_path_str = app_path.to_str().unwrap().to_string(); + let log_file_path = format!("{}/logs/app.log", app_path_str); + match std::fs::OpenOptions::new() + .create(true) + .append(true) + .open(log_file_path) + { + Ok(file) => { + cmd.stderr(std::process::Stdio::from(file)); + } + Err(err) => { + log::error!("Failed to open log file: {}", err); + } + }; + + cmd.kill_on_drop(true); + log::trace!("Command: {cmd:#?}"); + + config_params + .args + .iter() + .filter_map(Value::as_str) + .for_each(|arg| { + cmd.arg(arg); + }); + config_params.envs.iter().for_each(|(k, v)| { + if let Some(v_str) = v.as_str() { + cmd.env(k, v_str); + } + }); + + let process = TokioChildProcess::new(cmd).map_err(|e| { + log::error!("Failed to run command {name}: {e}"); + format!("Failed to run command {name}: {e}") + })?; + + let service = () + .serve(process) + .await + .map_err(|e| format!("Failed to start MCP server {name}: {e}"))?; + + // Get peer info and clone the needed values before moving the service let server_info = service.peer_info(); log::trace!("Connected to server: {server_info:#?}"); - ( - server_info.unwrap().server_info.name.clone(), - server_info.unwrap().server_info.version.clone(), - ) - }; - // Now move the service into the HashMap - servers.lock().await.insert(name.clone(), service); - log::info!("Server {name} started successfully."); + // Now move the service into the HashMap + servers + .lock() + .await + .insert(name.clone(), RunningServiceEnum::NoInit(service)); + log::info!("Server {name} started successfully."); - // Wait a short time to verify the server is stable before marking as connected - // This prevents race conditions where the server quits immediately - let verification_delay = Duration::from_millis(500); - sleep(verification_delay).await; + // Wait a short time to verify the server is stable before marking as connected + // This prevents race conditions where the server quits immediately + let verification_delay = Duration::from_millis(500); + sleep(verification_delay).await; - // Check if server is still running after the verification delay - let server_still_running = { - let servers_map = servers.lock().await; - servers_map.contains_key(&name) - }; + // Check if server is still running after the verification delay + let server_still_running = { + let servers_map = servers.lock().await; + servers_map.contains_key(&name) + }; - if !server_still_running { - return Err(format!( - "MCP server {} quit immediately after starting", - name - )); + if !server_still_running { + return Err(format!( + "MCP server {} quit immediately after starting", + name + )); + } + // Mark server as successfully connected (for restart policy) + { + let app_state = app.state::(); + let mut connected = app_state.mcp_successfully_connected.lock().await; + connected.insert(name.clone(), true); + log::info!("Marked MCP server {} as successfully connected", name); + } } - - // Mark server as successfully connected (for restart policy) - { - let app_state = app.state::(); - let mut connected = app_state.mcp_successfully_connected.lock().await; - connected.insert(name.clone(), true); - log::info!("Marked MCP server {} as successfully connected", name); - } - - // Emit event to the frontend - let event = format!("mcp-connected"); - let payload = serde_json::json!({ - "name": server_name, - "version": server_version, - }); - app.emit(&event, payload) - .map_err(|e| format!("Failed to emit event: {}", e))?; - Ok(()) } -pub fn extract_command_args( - config: &Value, -) -> Option<(String, Vec, serde_json::Map)> { +pub fn extract_command_args(config: &Value) -> Option { let obj = config.as_object()?; let command = obj.get("command")?.as_str()?.to_string(); let args = obj.get("args")?.as_array()?.clone(); + let url = obj.get("url").and_then(|u| u.as_str()).map(String::from); + let transport_type = obj.get("type").and_then(|t| t.as_str()).map(String::from); + let timeout = obj + .get("timeout") + .and_then(|t| t.as_u64()) + .map(Duration::from_secs); + let headers = obj + .get("headers") + .unwrap_or(&Value::Object(serde_json::Map::new())) + .as_object()? + .clone(); let envs = obj .get("env") .unwrap_or(&Value::Object(serde_json::Map::new())) .as_object()? .clone(); - Some((command, args, envs)) + Some(McpServerConfig { + timeout, + transport_type, + url, + command, + args, + envs, + headers + }) } pub fn extract_active_status(config: &Value) -> Option { @@ -604,7 +767,7 @@ pub fn extract_active_status(config: &Value) -> Option { /// Restart only servers that were previously active (like cortex restart behavior) pub async fn restart_active_mcp_servers( app: &AppHandle, - servers_state: Arc>>>, + servers_state: SharedMcpServers, ) -> Result<(), String> { let app_state = app.state::(); let active_servers = app_state.mcp_active_servers.lock().await; @@ -656,14 +819,21 @@ pub async fn clean_up_mcp_servers(state: State<'_, AppState>) { log::info!("MCP servers cleaned up successfully"); } -pub async fn stop_mcp_servers( - servers_state: Arc>>>, -) -> Result<(), String> { +pub async fn stop_mcp_servers(servers_state: SharedMcpServers) -> Result<(), String> { let mut servers_map = servers_state.lock().await; let keys: Vec = servers_map.keys().cloned().collect(); for key in keys { if let Some(service) = servers_map.remove(&key) { - service.cancel().await.map_err(|e| e.to_string())?; + match service { + RunningServiceEnum::NoInit(service) => { + log::info!("Stopping server {key}..."); + service.cancel().await.map_err(|e| e.to_string())?; + } + RunningServiceEnum::WithInit(service) => { + log::info!("Stopping server {key} with initialization..."); + service.cancel().await.map_err(|e| e.to_string())?; + } + } } } drop(servers_map); // Release the lock after stopping @@ -689,7 +859,7 @@ pub async fn reset_restart_count(restart_counts: &Arc /// Spawn the server monitoring task for handling restarts pub async fn spawn_server_monitoring_task( app: AppHandle, - servers_state: Arc>>>, + servers_state: SharedMcpServers, name: String, config: Value, max_restarts: u32, diff --git a/src-tauri/src/core/mcp/mod.rs b/src-tauri/src/core/mcp/mod.rs index 5b20160de..b9627f02f 100644 --- a/src-tauri/src/core/mcp/mod.rs +++ b/src-tauri/src/core/mcp/mod.rs @@ -1,6 +1,7 @@ pub mod commands; mod constants; pub mod helpers; +pub mod models; #[cfg(test)] mod tests; diff --git a/src-tauri/src/core/mcp/models.rs b/src-tauri/src/core/mcp/models.rs new file mode 100644 index 000000000..cd3debbc8 --- /dev/null +++ b/src-tauri/src/core/mcp/models.rs @@ -0,0 +1,26 @@ +use std::time::Duration; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +/// Configuration parameters extracted from MCP server config +#[derive(Debug, Clone)] +pub struct McpServerConfig { + pub transport_type: Option, + pub url: Option, + pub command: String, + pub args: Vec, + pub envs: serde_json::Map, + pub timeout: Option, + pub headers: serde_json::Map, +} + +/// Tool with server information +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolWithServer { + pub name: String, + pub description: Option, + #[serde(rename = "inputSchema")] + pub input_schema: serde_json::Value, + pub server: String, +} diff --git a/src-tauri/src/core/mcp/tests.rs b/src-tauri/src/core/mcp/tests.rs index 8346449b2..081a188e8 100644 --- a/src-tauri/src/core/mcp/tests.rs +++ b/src-tauri/src/core/mcp/tests.rs @@ -1,6 +1,6 @@ use super::helpers::run_mcp_commands; use crate::core::app::commands::get_jan_data_folder_path; -use rmcp::{service::RunningService, RoleClient}; +use crate::core::state::SharedMcpServers; use std::collections::HashMap; use std::fs::File; use std::io::Write; @@ -27,7 +27,7 @@ async fn test_run_mcp_commands() { .expect("Failed to write to config file"); // Call the run_mcp_commands function - let servers_state: Arc>>> = + let servers_state: SharedMcpServers = Arc::new(Mutex::new(HashMap::new())); let result = run_mcp_commands(app.handle(), servers_state).await; diff --git a/src-tauri/src/core/state.rs b/src-tauri/src/core/state.rs index bda2eb40c..3408052d4 100644 --- a/src-tauri/src/core/state.rs +++ b/src-tauri/src/core/state.rs @@ -1,20 +1,48 @@ use std::{collections::HashMap, sync::Arc}; use crate::core::downloads::models::DownloadManagerState; -use rmcp::{service::RunningService, RoleClient}; +use rmcp::{ + model::{CallToolRequestParam, CallToolResult, InitializeRequestParam, Tool}, + service::RunningService, + RoleClient, ServiceError, +}; +use tokio::sync::Mutex; use tokio::task::JoinHandle; /// Server handle type for managing the proxy server lifecycle pub type ServerHandle = JoinHandle>>; -use tokio::sync::Mutex; + +pub enum RunningServiceEnum { + NoInit(RunningService), + WithInit(RunningService), +} +pub type SharedMcpServers = Arc>>; #[derive(Default)] pub struct AppState { pub app_token: Option, - pub mcp_servers: Arc>>>, + pub mcp_servers: SharedMcpServers, pub download_manager: Arc>, pub mcp_restart_counts: Arc>>, pub mcp_active_servers: Arc>>, pub mcp_successfully_connected: Arc>>, pub server_handle: Arc>>, } + +impl RunningServiceEnum { + pub async fn list_all_tools(&self) -> Result, ServiceError> { + match self { + Self::NoInit(s) => s.list_all_tools().await, + Self::WithInit(s) => s.list_all_tools().await, + } + } + pub async fn call_tool( + &self, + params: CallToolRequestParam, + ) -> Result { + match self { + Self::NoInit(s) => s.call_tool(params).await, + Self::WithInit(s) => s.call_tool(params).await, + } + } +} diff --git a/web-app/package.json b/web-app/package.json index f469a4998..ac1d7366f 100644 --- a/web-app/package.json +++ b/web-app/package.json @@ -17,11 +17,12 @@ "@dnd-kit/sortable": "^10.0.0", "@janhq/core": "link:../core", "@radix-ui/react-accordion": "^1.2.10", - "@radix-ui/react-dialog": "^1.1.11", - "@radix-ui/react-dropdown-menu": "^2.1.11", + "@radix-ui/react-dialog": "^1.1.14", + "@radix-ui/react-dropdown-menu": "^2.1.15", "@radix-ui/react-hover-card": "^1.1.14", "@radix-ui/react-popover": "^1.1.13", "@radix-ui/react-progress": "^1.1.4", + "@radix-ui/react-radio-group": "^1.3.7", "@radix-ui/react-slider": "^1.3.2", "@radix-ui/react-slot": "^1.2.0", "@radix-ui/react-switch": "^1.2.2", @@ -43,13 +44,14 @@ "class-variance-authority": "^0.7.1", "culori": "^4.0.1", "emoji-picker-react": "^4.12.2", + "framer-motion": "^12.23.12", "fuse.js": "^7.1.0", "fzf": "^0.5.2", "i18next": "^25.0.1", "katex": "^0.16.22", "lodash.clonedeep": "^4.5.0", "lodash.debounce": "^4.0.8", - "lucide-react": "^0.522.0", + "lucide-react": "^0.536.0", "motion": "^12.10.5", "next-themes": "^0.4.6", "posthog-js": "^1.246.0", @@ -75,6 +77,7 @@ "ulidx": "^2.4.1", "unified": "^11.0.5", "uuid": "^11.1.0", + "vaul": "^1.1.2", "zustand": "^5.0.3" }, "devDependencies": { @@ -104,7 +107,7 @@ "istanbul-lib-report": "^3.0.1", "istanbul-reports": "^3.1.7", "jsdom": "^26.1.0", - "tailwind-merge": "^3.2.0", + "tailwind-merge": "^3.3.1", "typescript": "~5.8.3", "typescript-eslint": "^8.26.1", "vite": "^6.3.0", diff --git a/web-app/src/components/ui/__tests__/dropdrawer.test.tsx b/web-app/src/components/ui/__tests__/dropdrawer.test.tsx new file mode 100644 index 000000000..6203d9f4e --- /dev/null +++ b/web-app/src/components/ui/__tests__/dropdrawer.test.tsx @@ -0,0 +1,533 @@ +import React from 'react' +import { render, screen } from '@testing-library/react' +import { describe, it, expect, vi, beforeEach } from 'vitest' +import '@testing-library/jest-dom' + +import { + DropDrawer, + DropDrawerContent, + DropDrawerFooter, + DropDrawerGroup, + DropDrawerItem, + DropDrawerLabel, + DropDrawerSeparator, + DropDrawerSub, + DropDrawerSubContent, + DropDrawerSubTrigger, + DropDrawerTrigger, +} from '../dropdrawer' + +// Mock the media query hook +const mockUseSmallScreen = vi.fn() +vi.mock('@/hooks/useMediaQuery', () => ({ + useSmallScreen: () => mockUseSmallScreen(), +})) + +// Mock framer-motion to avoid animation complexity in tests +vi.mock('framer-motion', () => ({ + AnimatePresence: ({ children }: { children: React.ReactNode }) =>
{children}
, + motion: { + div: ({ children, ...props }: any) =>
{children}
, + }, +})) + +describe('DropDrawer Utilities', () => { + it('renders without crashing', () => { + expect(() => { + render( + + Test + + Item + + + ) + }).not.toThrow() + }) +}) + +describe('DropDrawer Component', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Desktop Mode', () => { + beforeEach(() => { + mockUseSmallScreen.mockReturnValue(false) + }) + + it('renders dropdown menu on desktop', () => { + render( + + Open Menu + + Item 1 + Item 2 + + + ) + + expect(screen.getByText('Open Menu')).toBeInTheDocument() + }) + + it('renders dropdown menu structure', () => { + render( + + Open Menu + + Desktop Item + + + ) + + // Only the trigger is visible initially + expect(screen.getByText('Open Menu')).toBeInTheDocument() + expect(screen.getByRole('button')).toHaveAttribute('aria-haspopup', 'menu') + }) + + it('structures dropdown with separators', () => { + render( + + Open Menu + + Item 1 + + Item 2 + + + ) + + // Verify component structure - content is not visible until opened + expect(screen.getByText('Open Menu')).toBeInTheDocument() + }) + + it('structures dropdown with labels', () => { + render( + + Open Menu + + Menu Section + Item 1 + + + ) + + // Only verify trigger is present - content shows on interaction + expect(screen.getByText('Open Menu')).toBeInTheDocument() + }) + }) + + describe('Mobile Mode', () => { + beforeEach(() => { + mockUseSmallScreen.mockReturnValue(true) + }) + + it('renders drawer on mobile', () => { + render( + + Open Drawer + + Mobile Item + + + ) + + expect(screen.getByText('Open Drawer')).toBeInTheDocument() + }) + + it('renders drawer structure', () => { + render( + + Open Drawer + + Mobile Item + + + ) + + // Verify drawer trigger is present + const trigger = screen.getByText('Open Drawer') + expect(trigger).toBeInTheDocument() + expect(screen.getByRole('button')).toHaveAttribute('aria-haspopup', 'dialog') + }) + + it('does not render separators in mobile mode', () => { + render( + + Open Drawer + + Item 1 + + Item 2 + + + ) + + // Mobile separators return null, so they shouldn't be in the DOM + const separators = screen.queryAllByRole('separator') + expect(separators).toHaveLength(0) + }) + + it('renders drawer with labels structure', () => { + render( + + Open Drawer + + Drawer Section + Item 1 + + + ) + + // Verify drawer structure is present + expect(screen.getByText('Open Drawer')).toBeInTheDocument() + }) + }) + + describe('DropDrawerItem', () => { + beforeEach(() => { + mockUseSmallScreen.mockReturnValue(false) + }) + + it('can be structured with click handlers', () => { + const handleClick = vi.fn() + + render( + + Open Menu + + Clickable Item + + + ) + + // Verify structure is valid + expect(screen.getByText('Open Menu')).toBeInTheDocument() + expect(handleClick).not.toHaveBeenCalled() + }) + + it('can be structured with icons', () => { + const TestIcon = () => Icon + + render( + + Open Menu + + }>Item with Icon + + + ) + + // Structure is valid + expect(screen.getByText('Open Menu')).toBeInTheDocument() + }) + + it('accepts variant props', () => { + render( + + Open Menu + + + Delete Item + + + + ) + + // Component structure is valid with variants + expect(screen.getByText('Open Menu')).toBeInTheDocument() + }) + + it('accepts disabled prop', () => { + render( + + Open Menu + + + Disabled Item + + + + ) + + // Component structure is valid with disabled prop + expect(screen.getByText('Open Menu')).toBeInTheDocument() + }) + }) + + describe('DropDrawerGroup', () => { + it('structures groups in desktop mode', () => { + mockUseSmallScreen.mockReturnValue(false) + + render( + + Open Menu + + + Group Item 1 + Group Item 2 + + + + ) + + // Component structure is valid + expect(screen.getByText('Open Menu')).toBeInTheDocument() + }) + + it('structures groups in mobile mode', () => { + mockUseSmallScreen.mockReturnValue(true) + + render( + + Open Drawer + + + Item 1 + Item 2 + + + + ) + + // Component structure is valid in mobile mode + expect(screen.getByText('Open Drawer')).toBeInTheDocument() + }) + }) + + describe('DropDrawerFooter', () => { + it('structures footer in desktop mode', () => { + mockUseSmallScreen.mockReturnValue(false) + + render( + + Open Menu + + Item + Footer Content + + + ) + + // Component structure is valid + expect(screen.getByText('Open Menu')).toBeInTheDocument() + }) + + it('structures footer in mobile mode', () => { + mockUseSmallScreen.mockReturnValue(true) + + render( + + Open Drawer + + Item + Mobile Footer + + + ) + + // Component structure is valid in mobile mode + expect(screen.getByText('Open Drawer')).toBeInTheDocument() + }) + }) + + describe('Submenu Components', () => { + beforeEach(() => { + mockUseSmallScreen.mockReturnValue(false) + }) + + it('structures submenu in desktop mode', () => { + render( + + Open Menu + + + Submenu Trigger + + Submenu Item + + + + + ) + + // Component structure is valid + expect(screen.getByText('Open Menu')).toBeInTheDocument() + }) + + it('structures submenu in mobile mode', () => { + mockUseSmallScreen.mockReturnValue(true) + + render( + + Open Drawer + + + + Mobile Submenu + + + Submenu Item + + + + + ) + + // Component structure is valid in mobile mode + expect(screen.getByText('Open Drawer')).toBeInTheDocument() + }) + + it('handles submenu content correctly in mobile mode', () => { + mockUseSmallScreen.mockReturnValue(true) + + render( + + Open Drawer + + + Mobile Submenu + + Hidden Item + + + + + ) + + // Component handles mobile submenu structure correctly + expect(screen.getByText('Open Drawer')).toBeInTheDocument() + }) + }) + + describe('Accessibility', () => { + beforeEach(() => { + mockUseSmallScreen.mockReturnValue(false) + }) + + it('maintains proper ARIA attributes on triggers', () => { + render( + + Open Menu + + + Item 1 + + + + ) + + const trigger = screen.getByRole('button') + expect(trigger).toHaveAttribute('aria-haspopup', 'menu') + }) + + it('supports disabled state', () => { + const handleClick = vi.fn() + + mockUseSmallScreen.mockReturnValue(true) + + render( + + Open Drawer + + + Disabled Item + + + + ) + + // Component structure supports disabled prop + expect(screen.getByText('Open Drawer')).toBeInTheDocument() + expect(handleClick).not.toHaveBeenCalled() + }) + }) + + describe('Error Boundaries', () => { + it('requires proper context usage', () => { + // Suppress console.error for this test + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + + expect(() => { + render(Orphan Item) + }).toThrow() + + consoleSpy.mockRestore() + }) + }) + + describe('Custom Props and Styling', () => { + beforeEach(() => { + mockUseSmallScreen.mockReturnValue(false) + }) + + it('applies custom className', () => { + render( + + Custom Trigger + + Custom Item + + + ) + + const trigger = screen.getByText('Custom Trigger') + expect(trigger).toHaveClass('custom-trigger') + }) + + it('accepts additional props', () => { + render( + + Open Menu + + Custom Props Item + + + ) + + // Component structure accepts custom props + expect(screen.getByText('Open Menu')).toBeInTheDocument() + }) + }) + + describe('Responsive Behavior', () => { + it('adapts to different screen sizes', () => { + const { rerender } = render( + + Responsive Trigger + + Responsive Item + + + ) + + // Desktop mode + mockUseSmallScreen.mockReturnValue(false) + rerender( + + Responsive Trigger + + Responsive Item + + + ) + + let trigger = screen.getByText('Responsive Trigger') + expect(trigger).toHaveAttribute('aria-haspopup', 'menu') + + // Mobile mode + mockUseSmallScreen.mockReturnValue(true) + rerender( + + Responsive Trigger + + Responsive Item + + + ) + + trigger = screen.getByText('Responsive Trigger') + expect(trigger).toHaveAttribute('aria-haspopup', 'dialog') + }) + }) +}) \ No newline at end of file diff --git a/web-app/src/components/ui/__tests__/radio-group.test.tsx b/web-app/src/components/ui/__tests__/radio-group.test.tsx new file mode 100644 index 000000000..a788931d8 --- /dev/null +++ b/web-app/src/components/ui/__tests__/radio-group.test.tsx @@ -0,0 +1,62 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import { RadioGroup, RadioGroupItem } from '../radio-group' + +describe('RadioGroup', () => { + it('renders radio items correctly', () => { + render( + +
+ + +
+
+ + +
+
+ ) + + expect(screen.getByLabelText('HTTP')).toBeInTheDocument() + expect(screen.getByLabelText('SSE')).toBeInTheDocument() + }) + + it('allows selecting different options', async () => { + const user = userEvent.setup() + const onValueChange = vi.fn() + + render( + +
+ + +
+
+ + +
+
+ ) + + await user.click(screen.getByLabelText('SSE')) + expect(onValueChange).toHaveBeenCalledWith('sse') + }) + + it('has correct default selection', () => { + render( + +
+ + +
+
+ + +
+
+ ) + + expect(screen.getByLabelText('HTTP')).toBeChecked() + expect(screen.getByLabelText('SSE')).not.toBeChecked() + }) +}) \ No newline at end of file diff --git a/web-app/src/components/ui/drawer.tsx b/web-app/src/components/ui/drawer.tsx new file mode 100644 index 000000000..6766e6caf --- /dev/null +++ b/web-app/src/components/ui/drawer.tsx @@ -0,0 +1,133 @@ +import * as React from 'react' +import { Drawer as DrawerPrimitive } from 'vaul' + +import { cn } from '@/lib/utils' + +function Drawer({ + ...props +}: React.ComponentProps) { + return +} + +function DrawerTrigger({ + ...props +}: React.ComponentProps) { + return +} + +function DrawerPortal({ + ...props +}: React.ComponentProps) { + return +} + +function DrawerClose({ + ...props +}: React.ComponentProps) { + return +} + +function DrawerOverlay({ + className, + ...props +}: React.ComponentProps) { + return ( + + ) +} + +function DrawerContent({ + className, + children, + ...props +}: React.ComponentProps) { + return ( + + + +
+ {children} + + + ) +} + +function DrawerHeader({ className, ...props }: React.ComponentProps<'div'>) { + return ( +
+ ) +} + +function DrawerFooter({ className, ...props }: React.ComponentProps<'div'>) { + return ( +
+ ) +} + +function DrawerTitle({ + className, + ...props +}: React.ComponentProps) { + return ( + + ) +} + +function DrawerDescription({ + className, + ...props +}: React.ComponentProps) { + return ( + + ) +} + +export { + Drawer, + DrawerPortal, + DrawerOverlay, + DrawerTrigger, + DrawerClose, + DrawerContent, + DrawerHeader, + DrawerFooter, + DrawerTitle, + DrawerDescription, +} diff --git a/web-app/src/components/ui/dropdown-menu.tsx b/web-app/src/components/ui/dropdown-menu.tsx index 2a75d726d..15f721e2e 100644 --- a/web-app/src/components/ui/dropdown-menu.tsx +++ b/web-app/src/components/ui/dropdown-menu.tsx @@ -61,14 +61,17 @@ function DropdownMenuGroup({ function DropdownMenuItem({ className, inset, + variant = 'default', ...props }: React.ComponentProps & { inset?: boolean + variant?: 'default' | 'destructive' }) { return ( {children} - + ) } diff --git a/web-app/src/components/ui/dropdrawer.tsx b/web-app/src/components/ui/dropdrawer.tsx new file mode 100644 index 000000000..a727dd4b0 --- /dev/null +++ b/web-app/src/components/ui/dropdrawer.tsx @@ -0,0 +1,949 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +'use client' + +import { AnimatePresence, motion } from 'framer-motion' +import { ChevronLeftIcon, ChevronRightIcon } from 'lucide-react' +import * as React from 'react' + +import { + Drawer, + DrawerClose, + DrawerContent, + DrawerFooter, + DrawerHeader, + DrawerTitle, + DrawerTrigger, +} from '@/components/ui/drawer' + +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuLabel, + DropdownMenuSeparator, + DropdownMenuSub, + DropdownMenuSubContent, + DropdownMenuSubTrigger, + DropdownMenuTrigger, +} from '@/components/ui/dropdown-menu' + +import { cn } from '@/lib/utils' +import { useSmallScreen } from '@/hooks/useMediaQuery' + +const ANIMATION_CONFIG = { + variants: { + enter: (direction: 'forward' | 'backward') => ({ + x: direction === 'forward' ? '100%' : '-100%', + opacity: 0, + }), + center: { + x: 0, + opacity: 1, + }, + exit: (direction: 'forward' | 'backward') => ({ + x: direction === 'forward' ? '-100%' : '100%', + opacity: 0, + }), + }, + transition: { + duration: 0.3, + ease: [0.25, 0.1, 0.25, 1.0], + }, +} as const + +const getMobileItemStyles = ( + isInsideGroup: boolean, + inset?: boolean, + variant?: string, + disabled?: boolean +) => { + return cn( + 'flex cursor-pointer items-center justify-between px-4 py-4 w-full gap-4', + !isInsideGroup && 'bg-main-view-fg/50 mx-2 my-1.5 rounded-md', + isInsideGroup && 'bg-transparent py-4', + inset && 'pl-8', + variant === 'destructive' && 'text-destructive', + disabled && 'pointer-events-none opacity-50' + ) +} + +const DropDrawerContext = React.createContext<{ isMobile: boolean }>({ + isMobile: false, +}) + +const useDropDrawerContext = () => { + const context = React.useContext(DropDrawerContext) + if (!context) { + throw new Error( + 'DropDrawer components cannot be rendered outside the DropDrawer Context' + ) + } + return context +} + +const useComponentSelection = () => { + const { isMobile } = useDropDrawerContext() + + const selectComponent = (mobileComponent: T, desktopComponent: D) => { + return isMobile ? mobileComponent : desktopComponent + } + + return { isMobile, selectComponent } +} + +const useGroupDetection = () => { + const isInGroup = React.useCallback( + (element: HTMLElement | null): boolean => { + if (!element) return false + + let parent = element.parentElement + while (parent) { + if (parent.hasAttribute('data-drop-drawer-group')) { + return true + } + parent = parent.parentElement + } + return false + }, + [] + ) + + const useGroupState = () => { + const { isMobile } = useComponentSelection() + const itemRef = React.useRef(null) + const [isInsideGroup, setIsInsideGroup] = React.useState(false) + + React.useEffect(() => { + if (!isMobile) return + + const timer = setTimeout(() => { + if (itemRef.current) { + setIsInsideGroup(isInGroup(itemRef.current)) + } + }, 0) + + return () => clearTimeout(timer) + }, [isMobile]) + + return { itemRef, isInsideGroup } + } + + return { isInGroup, useGroupState } +} + +type ConditionalComponentProps = { + children: React.ReactNode + className?: string +} & (T | D) + +const ConditionalComponent = ({ + mobileComponent, + desktopComponent, + children, + ...props +}: { + mobileComponent: React.ComponentType + desktopComponent: React.ComponentType + children: React.ReactNode +} & ConditionalComponentProps) => { + const { selectComponent } = useComponentSelection() + const Component = selectComponent(mobileComponent, desktopComponent) + + return {children} +} + +function DropDrawer({ + children, + ...props +}: + | React.ComponentProps + | React.ComponentProps) { + const isMobile = useSmallScreen() + + return ( + + + {children} + + + ) +} + +function DropDrawerTrigger({ + className, + children, + ...props +}: + | React.ComponentProps + | React.ComponentProps) { + return ( + + {children} + + ) +} + +function DropDrawerContent({ + className, + children, + ...props +}: + | React.ComponentProps + | React.ComponentProps) { + const { isMobile } = useDropDrawerContext() + const [activeSubmenu, setActiveSubmenu] = React.useState(null) + const [submenuTitle, setSubmenuTitle] = React.useState(null) + const [submenuStack, setSubmenuStack] = React.useState< + { id: string; title: string }[] + >([]) + // Add animation direction state + const [animationDirection, setAnimationDirection] = React.useState< + 'forward' | 'backward' + >('forward') + + // Create a ref to store submenu content by ID + const submenuContentRef = React.useRef>( + new Map() + ) + + // Function to navigate to a submenu + const navigateToSubmenu = React.useCallback((id: string, title: string) => { + // Set animation direction to forward when navigating to a submenu + setAnimationDirection('forward') + setActiveSubmenu(id) + setSubmenuTitle(title) + setSubmenuStack((prev) => [...prev, { id, title }]) + }, []) + + // Function to go back to previous menu + const goBack = React.useCallback(() => { + // Set animation direction to backward when going back + setAnimationDirection('backward') + + if (submenuStack.length <= 1) { + // If we're at the first level, go back to main menu + setActiveSubmenu(null) + setSubmenuTitle(null) + setSubmenuStack([]) + } else { + // Go back to previous submenu + const newStack = [...submenuStack] + newStack.pop() // Remove current + const previous = newStack[newStack.length - 1] + setActiveSubmenu(previous.id) + setSubmenuTitle(previous.title) + setSubmenuStack(newStack) + } + }, [submenuStack]) + + // Function to register submenu content + const registerSubmenuContent = React.useCallback( + (id: string, content: React.ReactNode[]) => { + submenuContentRef.current.set(id, content) + }, + [] + ) + + const extractSubmenuContent = React.useCallback( + (elements: React.ReactNode, targetId: string): React.ReactNode[] => { + const result: React.ReactNode[] = [] + + const findSubmenuContent = (node: React.ReactNode) => { + if (!React.isValidElement(node)) return + + const element = node as React.ReactElement + const props = element.props as { + 'id'?: string + 'data-submenu-id'?: string + 'children'?: React.ReactNode + } + + if (element.type === DropDrawerSub) { + const elementId = props.id || props['data-submenu-id'] + + if (elementId === targetId) { + React.Children.forEach(props.children, (child) => { + if ( + React.isValidElement(child) && + child.type === DropDrawerSubContent + ) { + const subContentProps = child.props as { + children?: React.ReactNode + } + React.Children.forEach( + subContentProps.children, + (contentChild) => { + result.push(contentChild) + } + ) + } + }) + return + } + } + + if (props.children) { + React.Children.forEach(props.children, findSubmenuContent) + } + } + + React.Children.forEach(elements, findSubmenuContent) + return result + }, + [] + ) + + // Get submenu content (always extract fresh to reflect state changes) + const getSubmenuContent = React.useCallback( + (id: string) => { + // Always extract fresh content to ensure state changes are reflected + const submenuContent = extractSubmenuContent(children, id) + return submenuContent + }, + [children, extractSubmenuContent] + ) + + if (isMobile) { + return ( + { + if (id === null) { + setActiveSubmenu(null) + setSubmenuTitle(null) + setSubmenuStack([]) + } + }, + submenuTitle, + setSubmenuTitle, + navigateToSubmenu, + registerSubmenuContent, + }} + > + + {activeSubmenu ? ( + <> + +
+ + + {submenuTitle || 'Submenu'} + +
+
+
+ {/* Use AnimatePresence to handle exit animations */} + + + {activeSubmenu + ? getSubmenuContent(activeSubmenu) + : children} + + +
+ + ) : ( + <> + + Menu + +
+ + + {children} + + +
+ + )} +
+
+ ) + } + + return ( + + + {children} + + + ) +} + +function DropDrawerItem({ + className, + children, + onSelect, + onClick, + icon, + variant = 'default', + inset, + disabled, + ...props +}: React.ComponentProps & { + icon?: React.ReactNode +}) { + const { isMobile } = useComponentSelection() + const { useGroupState } = useGroupDetection() + const { itemRef, isInsideGroup } = useGroupState() + + if (isMobile) { + const handleClick = (e: React.MouseEvent) => { + if (disabled) return + + // If this item only has an icon (like a switch) and no other interactive content, + // don't handle clicks on the main area - let the icon handle everything + if (icon && !onClick && !onSelect) { + return + } + + // Check if the click came from the icon area (where the Switch is) + const target = e.target as HTMLElement + const iconContainer = (e.currentTarget as HTMLElement).querySelector( + '[data-icon-container]' + ) + + if (iconContainer && iconContainer.contains(target)) { + // Don't handle the click if it came from the icon area + return + } + + if (onClick) onClick(e) + if (onSelect) onSelect(e as unknown as Event) + } + + // Only wrap in DrawerClose if it's not a submenu item + const content = ( +
+
{children}
+ {icon && ( +
+ {icon} +
+ )} +
+ ) + + // Check if this is inside a submenu + const isInSubmenu = + (props as Record)['data-parent-submenu-id'] || + (props as Record)['data-parent-submenu'] + + if (isInSubmenu) { + return content + } + + return {content} + } + + return ( + } + variant={variant} + inset={inset} + disabled={disabled} + {...props} + > +
+
{children}
+ {icon &&
{icon}
} +
+
+ ) +} + +function DropDrawerSeparator({ + className, + ...props +}: React.ComponentProps) { + const { isMobile } = useComponentSelection() + + if (isMobile) { + return null + } + + return ( + + ) +} + +function DropDrawerLabel({ + className, + children, + ...props +}: + | React.ComponentProps + | React.ComponentProps) { + const { isMobile } = useComponentSelection() + + if (isMobile) { + return ( + + + {children} + + + ) + } + + return ( + + {children} + + ) +} + +function DropDrawerFooter({ + className, + children, + ...props +}: React.ComponentProps | React.ComponentProps<'div'>) { + const { isMobile } = useDropDrawerContext() + + if (isMobile) { + return ( + + {children} + + ) + } + + // No direct equivalent in DropdownMenu, so we'll just render a div + return ( +
+ {children} +
+ ) +} + +function DropDrawerGroup({ + className, + children, + ...props +}: React.ComponentProps<'div'> & { + children: React.ReactNode +}) { + const { isMobile } = useDropDrawerContext() + + // Add separators between children on mobile + const childrenWithSeparators = React.useMemo(() => { + if (!isMobile) return children + + const childArray = React.Children.toArray(children) + + // Filter out any existing separators + const filteredChildren = childArray.filter( + (child) => + React.isValidElement(child) && child.type !== DropDrawerSeparator + ) + + // Add separators between items + return filteredChildren.flatMap((child, index) => { + if (index === filteredChildren.length - 1) return [child] + return [ + child, +