import { useState, useMemo, useEffect, useCallback, useRef } from 'react' import Image from 'next/image' import { InferenceEngine, Model } from '@janhq/core' import { Badge, Button, Input, ScrollArea, Select, useClickOutside, } from '@janhq/joi' import { useAtom, useAtomValue, useSetAtom } from 'jotai' import { ChevronDownIcon, ChevronUpIcon, DownloadCloudIcon, XIcon, } from 'lucide-react' import { twMerge } from 'tailwind-merge' import ProgressCircle from '@/containers/Loader/ProgressCircle' import ModelLabel from '@/containers/ModelLabel' import SetupRemoteModel from '@/containers/SetupRemoteModel' import { useCreateNewThread } from '@/hooks/useCreateNewThread' import useDownloadModel from '@/hooks/useDownloadModel' import { modelDownloadStateAtom } from '@/hooks/useDownloadState' import useRecommendedModel from '@/hooks/useRecommendedModel' import useUpdateModelParameters from '@/hooks/useUpdateModelParameters' import { formatDownloadPercentage, toGibibytes } from '@/utils/converter' import { getLogoEngine, getTitleByEngine, localEngines, priorityEngine, } from '@/utils/modelEngine' import { extensionManager } from '@/extension' import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom' import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom' import { configuredModelsAtom, getDownloadingModelAtom, selectedModelAtom, showEngineListModelAtom, } from '@/helpers/atoms/Model.atom' import { activeThreadAtom, setThreadModelParamsAtom, } from '@/helpers/atoms/Thread.atom' type Props = { chatInputMode?: boolean strictedThread?: boolean disabled?: boolean } const ModelDropdown = ({ disabled, chatInputMode, strictedThread = true, }: Props) => { const { downloadModel } = useDownloadModel() const [searchFilter, setSearchFilter] = useState('all') const [filterOptionsOpen, setFilterOptionsOpen] = useState(false) const [searchText, setSearchText] = useState('') const [open, setOpen] = useState(false) const activeThread = useAtomValue(activeThreadAtom) const downloadingModels = useAtomValue(getDownloadingModelAtom) const [toggle, setToggle] = useState(null) const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom) const { recommendedModel, downloadedModels } = useRecommendedModel() const [dropdownOptions, setDropdownOptions] = useState( null ) const downloadStates = useAtomValue(modelDownloadStateAtom) const setThreadModelParams = useSetAtom(setThreadModelParamsAtom) const { updateModelParameter } = useUpdateModelParameters() const searchInputRef = useRef(null) const configuredModels = useAtomValue(configuredModelsAtom) const featuredModel = configuredModels.filter((x) => x.metadata.tags.includes('Featured') ) const preserveModelSettings = useAtomValue(preserveModelSettingsAtom) const { updateThreadMetadata } = useCreateNewThread() useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [ dropdownOptions, toggle, ]) const [showEngineListModel, setShowEngineListModel] = useAtom( showEngineListModelAtom ) const isModelSupportRagAndTools = useCallback((model: Model) => { return ( model?.engine === InferenceEngine.openai || localEngines.includes(model?.engine as InferenceEngine) ) }, []) const filteredDownloadedModels = useMemo( () => configuredModels .filter((e) => e.name.toLowerCase().includes(searchText.toLowerCase().trim()) ) .filter((e) => { if (searchFilter === 'all') { return e.engine } if (searchFilter === 'local') { return localEngines.includes(e.engine) } if (searchFilter === 'remote') { return !localEngines.includes(e.engine) } }) .sort((a, b) => a.name.localeCompare(b.name)) .sort((a, b) => { const aInDownloadedModels = downloadedModels.some( (item) => item.id === a.id ) const bInDownloadedModels = downloadedModels.some( (item) => item.id === b.id ) if (aInDownloadedModels && !bInDownloadedModels) { return -1 } else if (!aInDownloadedModels && bInDownloadedModels) { return 1 } else { return 0 } }), [configuredModels, searchText, searchFilter, downloadedModels] ) useEffect(() => { if (open && searchInputRef.current) { searchInputRef.current.focus() } }, [open]) useEffect(() => { if (!activeThread) return let model = downloadedModels.find( (model) => model.id === activeThread.assistants[0].model.id ) if (!model) { model = recommendedModel } setSelectedModel(model) }, [recommendedModel, activeThread, downloadedModels, setSelectedModel]) const onClickModelItem = useCallback( async (modelId: string) => { const model = downloadedModels.find((m) => m.id === modelId) setSelectedModel(model) setOpen(false) if (activeThread) { // Change assistand tools based on model support RAG updateThreadMetadata({ ...activeThread, assistants: [ { ...activeThread.assistants[0], tools: [ { type: 'retrieval', enabled: isModelSupportRagAndTools(model as Model), settings: { ...(activeThread.assistants[0].tools && activeThread.assistants[0].tools[0]?.settings), }, }, ], }, ], }) // Default setting ctx_len for the model for a better onboarding experience // TODO: When Cortex support hardware instructions, we should remove this const defaultContextLength = preserveModelSettings ? model?.metadata?.default_ctx_len : 2048 const defaultMaxTokens = preserveModelSettings ? model?.metadata?.default_max_tokens : 2048 const overriddenSettings = model?.settings.ctx_len && model.settings.ctx_len > 2048 ? { ctx_len: defaultContextLength ?? 2048 } : {} const overriddenParameters = model?.parameters.max_tokens && model.parameters.max_tokens ? { max_tokens: defaultMaxTokens ?? 2048 } : {} const modelParams = { ...model?.parameters, ...model?.settings, ...overriddenParameters, ...overriddenSettings, } // Update model parameter to the thread state setThreadModelParams(activeThread.id, modelParams) // Update model parameter to the thread file if (model) updateModelParameter(activeThread, { params: modelParams, modelId: model.id, engine: model.engine, }) } }, [ downloadedModels, activeThread, setSelectedModel, isModelSupportRagAndTools, setThreadModelParams, updateModelParameter, updateThreadMetadata, preserveModelSettings, ] ) const [extensionHasSettings, setExtensionHasSettings] = useState< { name?: string; setting: string; apiKey: string; provider: string }[] >([]) const inActiveEngineProvider = useAtomValue(inActiveEngineProviderAtom) useEffect(() => { const getAllSettings = async () => { const extensionsMenu: { name?: string setting: string apiKey: string provider: string }[] = [] const extensions = extensionManager.getAll() for (const extension of extensions) { if (typeof extension.getSettings === 'function') { const settings = await extension.getSettings() if ( (settings && settings.length > 0) || (await extension.installationState()) !== 'NotRequired' ) { extensionsMenu.push({ name: extension.productName, setting: extension.name, apiKey: 'apiKey' in extension && typeof extension.apiKey === 'string' ? extension.apiKey : '', provider: 'provider' in extension && typeof extension.provider === 'string' ? extension.provider : '', }) } } } setExtensionHasSettings(extensionsMenu) } getAllSettings() }, []) const findByEngine = filteredDownloadedModels .filter((x) => !inActiveEngineProvider.includes(x.engine)) .map((x) => x.engine) const groupByEngine = findByEngine .filter(function (item, index) { if (findByEngine.indexOf(item) === index) return item }) .sort((a, b) => { if (priorityEngine.includes(a) && priorityEngine.includes(b)) { return priorityEngine.indexOf(a) - priorityEngine.indexOf(b) } else if (priorityEngine.includes(a)) { return -1 } else if (priorityEngine.includes(b)) { return 1 } else { return 0 // Leave the rest in their original order } }) const getEngineStatusReady: InferenceEngine[] = extensionHasSettings ?.filter((e) => e.apiKey.length > 0) .map((x) => x.provider as InferenceEngine) useEffect(() => { setShowEngineListModel((prev) => [ ...prev, ...(getEngineStatusReady as InferenceEngine[]), ]) // eslint-disable-next-line react-hooks/exhaustive-deps }, [setShowEngineListModel, extensionHasSettings]) const isDownloadALocalModel = downloadedModels.some((x) => localEngines.includes(x.engine) ) if (strictedThread && !activeThread) { return null } return (
{chatInputMode ? ( setOpen(!open)} > {selectedModel?.name} ) : ( } onClick={() => setOpen(!open)} /> )}
setSearchText(e.target.value)} suffixIcon={ searchText.length > 0 && ( setSearchText('')} /> ) } />