Merge pull request #6236 from menloresearch/feat/add-tool-call-cancellation
This commit is contained in:
commit
6efdd66bbd
@ -2,6 +2,7 @@ use rmcp::model::{CallToolRequestParam, CallToolResult};
|
|||||||
use serde_json::{Map, Value};
|
use serde_json::{Map, Value};
|
||||||
use tauri::{AppHandle, Emitter, Runtime, State};
|
use tauri::{AppHandle, Emitter, Runtime, State};
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
|
use tokio::sync::oneshot;
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
constants::{DEFAULT_MCP_CONFIG, MCP_TOOL_CALL_TIMEOUT},
|
constants::{DEFAULT_MCP_CONFIG, MCP_TOOL_CALL_TIMEOUT},
|
||||||
@ -179,6 +180,7 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<ToolWithServer>
|
|||||||
/// * `state` - Application state containing MCP server connections
|
/// * `state` - Application state containing MCP server connections
|
||||||
/// * `tool_name` - Name of the tool to call
|
/// * `tool_name` - Name of the tool to call
|
||||||
/// * `arguments` - Optional map of argument names to values
|
/// * `arguments` - Optional map of argument names to values
|
||||||
|
/// * `cancellation_token` - Optional token to allow cancellation from JS side
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// * `Result<CallToolResult, String>` - Result of the tool call if successful, or error message if failed
|
/// * `Result<CallToolResult, String>` - Result of the tool call if successful, or error message if failed
|
||||||
@ -187,13 +189,23 @@ pub async fn get_tools(state: State<'_, AppState>) -> Result<Vec<ToolWithServer>
|
|||||||
/// 1. Locks the MCP servers mutex to access server connections
|
/// 1. Locks the MCP servers mutex to access server connections
|
||||||
/// 2. Searches through all servers for one containing the named tool
|
/// 2. Searches through all servers for one containing the named tool
|
||||||
/// 3. When found, calls the tool on that server with the provided arguments
|
/// 3. When found, calls the tool on that server with the provided arguments
|
||||||
/// 4. Returns error if no server has the requested tool
|
/// 4. Supports cancellation via cancellation_token
|
||||||
|
/// 5. Returns error if no server has the requested tool
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub async fn call_tool(
|
pub async fn call_tool(
|
||||||
state: State<'_, AppState>,
|
state: State<'_, AppState>,
|
||||||
tool_name: String,
|
tool_name: String,
|
||||||
arguments: Option<Map<String, Value>>,
|
arguments: Option<Map<String, Value>>,
|
||||||
|
cancellation_token: Option<String>,
|
||||||
) -> Result<CallToolResult, String> {
|
) -> Result<CallToolResult, String> {
|
||||||
|
// Set up cancellation if token is provided
|
||||||
|
let (cancel_tx, cancel_rx) = oneshot::channel::<()>();
|
||||||
|
|
||||||
|
if let Some(token) = &cancellation_token {
|
||||||
|
let mut cancellations = state.tool_call_cancellations.lock().await;
|
||||||
|
cancellations.insert(token.clone(), cancel_tx);
|
||||||
|
}
|
||||||
|
|
||||||
let servers = state.mcp_servers.lock().await;
|
let servers = state.mcp_servers.lock().await;
|
||||||
|
|
||||||
// Iterate through servers and find the first one that contains the tool
|
// Iterate through servers and find the first one that contains the tool
|
||||||
@ -209,25 +221,77 @@ pub async fn call_tool(
|
|||||||
|
|
||||||
println!("Found tool {} in server", tool_name);
|
println!("Found tool {} in server", tool_name);
|
||||||
|
|
||||||
// Call the tool with timeout
|
// Call the tool with timeout and cancellation support
|
||||||
let tool_call = service.call_tool(CallToolRequestParam {
|
let tool_call = service.call_tool(CallToolRequestParam {
|
||||||
name: tool_name.clone().into(),
|
name: tool_name.clone().into(),
|
||||||
arguments,
|
arguments,
|
||||||
});
|
});
|
||||||
|
|
||||||
return match timeout(MCP_TOOL_CALL_TIMEOUT, tool_call).await {
|
// Race between timeout, tool call, and cancellation
|
||||||
Ok(result) => result.map_err(|e| e.to_string()),
|
let result = if cancellation_token.is_some() {
|
||||||
|
tokio::select! {
|
||||||
|
result = timeout(MCP_TOOL_CALL_TIMEOUT, tool_call) => {
|
||||||
|
match result {
|
||||||
|
Ok(call_result) => call_result.map_err(|e| e.to_string()),
|
||||||
Err(_) => Err(format!(
|
Err(_) => Err(format!(
|
||||||
"Tool call '{}' timed out after {} seconds",
|
"Tool call '{}' timed out after {} seconds",
|
||||||
tool_name,
|
tool_name,
|
||||||
MCP_TOOL_CALL_TIMEOUT.as_secs()
|
MCP_TOOL_CALL_TIMEOUT.as_secs()
|
||||||
)),
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = cancel_rx => {
|
||||||
|
Err(format!("Tool call '{}' was cancelled", tool_name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
match timeout(MCP_TOOL_CALL_TIMEOUT, tool_call).await {
|
||||||
|
Ok(call_result) => call_result.map_err(|e| e.to_string()),
|
||||||
|
Err(_) => Err(format!(
|
||||||
|
"Tool call '{}' timed out after {} seconds",
|
||||||
|
tool_name,
|
||||||
|
MCP_TOOL_CALL_TIMEOUT.as_secs()
|
||||||
|
)),
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Clean up cancellation token
|
||||||
|
if let Some(token) = &cancellation_token {
|
||||||
|
let mut cancellations = state.tool_call_cancellations.lock().await;
|
||||||
|
cancellations.remove(token);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(format!("Tool {} not found", tool_name))
|
Err(format!("Tool {} not found", tool_name))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Cancels a running tool call by its cancellation token
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `state` - Application state containing cancellation tokens
|
||||||
|
/// * `cancellation_token` - Token identifying the tool call to cancel
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// * `Result<(), String>` - Success if token found and cancelled, error otherwise
|
||||||
|
#[tauri::command]
|
||||||
|
pub async fn cancel_tool_call(
|
||||||
|
state: State<'_, AppState>,
|
||||||
|
cancellation_token: String,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let mut cancellations = state.tool_call_cancellations.lock().await;
|
||||||
|
|
||||||
|
if let Some(cancel_tx) = cancellations.remove(&cancellation_token) {
|
||||||
|
// Send cancellation signal - ignore if receiver is already dropped
|
||||||
|
let _ = cancel_tx.send(());
|
||||||
|
println!("Tool call with token {} cancelled", cancellation_token);
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(format!("Cancellation token {} not found", cancellation_token))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub async fn get_mcp_configs(app: AppHandle) -> Result<String, String> {
|
pub async fn get_mcp_configs(app: AppHandle) -> Result<String, String> {
|
||||||
let mut path = get_jan_data_folder_path(app);
|
let mut path = get_jan_data_folder_path(app);
|
||||||
|
|||||||
@ -6,7 +6,7 @@ use rmcp::{
|
|||||||
service::RunningService,
|
service::RunningService,
|
||||||
RoleClient, ServiceError,
|
RoleClient, ServiceError,
|
||||||
};
|
};
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::{Mutex, oneshot};
|
||||||
use tokio::task::JoinHandle;
|
use tokio::task::JoinHandle;
|
||||||
|
|
||||||
/// Server handle type for managing the proxy server lifecycle
|
/// Server handle type for managing the proxy server lifecycle
|
||||||
@ -27,6 +27,7 @@ pub struct AppState {
|
|||||||
pub mcp_active_servers: Arc<Mutex<HashMap<String, serde_json::Value>>>,
|
pub mcp_active_servers: Arc<Mutex<HashMap<String, serde_json::Value>>>,
|
||||||
pub mcp_successfully_connected: Arc<Mutex<HashMap<String, bool>>>,
|
pub mcp_successfully_connected: Arc<Mutex<HashMap<String, bool>>>,
|
||||||
pub server_handle: Arc<Mutex<Option<ServerHandle>>>,
|
pub server_handle: Arc<Mutex<Option<ServerHandle>>>,
|
||||||
|
pub tool_call_cancellations: Arc<Mutex<HashMap<String, oneshot::Sender<()>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RunningServiceEnum {
|
impl RunningServiceEnum {
|
||||||
|
|||||||
@ -74,6 +74,7 @@ pub fn run() {
|
|||||||
// MCP commands
|
// MCP commands
|
||||||
core::mcp::commands::get_tools,
|
core::mcp::commands::get_tools,
|
||||||
core::mcp::commands::call_tool,
|
core::mcp::commands::call_tool,
|
||||||
|
core::mcp::commands::cancel_tool_call,
|
||||||
core::mcp::commands::restart_mcp_servers,
|
core::mcp::commands::restart_mcp_servers,
|
||||||
core::mcp::commands::get_connected_servers,
|
core::mcp::commands::get_connected_servers,
|
||||||
core::mcp::commands::save_mcp_configs,
|
core::mcp::commands::save_mcp_configs,
|
||||||
@ -105,6 +106,7 @@ pub fn run() {
|
|||||||
mcp_active_servers: Arc::new(Mutex::new(HashMap::new())),
|
mcp_active_servers: Arc::new(Mutex::new(HashMap::new())),
|
||||||
mcp_successfully_connected: Arc::new(Mutex::new(HashMap::new())),
|
mcp_successfully_connected: Arc::new(Mutex::new(HashMap::new())),
|
||||||
server_handle: Arc::new(Mutex::new(None)),
|
server_handle: Arc::new(Mutex::new(None)),
|
||||||
|
tool_call_cancellations: Arc::new(Mutex::new(HashMap::new())),
|
||||||
})
|
})
|
||||||
.setup(|app| {
|
.setup(|app| {
|
||||||
app.handle().plugin(
|
app.handle().plugin(
|
||||||
|
|||||||
@ -46,8 +46,13 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
|
|||||||
const textareaRef = useRef<HTMLTextAreaElement>(null)
|
const textareaRef = useRef<HTMLTextAreaElement>(null)
|
||||||
const [isFocused, setIsFocused] = useState(false)
|
const [isFocused, setIsFocused] = useState(false)
|
||||||
const [rows, setRows] = useState(1)
|
const [rows, setRows] = useState(1)
|
||||||
const { streamingContent, abortControllers, loadingModel, tools } =
|
const {
|
||||||
useAppState()
|
streamingContent,
|
||||||
|
abortControllers,
|
||||||
|
loadingModel,
|
||||||
|
tools,
|
||||||
|
cancelToolCall,
|
||||||
|
} = useAppState()
|
||||||
const { prompt, setPrompt } = usePrompt()
|
const { prompt, setPrompt } = usePrompt()
|
||||||
const { currentThreadId } = useThreads()
|
const { currentThreadId } = useThreads()
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
@ -161,8 +166,9 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
|
|||||||
const stopStreaming = useCallback(
|
const stopStreaming = useCallback(
|
||||||
(threadId: string) => {
|
(threadId: string) => {
|
||||||
abortControllers[threadId]?.abort()
|
abortControllers[threadId]?.abort()
|
||||||
|
cancelToolCall?.()
|
||||||
},
|
},
|
||||||
[abortControllers]
|
[abortControllers, cancelToolCall]
|
||||||
)
|
)
|
||||||
|
|
||||||
const fileInputRef = useRef<HTMLInputElement>(null)
|
const fileInputRef = useRef<HTMLInputElement>(null)
|
||||||
|
|||||||
@ -13,6 +13,7 @@ type AppState = {
|
|||||||
tokenSpeed?: TokenSpeed
|
tokenSpeed?: TokenSpeed
|
||||||
currentToolCall?: ChatCompletionMessageToolCall
|
currentToolCall?: ChatCompletionMessageToolCall
|
||||||
showOutOfContextDialog?: boolean
|
showOutOfContextDialog?: boolean
|
||||||
|
cancelToolCall?: () => void
|
||||||
setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
|
setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
|
||||||
updateStreamingContent: (content: ThreadMessage | undefined) => void
|
updateStreamingContent: (content: ThreadMessage | undefined) => void
|
||||||
updateCurrentToolCall: (
|
updateCurrentToolCall: (
|
||||||
@ -24,6 +25,7 @@ type AppState = {
|
|||||||
updateTokenSpeed: (message: ThreadMessage, increment?: number) => void
|
updateTokenSpeed: (message: ThreadMessage, increment?: number) => void
|
||||||
resetTokenSpeed: () => void
|
resetTokenSpeed: () => void
|
||||||
setOutOfContextDialog: (show: boolean) => void
|
setOutOfContextDialog: (show: boolean) => void
|
||||||
|
setCancelToolCall: (cancel: (() => void) | undefined) => void
|
||||||
}
|
}
|
||||||
|
|
||||||
export const useAppState = create<AppState>()((set) => ({
|
export const useAppState = create<AppState>()((set) => ({
|
||||||
@ -34,6 +36,7 @@ export const useAppState = create<AppState>()((set) => ({
|
|||||||
abortControllers: {},
|
abortControllers: {},
|
||||||
tokenSpeed: undefined,
|
tokenSpeed: undefined,
|
||||||
currentToolCall: undefined,
|
currentToolCall: undefined,
|
||||||
|
cancelToolCall: undefined,
|
||||||
updateStreamingContent: (content: ThreadMessage | undefined) => {
|
updateStreamingContent: (content: ThreadMessage | undefined) => {
|
||||||
const assistants = useAssistant.getState().assistants
|
const assistants = useAssistant.getState().assistants
|
||||||
const currentAssistant = useAssistant.getState().currentAssistant
|
const currentAssistant = useAssistant.getState().currentAssistant
|
||||||
@ -112,4 +115,9 @@ export const useAppState = create<AppState>()((set) => ({
|
|||||||
showOutOfContextDialog: show,
|
showOutOfContextDialog: show,
|
||||||
}))
|
}))
|
||||||
},
|
},
|
||||||
|
setCancelToolCall: (cancel) => {
|
||||||
|
set(() => ({
|
||||||
|
cancelToolCall: cancel,
|
||||||
|
}))
|
||||||
|
},
|
||||||
}))
|
}))
|
||||||
|
|||||||
@ -31,8 +31,9 @@ import { ulid } from 'ulidx'
|
|||||||
import { MCPTool } from '@/types/completion'
|
import { MCPTool } from '@/types/completion'
|
||||||
import { CompletionMessagesBuilder } from './messages'
|
import { CompletionMessagesBuilder } from './messages'
|
||||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||||
import { callTool } from '@/services/mcp'
|
import { callToolWithCancellation } from '@/services/mcp'
|
||||||
import { ExtensionManager } from './extension'
|
import { ExtensionManager } from './extension'
|
||||||
|
import { useAppState } from '@/hooks/useAppState'
|
||||||
|
|
||||||
export type ChatCompletionResponse =
|
export type ChatCompletionResponse =
|
||||||
| chatCompletion
|
| chatCompletion
|
||||||
@ -381,13 +382,17 @@ export const postMessageProcessing = async (
|
|||||||
)
|
)
|
||||||
: true)
|
: true)
|
||||||
|
|
||||||
let result = approved
|
const { promise, cancel } = callToolWithCancellation({
|
||||||
? await callTool({
|
|
||||||
toolName: toolCall.function.name,
|
toolName: toolCall.function.name,
|
||||||
arguments: toolCall.function.arguments.length
|
arguments: toolCall.function.arguments.length
|
||||||
? JSON.parse(toolCall.function.arguments)
|
? JSON.parse(toolCall.function.arguments)
|
||||||
: {},
|
: {},
|
||||||
}).catch((e) => {
|
})
|
||||||
|
|
||||||
|
useAppState.getState().setCancelToolCall(cancel)
|
||||||
|
|
||||||
|
let result = approved
|
||||||
|
? await promise.catch((e) => {
|
||||||
console.error('Tool call failed:', e)
|
console.error('Tool call failed:', e)
|
||||||
return {
|
return {
|
||||||
content: [
|
content: [
|
||||||
|
|||||||
@ -5,6 +5,7 @@ export const AppRoutes = [
|
|||||||
'installExtensions',
|
'installExtensions',
|
||||||
'getTools',
|
'getTools',
|
||||||
'callTool',
|
'callTool',
|
||||||
|
'cancelToolCall',
|
||||||
'listThreads',
|
'listThreads',
|
||||||
'createThread',
|
'createThread',
|
||||||
'modifyThread',
|
'modifyThread',
|
||||||
|
|||||||
@ -56,3 +56,44 @@ export const callTool = (args: {
|
|||||||
}): Promise<{ error: string; content: { text: string }[] }> => {
|
}): Promise<{ error: string; content: { text: string }[] }> => {
|
||||||
return window.core?.api?.callTool(args)
|
return window.core?.api?.callTool(args)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @description Enhanced function to invoke an MCP tool with cancellation support
|
||||||
|
* @param args - Tool call arguments
|
||||||
|
* @param cancellationToken - Optional cancellation token
|
||||||
|
* @returns Promise with tool result and cancellation function
|
||||||
|
*/
|
||||||
|
export const callToolWithCancellation = (args: {
|
||||||
|
toolName: string
|
||||||
|
arguments: object
|
||||||
|
cancellationToken?: string
|
||||||
|
}): {
|
||||||
|
promise: Promise<{ error: string; content: { text: string }[] }>
|
||||||
|
cancel: () => Promise<void>
|
||||||
|
token: string
|
||||||
|
} => {
|
||||||
|
// Generate a unique cancellation token if not provided
|
||||||
|
const token = args.cancellationToken ?? `tool_call_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
||||||
|
|
||||||
|
// Create the tool call promise with cancellation token
|
||||||
|
const promise = window.core?.api?.callTool({
|
||||||
|
...args,
|
||||||
|
cancellationToken: token
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create cancel function
|
||||||
|
const cancel = async () => {
|
||||||
|
await window.core?.api?.cancelToolCall({ cancellationToken: token })
|
||||||
|
}
|
||||||
|
|
||||||
|
return { promise, cancel, token }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @description This function cancels a running tool call
|
||||||
|
* @param cancellationToken - The token identifying the tool call to cancel
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
export const cancelToolCall = (cancellationToken: string): Promise<void> => {
|
||||||
|
return window.core?.api?.cancelToolCall({ cancellationToken })
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user