From 81fea5665b471139b51647c230c8135d591df70d Mon Sep 17 00:00:00 2001
From: Louis
Date: Tue, 25 Feb 2025 09:36:55 +0700
Subject: [PATCH] chore: enhance onboarding screen's models (#4723)

* chore: enhance onboarding screen's models

* chore: lint fix

* chore: correct lint fix command

* chore: fix tests
---
 .husky/pre-commit                            |  2 +-
 web/containers/ModelDropdown/index.tsx       | 36 ++++----
 .../Settings/SettingLeftPanel/index.tsx      |  2 +-
 .../OnDeviceStarterScreen/index.test.tsx     | 91 ++++++++++++-------
 .../ChatBody/OnDeviceStarterScreen/index.tsx | 80 +++++++---------
 web/utils/model.ts                           |  9 +-
 6 files changed, 122 insertions(+), 98 deletions(-)

diff --git a/.husky/pre-commit b/.husky/pre-commit
index 53c4e577e..94c03b512 100644
--- a/.husky/pre-commit
+++ b/.husky/pre-commit
@@ -1 +1 @@
-npx oxlint@latest --fix
\ No newline at end of file
+yarn lint --fix --quiet
\ No newline at end of file
diff --git a/web/containers/ModelDropdown/index.tsx b/web/containers/ModelDropdown/index.tsx
index a702d12f7..6d58b9893 100644
--- a/web/containers/ModelDropdown/index.tsx
+++ b/web/containers/ModelDropdown/index.tsx
@@ -35,6 +35,7 @@ import useDownloadModel from '@/hooks/useDownloadModel'
 import { modelDownloadStateAtom } from '@/hooks/useDownloadState'
 import { useGetEngines } from '@/hooks/useEngineManagement'
+import { useGetModelSources } from '@/hooks/useModelSource'
 import useRecommendedModel from '@/hooks/useRecommendedModel'
 import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
@@ -44,6 +45,8 @@ import { formatDownloadPercentage, toGigabytes } from '@/utils/converter'
 import { manualRecommendationModel } from '@/utils/model'
 import { getLogoEngine, getTitleByEngine } from '@/utils/modelEngine'
+import { extractModelName } from '@/utils/modelSource'
+
 import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
 import {
   configuredModelsAtom,
@@ -84,6 +87,7 @@ const ModelDropdown = ({
   const [toggle, setToggle] = useState(null)
   const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
   const { recommendedModel, downloadedModels } = useRecommendedModel()
+  const { sources } = useGetModelSources()
   const [dropdownOptions, setDropdownOptions] = useState(
     null
   )
@@ -97,11 +101,8 @@ const ModelDropdown = ({
   const configuredModels = useAtomValue(configuredModelsAtom)
   const { stopModel } = useActiveModel()
 
-  const featuredModels = configuredModels.filter(
-    (x) =>
-      manualRecommendationModel.includes(x.id) &&
-      x.metadata?.tags?.includes('Featured') &&
-      x.metadata?.size < 5000000000
+  const featuredModels = sources?.filter((x) =>
+    manualRecommendationModel.includes(x.id)
   )
 
   const { updateThreadMetadata } = useCreateNewThread()
@@ -464,9 +465,9 @@ const ModelDropdown = ({
                   showModel && !searchText.length && (
-                    {featuredModels.map((model) => {
+                    {featuredModels?.map((model) => {
                       const isDownloading = downloadingModels.some(
-                        (md) => md === model.id
+                        (md) => md === (model.models[0]?.id ?? model.id)
                       )
                       return (
-                            {model.name}
+                            {extractModelName(model.id)}
-                            {toGigabytes(model.metadata?.size)}
+                            {toGigabytes(model.models[0]?.size)}
                             {!isDownloading ? (
-                                  downloadModel(
-                                    model.sources[0].url,
-                                    model.id
-                                  )
+                                  downloadModel(model.models[0]?.id)
                                 }
                               />
                             ) : (
                               Object.values(downloadStates)
-                                .filter((x) => x.modelId === model.id)
+                                .filter(
+                                  (x) =>
+                                    x.modelId ===
+                                    (model.models[0]?.id ?? model.id)
+                                )
                                 .map((item) => (
diff --git a/web/screens/Settings/SettingLeftPanel/index.tsx b/web/screens/Settings/SettingLeftPanel/index.tsx
     for (const extension of extensions) {
       const settings = await extension.getSettings()
-      if (settings && settings.length > 0) {
+      if (settings && settings.length > 0 && settings.some((e) => e.title)) {
         extensionsMenu.push({
           name: extension.productName,
           setting: extension.name,
diff --git a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.test.tsx b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.test.tsx
index 44ecc3dad..8e316ee98 100644
--- a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.test.tsx
+++ b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.test.tsx
@@ -51,27 +51,30 @@ jest.mock('@/hooks/useDownloadModel', () => ({
 const mockAtomValue = jest.spyOn(jotai, 'useAtomValue')
 const mockSetAtom = jest.spyOn(jotai, 'useSetAtom')
 
-describe('OnDeviceStarterScreen', () => {
-  const mockExtensionHasSettings = [
-    {
-      name: 'Test Extension',
-      setting: 'test-setting',
-      apiKey: 'test-key',
-      provider: 'test-provider',
-    },
-  ]
+jest.mock('@/hooks/useModelSource')
+import * as source from '@/hooks/useModelSource'
+
+describe('OnDeviceStarterScreen', () => {
   beforeEach(() => {
     mockAtomValue.mockImplementation(() => [])
     mockSetAtom.mockImplementation(() => jest.fn())
   })
+  jest.spyOn(source, 'useGetModelSources').mockReturnValue({
+    sources: [],
+    error: null,
+    mutate: jest.fn(),
+  })
 
   it('renders the component', () => {
+    jest.spyOn(source, 'useGetModelSources').mockReturnValue({
+      sources: [],
+      error: null,
+      mutate: jest.fn(),
+    })
     render(
-
+
     )
@@ -80,11 +83,14 @@ describe('OnDeviceStarterScreen', () => {
   })
 
   it('handles search input', () => {
+    jest.spyOn(source, 'useGetModelSources').mockReturnValue({
+      sources: [],
+      error: null,
+      mutate: jest.fn(),
+    })
     render(
-
+
     )
@@ -97,11 +103,14 @@ describe('OnDeviceStarterScreen', () => {
   it('displays "No Result Found" when no models match the search', () => {
     mockAtomValue.mockImplementation(() => [])
 
+    jest.spyOn(source, 'useGetModelSources').mockReturnValue({
+      sources: [],
+      error: null,
+      mutate: jest.fn(),
+    })
     render(
-
+
     )
@@ -114,38 +123,60 @@ describe('OnDeviceStarterScreen', () => {
   it('renders featured models', () => {
     const mockConfiguredModels = [
       {
-        id: 'gemma-2-9b-it',
-        name: 'Gemma 2B',
+        id: 'cortexso/deepseek-r1',
+        name: 'DeepSeek R1',
         metadata: {
           tags: ['Featured'],
           author: 'Test Author',
           size: 3000000000,
         },
+        models: [
+          {
+            id: 'cortexso/deepseek-r1',
+            name: 'DeepSeek R1',
+            metadata: {
+              tags: ['Featured'],
+            },
+          },
+        ],
       },
       {
-        id: 'llama3.1-8b-instruct',
+        id: 'cortexso/llama3.2',
         name: 'Llama 3.1',
         metadata: { tags: [], author: 'Test Author', size: 2000000000 },
+        models: [
+          {
+            id: 'cortexso/deepseek-r1',
+            name: 'DeepSeek R1',
+            metadata: {
+              tags: ['Featured'],
+            },
+          },
+        ],
      },
     ]
 
-    mockAtomValue.mockImplementation((atom) => {
-      return mockConfiguredModels
+    jest.spyOn(source, 'useGetModelSources').mockReturnValue({
+      sources: mockConfiguredModels,
+      error: null,
+      mutate: jest.fn(),
    })
 
     render(
-
+
     )
 
-    expect(screen.getByText('Gemma 2B')).toBeInTheDocument()
-    expect(screen.queryByText('Llama 3.1')).not.toBeInTheDocument()
+    expect(screen.getAllByText('deepseek-r1')[0]).toBeInTheDocument()
   })
 
   it('renders cloud models', () => {
+    jest.spyOn(source, 'useGetModelSources').mockReturnValue({
+      sources: [],
+      error: null,
+      mutate: jest.fn(),
+    })
     const mockRemoteModels = [
       { id: 'remote-model-1', name: 'Remote Model 1', engine: 'openai' },
       { id: 'remote-model-2', name: 'Remote Model 2', engine: 'anthropic' },
     ]
@@ -160,9 +191,7 @@ describe('OnDeviceStarterScreen', () => {
     render(
-
+
     )
diff --git a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx
index 7bbc9acbb..20a7216bc 100644
--- a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx
+++ b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx
@@ -26,6 +26,8 @@ import { modelDownloadStateAtom } from '@/hooks/useDownloadState'
 
 import { useGetEngines } from '@/hooks/useEngineManagement'
 
+import { useGetModelSources } from '@/hooks/useModelSource'
+
 import { formatDownloadPercentage, toGigabytes } from '@/utils/converter'
 import { manualRecommendationModel } from '@/utils/model'
 import {
@@ -34,6 +36,8 @@ import {
   isLocalEngine,
 } from '@/utils/modelEngine'
 
+import { extractModelName } from '@/utils/modelSource'
+
 import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
 import {
   configuredModelsAtom,
@@ -55,36 +59,17 @@ const OnDeviceStarterScreen = ({ isShowStarterScreen }: Props) => {
   const { engines } = useGetEngines()
   const configuredModels = useAtomValue(configuredModelsAtom)
+  const { sources } = useGetModelSources()
   const setMainViewState = useSetAtom(mainViewStateAtom)
 
-  const featuredModel = configuredModels.filter((x) => {
-    const manualRecommendModel = configuredModels.filter((x) =>
-      manualRecommendationModel.includes(x.id)
-    )
-
-    if (manualRecommendModel.length === 2) {
-      return (
-        x.id === manualRecommendationModel[0] ||
-        x.id === manualRecommendationModel[1]
-      )
-    } else {
-      return (
-        x.metadata?.tags?.includes('Featured') && x.metadata?.size < 5000000000
-      )
-    }
-  })
+  const featuredModels = sources?.filter((x) =>
+    manualRecommendationModel.includes(x.id)
+  )
 
   const remoteModel = configuredModels.filter(
     (x) => !isLocalEngine(engines, x.engine)
   )
 
-  const filteredModels = configuredModels.filter((model) => {
-    return (
-      isLocalEngine(engines, model.engine) &&
-      model.name.toLowerCase().includes(searchValue.toLowerCase())
-    )
-  })
-
   const remoteModelEngine = remoteModel.map((x) => x.engine)
 
   const groupByEngine = remoteModelEngine.filter(function (item, index) {
@@ -142,16 +127,16 @@ const OnDeviceStarterScreen = ({ isShowStarterScreen }: Props) => {
                     !isOpen ? 'invisible' : 'visible'
                   )}
                 >
-                  {!filteredModels.length ? (
+                  {!featuredModels?.length ? (
                         No Result Found
                   ) : (
-                    filteredModels.map((model) => {
+                    sources?.map((model) => {
                       const isDownloading = downloadingModels.some(
-                        (md) => md === model.id
+                        (md) => md === (model.models[0]?.id ?? model.id)
                       )
                       return (
                         {
                         >
-                            {model.name}
+                            {extractModelName(model.id)}
-
+
-                            {toGigabytes(model.metadata?.size)}
+                            {toGigabytes(model.models[0]?.size)}
                             {!isDownloading ? (
                                 className="cursor-pointer text-[hsla(var(--app-link))]"
                                 onClick={() =>
                                   downloadModel(
-                                    model.sources[0].url,
-                                    model.id,
-                                    model.name
+                                    model.models[0]?.id ?? model.id
                                   )
                                 }
                               />
                             ) : (
                               Object.values(downloadStates)
-                                .filter((x) => x.modelId === model.id)
+                                .filter(
+                                  (x) => x.modelId === model.models[0]?.id
+                                )
                                 .map((item) => (
-                {featuredModel.slice(0, 2).map((featModel) => {
+                {featuredModels?.map((featModel) => {
                   const isDownloading = downloadingModels.some(
-                    (md) => md === featModel.id
+                    (md) => md === (featModel.models[0]?.id ?? featModel.id)
                   )
                   return (
                     {
                       className="my-2 flex items-start justify-between gap-2 border-b border-[hsla(var(--app-border))] pb-4 pt-1 last:border-none"
                     >
-
-                        {featModel.name}
+
+                        {extractModelName(featModel.id)}
+
                       {isDownloading ? (
                         {Object.values(downloadStates)
-                          .filter((x) => x.modelId === featModel.id)
+                          .filter(
+                            (x) => x.modelId === featModel.models[0]?.id
+                          )
                           .map((item, i) => (
                         {
                         ))}
-                        {toGigabytes(featModel.metadata?.size)}
+                        {toGigabytes(featModel.models[0]?.size)}
                       ) : (
@@ -271,17 +263,13 @@ const OnDeviceStarterScreen = ({ isShowStarterScreen }: Props) => {
                         theme="ghost"
                         className="!bg-[hsla(var(--secondary-bg))]"
                         onClick={() =>
-                          downloadModel(
-                            featModel.sources[0].url,
-                            featModel.id,
-                            featModel.name
-                          )
+                          downloadModel(featModel.models[0]?.id)
                         }
                       >
                         Download
-                        {toGigabytes(featModel.metadata?.size)}
+                        {toGigabytes(featModel.models[0]?.size)}
                       )}
diff --git a/web/utils/model.ts b/web/utils/model.ts
index 00bf80c12..2774ec500 100644
--- a/web/utils/model.ts
+++ b/web/utils/model.ts
@@ -8,7 +8,12 @@ export const normalizeModelId = (downloadUrl: string): string => {
   return downloadUrl.split('/').pop() ?? downloadUrl
 }
 
+/**
+ * Default models to recommend to users when they first open the app.
+ * TODO: These will be replaced when we have a proper recommendation system
+ * AND cortexso repositories are updated with tags.
+ */
 export const manualRecommendationModel = [
-  'llama3.2-1b-instruct',
-  'llama3.2-3b-instruct',
+  'cortexso/deepseek-r1',
+  'cortexso/llama3.2',
 ]
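
Reviewer note: the sketch below summarizes the featured-model selection this patch converges on in ModelDropdown and OnDeviceStarterScreen. It is an illustration only; the ModelSource shape is assumed from the fields the diff actually touches (id, models[0].id, models[0].size) and may differ from the real type returned by useGetModelSources in '@/hooks/useModelSource'.

    // Hypothetical shapes -- only the fields this patch relies on are modeled.
    interface SourceModel {
      id: string
      size?: number
    }

    interface ModelSourceSketch {
      id: string // e.g. 'cortexso/deepseek-r1'
      models: SourceModel[]
    }

    // Curated onboarding picks, mirroring web/utils/model.ts after this patch.
    const manualRecommendationModel = ['cortexso/deepseek-r1', 'cortexso/llama3.2']

    // Featured models are now the curated model sources, replacing the old
    // tag/size filter over configuredModels.
    const pickFeaturedModels = (sources?: ModelSourceSketch[]) =>
      sources?.filter((x) => manualRecommendationModel.includes(x.id))

    // Downloads, progress lookups, and size display key off a source's first model.
    const downloadIdOf = (src: ModelSourceSketch) => src.models[0]?.id ?? src.id
    const sizeOf = (src: ModelSourceSketch) => src.models[0]?.size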