import { create } from 'zustand'
import { persist, createJSONStorage } from 'zustand/middleware'
import { localStorageKey } from '@/constants/localStorage'
import { MCPTool } from '@/types/completion'

/**
 * State and actions of the tool-availability store.
 *
 * A thread with no entry in `disabledTools` falls back to
 * `defaultDisabledTools` for every read.
 */
type ToolDisabledState = {
  /** Disabled tools per thread: threadId -> toolNames[]. */
  disabledTools: Record<string, string[]>
  /** Global default disabled tools (for new threads/index page). */
  defaultDisabledTools: string[]
  /** Flag to track if defaults have been initialized from extension. */
  defaultsInitialized: boolean

  // Actions

  /**
   * Enable (`available === true`) or disable (`available === false`) a single
   * tool for the given thread.
   */
  setToolDisabledForThread: (
    threadId: string,
    toolName: string,
    available: boolean
  ) => void
  /** Whether `toolName` is disabled for the thread (defaults apply if the thread has no entry). */
  isToolDisabled: (threadId: string, toolName: string) => boolean
  /** Disabled tool names for the thread (defaults apply if the thread has no entry). */
  getDisabledToolsForThread: (threadId: string) => string[]
  /** Replace the global default list of disabled tools. */
  setDefaultDisabledTools: (toolNames: string[]) => void
  /** Read the global default list of disabled tools. */
  getDefaultDisabledTools: () => string[]
  /** Read the `defaultsInitialized` flag. */
  isDefaultsInitialized: () => boolean
  /** Set the `defaultsInitialized` flag to `true`. */
  markDefaultsAsInitialized: () => void
  /** Initialize thread tools from default or existing thread settings. */
  initializeThreadTools: (threadId: string, allTools: MCPTool[]) => void
}

/**
 * Store tracking which tools are disabled, per thread and by default.
 * The three data fields are persisted to `localStorage` under
 * `localStorageKey.toolAvailability`.
 */
export const useToolAvailable = create<ToolDisabledState>()(
  persist(
    (set, get) => ({
      disabledTools: {},
      defaultDisabledTools: [],
      defaultsInitialized: false,

      setToolDisabledForThread: (
        threadId: string,
        toolName: string,
        available: boolean
      ) => {
        set((state) => {
          // Start from the same list the read paths report for this thread:
          // its own entry, or the defaults when it has none. Starting from an
          // empty list would silently re-enable every default-disabled tool
          // on the first toggle.
          const currentTools =
            state.disabledTools[threadId] ?? state.defaultDisabledTools

          // Always strip the tool first so that disabling an already-disabled
          // tool cannot store a duplicate entry.
          const withoutTool = currentTools.filter((tool) => tool !== toolName)
          const updatedTools = available
            ? withoutTool
            : [...withoutTool, toolName]

          return {
            disabledTools: {
              ...state.disabledTools,
              [threadId]: updatedTools,
            },
          }
        })
      },

      isToolDisabled: (threadId: string, toolName: string) => {
        const state = get()
        // If no thread-specific settings, use default
        const effectiveTools =
          state.disabledTools[threadId] ?? state.defaultDisabledTools
        return effectiveTools.includes(toolName)
      },

      getDisabledToolsForThread: (threadId: string) => {
        const state = get()
        // If no thread-specific settings, use default
        return state.disabledTools[threadId] ?? state.defaultDisabledTools
      },

      setDefaultDisabledTools: (toolNames: string[]) => {
        set({ defaultDisabledTools: toolNames })
      },

      getDefaultDisabledTools: () => {
        return get().defaultDisabledTools
      },

      isDefaultsInitialized: () => {
        return get().defaultsInitialized
      },

      markDefaultsAsInitialized: () => {
        set({ defaultsInitialized: true })
      },

      initializeThreadTools: (threadId: string, allTools: MCPTool[]) => {
        const state = get()

        // If thread already has settings, don't override
        if (state.disabledTools[threadId]) {
          return
        }

        // Initialize with default tools only, keeping just the defaults that
        // name a tool present in `allTools`.
        // Don't auto-enable all tools if defaults are explicitly empty
        const knownToolNames = new Set(allTools.map((tool) => tool.name))
        const initialTools = state.defaultDisabledTools.filter((toolName) =>
          knownToolNames.has(toolName)
        )

        set((currentState) => ({
          disabledTools: {
            ...currentState.disabledTools,
            [threadId]: initialTools,
          },
        }))
      },
    }),
    {
      name: localStorageKey.toolAvailability,
      storage: createJSONStorage(() => localStorage),
      // Persist all data fields (the action functions are not serialized)
      partialize: (state) => ({
        disabledTools: state.disabledTools,
        defaultDisabledTools: state.defaultDisabledTools,
        defaultsInitialized: state.defaultsInitialized,
      }),
    }
  )
)