diff --git a/joi/src/core/Tabs/Tabs.test.tsx b/joi/src/core/Tabs/Tabs.test.tsx index b6dcf8a7b..46bd48435 100644 --- a/joi/src/core/Tabs/Tabs.test.tsx +++ b/joi/src/core/Tabs/Tabs.test.tsx @@ -96,4 +96,20 @@ describe('@joi/core/Tabs', () => { 'Disabled tab' ) }) + + it('applies the tabStyle if provided', () => { + render( + {}} + tabStyle="segmented" + /> + ) + + const tabsContainer = screen.getByTestId('segmented-style') + expect(tabsContainer).toHaveClass('tabs') + expect(tabsContainer).toHaveClass('tabs--segmented') + }) }) diff --git a/joi/src/core/Tabs/index.tsx b/joi/src/core/Tabs/index.tsx index af004e2ba..2dca19831 100644 --- a/joi/src/core/Tabs/index.tsx +++ b/joi/src/core/Tabs/index.tsx @@ -7,6 +7,8 @@ import { Tooltip } from '../Tooltip' import './styles.scss' import { twMerge } from 'tailwind-merge' +type TabStyles = 'segmented' + type TabsProps = { options: { name: string @@ -14,8 +16,10 @@ type TabsProps = { disabled?: boolean tooltipContent?: string }[] - children: ReactNode + children?: ReactNode + defaultValue?: string + tabStyle?: TabStyles value: string onValueChange?: (value: string) => void } @@ -40,15 +44,18 @@ const TabsContent = ({ value, children, className }: TabsContentProps) => { const Tabs = ({ options, children, + tabStyle, defaultValue, value, onValueChange, + ...props }: TabsProps) => ( {options.map((option, i) => { diff --git a/joi/src/core/Tabs/styles.scss b/joi/src/core/Tabs/styles.scss index a24585b4e..ce3df013b 100644 --- a/joi/src/core/Tabs/styles.scss +++ b/joi/src/core/Tabs/styles.scss @@ -3,6 +3,27 @@ flex-direction: column; width: 100%; + &--segmented { + background-color: hsla(var(--secondary-bg)); + border-radius: 6px; + height: 33px; + + .tabs__list { + border: none; + justify-content: center; + align-items: center; + height: 33px; + } + + .tabs__trigger[data-state='active'] { + background-color: hsla(var(--app-bg)); + border: none; + height: 25px; + margin: 0 4px; + border-radius: 5px; + } + } + &__list { flex-shrink: 0; display: flex; @@ -14,9 +35,11 @@ flex: 1; height: 38px; display: flex; + color: hsla(var(--text-secondary)); align-items: center; justify-content: center; line-height: 1; + font-weight: medium; user-select: none; &:focus { position: relative; @@ -38,4 +61,5 @@ .tabs__trigger[data-state='active'] { border-bottom: 1px solid hsla(var(--primary-bg)); font-weight: 600; + color: hsla(var(--text-primary)); } diff --git a/web/containers/ModelDropdown/index.test.tsx b/web/containers/ModelDropdown/index.test.tsx new file mode 100644 index 000000000..7541f891b --- /dev/null +++ b/web/containers/ModelDropdown/index.test.tsx @@ -0,0 +1,101 @@ +import { render, screen, waitFor } from '@testing-library/react' +import { useAtomValue, useAtom, useSetAtom } from 'jotai' +import ModelDropdown from './index' +import useRecommendedModel from '@/hooks/useRecommendedModel' +import '@testing-library/jest-dom' + +class ResizeObserverMock { + observe() {} + unobserve() {} + disconnect() {} +} + +global.ResizeObserver = ResizeObserverMock + +jest.mock('jotai', () => { + const originalModule = jest.requireActual('jotai') + return { + ...originalModule, + useAtom: jest.fn(), + useAtomValue: jest.fn(), + useSetAtom: jest.fn(), + } +}) + +jest.mock('@/containers/ModelLabel') +jest.mock('@/hooks/useRecommendedModel') + +describe('ModelDropdown', () => { + const remoteModel = { + metadata: { tags: ['Featured'], size: 100 }, + name: 'Test Model', + engine: 'openai', + } + + const localModel = { + metadata: { tags: ['Local'], size: 100 }, + name: 'Local Model', + engine: 'nitro', + } + + const configuredModels = [remoteModel, localModel] + + const mockConfiguredModel = configuredModels + const selectedModel = { id: 'selectedModel', name: 'selectedModel' } + const setSelectedModel = jest.fn() + const showEngineListModel = ['nitro'] + const showEngineListModelAtom = jest.fn() + + beforeEach(() => { + jest.clearAllMocks() + ;(useAtom as jest.Mock).mockReturnValue([selectedModel, setSelectedModel]) + ;(useAtom as jest.Mock).mockReturnValue([ + showEngineListModel, + showEngineListModelAtom, + ]) + ;(useAtomValue as jest.Mock).mockReturnValue(mockConfiguredModel) + ;(useRecommendedModel as jest.Mock).mockReturnValue({ + recommendedModel: { id: 'model1', parameters: [], settings: [] }, + downloadedModels: [], + }) + }) + + it('renders the ModelDropdown component', async () => { + render() + + await waitFor(() => { + expect(screen.getByTestId('model-selector')).toBeInTheDocument() + }) + }) + + it('renders the ModelDropdown component as disabled', async () => { + render() + + await waitFor(() => { + expect(screen.getByTestId('model-selector')).toBeInTheDocument() + expect(screen.getByTestId('model-selector')).toHaveClass( + 'pointer-events-none' + ) + }) + }) + + it('renders the ModelDropdown component as badge for chat Input', async () => { + render() + + await waitFor(() => { + expect(screen.getByTestId('model-selector')).toBeInTheDocument() + expect(screen.getByTestId('model-selector-badge')).toBeInTheDocument() + expect(screen.getByTestId('model-selector-badge')).toHaveClass('badge') + }) + }) + + it('renders the Tab correctly', async () => { + render() + + await waitFor(() => { + expect(screen.getByTestId('model-selector')).toBeInTheDocument() + expect(screen.getByText('On-device')) + expect(screen.getByText('Cloud')) + }) + }) +}) diff --git a/web/containers/ModelDropdown/index.tsx b/web/containers/ModelDropdown/index.tsx index d8743ddce..2a0c4ffaf 100644 --- a/web/containers/ModelDropdown/index.tsx +++ b/web/containers/ModelDropdown/index.tsx @@ -8,7 +8,7 @@ import { Button, Input, ScrollArea, - Select, + Tabs, useClickOutside, } from '@janhq/joi' @@ -70,8 +70,8 @@ const ModelDropdown = ({ strictedThread = true, }: Props) => { const { downloadModel } = useDownloadModel() - const [searchFilter, setSearchFilter] = useState('all') - const [filterOptionsOpen, setFilterOptionsOpen] = useState(false) + + const [searchFilter, setSearchFilter] = useState('local') const [searchText, setSearchText] = useState('') const [open, setOpen] = useState(false) const activeThread = useAtomValue(activeThreadAtom) @@ -92,10 +92,7 @@ const ModelDropdown = ({ ) const { updateThreadMetadata } = useCreateNewThread() - useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [ - dropdownOptions, - toggle, - ]) + useClickOutside(() => setOpen(false), null, [dropdownOptions, toggle]) const [showEngineListModel, setShowEngineListModel] = useAtom( showEngineListModelAtom @@ -115,9 +112,6 @@ const ModelDropdown = ({ e.name.toLowerCase().includes(searchText.toLowerCase().trim()) ) .filter((e) => { - if (searchFilter === 'all') { - return e.engine - } if (searchFilter === 'local') { return localEngines.includes(e.engine) } @@ -152,9 +146,9 @@ const ModelDropdown = ({ useEffect(() => { if (!activeThread) return - let model = downloadedModels.find( - (model) => model.id === activeThread.assistants[0].model.id - ) + const modelId = activeThread?.assistants?.[0]?.model?.id + + let model = downloadedModels.find((model) => model.id === modelId) if (!model) { model = recommendedModel } @@ -309,10 +303,14 @@ const ModelDropdown = ({ } return ( -
+
{chatInputMode ? (
-
+
+ setSearchFilter(value)} + /> +
+
setSearchText(e.target.value)} suffixIcon={ searchText.length > 0 && ( @@ -365,26 +374,8 @@ const ModelDropdown = ({ ) } /> -
-