// jan/web-app/src/hooks/useToolAvailable.ts
//
// 130 lines
// 3.9 KiB
// TypeScript

import { create } from 'zustand'
import { persist, createJSONStorage } from 'zustand/middleware'
import { localStorageKey } from '@/constants/localStorage'
import { MCPTool } from '@/types/completion'
/**
 * Shape of the tool-availability store: which tools are disabled for each
 * thread, the global defaults, and the actions that read or update them.
 */
type ToolDisabledState = {
  /** Disabled tool names keyed by thread id (threadId -> toolNames[]). */
  disabledTools: Record<string, string[]>

  /** Globally disabled tool names, used for new threads and the index page. */
  defaultDisabledTools: string[]

  /** Whether the defaults have already been initialized from the extension. */
  defaultsInitialized: boolean

  /**
   * Toggles one tool for one thread.
   *
   * @param threadId - Thread whose disabled list is updated.
   * @param toolName - Name of the tool being toggled.
   * @param available - `true` removes the tool from the thread's disabled
   *   list, `false` adds it.
   */
  setToolDisabledForThread: (
    threadId: string,
    toolName: string,
    available: boolean
  ) => void

  /** Replaces the global default list of disabled tool names. */
  setDefaultDisabledTools: (toolNames: string[]) => void

  /** Records that the defaults have been initialized. */
  markDefaultsAsInitialized: () => void

  /**
   * Gives a thread its own disabled list derived from the defaults, keeping
   * only names present in `allTools`. Threads that already have a list are
   * left untouched.
   */
  initializeThreadTools: (threadId: string, allTools: MCPTool[]) => void

  /**
   * Reports whether a tool is disabled for a thread; threads without their
   * own list are answered from the global defaults.
   */
  isToolDisabled: (threadId: string, toolName: string) => boolean

  /**
   * Returns the disabled tool names for a thread, or the global defaults when
   * the thread has no list of its own.
   */
  getDisabledToolsForThread: (threadId: string) => string[]

  /** Returns the global default list of disabled tool names. */
  getDefaultDisabledTools: () => string[]

  /** Returns the `defaultsInitialized` flag. */
  isDefaultsInitialized: () => boolean
}
/**
 * Persisted Zustand store tracking which tools are disabled.
 *
 * Each thread may have its own list of disabled tool names; a thread without
 * one is answered from `defaultDisabledTools`. The three state fields are
 * written to `localStorage` under `localStorageKey.toolAvailability`.
 */
export const useToolAvailable = create<ToolDisabledState>()(
  persist(
    (set, get) => ({
      disabledTools: {},
      defaultDisabledTools: [],
      defaultsInitialized: false,

      /**
       * Enables (`available === true`) or disables a tool for one thread.
       * Disabling an already-disabled tool is a no-op for the list contents.
       */
      setToolDisabledForThread: (
        threadId: string,
        toolName: string,
        available: boolean
      ) => {
        set((state) => {
          // A thread without its own list is displayed using the defaults
          // (see isToolDisabled), so start from them: otherwise the first
          // toggle would silently re-enable every other default-disabled tool.
          const currentTools =
            state.disabledTools[threadId] ?? state.defaultDisabledTools

          let updatedTools: string[]
          if (available) {
            // Enabling: drop the tool from the disabled list.
            updatedTools = currentTools.filter((tool) => tool !== toolName)
          } else if (currentTools.includes(toolName)) {
            // Already disabled: copy as-is so no duplicate entry is stored.
            updatedTools = [...currentTools]
          } else {
            // Disabling: append the tool.
            updatedTools = [...currentTools, toolName]
          }

          return {
            disabledTools: {
              ...state.disabledTools,
              [threadId]: updatedTools,
            },
          }
        })
      },

      /** Whether `toolName` is disabled for the thread (defaults if unset). */
      isToolDisabled: (threadId: string, toolName: string) => {
        const state = get()
        const effectiveTools =
          state.disabledTools[threadId] ?? state.defaultDisabledTools
        return effectiveTools.includes(toolName)
      },

      /** Disabled tool names for the thread, or the defaults if unset. */
      getDisabledToolsForThread: (threadId: string) => {
        const state = get()
        return state.disabledTools[threadId] ?? state.defaultDisabledTools
      },

      /** Replaces the global default list of disabled tools. */
      setDefaultDisabledTools: (toolNames: string[]) => {
        set({ defaultDisabledTools: toolNames })
      },

      /** Returns the global default list of disabled tools. */
      getDefaultDisabledTools: () => {
        return get().defaultDisabledTools
      },

      /** Returns whether the defaults have been marked as initialized. */
      isDefaultsInitialized: () => {
        return get().defaultsInitialized
      },

      /** Marks the defaults as initialized. */
      markDefaultsAsInitialized: () => {
        set({ defaultsInitialized: true })
      },

      /**
       * Creates the thread's own disabled list from the defaults, keeping
       * only names that exist in `allTools`. Does nothing when the thread
       * already has a list, so existing per-thread choices are never
       * overridden.
       */
      initializeThreadTools: (threadId: string, allTools: MCPTool[]) => {
        const state = get()
        if (state.disabledTools[threadId]) {
          return
        }

        // Only the defaults are applied here; an explicitly empty default
        // list yields an empty (nothing disabled) thread list.
        const knownToolNames = new Set(allTools.map((tool) => tool.name))
        const initialTools = state.defaultDisabledTools.filter((toolName) =>
          knownToolNames.has(toolName)
        )

        set((currentState) => ({
          disabledTools: {
            ...currentState.disabledTools,
            [threadId]: initialTools,
          },
        }))
      },
    }),
    {
      name: localStorageKey.toolAvailability,
      storage: createJSONStorage(() => localStorage),
      // Persist every data field; the action functions are not serialized.
      partialize: (state) => ({
        disabledTools: state.disabledTools,
        defaultDisabledTools: state.defaultDisabledTools,
        defaultsInitialized: state.defaultsInitialized,
      }),
    }
  )
)