enhancement: ux tool call permission dialog and state active (#5157)
* enhancement: mcp toold dialog approval * enhancement: update mcp tool enable or disable * chore: add toggle mcl global permission
This commit is contained in:
parent
573e667c34
commit
057accfb96
@ -5,7 +5,7 @@ import { cva, type VariantProps } from 'class-variance-authority'
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
const buttonVariants = cva(
|
||||
"inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium transition-all disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4 shrink-0 [&_svg]:shrink-0 outline-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive cursor-pointer",
|
||||
"inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium transition-all disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4 shrink-0 [&_svg]:shrink-0 outline-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[0px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive cursor-pointer focus:outline-none",
|
||||
{
|
||||
variants: {
|
||||
variant: {
|
||||
|
||||
@ -44,11 +44,18 @@ function DialogOverlay({
|
||||
)
|
||||
}
|
||||
|
||||
type DialogContentProps = React.ComponentProps<
|
||||
typeof DialogPrimitive.Content
|
||||
> & {
|
||||
showCloseButton?: boolean
|
||||
}
|
||||
|
||||
function DialogContent({
|
||||
showCloseButton = true,
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}: React.ComponentProps<typeof DialogPrimitive.Content>) {
|
||||
}: DialogContentProps) {
|
||||
return (
|
||||
<DialogPortal data-slot="dialog-portal">
|
||||
<DialogOverlay />
|
||||
@ -61,10 +68,12 @@ function DialogContent({
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
<DialogPrimitive.Close className="data-[state=open]:text-main-view-fg/50 absolute top-4 right-4 rounded-xs opacity-70 transition-opacity hover:opacity-100 focus:ring-0 focus:outline-0 disabled:pointer-events-none [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4 cursor-pointer">
|
||||
<XIcon />
|
||||
<span className="sr-only">Close</span>
|
||||
</DialogPrimitive.Close>
|
||||
{showCloseButton && (
|
||||
<DialogPrimitive.Close className="data-[state=open]:text-main-view-fg/50 absolute top-4 right-4 rounded-xs opacity-70 transition-opacity hover:opacity-100 focus:ring-0 focus:outline-0 disabled:pointer-events-none [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4 cursor-pointer">
|
||||
<XIcon />
|
||||
<span className="sr-only">Close</span>
|
||||
</DialogPrimitive.Close>
|
||||
)}
|
||||
</DialogPrimitive.Content>
|
||||
</DialogPortal>
|
||||
)
|
||||
|
||||
@ -2,16 +2,17 @@ export const localStorageKey = {
|
||||
LeftPanel: 'left-panel',
|
||||
threads: 'threads',
|
||||
messages: 'messages',
|
||||
assistant: 'assistant',
|
||||
theme: 'theme',
|
||||
modelProvider: 'model-provider',
|
||||
settingAppearance: 'setting-appearance',
|
||||
settingGeneral: 'setting-general',
|
||||
settingCodeBlock: 'setting-code-block',
|
||||
settingMCPSevers: 'setting-mcp-servers',
|
||||
settingLocalApiServer: 'setting-local-api-server',
|
||||
settingProxyConfig: 'setting-proxy-config',
|
||||
settingHardware: 'setting-hardware',
|
||||
productAnalyticPrompt: 'productAnalyticPrompt',
|
||||
productAnalytic: 'productAnalytic',
|
||||
toolApproval: 'tool-approval',
|
||||
toolAvailability: 'tool-availability',
|
||||
mcpGlobalPermissions: 'mcp-global-permissions',
|
||||
}
|
||||
|
||||
@ -27,22 +27,27 @@ import { MovingBorder } from './MovingBorder'
|
||||
import { useChat } from '@/hooks/useChat'
|
||||
import DropdownModelProvider from '@/containers/DropdownModelProvider'
|
||||
import { ModelLoader } from '@/containers/loaders/ModelLoader'
|
||||
import DropdownToolsAvailable from '@/containers/DropdownToolsAvailable'
|
||||
import { getConnectedServers } from '@/services/mcp'
|
||||
|
||||
type ChatInputProps = {
|
||||
className?: string
|
||||
showSpeedToken?: boolean
|
||||
model?: ThreadModel
|
||||
initialMessage?: boolean
|
||||
}
|
||||
|
||||
const ChatInput = ({
|
||||
model,
|
||||
className,
|
||||
showSpeedToken = true,
|
||||
initialMessage,
|
||||
}: ChatInputProps) => {
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null)
|
||||
const [isFocused, setIsFocused] = useState(false)
|
||||
const [rows, setRows] = useState(1)
|
||||
const { streamingContent, abortControllers, loadingModel } = useAppState()
|
||||
const { streamingContent, abortControllers, loadingModel, tools } =
|
||||
useAppState()
|
||||
const { prompt, setPrompt } = usePrompt()
|
||||
const { currentThreadId } = useThreads()
|
||||
const { t } = useTranslation()
|
||||
@ -62,6 +67,30 @@ const ChatInput = ({
|
||||
dataUrl: string
|
||||
}>
|
||||
>([])
|
||||
const [connectedServers, setConnectedServers] = useState<string[]>([])
|
||||
|
||||
// Check for connected MCP servers
|
||||
useEffect(() => {
|
||||
const checkConnectedServers = async () => {
|
||||
try {
|
||||
const servers = await getConnectedServers()
|
||||
setConnectedServers(servers)
|
||||
} catch (error) {
|
||||
console.error('Failed to get connected servers:', error)
|
||||
setConnectedServers([])
|
||||
}
|
||||
}
|
||||
|
||||
checkConnectedServers()
|
||||
|
||||
// Poll for connected servers every 3 seconds
|
||||
const intervalId = setInterval(checkConnectedServers, 3000)
|
||||
|
||||
return () => clearInterval(intervalId)
|
||||
}, [])
|
||||
|
||||
// Check if there are active MCP servers
|
||||
const hasActiveMCPServers = connectedServers.length > 0 || tools.length > 0
|
||||
|
||||
const handleSendMesage = (prompt: string) => {
|
||||
if (!selectedModel) {
|
||||
@ -404,11 +433,31 @@ const ChatInput = ({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{selectedModel?.capabilities?.includes('tools') && (
|
||||
<div className="h-6 p-1 flex items-center justify-center rounded-sm hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out gap-1">
|
||||
<IconTool size={18} className="text-main-view-fg/50" />
|
||||
</div>
|
||||
)}
|
||||
{selectedModel?.capabilities?.includes('tools') &&
|
||||
hasActiveMCPServers && (
|
||||
<DropdownToolsAvailable initialMessage={initialMessage}>
|
||||
{(isOpen, toolsCount) => (
|
||||
<div
|
||||
className={cn(
|
||||
'h-6 p-1 flex items-center justify-center rounded-sm hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out gap-1 cursor-pointer relative',
|
||||
isOpen && 'bg-main-view-fg/10'
|
||||
)}
|
||||
>
|
||||
<IconTool
|
||||
size={18}
|
||||
className="text-main-view-fg/50"
|
||||
/>
|
||||
{toolsCount > 0 && (
|
||||
<div className="absolute -top-1 -right-1.5 bg-accent text-accent-fg text-xs rounded-full size-4 flex items-center justify-center font-medium">
|
||||
<span className="leading-0">
|
||||
{toolsCount > 99 ? '99+' : toolsCount}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</DropdownToolsAvailable>
|
||||
)}
|
||||
|
||||
{selectedModel?.capabilities?.includes('web_search') && (
|
||||
<div className="h-6 p-1 flex items-center justify-center rounded-sm hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out gap-1">
|
||||
|
||||
176
web-app/src/containers/DropdownToolsAvailable.tsx
Normal file
176
web-app/src/containers/DropdownToolsAvailable.tsx
Normal file
@ -0,0 +1,176 @@
|
||||
import { useEffect, useState } from 'react'
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuLabel,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuTrigger,
|
||||
} from '@/components/ui/dropdown-menu'
|
||||
import { Switch } from '@/components/ui/switch'
|
||||
import { getTools } from '@/services/mcp'
|
||||
import { MCPTool } from '@/types/completion'
|
||||
|
||||
import { useThreads } from '@/hooks/useThreads'
|
||||
import { useToolAvailable } from '@/hooks/useToolAvailable'
|
||||
|
||||
import React from 'react'
|
||||
|
||||
interface DropdownToolsAvailableProps {
|
||||
children: (isOpen: boolean, toolsCount: number) => React.ReactNode
|
||||
initialMessage?: boolean
|
||||
}
|
||||
|
||||
export default function DropdownToolsAvailable({
|
||||
children,
|
||||
initialMessage = false,
|
||||
}: DropdownToolsAvailableProps) {
|
||||
const [tools, setTools] = useState<MCPTool[]>([])
|
||||
const [isOpen, setIsOpen] = useState(false)
|
||||
const { getCurrentThread } = useThreads()
|
||||
const {
|
||||
isToolAvailable,
|
||||
setToolAvailableForThread,
|
||||
setDefaultAvailableTools,
|
||||
initializeThreadTools,
|
||||
getAvailableToolsForThread,
|
||||
getDefaultAvailableTools,
|
||||
} = useToolAvailable()
|
||||
|
||||
const currentThread = getCurrentThread()
|
||||
|
||||
useEffect(() => {
|
||||
const fetchTools = async () => {
|
||||
try {
|
||||
const availableTools = await getTools()
|
||||
setTools(availableTools)
|
||||
|
||||
// If this is for the initial message (index page) and no defaults are set,
|
||||
// initialize with all tools as default
|
||||
if (
|
||||
initialMessage &&
|
||||
getDefaultAvailableTools().length === 0 &&
|
||||
availableTools.length > 0
|
||||
) {
|
||||
setDefaultAvailableTools(availableTools.map((tool) => tool.name))
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch tools:', error)
|
||||
setTools([])
|
||||
}
|
||||
}
|
||||
|
||||
// Only fetch tools once when component mounts
|
||||
fetchTools()
|
||||
}, [initialMessage, setDefaultAvailableTools, getDefaultAvailableTools])
|
||||
|
||||
// Separate effect for thread initialization - only when we have tools and a new thread
|
||||
useEffect(() => {
|
||||
if (tools.length > 0 && currentThread?.id) {
|
||||
initializeThreadTools(currentThread.id, tools)
|
||||
}
|
||||
}, [currentThread?.id, tools, initializeThreadTools])
|
||||
|
||||
const handleToolToggle = (toolName: string, checked: boolean) => {
|
||||
if (initialMessage) {
|
||||
// Update default tools for new threads/index page
|
||||
const currentDefaults = getDefaultAvailableTools()
|
||||
if (checked) {
|
||||
if (!currentDefaults.includes(toolName)) {
|
||||
setDefaultAvailableTools([...currentDefaults, toolName])
|
||||
}
|
||||
} else {
|
||||
setDefaultAvailableTools(
|
||||
currentDefaults.filter((name) => name !== toolName)
|
||||
)
|
||||
}
|
||||
} else if (currentThread?.id) {
|
||||
// Update tools for specific thread
|
||||
setToolAvailableForThread(currentThread.id, toolName, checked)
|
||||
}
|
||||
}
|
||||
|
||||
const isToolChecked = (toolName: string): boolean => {
|
||||
if (initialMessage) {
|
||||
// Use default tools for index page
|
||||
return getDefaultAvailableTools().includes(toolName)
|
||||
} else if (currentThread?.id) {
|
||||
// Use thread-specific tools
|
||||
return isToolAvailable(currentThread.id, toolName)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const getEnabledToolsCount = (): number => {
|
||||
if (initialMessage) {
|
||||
return getDefaultAvailableTools().length
|
||||
} else if (currentThread?.id) {
|
||||
return getAvailableToolsForThread(currentThread.id).length
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
const renderTrigger = () => children(isOpen, getEnabledToolsCount())
|
||||
|
||||
if (tools.length === 0) {
|
||||
return (
|
||||
<DropdownMenu onOpenChange={setIsOpen}>
|
||||
<DropdownMenuTrigger asChild>{renderTrigger()}</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="start" className="max-w-64">
|
||||
<DropdownMenuItem disabled>No tools available</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<DropdownMenu onOpenChange={setIsOpen}>
|
||||
<DropdownMenuTrigger asChild>{renderTrigger()}</DropdownMenuTrigger>
|
||||
|
||||
<DropdownMenuContent
|
||||
side="top"
|
||||
align="start"
|
||||
className="max-w-64 max-h-64 "
|
||||
>
|
||||
<DropdownMenuLabel className="flex items-center gap-2 sticky -top-1 z-10 bg-main-view px-4 pl-2 py-2">
|
||||
Available Tools
|
||||
</DropdownMenuLabel>
|
||||
<DropdownMenuSeparator />
|
||||
<div>
|
||||
{tools.map((tool) => {
|
||||
const isChecked = isToolChecked(tool.name)
|
||||
return (
|
||||
<div
|
||||
key={tool.name}
|
||||
className="px-2 py-2 hover:bg-main-view-fg/5 rounded-sm"
|
||||
>
|
||||
<div className="flex items-start justify-between gap-3">
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-start justify-between gap-4">
|
||||
<div>
|
||||
<h4 className="text-sm font-medium truncate">
|
||||
{tool.name}
|
||||
</h4>
|
||||
{tool.description && (
|
||||
<p className="text-xs text-main-view-fg/70 mt-1 line-clamp-2">
|
||||
{tool.description}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
<Switch
|
||||
checked={isChecked}
|
||||
onCheckedChange={(checked) =>
|
||||
handleToolToggle(tool.name, checked)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
)
|
||||
}
|
||||
@ -166,7 +166,13 @@ function RenderMarkdownComponent({
|
||||
|
||||
// Render the markdown content
|
||||
return (
|
||||
<div className={cn('markdown break-words select-text', className)}>
|
||||
<div
|
||||
className={cn(
|
||||
'markdown break-words select-text',
|
||||
isUser && 'is-user',
|
||||
className
|
||||
)}
|
||||
>
|
||||
<ReactMarkdown
|
||||
remarkPlugins={remarkPlugins}
|
||||
rehypePlugins={rehypePlugins}
|
||||
|
||||
@ -39,6 +39,7 @@ const ToolCallBlock = ({ id, name, result, loading }: Props) => {
|
||||
<div
|
||||
className="mx-auto w-full cursor-pointer break-words"
|
||||
onClick={handleClick}
|
||||
data-tool-call-block={id}
|
||||
>
|
||||
<div className="rounded-lg bg-main-view-fg/4 border border-dashed border-main-view-fg/10">
|
||||
<div className="flex items-center gap-3 p-2">
|
||||
|
||||
77
web-app/src/containers/dialogs/ToolApproval.tsx
Normal file
77
web-app/src/containers/dialogs/ToolApproval.tsx
Normal file
@ -0,0 +1,77 @@
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from '@/components/ui/dialog'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { useToolApproval } from '@/hooks/useToolApproval'
|
||||
import { AlertTriangle } from 'lucide-react'
|
||||
|
||||
export default function ToolApproval() {
|
||||
const { isModalOpen, modalProps, setModalOpen } = useToolApproval()
|
||||
|
||||
if (!modalProps) {
|
||||
return null
|
||||
}
|
||||
|
||||
const { toolName, onApprove, onDeny } = modalProps
|
||||
|
||||
const handleAllowOnce = () => {
|
||||
onApprove(true) // true = allow once only
|
||||
}
|
||||
|
||||
const handleAllow = () => {
|
||||
onApprove(false) // false = remember for this thread
|
||||
}
|
||||
|
||||
const handleDeny = () => {
|
||||
onDeny()
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog open={isModalOpen} onOpenChange={setModalOpen}>
|
||||
<DialogContent showCloseButton={false}>
|
||||
<DialogHeader>
|
||||
<div className="flex items-start gap-3">
|
||||
<div className="shrink-0">
|
||||
<AlertTriangle className="size-4" />
|
||||
</div>
|
||||
<div>
|
||||
<DialogTitle>Tool Call Request</DialogTitle>
|
||||
<DialogDescription className="mt-1 text-main-view-fg/70">
|
||||
The assistant wants to use the tool: <strong>{toolName}</strong>
|
||||
</DialogDescription>
|
||||
</div>
|
||||
</div>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="bg-main-view-fg/8 p-2 border border-main-view-fg/5 rounded-lg">
|
||||
<p className="text-sm text-main-view-fg/70 leading-relaxed">
|
||||
<strong>Security Notice:</strong> Malicious tools or conversation
|
||||
content could potentially trick the assistant into attempting
|
||||
harmful actions. Review each tool call carefully before approving.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<DialogFooter className="flex flex-col gap-2 sm:flex-row sm:justify-end">
|
||||
<Button variant="link" onClick={handleDeny} className="w-full">
|
||||
Deny
|
||||
</Button>
|
||||
<Button
|
||||
variant="link"
|
||||
onClick={handleAllowOnce}
|
||||
className="border border-main-view-fg/20"
|
||||
>
|
||||
Allow Once
|
||||
</Button>
|
||||
<Button variant="default" onClick={handleAllow}>
|
||||
Always Allow
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
)
|
||||
}
|
||||
@ -26,6 +26,9 @@ import { listen } from '@tauri-apps/api/event'
|
||||
import { SystemEvent } from '@/types/events'
|
||||
import { stopModel, startModel } from '@/services/models'
|
||||
|
||||
import { useToolApproval } from '@/hooks/useToolApproval'
|
||||
import { useToolAvailable } from '@/hooks/useToolAvailable'
|
||||
|
||||
export const useChat = () => {
|
||||
const { prompt, setPrompt } = usePrompt()
|
||||
const {
|
||||
@ -39,6 +42,10 @@ export const useChat = () => {
|
||||
} = useAppState()
|
||||
const { currentAssistant } = useAssistant()
|
||||
|
||||
const { approvedTools, showApprovalModal, allowAllMCPPermissions } =
|
||||
useToolApproval()
|
||||
const { getAvailableToolsForThread } = useToolAvailable()
|
||||
|
||||
const { getProviderByName, selectedModel, selectedProvider } =
|
||||
useModelProvider()
|
||||
|
||||
@ -123,9 +130,16 @@ export const useChat = () => {
|
||||
|
||||
let isCompleted = false
|
||||
|
||||
// Filter tools based on model capabilities and available tools for this thread
|
||||
let availableTools = selectedModel?.capabilities?.includes('tools')
|
||||
? tools
|
||||
? tools.filter((tool) => {
|
||||
const availableToolNames = getAvailableToolsForThread(
|
||||
activeThread.id
|
||||
)
|
||||
return availableToolNames.includes(tool.name)
|
||||
})
|
||||
: []
|
||||
|
||||
// TODO: Later replaced by Agent setup?
|
||||
const followUpWithToolUse = true
|
||||
while (!isCompleted && !abortController.signal.aborted) {
|
||||
@ -193,7 +207,10 @@ export const useChat = () => {
|
||||
toolCalls,
|
||||
builder,
|
||||
finalContent,
|
||||
abortController
|
||||
abortController,
|
||||
approvedTools,
|
||||
allowAllMCPPermissions ? undefined : showApprovalModal,
|
||||
allowAllMCPPermissions
|
||||
)
|
||||
addMessage(updatedMessage ?? finalContent)
|
||||
|
||||
@ -225,6 +242,10 @@ export const useChat = () => {
|
||||
tools,
|
||||
updateLoadingModel,
|
||||
updateTokenSpeed,
|
||||
approvedTools,
|
||||
showApprovalModal,
|
||||
getAvailableToolsForThread,
|
||||
allowAllMCPPermissions,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
107
web-app/src/hooks/useToolApproval.ts
Normal file
107
web-app/src/hooks/useToolApproval.ts
Normal file
@ -0,0 +1,107 @@
|
||||
import { create } from 'zustand'
|
||||
import { persist, createJSONStorage } from 'zustand/middleware'
|
||||
import { localStorageKey } from '@/constants/localStorage'
|
||||
|
||||
export type ToolApprovalModalProps = {
|
||||
toolName: string
|
||||
threadId: string
|
||||
onApprove: (allowOnce: boolean) => void
|
||||
onDeny: () => void
|
||||
}
|
||||
|
||||
type ToolApprovalState = {
|
||||
// Track approved tools per thread
|
||||
approvedTools: Record<string, string[]> // threadId -> toolNames[]
|
||||
// Global MCP permission toggle
|
||||
allowAllMCPPermissions: boolean
|
||||
// Modal state
|
||||
isModalOpen: boolean
|
||||
modalProps: ToolApprovalModalProps | null
|
||||
|
||||
// Actions
|
||||
approveToolForThread: (threadId: string, toolName: string) => void
|
||||
isToolApproved: (threadId: string, toolName: string) => boolean
|
||||
showApprovalModal: (toolName: string, threadId: string) => Promise<boolean>
|
||||
closeModal: () => void
|
||||
setModalOpen: (open: boolean) => void
|
||||
setAllowAllMCPPermissions: (allow: boolean) => void
|
||||
}
|
||||
|
||||
export const useToolApproval = create<ToolApprovalState>()(
|
||||
persist(
|
||||
(set, get) => ({
|
||||
approvedTools: {},
|
||||
allowAllMCPPermissions: false,
|
||||
isModalOpen: false,
|
||||
modalProps: null,
|
||||
|
||||
approveToolForThread: (threadId: string, toolName: string) => {
|
||||
set((state) => ({
|
||||
approvedTools: {
|
||||
...state.approvedTools,
|
||||
[threadId]: [
|
||||
...(state.approvedTools[threadId] || []),
|
||||
toolName,
|
||||
].filter((tool, index, arr) => arr.indexOf(tool) === index), // Remove duplicates
|
||||
},
|
||||
}))
|
||||
},
|
||||
|
||||
isToolApproved: (threadId: string, toolName: string) => {
|
||||
const state = get()
|
||||
return state.approvedTools[threadId]?.includes(toolName) || false
|
||||
},
|
||||
|
||||
showApprovalModal: (toolName: string, threadId: string) => {
|
||||
return new Promise<boolean>((resolve) => {
|
||||
set({
|
||||
isModalOpen: true,
|
||||
modalProps: {
|
||||
toolName,
|
||||
threadId,
|
||||
onApprove: (allowOnce: boolean) => {
|
||||
if (!allowOnce) {
|
||||
// If not "allow once", add to approved tools for this thread
|
||||
get().approveToolForThread(threadId, toolName)
|
||||
}
|
||||
get().closeModal()
|
||||
resolve(true)
|
||||
},
|
||||
onDeny: () => {
|
||||
get().closeModal()
|
||||
resolve(false)
|
||||
},
|
||||
},
|
||||
})
|
||||
})
|
||||
},
|
||||
|
||||
closeModal: () => {
|
||||
set({
|
||||
isModalOpen: false,
|
||||
modalProps: null,
|
||||
})
|
||||
},
|
||||
|
||||
setModalOpen: (open: boolean) => {
|
||||
set({ isModalOpen: open })
|
||||
if (!open) {
|
||||
get().closeModal()
|
||||
}
|
||||
},
|
||||
|
||||
setAllowAllMCPPermissions: (allow: boolean) => {
|
||||
set({ allowAllMCPPermissions: allow })
|
||||
},
|
||||
}),
|
||||
{
|
||||
name: localStorageKey.toolApproval,
|
||||
storage: createJSONStorage(() => localStorage),
|
||||
// Only persist approved tools and global permission setting, not modal state
|
||||
partialize: (state) => ({
|
||||
approvedTools: state.approvedTools,
|
||||
allowAllMCPPermissions: state.allowAllMCPPermissions,
|
||||
}),
|
||||
}
|
||||
)
|
||||
)
|
||||
117
web-app/src/hooks/useToolAvailable.ts
Normal file
117
web-app/src/hooks/useToolAvailable.ts
Normal file
@ -0,0 +1,117 @@
|
||||
import { create } from 'zustand'
|
||||
import { persist, createJSONStorage } from 'zustand/middleware'
|
||||
import { localStorageKey } from '@/constants/localStorage'
|
||||
import { MCPTool } from '@/types/completion'
|
||||
|
||||
type ToolAvailableState = {
|
||||
// Track available tools per thread
|
||||
availableTools: Record<string, string[]> // threadId -> toolNames[]
|
||||
// Global default available tools (for new threads/index page)
|
||||
defaultAvailableTools: string[]
|
||||
|
||||
// Actions
|
||||
setToolAvailableForThread: (
|
||||
threadId: string,
|
||||
toolName: string,
|
||||
available: boolean
|
||||
) => void
|
||||
isToolAvailable: (threadId: string, toolName: string) => boolean
|
||||
getAvailableToolsForThread: (threadId: string) => string[]
|
||||
setDefaultAvailableTools: (toolNames: string[]) => void
|
||||
getDefaultAvailableTools: () => string[]
|
||||
// Initialize thread tools from default or existing thread settings
|
||||
initializeThreadTools: (threadId: string, allTools: MCPTool[]) => void
|
||||
}
|
||||
|
||||
export const useToolAvailable = create<ToolAvailableState>()(
|
||||
persist(
|
||||
(set, get) => ({
|
||||
availableTools: {},
|
||||
defaultAvailableTools: [],
|
||||
|
||||
setToolAvailableForThread: (
|
||||
threadId: string,
|
||||
toolName: string,
|
||||
available: boolean
|
||||
) => {
|
||||
set((state) => {
|
||||
const currentTools = state.availableTools[threadId] || []
|
||||
let updatedTools: string[]
|
||||
|
||||
if (available) {
|
||||
// Add tool if not already present
|
||||
updatedTools = currentTools.includes(toolName)
|
||||
? currentTools
|
||||
: [...currentTools, toolName]
|
||||
} else {
|
||||
// Remove tool
|
||||
updatedTools = currentTools.filter((tool) => tool !== toolName)
|
||||
}
|
||||
|
||||
return {
|
||||
availableTools: {
|
||||
...state.availableTools,
|
||||
[threadId]: updatedTools,
|
||||
},
|
||||
}
|
||||
})
|
||||
},
|
||||
|
||||
isToolAvailable: (threadId: string, toolName: string) => {
|
||||
const state = get()
|
||||
// If no thread-specific settings, use default
|
||||
if (!state.availableTools[threadId]) {
|
||||
return state.defaultAvailableTools.includes(toolName)
|
||||
}
|
||||
return state.availableTools[threadId]?.includes(toolName) || false
|
||||
},
|
||||
|
||||
getAvailableToolsForThread: (threadId: string) => {
|
||||
const state = get()
|
||||
// If no thread-specific settings, use default
|
||||
if (!state.availableTools[threadId]) {
|
||||
return state.defaultAvailableTools
|
||||
}
|
||||
return state.availableTools[threadId] || []
|
||||
},
|
||||
|
||||
setDefaultAvailableTools: (toolNames: string[]) => {
|
||||
set({ defaultAvailableTools: toolNames })
|
||||
},
|
||||
|
||||
getDefaultAvailableTools: () => {
|
||||
return get().defaultAvailableTools
|
||||
},
|
||||
|
||||
initializeThreadTools: (threadId: string, allTools: MCPTool[]) => {
|
||||
const state = get()
|
||||
// If thread already has settings, don't override
|
||||
if (state.availableTools[threadId]) {
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize with default tools only
|
||||
// Don't auto-enable all tools if defaults are explicitly empty
|
||||
const initialTools = state.defaultAvailableTools.filter((toolName) =>
|
||||
allTools.some((tool) => tool.name === toolName)
|
||||
)
|
||||
|
||||
set((currentState) => ({
|
||||
availableTools: {
|
||||
...currentState.availableTools,
|
||||
[threadId]: initialTools,
|
||||
},
|
||||
}))
|
||||
},
|
||||
}),
|
||||
{
|
||||
name: localStorageKey.toolAvailability,
|
||||
storage: createJSONStorage(() => localStorage),
|
||||
// Persist all state
|
||||
partialize: (state) => ({
|
||||
availableTools: state.availableTools,
|
||||
defaultAvailableTools: state.defaultAvailableTools,
|
||||
}),
|
||||
}
|
||||
)
|
||||
)
|
||||
@ -252,12 +252,18 @@ export const extractToolCall = (
|
||||
* @param builder
|
||||
* @param message
|
||||
* @param content
|
||||
* @param approvedTools - Record of approved tools per thread
|
||||
* @param showModal - Function to show approval modal, returns true if approved
|
||||
* @param allowAllMCPPermissions - Global setting to allow all MCP permissions without modal
|
||||
*/
|
||||
export const postMessageProcessing = async (
|
||||
calls: ChatCompletionMessageToolCall[],
|
||||
builder: CompletionMessagesBuilder,
|
||||
message: ThreadMessage,
|
||||
abortController: AbortController
|
||||
abortController: AbortController,
|
||||
approvedTools: Record<string, string[]> = {},
|
||||
showModal?: (toolName: string, threadId: string) => Promise<boolean>,
|
||||
allowAllMCPPermissions: boolean = false
|
||||
) => {
|
||||
// Handle completed tool calls
|
||||
if (calls.length) {
|
||||
@ -284,12 +290,30 @@ export const postMessageProcessing = async (
|
||||
],
|
||||
}
|
||||
|
||||
const result = await callTool({
|
||||
toolName: toolCall.function.name,
|
||||
arguments: toolCall.function.arguments.length
|
||||
? JSON.parse(toolCall.function.arguments)
|
||||
: {},
|
||||
})
|
||||
// Check if tool is approved or show modal for approval
|
||||
const approved =
|
||||
allowAllMCPPermissions ||
|
||||
approvedTools[message.thread_id]?.includes(toolCall.function.name) ||
|
||||
(showModal
|
||||
? await showModal(toolCall.function.name, message.thread_id)
|
||||
: true)
|
||||
|
||||
const result = approved
|
||||
? await callTool({
|
||||
toolName: toolCall.function.name,
|
||||
arguments: toolCall.function.arguments.length
|
||||
? JSON.parse(toolCall.function.arguments)
|
||||
: {},
|
||||
})
|
||||
: {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'The user has chosen to disallow the tool call.',
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
if ('error' in result && result.error) break
|
||||
|
||||
message.metadata = {
|
||||
|
||||
@ -17,6 +17,7 @@ import { PromptAnalytic } from '@/containers/analytics/PromptAnalytic'
|
||||
import { AnalyticProvider } from '@/providers/AnalyticProvider'
|
||||
import { useLeftPanel } from '@/hooks/useLeftPanel'
|
||||
import { cn } from '@/lib/utils'
|
||||
import ToolApproval from '@/containers/dialogs/ToolApproval'
|
||||
|
||||
export const Route = createRootRoute({
|
||||
component: RootLayout,
|
||||
@ -92,6 +93,7 @@ function RootLayout() {
|
||||
{isLocalAPIServerLogsRoute ? <LogsLayout /> : <AppLayout />}
|
||||
{/* <TanStackRouterDevtools position="bottom-right" /> */}
|
||||
<CortexFailureDialog />
|
||||
<ToolApproval />
|
||||
</Fragment>
|
||||
)
|
||||
}
|
||||
|
||||
@ -64,7 +64,11 @@ function Index() {
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex-1 shrink-0">
|
||||
<ChatInput showSpeedToken={false} model={selectedModel} />
|
||||
<ChatInput
|
||||
showSpeedToken={false}
|
||||
model={selectedModel}
|
||||
initialMessage={true}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -18,6 +18,7 @@ import EditJsonMCPserver from '@/containers/dialogs/EditJsonMCPserver'
|
||||
import { Switch } from '@/components/ui/switch'
|
||||
import { twMerge } from 'tailwind-merge'
|
||||
import { getConnectedServers } from '@/services/mcp'
|
||||
import { useToolApproval } from '@/hooks/useToolApproval'
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
export const Route = createFileRoute(route.settings.mcp_servers as any)({
|
||||
@ -26,6 +27,8 @@ export const Route = createFileRoute(route.settings.mcp_servers as any)({
|
||||
|
||||
function MCPServers() {
|
||||
const { mcpServers, addServer, editServer, deleteServer } = useMCPServers()
|
||||
const { allowAllMCPPermissions, setAllowAllMCPPermissions } =
|
||||
useToolApproval()
|
||||
|
||||
const [open, setOpen] = useState(false)
|
||||
const [editingKey, setEditingKey] = useState<string | null>(null)
|
||||
@ -195,6 +198,35 @@ function MCPServers() {
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
|
||||
{/* Global MCP Permission Toggle */}
|
||||
<Card
|
||||
header={
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="space-y-1">
|
||||
<h1 className="text-main-view-fg font-medium text-base">
|
||||
Allow All MCP Tool Permissions
|
||||
</h1>
|
||||
<p className="text-sm text-main-view-fg/70">
|
||||
When enabled, all MCP tool calls will be automatically
|
||||
approved without showing permission dialogs.
|
||||
<span className="font-semibold text-main-view-fg">
|
||||
{' '}
|
||||
Use with caution
|
||||
</span>{' '}
|
||||
- only enable this if you trust all your MCP servers.
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex-shrink-0 ml-4">
|
||||
<Switch
|
||||
checked={allowAllMCPPermissions}
|
||||
onCheckedChange={setAllowAllMCPPermissions}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
|
||||
{Object.keys(mcpServers).length === 0 ? (
|
||||
<div className="py-4 text-center font-medium text-main-view-fg/50">
|
||||
No MCP servers found
|
||||
|
||||
@ -1,6 +1,16 @@
|
||||
.markdown {
|
||||
@apply text-inherit;
|
||||
|
||||
&.is-user {
|
||||
p {
|
||||
line-height: 1.6;
|
||||
margin-bottom: 1em;
|
||||
&:first-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* Headings */
|
||||
:is(h1, h2, h3, h4, h5, h6) {
|
||||
font-weight: 600;
|
||||
@ -41,10 +51,6 @@
|
||||
p {
|
||||
line-height: 1.6;
|
||||
margin-bottom: 1em;
|
||||
|
||||
&:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
}
|
||||
|
||||
strong {
|
||||
@ -199,3 +205,6 @@
|
||||
margin: 2em 0;
|
||||
}
|
||||
}
|
||||
[data-tool-call-block] + [data-tool-call-block] {
|
||||
margin-top: 16px;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user