enhance: tabs component in model selection (#3730)

* ui: tabs-model-selection

* chore: updat tabs variant

* test: update test and render correct tab
This commit is contained in:
Faisal Amir 2024-09-24 20:14:43 +07:00 committed by GitHub
parent acd3be3a2a
commit 886b1cbc54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 177 additions and 38 deletions

View File

@ -96,4 +96,20 @@ describe('@joi/core/Tabs', () => {
'Disabled tab'
)
})
it('applies the tabStyle if provided', () => {
render(
<Tabs
data-testid="segmented-style"
options={mockOptions}
value="tab1"
onValueChange={() => {}}
tabStyle="segmented"
/>
)
const tabsContainer = screen.getByTestId('segmented-style')
expect(tabsContainer).toHaveClass('tabs')
expect(tabsContainer).toHaveClass('tabs--segmented')
})
})

View File

@ -7,6 +7,8 @@ import { Tooltip } from '../Tooltip'
import './styles.scss'
import { twMerge } from 'tailwind-merge'
type TabStyles = 'segmented'
type TabsProps = {
options: {
name: string
@ -14,8 +16,10 @@ type TabsProps = {
disabled?: boolean
tooltipContent?: string
}[]
children: ReactNode
children?: ReactNode
defaultValue?: string
tabStyle?: TabStyles
value: string
onValueChange?: (value: string) => void
}
@ -40,15 +44,18 @@ const TabsContent = ({ value, children, className }: TabsContentProps) => {
const Tabs = ({
options,
children,
tabStyle,
defaultValue,
value,
onValueChange,
...props
}: TabsProps) => (
<TabsPrimitive.Root
className="tabs"
className={twMerge('tabs', tabStyle && `tabs--${tabStyle}`)}
value={value}
defaultValue={defaultValue}
onValueChange={onValueChange}
{...props}
>
<TabsPrimitive.List className="tabs__list">
{options.map((option, i) => {

View File

@ -3,6 +3,27 @@
flex-direction: column;
width: 100%;
&--segmented {
background-color: hsla(var(--secondary-bg));
border-radius: 6px;
height: 33px;
.tabs__list {
border: none;
justify-content: center;
align-items: center;
height: 33px;
}
.tabs__trigger[data-state='active'] {
background-color: hsla(var(--app-bg));
border: none;
height: 25px;
margin: 0 4px;
border-radius: 5px;
}
}
&__list {
flex-shrink: 0;
display: flex;
@ -14,9 +35,11 @@
flex: 1;
height: 38px;
display: flex;
color: hsla(var(--text-secondary));
align-items: center;
justify-content: center;
line-height: 1;
font-weight: medium;
user-select: none;
&:focus {
position: relative;
@ -38,4 +61,5 @@
.tabs__trigger[data-state='active'] {
border-bottom: 1px solid hsla(var(--primary-bg));
font-weight: 600;
color: hsla(var(--text-primary));
}

View File

@ -0,0 +1,101 @@
import { render, screen, waitFor } from '@testing-library/react'
import { useAtomValue, useAtom, useSetAtom } from 'jotai'
import ModelDropdown from './index'
import useRecommendedModel from '@/hooks/useRecommendedModel'
import '@testing-library/jest-dom'
class ResizeObserverMock {
observe() {}
unobserve() {}
disconnect() {}
}
global.ResizeObserver = ResizeObserverMock
jest.mock('jotai', () => {
const originalModule = jest.requireActual('jotai')
return {
...originalModule,
useAtom: jest.fn(),
useAtomValue: jest.fn(),
useSetAtom: jest.fn(),
}
})
jest.mock('@/containers/ModelLabel')
jest.mock('@/hooks/useRecommendedModel')
describe('ModelDropdown', () => {
const remoteModel = {
metadata: { tags: ['Featured'], size: 100 },
name: 'Test Model',
engine: 'openai',
}
const localModel = {
metadata: { tags: ['Local'], size: 100 },
name: 'Local Model',
engine: 'nitro',
}
const configuredModels = [remoteModel, localModel]
const mockConfiguredModel = configuredModels
const selectedModel = { id: 'selectedModel', name: 'selectedModel' }
const setSelectedModel = jest.fn()
const showEngineListModel = ['nitro']
const showEngineListModelAtom = jest.fn()
beforeEach(() => {
jest.clearAllMocks()
;(useAtom as jest.Mock).mockReturnValue([selectedModel, setSelectedModel])
;(useAtom as jest.Mock).mockReturnValue([
showEngineListModel,
showEngineListModelAtom,
])
;(useAtomValue as jest.Mock).mockReturnValue(mockConfiguredModel)
;(useRecommendedModel as jest.Mock).mockReturnValue({
recommendedModel: { id: 'model1', parameters: [], settings: [] },
downloadedModels: [],
})
})
it('renders the ModelDropdown component', async () => {
render(<ModelDropdown />)
await waitFor(() => {
expect(screen.getByTestId('model-selector')).toBeInTheDocument()
})
})
it('renders the ModelDropdown component as disabled', async () => {
render(<ModelDropdown disabled />)
await waitFor(() => {
expect(screen.getByTestId('model-selector')).toBeInTheDocument()
expect(screen.getByTestId('model-selector')).toHaveClass(
'pointer-events-none'
)
})
})
it('renders the ModelDropdown component as badge for chat Input', async () => {
render(<ModelDropdown chatInputMode />)
await waitFor(() => {
expect(screen.getByTestId('model-selector')).toBeInTheDocument()
expect(screen.getByTestId('model-selector-badge')).toBeInTheDocument()
expect(screen.getByTestId('model-selector-badge')).toHaveClass('badge')
})
})
it('renders the Tab correctly', async () => {
render(<ModelDropdown />)
await waitFor(() => {
expect(screen.getByTestId('model-selector')).toBeInTheDocument()
expect(screen.getByText('On-device'))
expect(screen.getByText('Cloud'))
})
})
})

View File

@ -8,7 +8,7 @@ import {
Button,
Input,
ScrollArea,
Select,
Tabs,
useClickOutside,
} from '@janhq/joi'
@ -70,8 +70,8 @@ const ModelDropdown = ({
strictedThread = true,
}: Props) => {
const { downloadModel } = useDownloadModel()
const [searchFilter, setSearchFilter] = useState('all')
const [filterOptionsOpen, setFilterOptionsOpen] = useState(false)
const [searchFilter, setSearchFilter] = useState('local')
const [searchText, setSearchText] = useState('')
const [open, setOpen] = useState(false)
const activeThread = useAtomValue(activeThreadAtom)
@ -92,10 +92,7 @@ const ModelDropdown = ({
)
const { updateThreadMetadata } = useCreateNewThread()
useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [
dropdownOptions,
toggle,
])
useClickOutside(() => setOpen(false), null, [dropdownOptions, toggle])
const [showEngineListModel, setShowEngineListModel] = useAtom(
showEngineListModelAtom
@ -115,9 +112,6 @@ const ModelDropdown = ({
e.name.toLowerCase().includes(searchText.toLowerCase().trim())
)
.filter((e) => {
if (searchFilter === 'all') {
return e.engine
}
if (searchFilter === 'local') {
return localEngines.includes(e.engine)
}
@ -152,9 +146,9 @@ const ModelDropdown = ({
useEffect(() => {
if (!activeThread) return
let model = downloadedModels.find(
(model) => model.id === activeThread.assistants[0].model.id
)
const modelId = activeThread?.assistants?.[0]?.model?.id
let model = downloadedModels.find((model) => model.id === modelId)
if (!model) {
model = recommendedModel
}
@ -309,10 +303,14 @@ const ModelDropdown = ({
}
return (
<div className={twMerge('relative', disabled && 'pointer-events-none')}>
<div
className={twMerge('relative', disabled && 'pointer-events-none')}
data-testid="model-selector"
>
<div ref={setToggle}>
{chatInputMode ? (
<Badge
data-testid="model-selector-badge"
theme="secondary"
variant={open ? 'solid' : 'outline'}
className={twMerge(
@ -341,19 +339,30 @@ const ModelDropdown = ({
</div>
<div
className={twMerge(
'w=80 absolute right-0 z-20 mt-2 max-h-80 w-full overflow-hidden rounded-lg border border-[hsla(var(--app-border))] bg-[hsla(var(--app-bg))] shadow-sm',
'absolute right-0 z-20 mt-2 max-h-80 w-full overflow-hidden rounded-lg border border-[hsla(var(--app-border))] bg-[hsla(var(--app-bg))] shadow-sm',
open ? 'flex' : 'hidden',
chatInputMode && 'bottom-8 left-0 w-72'
)}
ref={setDropdownOptions}
>
<div className="w-full">
<div className="relative">
<div className="p-2 pb-0">
<Tabs
options={[
{ name: 'On-device', value: 'local' },
{ name: 'Cloud', value: 'remote' },
]}
tabStyle="segmented"
value={searchFilter as string}
onValueChange={(value) => setSearchFilter(value)}
/>
</div>
<div className="relative border-b border-[hsla(var(--app-border))] py-2">
<Input
placeholder="Search"
value={searchText}
ref={searchInputRef}
className="rounded-none border-x-0 border-t-0 focus-within:ring-0 hover:border-b-[hsla(var(--app-border))]"
className="rounded-none border-x-0 border-b-0 border-t-0 focus-within:ring-0 "
onChange={(e) => setSearchText(e.target.value)}
suffixIcon={
searchText.length > 0 && (
@ -365,26 +374,8 @@ const ModelDropdown = ({
)
}
/>
<div
className={twMerge(
'absolute right-2 top-1/2 -translate-y-1/2',
searchText.length && 'hidden'
)}
>
<Select
value={searchFilter}
className="h-6 gap-1 px-2"
options={[
{ name: 'All', value: 'all' },
{ name: 'On-device', value: 'local' },
{ name: 'Cloud', value: 'remote' },
]}
onValueChange={(value) => setSearchFilter(value)}
onOpenChange={(open) => setFilterOptionsOpen(open)}
/>
</div>
</div>
<ScrollArea className="h-[calc(100%-36px)] w-full">
<ScrollArea className="h-[calc(100%-90px)] w-full">
{groupByEngine.map((engine, i) => {
const apiKey = !localEngines.includes(engine)
? extensionHasSettings.filter((x) => x.provider === engine)[0]