From 91f05b8f321c8799f5ee6f5609f4373654f2737a Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 19 Aug 2025 23:27:12 +0700 Subject: [PATCH] feat: add tool call cancellation --- src-tauri/src/core/mcp/commands.rs | 82 +++++++++++++++++++++++++--- src-tauri/src/core/state.rs | 3 +- src-tauri/src/lib.rs | 2 + web-app/src/containers/ChatInput.tsx | 12 +++- web-app/src/hooks/useAppState.ts | 8 +++ web-app/src/lib/completion.ts | 19 ++++--- web-app/src/lib/service.ts | 1 + web-app/src/services/mcp.ts | 41 ++++++++++++++ 8 files changed, 148 insertions(+), 20 deletions(-) diff --git a/src-tauri/src/core/mcp/commands.rs b/src-tauri/src/core/mcp/commands.rs index 02caca827..48c7f88a1 100644 --- a/src-tauri/src/core/mcp/commands.rs +++ b/src-tauri/src/core/mcp/commands.rs @@ -2,6 +2,7 @@ use rmcp::model::{CallToolRequestParam, CallToolResult}; use serde_json::{Map, Value}; use tauri::{AppHandle, Emitter, Runtime, State}; use tokio::time::timeout; +use tokio::sync::oneshot; use super::{ constants::{DEFAULT_MCP_CONFIG, MCP_TOOL_CALL_TIMEOUT}, @@ -179,6 +180,7 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result /// * `state` - Application state containing MCP server connections /// * `tool_name` - Name of the tool to call /// * `arguments` - Optional map of argument names to values +/// * `cancellation_token` - Optional token to allow cancellation from JS side /// /// # Returns /// * `Result` - Result of the tool call if successful, or error message if failed @@ -187,13 +189,23 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result /// 1. Locks the MCP servers mutex to access server connections /// 2. Searches through all servers for one containing the named tool /// 3. When found, calls the tool on that server with the provided arguments -/// 4. Returns error if no server has the requested tool +/// 4. Supports cancellation via cancellation_token +/// 5. Returns error if no server has the requested tool #[tauri::command] pub async fn call_tool( state: State<'_, AppState>, tool_name: String, arguments: Option>, + cancellation_token: Option, ) -> Result { + // Set up cancellation if token is provided + let (cancel_tx, cancel_rx) = oneshot::channel::<()>(); + + if let Some(token) = &cancellation_token { + let mut cancellations = state.tool_call_cancellations.lock().await; + cancellations.insert(token.clone(), cancel_tx); + } + let servers = state.mcp_servers.lock().await; // Iterate through servers and find the first one that contains the tool @@ -209,25 +221,77 @@ pub async fn call_tool( println!("Found tool {} in server", tool_name); - // Call the tool with timeout + // Call the tool with timeout and cancellation support let tool_call = service.call_tool(CallToolRequestParam { name: tool_name.clone().into(), arguments, }); - return match timeout(MCP_TOOL_CALL_TIMEOUT, tool_call).await { - Ok(result) => result.map_err(|e| e.to_string()), - Err(_) => Err(format!( - "Tool call '{}' timed out after {} seconds", - tool_name, - MCP_TOOL_CALL_TIMEOUT.as_secs() - )), + // Race between timeout, tool call, and cancellation + let result = if cancellation_token.is_some() { + tokio::select! { + result = timeout(MCP_TOOL_CALL_TIMEOUT, tool_call) => { + match result { + Ok(call_result) => call_result.map_err(|e| e.to_string()), + Err(_) => Err(format!( + "Tool call '{}' timed out after {} seconds", + tool_name, + MCP_TOOL_CALL_TIMEOUT.as_secs() + )), + } + } + _ = cancel_rx => { + Err(format!("Tool call '{}' was cancelled", tool_name)) + } + } + } else { + match timeout(MCP_TOOL_CALL_TIMEOUT, tool_call).await { + Ok(call_result) => call_result.map_err(|e| e.to_string()), + Err(_) => Err(format!( + "Tool call '{}' timed out after {} seconds", + tool_name, + MCP_TOOL_CALL_TIMEOUT.as_secs() + )), + } }; + + // Clean up cancellation token + if let Some(token) = &cancellation_token { + let mut cancellations = state.tool_call_cancellations.lock().await; + cancellations.remove(token); + } + + return result; } Err(format!("Tool {} not found", tool_name)) } +/// Cancels a running tool call by its cancellation token +/// +/// # Arguments +/// * `state` - Application state containing cancellation tokens +/// * `cancellation_token` - Token identifying the tool call to cancel +/// +/// # Returns +/// * `Result<(), String>` - Success if token found and cancelled, error otherwise +#[tauri::command] +pub async fn cancel_tool_call( + state: State<'_, AppState>, + cancellation_token: String, +) -> Result<(), String> { + let mut cancellations = state.tool_call_cancellations.lock().await; + + if let Some(cancel_tx) = cancellations.remove(&cancellation_token) { + // Send cancellation signal - ignore if receiver is already dropped + let _ = cancel_tx.send(()); + println!("Tool call with token {} cancelled", cancellation_token); + Ok(()) + } else { + Err(format!("Cancellation token {} not found", cancellation_token)) + } +} + #[tauri::command] pub async fn get_mcp_configs(app: AppHandle) -> Result { let mut path = get_jan_data_folder_path(app); diff --git a/src-tauri/src/core/state.rs b/src-tauri/src/core/state.rs index 3408052d4..ddbbcf7bd 100644 --- a/src-tauri/src/core/state.rs +++ b/src-tauri/src/core/state.rs @@ -6,7 +6,7 @@ use rmcp::{ service::RunningService, RoleClient, ServiceError, }; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, oneshot}; use tokio::task::JoinHandle; /// Server handle type for managing the proxy server lifecycle @@ -27,6 +27,7 @@ pub struct AppState { pub mcp_active_servers: Arc>>, pub mcp_successfully_connected: Arc>>, pub server_handle: Arc>>, + pub tool_call_cancellations: Arc>>>, } impl RunningServiceEnum { diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 63d60a571..10a9d7556 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -74,6 +74,7 @@ pub fn run() { // MCP commands core::mcp::commands::get_tools, core::mcp::commands::call_tool, + core::mcp::commands::cancel_tool_call, core::mcp::commands::restart_mcp_servers, core::mcp::commands::get_connected_servers, core::mcp::commands::save_mcp_configs, @@ -105,6 +106,7 @@ pub fn run() { mcp_active_servers: Arc::new(Mutex::new(HashMap::new())), mcp_successfully_connected: Arc::new(Mutex::new(HashMap::new())), server_handle: Arc::new(Mutex::new(None)), + tool_call_cancellations: Arc::new(Mutex::new(HashMap::new())), }) .setup(|app| { app.handle().plugin( diff --git a/web-app/src/containers/ChatInput.tsx b/web-app/src/containers/ChatInput.tsx index 59cdaa3cd..c6360253e 100644 --- a/web-app/src/containers/ChatInput.tsx +++ b/web-app/src/containers/ChatInput.tsx @@ -46,8 +46,13 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { const textareaRef = useRef(null) const [isFocused, setIsFocused] = useState(false) const [rows, setRows] = useState(1) - const { streamingContent, abortControllers, loadingModel, tools } = - useAppState() + const { + streamingContent, + abortControllers, + loadingModel, + tools, + cancelToolCall, + } = useAppState() const { prompt, setPrompt } = usePrompt() const { currentThreadId } = useThreads() const { t } = useTranslation() @@ -161,8 +166,9 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { const stopStreaming = useCallback( (threadId: string) => { abortControllers[threadId]?.abort() + cancelToolCall?.() }, - [abortControllers] + [abortControllers, cancelToolCall] ) const fileInputRef = useRef(null) diff --git a/web-app/src/hooks/useAppState.ts b/web-app/src/hooks/useAppState.ts index 5876daefb..7b3841f5c 100644 --- a/web-app/src/hooks/useAppState.ts +++ b/web-app/src/hooks/useAppState.ts @@ -13,6 +13,7 @@ type AppState = { tokenSpeed?: TokenSpeed currentToolCall?: ChatCompletionMessageToolCall showOutOfContextDialog?: boolean + cancelToolCall?: () => void setServerStatus: (value: 'running' | 'stopped' | 'pending') => void updateStreamingContent: (content: ThreadMessage | undefined) => void updateCurrentToolCall: ( @@ -24,6 +25,7 @@ type AppState = { updateTokenSpeed: (message: ThreadMessage, increment?: number) => void resetTokenSpeed: () => void setOutOfContextDialog: (show: boolean) => void + setCancelToolCall: (cancel: (() => void) | undefined) => void } export const useAppState = create()((set) => ({ @@ -34,6 +36,7 @@ export const useAppState = create()((set) => ({ abortControllers: {}, tokenSpeed: undefined, currentToolCall: undefined, + cancelToolCall: undefined, updateStreamingContent: (content: ThreadMessage | undefined) => { const assistants = useAssistant.getState().assistants const currentAssistant = useAssistant.getState().currentAssistant @@ -112,4 +115,9 @@ export const useAppState = create()((set) => ({ showOutOfContextDialog: show, })) }, + setCancelToolCall: (cancel) => { + set(() => ({ + cancelToolCall: cancel, + })) + }, })) diff --git a/web-app/src/lib/completion.ts b/web-app/src/lib/completion.ts index 22ac724e9..c92a0b096 100644 --- a/web-app/src/lib/completion.ts +++ b/web-app/src/lib/completion.ts @@ -31,8 +31,9 @@ import { ulid } from 'ulidx' import { MCPTool } from '@/types/completion' import { CompletionMessagesBuilder } from './messages' import { ChatCompletionMessageToolCall } from 'openai/resources' -import { callTool } from '@/services/mcp' +import { callToolWithCancellation } from '@/services/mcp' import { ExtensionManager } from './extension' +import { useAppState } from '@/hooks/useAppState' export type ChatCompletionResponse = | chatCompletion @@ -381,13 +382,17 @@ export const postMessageProcessing = async ( ) : true) + const { promise, cancel } = callToolWithCancellation({ + toolName: toolCall.function.name, + arguments: toolCall.function.arguments.length + ? JSON.parse(toolCall.function.arguments) + : {}, + }) + + useAppState.getState().setCancelToolCall(cancel) + let result = approved - ? await callTool({ - toolName: toolCall.function.name, - arguments: toolCall.function.arguments.length - ? JSON.parse(toolCall.function.arguments) - : {}, - }).catch((e) => { + ? await promise.catch((e) => { console.error('Tool call failed:', e) return { content: [ diff --git a/web-app/src/lib/service.ts b/web-app/src/lib/service.ts index 0898cc4dc..809090b9d 100644 --- a/web-app/src/lib/service.ts +++ b/web-app/src/lib/service.ts @@ -5,6 +5,7 @@ export const AppRoutes = [ 'installExtensions', 'getTools', 'callTool', + 'cancelToolCall', 'listThreads', 'createThread', 'modifyThread', diff --git a/web-app/src/services/mcp.ts b/web-app/src/services/mcp.ts index 8159a5048..c266c6a13 100644 --- a/web-app/src/services/mcp.ts +++ b/web-app/src/services/mcp.ts @@ -56,3 +56,44 @@ export const callTool = (args: { }): Promise<{ error: string; content: { text: string }[] }> => { return window.core?.api?.callTool(args) } + +/** + * @description Enhanced function to invoke an MCP tool with cancellation support + * @param args - Tool call arguments + * @param cancellationToken - Optional cancellation token + * @returns Promise with tool result and cancellation function + */ +export const callToolWithCancellation = (args: { + toolName: string + arguments: object + cancellationToken?: string +}): { + promise: Promise<{ error: string; content: { text: string }[] }> + cancel: () => Promise + token: string +} => { + // Generate a unique cancellation token if not provided + const token = args.cancellationToken ?? `tool_call_${Date.now()}_${Math.random().toString(36).substr(2, 9)}` + + // Create the tool call promise with cancellation token + const promise = window.core?.api?.callTool({ + ...args, + cancellationToken: token + }) + + // Create cancel function + const cancel = async () => { + await window.core?.api?.cancelToolCall({ cancellationToken: token }) + } + + return { promise, cancel, token } +} + +/** + * @description This function cancels a running tool call + * @param cancellationToken - The token identifying the tool call to cancel + * @returns + */ +export const cancelToolCall = (cancellationToken: string): Promise => { + return window.core?.api?.cancelToolCall({ cancellationToken }) +}