// ModelDropdown — model picker for a thread: lists configured/downloaded models
// grouped by engine, with search, a local/remote tab filter, remote-engine
// setup entry points, and in-dropdown download of featured models.
//
// NOTE(review): this file appears to have been mangled by an extraction step
// that stripped every `<...>` span. All JSX element tags in the render body
// (below the final `return (`) and all TypeScript generic type arguments
// (e.g. `useState(null)`, `useRef(null)`, `atom(false)` were presumably
// `useState<T | null>(null)` etc.) are missing. Code is reproduced
// token-for-token; only comments were added. Recover the original markup from
// version control before attempting behavioral changes.
import { useState, useMemo, useEffect, useCallback, useRef } from 'react'
import Image from 'next/image'
import { EngineConfig, InferenceEngine } from '@janhq/core'
import {
  Badge,
  Button,
  Input,
  ScrollArea,
  Tabs,
  useClickOutside,
} from '@janhq/joi'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import {
  ChevronDownIcon,
  ChevronUpIcon,
  DownloadCloudIcon,
  XIcon,
} from 'lucide-react'
import { twMerge } from 'tailwind-merge'
import ProgressCircle from '@/containers/Loader/ProgressCircle'
import ModelLabel from '@/containers/ModelLabel'
import SetupRemoteModel from '@/containers/SetupRemoteModel'
import { useActiveModel } from '@/hooks/useActiveModel'
import { useCreateNewThread } from '@/hooks/useCreateNewThread'
import useDownloadModel from '@/hooks/useDownloadModel'
import { modelDownloadStateAtom } from '@/hooks/useDownloadState'
import { useGetEngines } from '@/hooks/useEngineManagement'
import { useGetModelSources } from '@/hooks/useModelSource'
import useRecommendedModel from '@/hooks/useRecommendedModel'
import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
import { formatDownloadPercentage, toGigabytes } from '@/utils/converter'
import { manualRecommendationModel } from '@/utils/model'
import { getLogoEngine, getTitleByEngine } from '@/utils/modelEngine'
import { extractModelName } from '@/utils/modelSource'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import {
  configuredModelsAtom,
  getDownloadingModelAtom,
  selectedModelAtom,
  showEngineListModelAtom,
} from '@/helpers/atoms/Model.atom'
import {
  activeThreadAtom,
  setThreadModelParamsAtom,
} from '@/helpers/atoms/Thread.atom'

type Props = {
  // When true the dropdown is rendered as a compact badge inside the chat input.
  chatInputMode?: boolean
  // When true (default) the component renders nothing unless a thread is active.
  strictedThread?: boolean
  disabled?: boolean
}

// Shared open/close state so other parts of the app can open this dropdown.
// NOTE(review): generic argument stripped — presumably `atom<boolean>(false)`.
export const modelDropdownStateAtom = atom(false)

const ModelDropdown = ({
  disabled,
  chatInputMode,
  strictedThread = true,
}: Props) => {
  const { downloadModel } = useDownloadModel()
  // Global dropdown visibility, mirrored into local `open` state below.
  const [modelDropdownState, setModelDropdownState] = useAtom(
    modelDropdownStateAtom
  )
  // 'local' | 'remote' tab filter for the engine list (defaults to local).
  const [searchFilter, setSearchFilter] = useState('local')
  const [searchText, setSearchText] = useState('')
  const [open, setOpen] = useState(modelDropdownState)
  const activeThread = useAtomValue(activeThreadAtom)
  const activeAssistant = useAtomValue(activeAssistantAtom)
  const downloadingModels = useAtomValue(getDownloadingModelAtom)
  // NOTE(review): generic stripped — `toggle` is used as a click-outside
  // anchor element below; presumably `useState<HTMLElement | null>(null)`.
  const [toggle, setToggle] = useState(null)
  const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
  const { recommendedModel, downloadedModels } = useRecommendedModel()
  const { sources } = useGetModelSources()
  // NOTE(review): generic stripped — used as the other click-outside anchor.
  const [dropdownOptions, setDropdownOptions] = useState( null )
  const { engines } = useGetEngines()
  const downloadStates = useAtomValue(modelDownloadStateAtom)
  const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
  const { updateModelParameter } = useUpdateModelParameters()
  // NOTE(review): generic stripped — focused when the dropdown opens, so
  // presumably `useRef<HTMLInputElement | null>(null)`.
  const searchInputRef = useRef(null)
  const configuredModels = useAtomValue(configuredModelsAtom)
  const { stopModel } = useActiveModel()
  // Model sources hand-picked for the "featured" download section.
  const featuredModels = sources?.filter((x) =>
    manualRecommendationModel.includes(x.id)
  )
  const { updateThreadMetadata } = useCreateNewThread()
  // Flatten the engines record into [{ name, type, engine }] entries; an engine
  // is 'remote' only when its first config entry says so, otherwise 'local'.
  const engineList = useMemo(
    () =>
      Object.entries(engines ?? {}).flatMap((e) => ({
        name: e[0],
        type: e[1][0]?.type === 'remote' ? 'remote' : 'local',
        engine: e[1][0],
      })),
    [engines]
  )
  // Close the dropdown on outside clicks, except clicks on the two anchors.
  // (Referencing handleChangeStateOpen before its declaration is safe: the
  // arrow body only runs after render.)
  useClickOutside(() => handleChangeStateOpen(false), null, [
    dropdownOptions,
    toggle,
  ])
  const [showEngineListModel, setShowEngineListModel] = useAtom(
    showEngineListModelAtom
  )
  // Keep local `open` and the shared atom in sync when toggling visibility.
  const handleChangeStateOpen = useCallback(
    (state: boolean) => {
      setOpen(state)
      setModelDropdownState(state)
    },
    [setModelDropdownState]
  )
  // Union of configured + downloaded models (deduped by id), narrowed by the
  // search text and the local/remote tab, sorted by name with downloaded
  // models floated to the top (the second sort is stable over the first).
  const filteredDownloadedModels = useMemo(
    () =>
      configuredModels
        .concat(
          downloadedModels.filter(
            (e) => !configuredModels.some((x) => x.id === e.id)
          )
        )
        .filter((e) =>
          e.name.toLowerCase().includes(searchText.toLowerCase().trim())
        )
        .filter((e) => {
          if (searchFilter === 'local') {
            return (
              engineList.find((t) => t.engine?.engine === e.engine)?.type ===
              'local'
            )
          }
          return true
        })
        .sort((a, b) => a.name.localeCompare(b.name))
        .sort((a, b) => {
          const aInDownloadedModels = downloadedModels.some(
            (item) => item.id === a.id
          )
          const bInDownloadedModels = downloadedModels.some(
            (item) => item.id === b.id
          )
          if (aInDownloadedModels && !bInDownloadedModels) {
            return -1
          } else if (!aInDownloadedModels && bInDownloadedModels) {
            return 1
          } else {
            return 0
          }
        }),
    [configuredModels, searchText, searchFilter, downloadedModels, engineList]
  )
  // In chat-input mode, opening via the shared atom also opens the local UI.
  useEffect(() => {
    if (modelDropdownState && chatInputMode) {
      setOpen(modelDropdownState)
    }
  }, [chatInputMode, modelDropdownState])
  // Focus the search box whenever the dropdown opens.
  useEffect(() => {
    if (open && searchInputRef.current) {
      searchInputRef.current.focus()
    }
  }, [open])
  // Auto-expand engine sections that already have an API key configured.
  // NOTE(review): this appends without deduplication on every engineList
  // change, so `showEngineListModel` can accumulate duplicate names — verify
  // whether the atom (or consumers) dedupe.
  useEffect(() => {
    setShowEngineListModel((prev) => [
      ...prev,
      ...engineList
        .filter((x) => (x.engine?.api_key?.length ?? 0) > 0)
        .map((e) => e.name),
    ])
  }, [setShowEngineListModel, engineList])
  // Derive the selected model from the active assistant's model id; only
  // select it when its engine is local or has an API key, otherwise clear.
  useEffect(() => {
    if (!activeThread) return
    const modelId = activeAssistant?.model?.id
    const model = downloadedModels.find((model) => model.id === modelId)
    if (model) {
      if (
        engines?.[model.engine]?.[0]?.type === 'local' ||
        (engines?.[model.engine]?.[0]?.api_key?.length ?? 0) > 0
      )
        setSelectedModel(model)
    } else {
      setSelectedModel(undefined)
    }
  }, [
    recommendedModel,
    activeThread,
    downloadedModels,
    setSelectedModel,
    activeAssistant?.model?.id,
    engines,
  ])
  // True when `engine` names a known local engine.
  const isLocalEngine = useCallback(
    (engine?: string) => {
      if (!engine) return false
      return engineList.some((t) => t.name === engine && t.type === 'local')
    },
    [engineList]
  )
  // Select a model for the active thread: stop the running model, toggle the
  // retrieval tool based on engine support, clamp context/max_tokens to 8192,
  // and persist the merged parameters to thread state and the thread file.
  const onClickModelItem = useCallback(
    async (modelId: string) => {
      if (!activeAssistant) return
      const model = downloadedModels.find((m) => m.id === modelId)
      setSelectedModel(model)
      setOpen(false)
      stopModel()
      if (activeThread) {
        // Change assistant tools based on model support RAG
        updateThreadMetadata({
          ...activeThread,
          assistants: [
            {
              ...activeAssistant,
              tools: [
                {
                  type: 'retrieval',
                  enabled: model?.engine === InferenceEngine.cortex,
                  settings: {
                    ...(activeAssistant.tools &&
                      activeAssistant.tools[0]?.settings),
                  },
                },
              ],
            },
          ],
        })
        // Cap the context length at 8192 when the model declares one.
        const contextLength = model?.settings.ctx_len
          ? Math.min(8192, model?.settings.ctx_len ?? 8192)
          : undefined
        const overriddenParameters = {
          ctx_len: contextLength,
          max_tokens: contextLength
            ? Math.min(model?.parameters.max_tokens ?? 8192, contextLength)
            : model?.parameters.max_tokens,
        }
        const modelParams = {
          ...model?.parameters,
          ...model?.settings,
          ...overriddenParameters,
        }
        // Update model parameter to the thread state
        setThreadModelParams(activeThread.id, modelParams)
        // Update model parameter to the thread file
        if (model)
          updateModelParameter(activeThread, {
            params: modelParams,
            modelId: model.id,
            engine: model.engine,
          })
      }
    },
    [
      activeAssistant,
      downloadedModels,
      setSelectedModel,
      activeThread,
      updateThreadMetadata,
      setThreadModelParams,
      updateModelParameter,
      stopModel,
    ]
  )
  // True when at least one downloaded model belongs to a local engine
  // (controls whether the featured-models download section is shown).
  const isDownloadALocalModel = useMemo(
    () =>
      downloadedModels.some((x) =>
        engineList.some((t) => t.name === x.engine && t.type === 'local')
      ),
    [downloadedModels, engineList]
  )
  // In strict mode, render nothing without an active thread.
  if (strictedThread && !activeThread) {
    return null
  }
  // NOTE(review): everything from here to the closing paren is the render
  // body with its JSX element tags stripped (see file header). Reproduced
  // verbatim; do not edit without recovering the original. Two things to
  // verify once recovered: (1) `filteredDownloadedModels.some((e) => e.engine
  // === e.name)` in the engine-list filter shadows the outer engine `e`,
  // making the comparison self-referential; (2) `model.sources[0].url` is
  // read without a length check.
  return (
{chatInputMode ? ( handleChangeStateOpen(!open)} > {selectedModel?.name || 'Select a model'} ) : ( } onClick={() => setOpen(!open)} /> )}
setSearchFilter(value)} />
setSearchText(e.target.value)} suffixIcon={ searchText.length > 0 && ( setSearchText('')} /> ) } />
{engineList .filter((e) => e.type === searchFilter) .filter( (e) => e.type === 'remote' || e.name === InferenceEngine.cortex_llamacpp || filteredDownloadedModels.some((e) => e.engine === e.name) ) .map((engine, i) => { const isConfigured = engine.type === 'local' || ((engine.engine as EngineConfig).api_key?.length ?? 0) > 1 const engineLogo = getLogoEngine(engine.name as InferenceEngine) const showModel = showEngineListModel.includes(engine.name) const onClickChevron = () => { if (showModel) { setShowEngineListModel((prev) => prev.filter((item) => item !== engine.name) ) } else { setShowEngineListModel((prev) => [...prev, engine.name]) } } return (
{engineLogo && ( logo )}
{getTitleByEngine(engine.name)}
{engine.type === 'remote' && ( 0 } /> )} {!showModel ? ( ) : ( )}
{engine.type === 'local' && !isDownloadALocalModel && showModel && !searchText.length && (
    {featuredModels?.map((model) => { const isDownloading = downloadingModels.some( (md) => md === (model.models[0]?.id ?? model.id) ) return (
  • {extractModelName(model.id)}

    {toGigabytes(model.models[0]?.size)} {!isDownloading ? ( downloadModel(model.models[0]?.id) } /> ) : ( Object.values(downloadStates) .filter( (x) => x.modelId === (model.models[0]?.id ?? model.id) ) .map((item) => ( )) )}
  • ) })}
)}
    {filteredDownloadedModels .filter( (x) => x.engine === engine.name || (x.engine === InferenceEngine.nitro && engine.name === InferenceEngine.cortex_llamacpp) ) .filter((y) => { if (isLocalEngine(y.engine) && !searchText.length) { return downloadedModels.find((c) => c.id === y.id) } else { return y } }) .map((model) => { if (!showModel) return null const isDownloading = downloadingModels.some( (md) => md === model.id ) const isDownloaded = downloadedModels.some( (c) => c.id === model.id ) return ( <> {isDownloaded && (
  • { if ( !isConfigured && engine.type === 'remote' ) return null if (isDownloaded) { onClickModelItem(model.id) } }} >

    {model.name}

    {!isDownloaded && ( {toGigabytes(model.metadata?.size)} )} {!isDownloading && !isDownloaded ? ( downloadModel( model.sources[0].url, model.id ) } /> ) : ( Object.values(downloadStates) .filter((x) => x.modelId === model.id) .map((item) => ( )) )}
  • )} ) })}
) })}
  )
}

export default ModelDropdown