chore: enhance onboarding screen's models (#4723)

* chore: enhance onboarding screen's models

* chore: lint fix

* chore: correct lint fix command

* chore: fix tests
This commit is contained in:
Louis 2025-02-25 09:36:55 +07:00 committed by GitHub
parent 60257635ad
commit 81fea5665b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 122 additions and 98 deletions

View File

@ -1 +1 @@
npx oxlint@latest --fix yarn lint --fix --quiet

View File

@ -35,6 +35,7 @@ import useDownloadModel from '@/hooks/useDownloadModel'
import { modelDownloadStateAtom } from '@/hooks/useDownloadState' import { modelDownloadStateAtom } from '@/hooks/useDownloadState'
import { useGetEngines } from '@/hooks/useEngineManagement' import { useGetEngines } from '@/hooks/useEngineManagement'
import { useGetModelSources } from '@/hooks/useModelSource'
import useRecommendedModel from '@/hooks/useRecommendedModel' import useRecommendedModel from '@/hooks/useRecommendedModel'
import useUpdateModelParameters from '@/hooks/useUpdateModelParameters' import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
@ -44,6 +45,8 @@ import { formatDownloadPercentage, toGigabytes } from '@/utils/converter'
import { manualRecommendationModel } from '@/utils/model' import { manualRecommendationModel } from '@/utils/model'
import { getLogoEngine, getTitleByEngine } from '@/utils/modelEngine' import { getLogoEngine, getTitleByEngine } from '@/utils/modelEngine'
import { extractModelName } from '@/utils/modelSource'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import { import {
configuredModelsAtom, configuredModelsAtom,
@ -84,6 +87,7 @@ const ModelDropdown = ({
const [toggle, setToggle] = useState<HTMLDivElement | null>(null) const [toggle, setToggle] = useState<HTMLDivElement | null>(null)
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom) const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
const { recommendedModel, downloadedModels } = useRecommendedModel() const { recommendedModel, downloadedModels } = useRecommendedModel()
const { sources } = useGetModelSources()
const [dropdownOptions, setDropdownOptions] = useState<HTMLDivElement | null>( const [dropdownOptions, setDropdownOptions] = useState<HTMLDivElement | null>(
null null
) )
@ -97,11 +101,8 @@ const ModelDropdown = ({
const configuredModels = useAtomValue(configuredModelsAtom) const configuredModels = useAtomValue(configuredModelsAtom)
const { stopModel } = useActiveModel() const { stopModel } = useActiveModel()
const featuredModels = configuredModels.filter( const featuredModels = sources?.filter((x) =>
(x) => manualRecommendationModel.includes(x.id)
manualRecommendationModel.includes(x.id) &&
x.metadata?.tags?.includes('Featured') &&
x.metadata?.size < 5000000000
) )
const { updateThreadMetadata } = useCreateNewThread() const { updateThreadMetadata } = useCreateNewThread()
@ -464,9 +465,9 @@ const ModelDropdown = ({
showModel && showModel &&
!searchText.length && ( !searchText.length && (
<ul className="pb-2"> <ul className="pb-2">
{featuredModels.map((model) => { {featuredModels?.map((model) => {
const isDownloading = downloadingModels.some( const isDownloading = downloadingModels.some(
(md) => md === model.id (md) => md === (model.models[0]?.id ?? model.id)
) )
return ( return (
<li <li
@ -475,34 +476,35 @@ const ModelDropdown = ({
> >
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<p <p
className="max-w-[200px] overflow-hidden truncate whitespace-nowrap text-[hsla(var(--text-secondary))]" className="max-w-[200px] overflow-hidden truncate whitespace-nowrap capitalize text-[hsla(var(--text-secondary))]"
title={model.name} title={model.id}
> >
{model.name} {extractModelName(model.id)}
</p> </p>
<ModelLabel <ModelLabel
size={model.metadata?.size} size={model.models[0]?.size}
compact compact
/> />
</div> </div>
<div className="flex items-center gap-2 text-[hsla(var(--text-tertiary))]"> <div className="flex items-center gap-2 text-[hsla(var(--text-tertiary))]">
<span className="font-medium"> <span className="font-medium">
{toGigabytes(model.metadata?.size)} {toGigabytes(model.models[0]?.size)}
</span> </span>
{!isDownloading ? ( {!isDownloading ? (
<DownloadCloudIcon <DownloadCloudIcon
size={18} size={18}
className="cursor-pointer text-[hsla(var(--app-link))]" className="cursor-pointer text-[hsla(var(--app-link))]"
onClick={() => onClick={() =>
downloadModel( downloadModel(model.models[0]?.id)
model.sources[0].url,
model.id
)
} }
/> />
) : ( ) : (
Object.values(downloadStates) Object.values(downloadStates)
.filter((x) => x.modelId === model.id) .filter(
(x) =>
x.modelId ===
(model.models[0]?.id ?? model.id)
)
.map((item) => ( .map((item) => (
<ProgressCircle <ProgressCircle
key={item.modelId} key={item.modelId}

View File

@ -43,7 +43,7 @@ const SettingLeftPanel = () => {
for (const extension of extensions) { for (const extension of extensions) {
const settings = await extension.getSettings() const settings = await extension.getSettings()
if (settings && settings.length > 0) { if (settings && settings.length > 0 && settings.some((e) => e.title)) {
extensionsMenu.push({ extensionsMenu.push({
name: extension.productName, name: extension.productName,
setting: extension.name, setting: extension.name,

View File

@ -51,27 +51,30 @@ jest.mock('@/hooks/useDownloadModel', () => ({
const mockAtomValue = jest.spyOn(jotai, 'useAtomValue') const mockAtomValue = jest.spyOn(jotai, 'useAtomValue')
const mockSetAtom = jest.spyOn(jotai, 'useSetAtom') const mockSetAtom = jest.spyOn(jotai, 'useSetAtom')
describe('OnDeviceStarterScreen', () => { jest.mock('@/hooks/useModelSource')
const mockExtensionHasSettings = [
{
name: 'Test Extension',
setting: 'test-setting',
apiKey: 'test-key',
provider: 'test-provider',
},
]
import * as source from '@/hooks/useModelSource'
describe('OnDeviceStarterScreen', () => {
beforeEach(() => { beforeEach(() => {
mockAtomValue.mockImplementation(() => []) mockAtomValue.mockImplementation(() => [])
mockSetAtom.mockImplementation(() => jest.fn()) mockSetAtom.mockImplementation(() => jest.fn())
}) })
jest.spyOn(source, 'useGetModelSources').mockReturnValue({
sources: [],
error: null,
mutate: jest.fn(),
})
it('renders the component', () => { it('renders the component', () => {
jest.spyOn(source, 'useGetModelSources').mockReturnValue({
sources: [],
error: null,
mutate: jest.fn(),
})
render( render(
<Provider> <Provider>
<OnDeviceStarterScreen <OnDeviceStarterScreen isShowStarterScreen={true} />
extensionHasSettings={mockExtensionHasSettings}
/>
</Provider> </Provider>
) )
@ -80,11 +83,14 @@ describe('OnDeviceStarterScreen', () => {
}) })
it('handles search input', () => { it('handles search input', () => {
jest.spyOn(source, 'useGetModelSources').mockReturnValue({
sources: [],
error: null,
mutate: jest.fn(),
})
render( render(
<Provider> <Provider>
<OnDeviceStarterScreen <OnDeviceStarterScreen isShowStarterScreen={true} />
extensionHasSettings={mockExtensionHasSettings}
/>
</Provider> </Provider>
) )
@ -97,11 +103,14 @@ describe('OnDeviceStarterScreen', () => {
it('displays "No Result Found" when no models match the search', () => { it('displays "No Result Found" when no models match the search', () => {
mockAtomValue.mockImplementation(() => []) mockAtomValue.mockImplementation(() => [])
jest.spyOn(source, 'useGetModelSources').mockReturnValue({
sources: [],
error: null,
mutate: jest.fn(),
})
render( render(
<Provider> <Provider>
<OnDeviceStarterScreen <OnDeviceStarterScreen isShowStarterScreen={true} />
extensionHasSettings={mockExtensionHasSettings}
/>
</Provider> </Provider>
) )
@ -114,38 +123,60 @@ describe('OnDeviceStarterScreen', () => {
it('renders featured models', () => { it('renders featured models', () => {
const mockConfiguredModels = [ const mockConfiguredModels = [
{ {
id: 'gemma-2-9b-it', id: 'cortexso/deepseek-r1',
name: 'Gemma 2B', name: 'DeepSeek R1',
metadata: { metadata: {
tags: ['Featured'], tags: ['Featured'],
author: 'Test Author', author: 'Test Author',
size: 3000000000, size: 3000000000,
}, },
models: [
{
id: 'cortexso/deepseek-r1',
name: 'DeepSeek R1',
metadata: {
tags: ['Featured'],
},
},
],
}, },
{ {
id: 'llama3.1-8b-instruct', id: 'cortexso/llama3.2',
name: 'Llama 3.1', name: 'Llama 3.1',
metadata: { tags: [], author: 'Test Author', size: 2000000000 }, metadata: { tags: [], author: 'Test Author', size: 2000000000 },
models: [
{
id: 'cortexso/deepseek-r1',
name: 'DeepSeek R1',
metadata: {
tags: ['Featured'],
},
},
],
}, },
] ]
mockAtomValue.mockImplementation((atom) => { jest.spyOn(source, 'useGetModelSources').mockReturnValue({
return mockConfiguredModels sources: mockConfiguredModels,
error: null,
mutate: jest.fn(),
}) })
render( render(
<Provider> <Provider>
<OnDeviceStarterScreen <OnDeviceStarterScreen isShowStarterScreen={true} />
extensionHasSettings={mockExtensionHasSettings}
/>
</Provider> </Provider>
) )
expect(screen.getByText('Gemma 2B')).toBeInTheDocument() expect(screen.getAllByText('deepseek-r1')[0]).toBeInTheDocument()
expect(screen.queryByText('Llama 3.1')).not.toBeInTheDocument()
}) })
it('renders cloud models', () => { it('renders cloud models', () => {
jest.spyOn(source, 'useGetModelSources').mockReturnValue({
sources: [],
error: null,
mutate: jest.fn(),
})
const mockRemoteModels = [ const mockRemoteModels = [
{ id: 'remote-model-1', name: 'Remote Model 1', engine: 'openai' }, { id: 'remote-model-1', name: 'Remote Model 1', engine: 'openai' },
{ id: 'remote-model-2', name: 'Remote Model 2', engine: 'anthropic' }, { id: 'remote-model-2', name: 'Remote Model 2', engine: 'anthropic' },
@ -160,9 +191,7 @@ describe('OnDeviceStarterScreen', () => {
render( render(
<Provider> <Provider>
<OnDeviceStarterScreen <OnDeviceStarterScreen isShowStarterScreen={true} />
extensionHasSettings={mockExtensionHasSettings}
/>
</Provider> </Provider>
) )

View File

@ -26,6 +26,8 @@ import { modelDownloadStateAtom } from '@/hooks/useDownloadState'
import { useGetEngines } from '@/hooks/useEngineManagement' import { useGetEngines } from '@/hooks/useEngineManagement'
import { useGetModelSources } from '@/hooks/useModelSource'
import { formatDownloadPercentage, toGigabytes } from '@/utils/converter' import { formatDownloadPercentage, toGigabytes } from '@/utils/converter'
import { manualRecommendationModel } from '@/utils/model' import { manualRecommendationModel } from '@/utils/model'
import { import {
@ -34,6 +36,8 @@ import {
isLocalEngine, isLocalEngine,
} from '@/utils/modelEngine' } from '@/utils/modelEngine'
import { extractModelName } from '@/utils/modelSource'
import { mainViewStateAtom } from '@/helpers/atoms/App.atom' import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
import { import {
configuredModelsAtom, configuredModelsAtom,
@ -55,36 +59,17 @@ const OnDeviceStarterScreen = ({ isShowStarterScreen }: Props) => {
const { engines } = useGetEngines() const { engines } = useGetEngines()
const configuredModels = useAtomValue(configuredModelsAtom) const configuredModels = useAtomValue(configuredModelsAtom)
const { sources } = useGetModelSources()
const setMainViewState = useSetAtom(mainViewStateAtom) const setMainViewState = useSetAtom(mainViewStateAtom)
const featuredModel = configuredModels.filter((x) => { const featuredModels = sources?.filter((x) =>
const manualRecommendModel = configuredModels.filter((x) =>
manualRecommendationModel.includes(x.id) manualRecommendationModel.includes(x.id)
) )
if (manualRecommendModel.length === 2) {
return (
x.id === manualRecommendationModel[0] ||
x.id === manualRecommendationModel[1]
)
} else {
return (
x.metadata?.tags?.includes('Featured') && x.metadata?.size < 5000000000
)
}
})
const remoteModel = configuredModels.filter( const remoteModel = configuredModels.filter(
(x) => !isLocalEngine(engines, x.engine) (x) => !isLocalEngine(engines, x.engine)
) )
const filteredModels = configuredModels.filter((model) => {
return (
isLocalEngine(engines, model.engine) &&
model.name.toLowerCase().includes(searchValue.toLowerCase())
)
})
const remoteModelEngine = remoteModel.map((x) => x.engine) const remoteModelEngine = remoteModel.map((x) => x.engine)
const groupByEngine = remoteModelEngine.filter(function (item, index) { const groupByEngine = remoteModelEngine.filter(function (item, index) {
@ -142,16 +127,16 @@ const OnDeviceStarterScreen = ({ isShowStarterScreen }: Props) => {
!isOpen ? 'invisible' : 'visible' !isOpen ? 'invisible' : 'visible'
)} )}
> >
{!filteredModels.length ? ( {!featuredModels?.length ? (
<div className="p-3 text-center"> <div className="p-3 text-center">
<p className="line-clamp-1 text-[hsla(var(--text-secondary))]"> <p className="line-clamp-1 text-[hsla(var(--text-secondary))]">
No Result Found No Result Found
</p> </p>
</div> </div>
) : ( ) : (
filteredModels.map((model) => { sources?.map((model) => {
const isDownloading = downloadingModels.some( const isDownloading = downloadingModels.some(
(md) => md === model.id (md) => md === (model.models[0]?.id ?? model.id)
) )
return ( return (
<div <div
@ -160,16 +145,19 @@ const OnDeviceStarterScreen = ({ isShowStarterScreen }: Props) => {
> >
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
<p <p
className={twMerge('line-clamp-1')} className={'line-clamp-1 capitalize'}
title={model.name} title={extractModelName(model.id)}
> >
{model.name} {extractModelName(model.id)}
</p> </p>
<ModelLabel size={model.metadata?.size} compact /> <ModelLabel
size={model.models[0]?.size}
compact
/>
</div> </div>
<div className="flex items-center gap-2 text-[hsla(var(--text-tertiary))]"> <div className="flex items-center gap-2 text-[hsla(var(--text-tertiary))]">
<span className="font-medium"> <span className="font-medium">
{toGigabytes(model.metadata?.size)} {toGigabytes(model.models[0]?.size)}
</span> </span>
{!isDownloading ? ( {!isDownloading ? (
<DownloadCloudIcon <DownloadCloudIcon
@ -177,15 +165,15 @@ const OnDeviceStarterScreen = ({ isShowStarterScreen }: Props) => {
className="cursor-pointer text-[hsla(var(--app-link))]" className="cursor-pointer text-[hsla(var(--app-link))]"
onClick={() => onClick={() =>
downloadModel( downloadModel(
model.sources[0].url, model.models[0]?.id ?? model.id
model.id,
model.name
) )
} }
/> />
) : ( ) : (
Object.values(downloadStates) Object.values(downloadStates)
.filter((x) => x.modelId === model.id) .filter(
(x) => x.modelId === model.models[0]?.id
)
.map((item) => ( .map((item) => (
<ProgressCircle <ProgressCircle
key={item.modelId} key={item.modelId}
@ -222,9 +210,9 @@ const OnDeviceStarterScreen = ({ isShowStarterScreen }: Props) => {
</p> </p>
</div> </div>
{featuredModel.slice(0, 2).map((featModel) => { {featuredModels?.map((featModel) => {
const isDownloading = downloadingModels.some( const isDownloading = downloadingModels.some(
(md) => md === featModel.id (md) => md === (featModel.models[0]?.id ?? featModel.id)
) )
return ( return (
<div <div
@ -232,13 +220,17 @@ const OnDeviceStarterScreen = ({ isShowStarterScreen }: Props) => {
className="my-2 flex items-start justify-between gap-2 border-b border-[hsla(var(--app-border))] pb-4 pt-1 last:border-none" className="my-2 flex items-start justify-between gap-2 border-b border-[hsla(var(--app-border))] pb-4 pt-1 last:border-none"
> >
<div className="w-full text-left"> <div className="w-full text-left">
<h6 className="mt-1.5 font-medium">{featModel.name}</h6> <h6 className="mt-1.5 font-medium capitalize">
{extractModelName(featModel.id)}
</h6>
</div> </div>
{isDownloading ? ( {isDownloading ? (
<div className="flex w-full flex-col items-end gap-2"> <div className="flex w-full flex-col items-end gap-2">
{Object.values(downloadStates) {Object.values(downloadStates)
.filter((x) => x.modelId === featModel.id) .filter(
(x) => x.modelId === featModel.models[0]?.id
)
.map((item, i) => ( .map((item, i) => (
<div <div
className="mt-1.5 flex w-full items-center gap-2" className="mt-1.5 flex w-full items-center gap-2"
@ -262,7 +254,7 @@ const OnDeviceStarterScreen = ({ isShowStarterScreen }: Props) => {
</div> </div>
))} ))}
<span className="text-[hsla(var(--text-secondary))]"> <span className="text-[hsla(var(--text-secondary))]">
{toGigabytes(featModel.metadata?.size)} {toGigabytes(featModel.models[0]?.size)}
</span> </span>
</div> </div>
) : ( ) : (
@ -271,17 +263,13 @@ const OnDeviceStarterScreen = ({ isShowStarterScreen }: Props) => {
theme="ghost" theme="ghost"
className="!bg-[hsla(var(--secondary-bg))]" className="!bg-[hsla(var(--secondary-bg))]"
onClick={() => onClick={() =>
downloadModel( downloadModel(featModel.models[0]?.id)
featModel.sources[0].url,
featModel.id,
featModel.name
)
} }
> >
Download Download
</Button> </Button>
<span className="text-[hsla(var(--text-secondary))]"> <span className="text-[hsla(var(--text-secondary))]">
{toGigabytes(featModel.metadata?.size)} {toGigabytes(featModel.models[0]?.size)}
</span> </span>
</div> </div>
)} )}

View File

@ -8,7 +8,12 @@ export const normalizeModelId = (downloadUrl: string): string => {
return downloadUrl.split('/').pop() ?? downloadUrl return downloadUrl.split('/').pop() ?? downloadUrl
} }
/**
* Default models to recommend to users when they first open the app.
* TODO: These will be replaced when we have a proper recommendation system
* AND cortexso repositories are updated with tags.
*/
export const manualRecommendationModel = [ export const manualRecommendationModel = [
'llama3.2-1b-instruct', 'cortexso/deepseek-r1',
'llama3.2-3b-instruct', 'cortexso/llama3.2',
] ]