diff --git a/web-app/src/containers/DropdownModelProvider.tsx b/web-app/src/containers/DropdownModelProvider.tsx index 72460a9fe..3e77bad74 100644 --- a/web-app/src/containers/DropdownModelProvider.tsx +++ b/web-app/src/containers/DropdownModelProvider.tsx @@ -1,26 +1,33 @@ +import { useEffect, useState, useRef, useMemo, useCallback } from 'react' import { - DropdownMenu, - DropdownMenuContent, - DropdownMenuGroup, - DropdownMenuItem, - DropdownMenuLabel, - DropdownMenuTrigger, -} from '@/components/ui/dropdown-menu' + Popover, + PopoverContent, + PopoverTrigger, +} from '@/components/ui/popover' import { useModelProvider } from '@/hooks/useModelProvider' import { cn, getProviderTitle } from '@/lib/utils' -import { useEffect, useState } from 'react' +import { highlightFzfMatch } from '@/utils/highlight' import Capabilities from './Capabilities' -import { IconSettings } from '@tabler/icons-react' +import { IconSettings, IconX, IconCheck } from '@tabler/icons-react' import { useNavigate } from '@tanstack/react-router' import { route } from '@/constants/routes' import { useThreads } from '@/hooks/useThreads' import { ModelSetting } from '@/containers/ModelSetting' import ProvidersAvatar from '@/containers/ProvidersAvatar' +import { Fzf } from 'fzf' type DropdownModelProviderProps = { model?: ThreadModel } +interface SearchableModel { + provider: ModelProvider + model: Model + searchStr: string + value: string + highlightedId?: string +} + const DropdownModelProvider = ({ model }: DropdownModelProviderProps) => { const { providers, @@ -34,6 +41,11 @@ const DropdownModelProvider = ({ model }: DropdownModelProviderProps) => { const { updateCurrentThreadModel } = useThreads() const navigate = useNavigate() + // Search state + const [open, setOpen] = useState(false) + const [searchValue, setSearchValue] = useState('') + const searchInputRef = useRef(null) + // Initialize model provider only once useEffect(() => { // Auto select model when existing thread is passed @@ -43,7 +55,7 @@ const DropdownModelProvider = ({ model }: DropdownModelProviderProps) => { // default model, we should add from setting selectModelProvider('llama.cpp', 'llama3.2:3b') } - }, [model, selectModelProvider, updateCurrentThreadModel]) // Only run when threadData changes + }, [model, selectModelProvider, updateCurrentThreadModel]) // Update display model when selection changes useEffect(() => { @@ -54,6 +66,116 @@ const DropdownModelProvider = ({ model }: DropdownModelProviderProps) => { } }, [selectedProvider, selectedModel]) + // Reset search value when dropdown closes + const onOpenChange = useCallback((open: boolean) => { + setOpen(open) + if (!open) { + requestAnimationFrame(() => setSearchValue('')) + } else { + // Focus search input when opening + setTimeout(() => { + searchInputRef.current?.focus() + }, 100) + } + }, []) + + // Clear search and focus input + const onClearSearch = useCallback(() => { + setSearchValue('') + searchInputRef.current?.focus() + }, []) + + // Create searchable items from all models + const searchableItems = useMemo(() => { + const items: SearchableModel[] = [] + + providers.forEach((provider) => { + if (!provider.active) return + + provider.models.forEach((modelItem) => { + // Skip models that require API key but don't have one (except llama.cpp) + if (provider.provider !== 'llama.cpp' && !provider.api_key?.length) { + return + } + + const capabilities = modelItem.capabilities || [] + const capabilitiesString = capabilities.join(' ') + const providerTitle = getProviderTitle(provider.provider) + + // Create search string with model id, provider, and capabilities + const searchStr = + `${modelItem.id} ${providerTitle} ${provider.provider} ${capabilitiesString}`.toLowerCase() + + items.push({ + provider, + model: modelItem, + searchStr, + value: `${provider.provider}:${modelItem.id}`, + }) + }) + }) + + return items + }, [providers]) + + // Create Fzf instance for fuzzy search + const fzfInstance = useMemo(() => { + return new Fzf(searchableItems, { + selector: (item) => item.model.id, + }) + }, [searchableItems]) + + // Filter models based on search value + const filteredItems = useMemo(() => { + if (!searchValue) return searchableItems + + return fzfInstance.find(searchValue).map((result) => { + const item = result.item + const positions = Array.from(result.positions) || [] + const highlightedId = highlightFzfMatch( + item.model.id, + positions, + 'text-accent' + ) + + return { + ...item, + highlightedId, + } + }) + }, [searchableItems, searchValue, fzfInstance]) + + // Group filtered items by provider + const groupedItems = useMemo(() => { + const groups: Record = {} + + filteredItems.forEach((item) => { + const providerKey = item.provider.provider + if (!groups[providerKey]) { + groups[providerKey] = [] + } + groups[providerKey].push(item) + }) + + return groups + }, [filteredItems]) + + const handleSelect = useCallback( + (searchableModel: SearchableModel) => { + selectModelProvider( + searchableModel.provider.provider, + searchableModel.model.id + ) + updateCurrentThreadModel({ + id: searchableModel.model.id, + provider: searchableModel.provider.provider, + }) + setSearchValue('') + setOpen(false) + }, + [selectModelProvider, updateCurrentThreadModel] + ) + const currentModel = selectedModel?.id ? getModelBy(selectedModel?.id) : undefined @@ -63,114 +185,159 @@ const DropdownModelProvider = ({ model }: DropdownModelProviderProps) => { const provider = getProviderByName(selectedProvider) return ( - <> - -
- - - - {currentModel?.settings && provider && ( - - )} -
- - - {providers.map((provider, index) => { - // Only show active providers - if (!provider.active) return null + > + {displayModel} + + + + {currentModel?.settings && provider && ( + + )} + - return ( -
-
- - - - {getProviderTitle(provider.provider)} - - + +
+ {/* Search input */} +
+ setSearchValue(e.target.value)} + placeholder="Search models..." + className="text-sm font-normal outline-0" + /> + {searchValue.length > 0 && ( +
+ +
+ )} +
+ + {/* Model list */} +
+ {Object.keys(groupedItems).length === 0 && searchValue ? ( +
+ No models found for "{searchValue}" +
+ ) : ( +
+ {Object.entries(groupedItems).map(([providerKey, models]) => { + const providerInfo = providers.find( + (p) => p.provider === providerKey + ) + if (!providerInfo) return null + + return (
- navigate({ - to: route.settings.providers, - params: { providerName: provider.provider }, - }) - } + key={providerKey} + className="bg-main-view-fg/4 first:mt-0 rounded-sm my-1.5 mx-1.5 first:mb-0" > - -
-
- - {provider.models.map((model, modelIndex) => { - const capabilities = model.capabilities || [] - - return ( - { - selectModelProvider(provider.provider, model.id) - updateCurrentThreadModel({ - id: model.id, - provider: provider.provider, - }) - }} - > -
- - {model.id} + {/* Provider header */} +
+
+ + + {getProviderTitle(providerInfo.provider)} -
- -
- - ) - })} -
- ) - })} - - - - +
{ + e.stopPropagation() + navigate({ + to: route.settings.providers, + params: { providerName: providerInfo.provider }, + }) + setOpen(false) + }} + > + +
+
+ + {/* Models for this provider */} + {models.map((searchableModel) => { + const isSelected = + selectedModel?.id === searchableModel.model.id && + selectedProvider === searchableModel.provider.provider + const capabilities = + searchableModel.model.capabilities || [] + + return ( +
handleSelect(searchableModel)} + className={cn( + 'mx-1 mb-1 px-2 py-1.5 rounded cursor-pointer flex items-center gap-2 transition-all duration-200', + 'hover:bg-main-view-fg/10', + isSelected && 'bg-main-view-fg/15' + )} + > +
+ + {isSelected && ( + + )} +
+ {capabilities.length > 0 && ( +
+ +
+ )} +
+
+ ) + })} +
+ ) + })} +
+ )} +
+
+ + ) }