Merge pull request #6236 from menloresearch/feat/add-tool-call-cancellation
This commit is contained in:
commit
6efdd66bbd
@ -2,6 +2,7 @@ use rmcp::model::{CallToolRequestParam, CallToolResult};
|
||||
use serde_json::{Map, Value};
|
||||
use tauri::{AppHandle, Emitter, Runtime, State};
|
||||
use tokio::time::timeout;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
use super::{
|
||||
constants::{DEFAULT_MCP_CONFIG, MCP_TOOL_CALL_TIMEOUT},
|
||||
@ -179,6 +180,7 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<ToolWithServer>
|
||||
/// * `state` - Application state containing MCP server connections
|
||||
/// * `tool_name` - Name of the tool to call
|
||||
/// * `arguments` - Optional map of argument names to values
|
||||
/// * `cancellation_token` - Optional token to allow cancellation from JS side
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<CallToolResult, String>` - Result of the tool call if successful, or error message if failed
|
||||
@ -187,13 +189,23 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<ToolWithServer>
|
||||
/// 1. Locks the MCP servers mutex to access server connections
|
||||
/// 2. Searches through all servers for one containing the named tool
|
||||
/// 3. When found, calls the tool on that server with the provided arguments
|
||||
/// 4. Returns error if no server has the requested tool
|
||||
/// 4. Supports cancellation via cancellation_token
|
||||
/// 5. Returns error if no server has the requested tool
|
||||
#[tauri::command]
|
||||
pub async fn call_tool(
|
||||
state: State<'_, AppState>,
|
||||
tool_name: String,
|
||||
arguments: Option<Map<String, Value>>,
|
||||
cancellation_token: Option<String>,
|
||||
) -> Result<CallToolResult, String> {
|
||||
// Set up cancellation if token is provided
|
||||
let (cancel_tx, cancel_rx) = oneshot::channel::<()>();
|
||||
|
||||
if let Some(token) = &cancellation_token {
|
||||
let mut cancellations = state.tool_call_cancellations.lock().await;
|
||||
cancellations.insert(token.clone(), cancel_tx);
|
||||
}
|
||||
|
||||
let servers = state.mcp_servers.lock().await;
|
||||
|
||||
// Iterate through servers and find the first one that contains the tool
|
||||
@ -209,25 +221,77 @@ pub async fn call_tool(
|
||||
|
||||
println!("Found tool {} in server", tool_name);
|
||||
|
||||
// Call the tool with timeout
|
||||
// Call the tool with timeout and cancellation support
|
||||
let tool_call = service.call_tool(CallToolRequestParam {
|
||||
name: tool_name.clone().into(),
|
||||
arguments,
|
||||
});
|
||||
|
||||
return match timeout(MCP_TOOL_CALL_TIMEOUT, tool_call).await {
|
||||
Ok(result) => result.map_err(|e| e.to_string()),
|
||||
// Race between timeout, tool call, and cancellation
|
||||
let result = if cancellation_token.is_some() {
|
||||
tokio::select! {
|
||||
result = timeout(MCP_TOOL_CALL_TIMEOUT, tool_call) => {
|
||||
match result {
|
||||
Ok(call_result) => call_result.map_err(|e| e.to_string()),
|
||||
Err(_) => Err(format!(
|
||||
"Tool call '{}' timed out after {} seconds",
|
||||
tool_name,
|
||||
MCP_TOOL_CALL_TIMEOUT.as_secs()
|
||||
)),
|
||||
}
|
||||
}
|
||||
_ = cancel_rx => {
|
||||
Err(format!("Tool call '{}' was cancelled", tool_name))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match timeout(MCP_TOOL_CALL_TIMEOUT, tool_call).await {
|
||||
Ok(call_result) => call_result.map_err(|e| e.to_string()),
|
||||
Err(_) => Err(format!(
|
||||
"Tool call '{}' timed out after {} seconds",
|
||||
tool_name,
|
||||
MCP_TOOL_CALL_TIMEOUT.as_secs()
|
||||
)),
|
||||
}
|
||||
};
|
||||
|
||||
// Clean up cancellation token
|
||||
if let Some(token) = &cancellation_token {
|
||||
let mut cancellations = state.tool_call_cancellations.lock().await;
|
||||
cancellations.remove(token);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
Err(format!("Tool {} not found", tool_name))
|
||||
}
|
||||
|
||||
/// Cancels a running tool call by its cancellation token
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `state` - Application state containing cancellation tokens
|
||||
/// * `cancellation_token` - Token identifying the tool call to cancel
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<(), String>` - Success if token found and cancelled, error otherwise
|
||||
#[tauri::command]
|
||||
pub async fn cancel_tool_call(
|
||||
state: State<'_, AppState>,
|
||||
cancellation_token: String,
|
||||
) -> Result<(), String> {
|
||||
let mut cancellations = state.tool_call_cancellations.lock().await;
|
||||
|
||||
if let Some(cancel_tx) = cancellations.remove(&cancellation_token) {
|
||||
// Send cancellation signal - ignore if receiver is already dropped
|
||||
let _ = cancel_tx.send(());
|
||||
println!("Tool call with token {} cancelled", cancellation_token);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Cancellation token {} not found", cancellation_token))
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn get_mcp_configs(app: AppHandle) -> Result<String, String> {
|
||||
let mut path = get_jan_data_folder_path(app);
|
||||
|
||||
@ -6,7 +6,7 @@ use rmcp::{
|
||||
service::RunningService,
|
||||
RoleClient, ServiceError,
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::{Mutex, oneshot};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
/// Server handle type for managing the proxy server lifecycle
|
||||
@ -27,6 +27,7 @@ pub struct AppState {
|
||||
pub mcp_active_servers: Arc<Mutex<HashMap<String, serde_json::Value>>>,
|
||||
pub mcp_successfully_connected: Arc<Mutex<HashMap<String, bool>>>,
|
||||
pub server_handle: Arc<Mutex<Option<ServerHandle>>>,
|
||||
pub tool_call_cancellations: Arc<Mutex<HashMap<String, oneshot::Sender<()>>>>,
|
||||
}
|
||||
|
||||
impl RunningServiceEnum {
|
||||
|
||||
@ -74,6 +74,7 @@ pub fn run() {
|
||||
// MCP commands
|
||||
core::mcp::commands::get_tools,
|
||||
core::mcp::commands::call_tool,
|
||||
core::mcp::commands::cancel_tool_call,
|
||||
core::mcp::commands::restart_mcp_servers,
|
||||
core::mcp::commands::get_connected_servers,
|
||||
core::mcp::commands::save_mcp_configs,
|
||||
@ -105,6 +106,7 @@ pub fn run() {
|
||||
mcp_active_servers: Arc::new(Mutex::new(HashMap::new())),
|
||||
mcp_successfully_connected: Arc::new(Mutex::new(HashMap::new())),
|
||||
server_handle: Arc::new(Mutex::new(None)),
|
||||
tool_call_cancellations: Arc::new(Mutex::new(HashMap::new())),
|
||||
})
|
||||
.setup(|app| {
|
||||
app.handle().plugin(
|
||||
|
||||
@ -46,8 +46,13 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null)
|
||||
const [isFocused, setIsFocused] = useState(false)
|
||||
const [rows, setRows] = useState(1)
|
||||
const { streamingContent, abortControllers, loadingModel, tools } =
|
||||
useAppState()
|
||||
const {
|
||||
streamingContent,
|
||||
abortControllers,
|
||||
loadingModel,
|
||||
tools,
|
||||
cancelToolCall,
|
||||
} = useAppState()
|
||||
const { prompt, setPrompt } = usePrompt()
|
||||
const { currentThreadId } = useThreads()
|
||||
const { t } = useTranslation()
|
||||
@ -161,8 +166,9 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
|
||||
const stopStreaming = useCallback(
|
||||
(threadId: string) => {
|
||||
abortControllers[threadId]?.abort()
|
||||
cancelToolCall?.()
|
||||
},
|
||||
[abortControllers]
|
||||
[abortControllers, cancelToolCall]
|
||||
)
|
||||
|
||||
const fileInputRef = useRef<HTMLInputElement>(null)
|
||||
|
||||
@ -13,6 +13,7 @@ type AppState = {
|
||||
tokenSpeed?: TokenSpeed
|
||||
currentToolCall?: ChatCompletionMessageToolCall
|
||||
showOutOfContextDialog?: boolean
|
||||
cancelToolCall?: () => void
|
||||
setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
|
||||
updateStreamingContent: (content: ThreadMessage | undefined) => void
|
||||
updateCurrentToolCall: (
|
||||
@ -24,6 +25,7 @@ type AppState = {
|
||||
updateTokenSpeed: (message: ThreadMessage, increment?: number) => void
|
||||
resetTokenSpeed: () => void
|
||||
setOutOfContextDialog: (show: boolean) => void
|
||||
setCancelToolCall: (cancel: (() => void) | undefined) => void
|
||||
}
|
||||
|
||||
export const useAppState = create<AppState>()((set) => ({
|
||||
@ -34,6 +36,7 @@ export const useAppState = create<AppState>()((set) => ({
|
||||
abortControllers: {},
|
||||
tokenSpeed: undefined,
|
||||
currentToolCall: undefined,
|
||||
cancelToolCall: undefined,
|
||||
updateStreamingContent: (content: ThreadMessage | undefined) => {
|
||||
const assistants = useAssistant.getState().assistants
|
||||
const currentAssistant = useAssistant.getState().currentAssistant
|
||||
@ -112,4 +115,9 @@ export const useAppState = create<AppState>()((set) => ({
|
||||
showOutOfContextDialog: show,
|
||||
}))
|
||||
},
|
||||
setCancelToolCall: (cancel) => {
|
||||
set(() => ({
|
||||
cancelToolCall: cancel,
|
||||
}))
|
||||
},
|
||||
}))
|
||||
|
||||
@ -31,8 +31,9 @@ import { ulid } from 'ulidx'
|
||||
import { MCPTool } from '@/types/completion'
|
||||
import { CompletionMessagesBuilder } from './messages'
|
||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||
import { callTool } from '@/services/mcp'
|
||||
import { callToolWithCancellation } from '@/services/mcp'
|
||||
import { ExtensionManager } from './extension'
|
||||
import { useAppState } from '@/hooks/useAppState'
|
||||
|
||||
export type ChatCompletionResponse =
|
||||
| chatCompletion
|
||||
@ -381,13 +382,17 @@ export const postMessageProcessing = async (
|
||||
)
|
||||
: true)
|
||||
|
||||
let result = approved
|
||||
? await callTool({
|
||||
const { promise, cancel } = callToolWithCancellation({
|
||||
toolName: toolCall.function.name,
|
||||
arguments: toolCall.function.arguments.length
|
||||
? JSON.parse(toolCall.function.arguments)
|
||||
: {},
|
||||
}).catch((e) => {
|
||||
})
|
||||
|
||||
useAppState.getState().setCancelToolCall(cancel)
|
||||
|
||||
let result = approved
|
||||
? await promise.catch((e) => {
|
||||
console.error('Tool call failed:', e)
|
||||
return {
|
||||
content: [
|
||||
|
||||
@ -5,6 +5,7 @@ export const AppRoutes = [
|
||||
'installExtensions',
|
||||
'getTools',
|
||||
'callTool',
|
||||
'cancelToolCall',
|
||||
'listThreads',
|
||||
'createThread',
|
||||
'modifyThread',
|
||||
|
||||
@ -56,3 +56,44 @@ export const callTool = (args: {
|
||||
}): Promise<{ error: string; content: { text: string }[] }> => {
|
||||
return window.core?.api?.callTool(args)
|
||||
}
|
||||
|
||||
/**
|
||||
* @description Enhanced function to invoke an MCP tool with cancellation support
|
||||
* @param args - Tool call arguments
|
||||
* @param cancellationToken - Optional cancellation token
|
||||
* @returns Promise with tool result and cancellation function
|
||||
*/
|
||||
export const callToolWithCancellation = (args: {
|
||||
toolName: string
|
||||
arguments: object
|
||||
cancellationToken?: string
|
||||
}): {
|
||||
promise: Promise<{ error: string; content: { text: string }[] }>
|
||||
cancel: () => Promise<void>
|
||||
token: string
|
||||
} => {
|
||||
// Generate a unique cancellation token if not provided
|
||||
const token = args.cancellationToken ?? `tool_call_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
||||
|
||||
// Create the tool call promise with cancellation token
|
||||
const promise = window.core?.api?.callTool({
|
||||
...args,
|
||||
cancellationToken: token
|
||||
})
|
||||
|
||||
// Create cancel function
|
||||
const cancel = async () => {
|
||||
await window.core?.api?.cancelToolCall({ cancellationToken: token })
|
||||
}
|
||||
|
||||
return { promise, cancel, token }
|
||||
}
|
||||
|
||||
/**
|
||||
* @description This function cancels a running tool call
|
||||
* @param cancellationToken - The token identifying the tool call to cancel
|
||||
* @returns
|
||||
*/
|
||||
export const cancelToolCall = (cancellationToken: string): Promise<void> => {
|
||||
return window.core?.api?.cancelToolCall({ cancellationToken })
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user