Merge pull request #6236 from menloresearch/feat/add-tool-call-cancellation

This commit is contained in:
Louis 2025-08-20 09:04:53 +07:00 committed by GitHub
commit 6efdd66bbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 148 additions and 20 deletions

View File

@ -2,6 +2,7 @@ use rmcp::model::{CallToolRequestParam, CallToolResult};
use serde_json::{Map, Value}; use serde_json::{Map, Value};
use tauri::{AppHandle, Emitter, Runtime, State}; use tauri::{AppHandle, Emitter, Runtime, State};
use tokio::time::timeout; use tokio::time::timeout;
use tokio::sync::oneshot;
use super::{ use super::{
constants::{DEFAULT_MCP_CONFIG, MCP_TOOL_CALL_TIMEOUT}, constants::{DEFAULT_MCP_CONFIG, MCP_TOOL_CALL_TIMEOUT},
@ -179,6 +180,7 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<ToolWithServer>
/// * `state` - Application state containing MCP server connections /// * `state` - Application state containing MCP server connections
/// * `tool_name` - Name of the tool to call /// * `tool_name` - Name of the tool to call
/// * `arguments` - Optional map of argument names to values /// * `arguments` - Optional map of argument names to values
/// * `cancellation_token` - Optional token to allow cancellation from JS side
/// ///
/// # Returns /// # Returns
/// * `Result<CallToolResult, String>` - Result of the tool call if successful, or error message if failed /// * `Result<CallToolResult, String>` - Result of the tool call if successful, or error message if failed
@ -187,13 +189,23 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<ToolWithServer>
/// 1. Locks the MCP servers mutex to access server connections /// 1. Locks the MCP servers mutex to access server connections
/// 2. Searches through all servers for one containing the named tool /// 2. Searches through all servers for one containing the named tool
/// 3. When found, calls the tool on that server with the provided arguments /// 3. When found, calls the tool on that server with the provided arguments
/// 4. Returns error if no server has the requested tool /// 4. Supports cancellation via cancellation_token
/// 5. Returns error if no server has the requested tool
#[tauri::command] #[tauri::command]
pub async fn call_tool( pub async fn call_tool(
state: State<'_, AppState>, state: State<'_, AppState>,
tool_name: String, tool_name: String,
arguments: Option<Map<String, Value>>, arguments: Option<Map<String, Value>>,
cancellation_token: Option<String>,
) -> Result<CallToolResult, String> { ) -> Result<CallToolResult, String> {
// Set up cancellation if token is provided
let (cancel_tx, cancel_rx) = oneshot::channel::<()>();
if let Some(token) = &cancellation_token {
let mut cancellations = state.tool_call_cancellations.lock().await;
cancellations.insert(token.clone(), cancel_tx);
}
let servers = state.mcp_servers.lock().await; let servers = state.mcp_servers.lock().await;
// Iterate through servers and find the first one that contains the tool // Iterate through servers and find the first one that contains the tool
@ -209,25 +221,77 @@ pub async fn call_tool(
println!("Found tool {} in server", tool_name); println!("Found tool {} in server", tool_name);
// Call the tool with timeout // Call the tool with timeout and cancellation support
let tool_call = service.call_tool(CallToolRequestParam { let tool_call = service.call_tool(CallToolRequestParam {
name: tool_name.clone().into(), name: tool_name.clone().into(),
arguments, arguments,
}); });
return match timeout(MCP_TOOL_CALL_TIMEOUT, tool_call).await { // Race between timeout, tool call, and cancellation
Ok(result) => result.map_err(|e| e.to_string()), let result = if cancellation_token.is_some() {
tokio::select! {
result = timeout(MCP_TOOL_CALL_TIMEOUT, tool_call) => {
match result {
Ok(call_result) => call_result.map_err(|e| e.to_string()),
Err(_) => Err(format!( Err(_) => Err(format!(
"Tool call '{}' timed out after {} seconds", "Tool call '{}' timed out after {} seconds",
tool_name, tool_name,
MCP_TOOL_CALL_TIMEOUT.as_secs() MCP_TOOL_CALL_TIMEOUT.as_secs()
)), )),
}
}
_ = cancel_rx => {
Err(format!("Tool call '{}' was cancelled", tool_name))
}
}
} else {
match timeout(MCP_TOOL_CALL_TIMEOUT, tool_call).await {
Ok(call_result) => call_result.map_err(|e| e.to_string()),
Err(_) => Err(format!(
"Tool call '{}' timed out after {} seconds",
tool_name,
MCP_TOOL_CALL_TIMEOUT.as_secs()
)),
}
}; };
// Clean up cancellation token
if let Some(token) = &cancellation_token {
let mut cancellations = state.tool_call_cancellations.lock().await;
cancellations.remove(token);
}
return result;
} }
Err(format!("Tool {} not found", tool_name)) Err(format!("Tool {} not found", tool_name))
} }
/// Cancels a running tool call by its cancellation token
///
/// # Arguments
/// * `state` - Application state containing cancellation tokens
/// * `cancellation_token` - Token identifying the tool call to cancel
///
/// # Returns
/// * `Result<(), String>` - Success if token found and cancelled, error otherwise
#[tauri::command]
pub async fn cancel_tool_call(
state: State<'_, AppState>,
cancellation_token: String,
) -> Result<(), String> {
let mut cancellations = state.tool_call_cancellations.lock().await;
if let Some(cancel_tx) = cancellations.remove(&cancellation_token) {
// Send cancellation signal - ignore if receiver is already dropped
let _ = cancel_tx.send(());
println!("Tool call with token {} cancelled", cancellation_token);
Ok(())
} else {
Err(format!("Cancellation token {} not found", cancellation_token))
}
}
#[tauri::command] #[tauri::command]
pub async fn get_mcp_configs(app: AppHandle) -> Result<String, String> { pub async fn get_mcp_configs(app: AppHandle) -> Result<String, String> {
let mut path = get_jan_data_folder_path(app); let mut path = get_jan_data_folder_path(app);

View File

@ -6,7 +6,7 @@ use rmcp::{
service::RunningService, service::RunningService,
RoleClient, ServiceError, RoleClient, ServiceError,
}; };
use tokio::sync::Mutex; use tokio::sync::{Mutex, oneshot};
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
/// Server handle type for managing the proxy server lifecycle /// Server handle type for managing the proxy server lifecycle
@ -27,6 +27,7 @@ pub struct AppState {
pub mcp_active_servers: Arc<Mutex<HashMap<String, serde_json::Value>>>, pub mcp_active_servers: Arc<Mutex<HashMap<String, serde_json::Value>>>,
pub mcp_successfully_connected: Arc<Mutex<HashMap<String, bool>>>, pub mcp_successfully_connected: Arc<Mutex<HashMap<String, bool>>>,
pub server_handle: Arc<Mutex<Option<ServerHandle>>>, pub server_handle: Arc<Mutex<Option<ServerHandle>>>,
pub tool_call_cancellations: Arc<Mutex<HashMap<String, oneshot::Sender<()>>>>,
} }
impl RunningServiceEnum { impl RunningServiceEnum {

View File

@ -74,6 +74,7 @@ pub fn run() {
// MCP commands // MCP commands
core::mcp::commands::get_tools, core::mcp::commands::get_tools,
core::mcp::commands::call_tool, core::mcp::commands::call_tool,
core::mcp::commands::cancel_tool_call,
core::mcp::commands::restart_mcp_servers, core::mcp::commands::restart_mcp_servers,
core::mcp::commands::get_connected_servers, core::mcp::commands::get_connected_servers,
core::mcp::commands::save_mcp_configs, core::mcp::commands::save_mcp_configs,
@ -105,6 +106,7 @@ pub fn run() {
mcp_active_servers: Arc::new(Mutex::new(HashMap::new())), mcp_active_servers: Arc::new(Mutex::new(HashMap::new())),
mcp_successfully_connected: Arc::new(Mutex::new(HashMap::new())), mcp_successfully_connected: Arc::new(Mutex::new(HashMap::new())),
server_handle: Arc::new(Mutex::new(None)), server_handle: Arc::new(Mutex::new(None)),
tool_call_cancellations: Arc::new(Mutex::new(HashMap::new())),
}) })
.setup(|app| { .setup(|app| {
app.handle().plugin( app.handle().plugin(

View File

@ -46,8 +46,13 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
const textareaRef = useRef<HTMLTextAreaElement>(null) const textareaRef = useRef<HTMLTextAreaElement>(null)
const [isFocused, setIsFocused] = useState(false) const [isFocused, setIsFocused] = useState(false)
const [rows, setRows] = useState(1) const [rows, setRows] = useState(1)
const { streamingContent, abortControllers, loadingModel, tools } = const {
useAppState() streamingContent,
abortControllers,
loadingModel,
tools,
cancelToolCall,
} = useAppState()
const { prompt, setPrompt } = usePrompt() const { prompt, setPrompt } = usePrompt()
const { currentThreadId } = useThreads() const { currentThreadId } = useThreads()
const { t } = useTranslation() const { t } = useTranslation()
@ -161,8 +166,9 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
const stopStreaming = useCallback( const stopStreaming = useCallback(
(threadId: string) => { (threadId: string) => {
abortControllers[threadId]?.abort() abortControllers[threadId]?.abort()
cancelToolCall?.()
}, },
[abortControllers] [abortControllers, cancelToolCall]
) )
const fileInputRef = useRef<HTMLInputElement>(null) const fileInputRef = useRef<HTMLInputElement>(null)

View File

@ -13,6 +13,7 @@ type AppState = {
tokenSpeed?: TokenSpeed tokenSpeed?: TokenSpeed
currentToolCall?: ChatCompletionMessageToolCall currentToolCall?: ChatCompletionMessageToolCall
showOutOfContextDialog?: boolean showOutOfContextDialog?: boolean
cancelToolCall?: () => void
setServerStatus: (value: 'running' | 'stopped' | 'pending') => void setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
updateStreamingContent: (content: ThreadMessage | undefined) => void updateStreamingContent: (content: ThreadMessage | undefined) => void
updateCurrentToolCall: ( updateCurrentToolCall: (
@ -24,6 +25,7 @@ type AppState = {
updateTokenSpeed: (message: ThreadMessage, increment?: number) => void updateTokenSpeed: (message: ThreadMessage, increment?: number) => void
resetTokenSpeed: () => void resetTokenSpeed: () => void
setOutOfContextDialog: (show: boolean) => void setOutOfContextDialog: (show: boolean) => void
setCancelToolCall: (cancel: (() => void) | undefined) => void
} }
export const useAppState = create<AppState>()((set) => ({ export const useAppState = create<AppState>()((set) => ({
@ -34,6 +36,7 @@ export const useAppState = create<AppState>()((set) => ({
abortControllers: {}, abortControllers: {},
tokenSpeed: undefined, tokenSpeed: undefined,
currentToolCall: undefined, currentToolCall: undefined,
cancelToolCall: undefined,
updateStreamingContent: (content: ThreadMessage | undefined) => { updateStreamingContent: (content: ThreadMessage | undefined) => {
const assistants = useAssistant.getState().assistants const assistants = useAssistant.getState().assistants
const currentAssistant = useAssistant.getState().currentAssistant const currentAssistant = useAssistant.getState().currentAssistant
@ -112,4 +115,9 @@ export const useAppState = create<AppState>()((set) => ({
showOutOfContextDialog: show, showOutOfContextDialog: show,
})) }))
}, },
setCancelToolCall: (cancel) => {
set(() => ({
cancelToolCall: cancel,
}))
},
})) }))

View File

@ -31,8 +31,9 @@ import { ulid } from 'ulidx'
import { MCPTool } from '@/types/completion' import { MCPTool } from '@/types/completion'
import { CompletionMessagesBuilder } from './messages' import { CompletionMessagesBuilder } from './messages'
import { ChatCompletionMessageToolCall } from 'openai/resources' import { ChatCompletionMessageToolCall } from 'openai/resources'
import { callTool } from '@/services/mcp' import { callToolWithCancellation } from '@/services/mcp'
import { ExtensionManager } from './extension' import { ExtensionManager } from './extension'
import { useAppState } from '@/hooks/useAppState'
export type ChatCompletionResponse = export type ChatCompletionResponse =
| chatCompletion | chatCompletion
@ -381,13 +382,17 @@ export const postMessageProcessing = async (
) )
: true) : true)
let result = approved const { promise, cancel } = callToolWithCancellation({
? await callTool({
toolName: toolCall.function.name, toolName: toolCall.function.name,
arguments: toolCall.function.arguments.length arguments: toolCall.function.arguments.length
? JSON.parse(toolCall.function.arguments) ? JSON.parse(toolCall.function.arguments)
: {}, : {},
}).catch((e) => { })
useAppState.getState().setCancelToolCall(cancel)
let result = approved
? await promise.catch((e) => {
console.error('Tool call failed:', e) console.error('Tool call failed:', e)
return { return {
content: [ content: [

View File

@ -5,6 +5,7 @@ export const AppRoutes = [
'installExtensions', 'installExtensions',
'getTools', 'getTools',
'callTool', 'callTool',
'cancelToolCall',
'listThreads', 'listThreads',
'createThread', 'createThread',
'modifyThread', 'modifyThread',

View File

@ -56,3 +56,44 @@ export const callTool = (args: {
}): Promise<{ error: string; content: { text: string }[] }> => { }): Promise<{ error: string; content: { text: string }[] }> => {
return window.core?.api?.callTool(args) return window.core?.api?.callTool(args)
} }
/**
* @description Enhanced function to invoke an MCP tool with cancellation support
* @param args - Tool call arguments
* @param cancellationToken - Optional cancellation token
* @returns Promise with tool result and cancellation function
*/
export const callToolWithCancellation = (args: {
toolName: string
arguments: object
cancellationToken?: string
}): {
promise: Promise<{ error: string; content: { text: string }[] }>
cancel: () => Promise<void>
token: string
} => {
// Generate a unique cancellation token if not provided
const token = args.cancellationToken ?? `tool_call_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
// Create the tool call promise with cancellation token
const promise = window.core?.api?.callTool({
...args,
cancellationToken: token
})
// Create cancel function
const cancel = async () => {
await window.core?.api?.cancelToolCall({ cancellationToken: token })
}
return { promise, cancel, token }
}
/**
* @description This function cancels a running tool call
* @param cancellationToken - The token identifying the tool call to cancel
* @returns
*/
export const cancelToolCall = (cancellationToken: string): Promise<void> => {
return window.core?.api?.cancelToolCall({ cancellationToken })
}