fix: correct model dropdown for local models (#3736)
* fix: correct model dropdown for local models * fix: clean unused import * test: add Model.atom and model.Engine tests
This commit is contained in:
parent
8334076047
commit
ba1ddacde3
@ -6,7 +6,7 @@ import { useActiveModel } from '@/hooks/useActiveModel'
|
||||
|
||||
import { toGibibytes } from '@/utils/converter'
|
||||
|
||||
import { localEngines } from '@/utils/modelEngine'
|
||||
import { isLocalEngine } from '@/utils/modelEngine'
|
||||
|
||||
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
|
||||
|
||||
@ -35,7 +35,7 @@ const TableActiveModel = () => {
|
||||
})}
|
||||
</tr>
|
||||
</thead>
|
||||
{activeModel && localEngines.includes(activeModel.engine) ? (
|
||||
{activeModel && isLocalEngine(activeModel.engine) ? (
|
||||
<tbody>
|
||||
<tr>
|
||||
<td
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import { useAtomValue, useAtom, useSetAtom } from 'jotai'
|
||||
import { render, screen, waitFor, fireEvent } from '@testing-library/react'
|
||||
import { useAtomValue, useAtom } from 'jotai'
|
||||
import ModelDropdown from './index'
|
||||
import useRecommendedModel from '@/hooks/useRecommendedModel'
|
||||
import '@testing-library/jest-dom'
|
||||
@ -38,7 +38,7 @@ describe('ModelDropdown', () => {
|
||||
engine: 'nitro',
|
||||
}
|
||||
|
||||
const configuredModels = [remoteModel, localModel]
|
||||
const configuredModels = [remoteModel, localModel, localModel]
|
||||
|
||||
const mockConfiguredModel = configuredModels
|
||||
const selectedModel = { id: 'selectedModel', name: 'selectedModel' }
|
||||
@ -94,8 +94,20 @@ describe('ModelDropdown', () => {
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('model-selector')).toBeInTheDocument()
|
||||
expect(screen.getByText('On-device'))
|
||||
expect(screen.getByText('Cloud'))
|
||||
expect(screen.getByText('On-device')).toBeInTheDocument()
|
||||
expect(screen.getByText('Cloud')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('filters models correctly', async () => {
|
||||
render(<ModelDropdown />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('model-selector')).toBeInTheDocument()
|
||||
fireEvent.click(screen.getByText('Cloud'))
|
||||
fireEvent.change(screen.getByText('Cloud'), {
|
||||
target: { value: 'remote' },
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -40,7 +40,7 @@ import { formatDownloadPercentage, toGibibytes } from '@/utils/converter'
|
||||
import {
|
||||
getLogoEngine,
|
||||
getTitleByEngine,
|
||||
localEngines,
|
||||
isLocalEngine,
|
||||
priorityEngine,
|
||||
} from '@/utils/modelEngine'
|
||||
|
||||
@ -101,7 +101,7 @@ const ModelDropdown = ({
|
||||
const isModelSupportRagAndTools = useCallback((model: Model) => {
|
||||
return (
|
||||
model?.engine === InferenceEngine.openai ||
|
||||
localEngines.includes(model?.engine as InferenceEngine)
|
||||
isLocalEngine(model?.engine as InferenceEngine)
|
||||
)
|
||||
}, [])
|
||||
|
||||
@ -113,10 +113,10 @@ const ModelDropdown = ({
|
||||
)
|
||||
.filter((e) => {
|
||||
if (searchFilter === 'local') {
|
||||
return localEngines.includes(e.engine)
|
||||
return isLocalEngine(e.engine)
|
||||
}
|
||||
if (searchFilter === 'remote') {
|
||||
return !localEngines.includes(e.engine)
|
||||
return !isLocalEngine(e.engine)
|
||||
}
|
||||
})
|
||||
.sort((a, b) => a.name.localeCompare(b.name))
|
||||
@ -236,7 +236,6 @@ const ModelDropdown = ({
|
||||
for (const extension of extensions) {
|
||||
if (typeof extension.getSettings === 'function') {
|
||||
const settings = await extension.getSettings()
|
||||
|
||||
if (
|
||||
(settings && settings.length > 0) ||
|
||||
(await extension.installationState()) !== 'NotRequired'
|
||||
@ -295,7 +294,7 @@ const ModelDropdown = ({
|
||||
}, [setShowEngineListModel, extensionHasSettings])
|
||||
|
||||
const isDownloadALocalModel = downloadedModels.some((x) =>
|
||||
localEngines.includes(x.engine)
|
||||
isLocalEngine(x.engine)
|
||||
)
|
||||
|
||||
if (strictedThread && !activeThread) {
|
||||
@ -377,7 +376,7 @@ const ModelDropdown = ({
|
||||
</div>
|
||||
<ScrollArea className="h-[calc(100%-90px)] w-full">
|
||||
{groupByEngine.map((engine, i) => {
|
||||
const apiKey = !localEngines.includes(engine)
|
||||
const apiKey = !isLocalEngine(engine)
|
||||
? extensionHasSettings.filter((x) => x.provider === engine)[0]
|
||||
?.apiKey.length > 1
|
||||
: true
|
||||
@ -417,7 +416,7 @@ const ModelDropdown = ({
|
||||
</h6>
|
||||
</div>
|
||||
<div className="-mr-2 flex gap-1">
|
||||
{!localEngines.includes(engine) && (
|
||||
{!isLocalEngine(engine) && (
|
||||
<SetupRemoteModel engine={engine} />
|
||||
)}
|
||||
{!showModel ? (
|
||||
@ -438,7 +437,7 @@ const ModelDropdown = ({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{engine === InferenceEngine.nitro &&
|
||||
{isLocalEngine(engine) &&
|
||||
!isDownloadALocalModel &&
|
||||
showModel &&
|
||||
!searchText.length && (
|
||||
@ -503,10 +502,7 @@ const ModelDropdown = ({
|
||||
{filteredDownloadedModels
|
||||
.filter((x) => x.engine === engine)
|
||||
.filter((y) => {
|
||||
if (
|
||||
localEngines.includes(y.engine) &&
|
||||
!searchText.length
|
||||
) {
|
||||
if (isLocalEngine(y.engine) && !searchText.length) {
|
||||
return downloadedModels.find((c) => c.id === y.id)
|
||||
} else {
|
||||
return y
|
||||
@ -530,10 +526,7 @@ const ModelDropdown = ({
|
||||
: 'text-[hsla(var(--text-primary))]'
|
||||
)}
|
||||
onClick={() => {
|
||||
if (
|
||||
!apiKey &&
|
||||
!localEngines.includes(model.engine)
|
||||
)
|
||||
if (!apiKey && !isLocalEngine(model.engine))
|
||||
return null
|
||||
if (isdDownloaded) {
|
||||
onClickModelItem(model.id)
|
||||
|
||||
@ -21,7 +21,7 @@ import { ulid } from 'ulidx'
|
||||
|
||||
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
|
||||
|
||||
import { localEngines } from '@/utils/modelEngine'
|
||||
import { isLocalEngine } from '@/utils/modelEngine'
|
||||
import { extractInferenceParams } from '@/utils/modelParam'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
@ -242,9 +242,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
||||
}
|
||||
|
||||
// Check model engine; we don't want to generate a title when it's not a local engine. remote model using first promp
|
||||
if (
|
||||
!localEngines.includes(activeModelRef.current?.engine as InferenceEngine)
|
||||
) {
|
||||
if (!isLocalEngine(activeModelRef.current?.engine as InferenceEngine)) {
|
||||
const updatedThread: Thread = {
|
||||
...thread,
|
||||
title: (thread.metadata?.lastMessage as string) || defaultThreadTitle,
|
||||
|
||||
@ -8,7 +8,7 @@ import { SettingsIcon, PlusIcon } from 'lucide-react'
|
||||
|
||||
import { MainViewState } from '@/constants/screens'
|
||||
|
||||
import { localEngines } from '@/utils/modelEngine'
|
||||
import { isLocalEngine } from '@/utils/modelEngine'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
|
||||
@ -74,7 +74,7 @@ const SetupRemoteModel = ({ engine }: Props) => {
|
||||
)
|
||||
}
|
||||
|
||||
const apiKey = !localEngines.includes(engine)
|
||||
const apiKey = !isLocalEngine(engine)
|
||||
? extensionHasSettings.filter((x) => x.provider === engine)[0]?.apiKey
|
||||
.length > 1
|
||||
: true
|
||||
|
||||
298
web/helpers/atoms/Model.atom.test.ts
Normal file
298
web/helpers/atoms/Model.atom.test.ts
Normal file
@ -0,0 +1,298 @@
|
||||
import { act, renderHook, waitFor } from '@testing-library/react'
|
||||
import * as ModelAtoms from './Model.atom'
|
||||
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
describe('Model.atom.ts', () => {
|
||||
let mockJotaiGet: jest.Mock
|
||||
let mockJotaiSet: jest.Mock
|
||||
|
||||
beforeEach(() => {
|
||||
mockJotaiGet = jest.fn()
|
||||
mockJotaiSet = jest.fn()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('stateModel', () => {
|
||||
it('should initialize with correct default values', () => {
|
||||
expect(ModelAtoms.stateModel.init).toEqual({
|
||||
state: 'start',
|
||||
loading: false,
|
||||
model: '',
|
||||
})
|
||||
})
|
||||
})
|
||||
describe('activeAssistantModelAtom', () => {
|
||||
it('should initialize as undefined', () => {
|
||||
expect(ModelAtoms.activeAssistantModelAtom.init).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('selectedModelAtom', () => {
|
||||
it('should initialize as undefined', () => {
|
||||
expect(ModelAtoms.selectedModelAtom.init).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('showEngineListModelAtom', () => {
|
||||
it('should initialize as an empty array', () => {
|
||||
expect(ModelAtoms.showEngineListModelAtom.init).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('addDownloadingModelAtom', () => {
|
||||
it('should add downloading model', async () => {
|
||||
const { result: setAtom } = renderHook(() =>
|
||||
useSetAtom(ModelAtoms.addDownloadingModelAtom)
|
||||
)
|
||||
const { result: getAtom } = renderHook(() =>
|
||||
useAtomValue(ModelAtoms.getDownloadingModelAtom)
|
||||
)
|
||||
act(() => {
|
||||
setAtom.current({ id: '1' } as any)
|
||||
})
|
||||
expect(getAtom.current).toEqual([{ id: '1' }])
|
||||
})
|
||||
})
|
||||
|
||||
describe('removeDownloadingModelAtom', () => {
|
||||
it('should remove downloading model', async () => {
|
||||
const { result: setAtom } = renderHook(() =>
|
||||
useSetAtom(ModelAtoms.addDownloadingModelAtom)
|
||||
)
|
||||
const { result: removeAtom } = renderHook(() =>
|
||||
useSetAtom(ModelAtoms.removeDownloadingModelAtom)
|
||||
)
|
||||
const { result: getAtom } = renderHook(() =>
|
||||
useAtomValue(ModelAtoms.getDownloadingModelAtom)
|
||||
)
|
||||
act(() => {
|
||||
setAtom.current({ id: '1' } as any)
|
||||
removeAtom.current('1')
|
||||
})
|
||||
expect(getAtom.current).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('removeDownloadedModelAtom', () => {
|
||||
it('should remove downloaded model', async () => {
|
||||
const { result: setAtom } = renderHook(() =>
|
||||
useSetAtom(ModelAtoms.downloadedModelsAtom)
|
||||
)
|
||||
const { result: removeAtom } = renderHook(() =>
|
||||
useSetAtom(ModelAtoms.removeDownloadedModelAtom)
|
||||
)
|
||||
const { result: getAtom } = renderHook(() =>
|
||||
useAtomValue(ModelAtoms.downloadedModelsAtom)
|
||||
)
|
||||
act(() => {
|
||||
setAtom.current([{ id: '1' }] as any)
|
||||
})
|
||||
expect(getAtom.current).toEqual([
|
||||
{
|
||||
id: '1',
|
||||
},
|
||||
])
|
||||
act(() => {
|
||||
removeAtom.current('1')
|
||||
})
|
||||
expect(getAtom.current).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('importingModelAtom', () => {
|
||||
afterEach(() => {
|
||||
jest.resetAllMocks()
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
it('should not update for non-existing import', async () => {
|
||||
const { result: importAtom } = renderHook(() =>
|
||||
useAtom(ModelAtoms.importingModelsAtom)
|
||||
)
|
||||
const { result: updateAtom } = renderHook(() =>
|
||||
useSetAtom(ModelAtoms.updateImportingModelProgressAtom)
|
||||
)
|
||||
act(() => {
|
||||
importAtom.current[1]([])
|
||||
updateAtom.current('2', 50)
|
||||
})
|
||||
expect(importAtom.current[0]).toEqual([])
|
||||
})
|
||||
it('should update progress for existing import', async () => {
|
||||
const { result: importAtom } = renderHook(() =>
|
||||
useAtom(ModelAtoms.importingModelsAtom)
|
||||
)
|
||||
const { result: updateAtom } = renderHook(() =>
|
||||
useSetAtom(ModelAtoms.updateImportingModelProgressAtom)
|
||||
)
|
||||
|
||||
act(() => {
|
||||
importAtom.current[1]([
|
||||
{ importId: '1', status: 'MODEL_SELECTED' },
|
||||
] as any)
|
||||
updateAtom.current('1', 50)
|
||||
})
|
||||
expect(importAtom.current[0]).toEqual([
|
||||
{
|
||||
importId: '1',
|
||||
status: 'IMPORTING',
|
||||
percentage: 50,
|
||||
},
|
||||
])
|
||||
})
|
||||
|
||||
it('should not update with invalid data', async () => {
|
||||
const { result: importAtom } = renderHook(() =>
|
||||
useAtom(ModelAtoms.importingModelsAtom)
|
||||
)
|
||||
const { result: updateAtom } = renderHook(() =>
|
||||
useSetAtom(ModelAtoms.updateImportingModelProgressAtom)
|
||||
)
|
||||
|
||||
act(() => {
|
||||
importAtom.current[1]([
|
||||
{ importId: '1', status: 'MODEL_SELECTED' },
|
||||
] as any)
|
||||
updateAtom.current('2', 50)
|
||||
})
|
||||
expect(importAtom.current[0]).toEqual([
|
||||
{
|
||||
importId: '1',
|
||||
status: 'MODEL_SELECTED',
|
||||
},
|
||||
])
|
||||
})
|
||||
it('should update import error', async () => {
|
||||
const { result: importAtom } = renderHook(() =>
|
||||
useAtom(ModelAtoms.importingModelsAtom)
|
||||
)
|
||||
const { result: errorAtom } = renderHook(() =>
|
||||
useSetAtom(ModelAtoms.setImportingModelErrorAtom)
|
||||
)
|
||||
act(() => {
|
||||
importAtom.current[1]([
|
||||
{ importId: '1', status: 'IMPORTING', percentage: 50 },
|
||||
] as any)
|
||||
errorAtom.current('1', 'unknown')
|
||||
})
|
||||
expect(importAtom.current[0]).toEqual([
|
||||
{
|
||||
importId: '1',
|
||||
status: 'FAILED',
|
||||
percentage: 50,
|
||||
},
|
||||
])
|
||||
})
|
||||
it('should not update import error on invalid import ID', async () => {
|
||||
const { result: importAtom } = renderHook(() =>
|
||||
useAtom(ModelAtoms.importingModelsAtom)
|
||||
)
|
||||
const { result: errorAtom } = renderHook(() =>
|
||||
useSetAtom(ModelAtoms.setImportingModelErrorAtom)
|
||||
)
|
||||
act(() => {
|
||||
importAtom.current[1]([
|
||||
{ importId: '1', status: 'IMPORTING', percentage: 50 },
|
||||
] as any)
|
||||
errorAtom.current('2', 'unknown')
|
||||
})
|
||||
expect(importAtom.current[0]).toEqual([
|
||||
{
|
||||
importId: '1',
|
||||
status: 'IMPORTING',
|
||||
percentage: 50,
|
||||
},
|
||||
])
|
||||
})
|
||||
|
||||
it('should update import success', async () => {
|
||||
const { result: importAtom } = renderHook(() =>
|
||||
useAtom(ModelAtoms.importingModelsAtom)
|
||||
)
|
||||
const { result: successAtom } = renderHook(() =>
|
||||
useSetAtom(ModelAtoms.setImportingModelSuccessAtom)
|
||||
)
|
||||
|
||||
act(() => {
|
||||
importAtom.current[1]([{ importId: '1', status: 'IMPORTING' }] as any)
|
||||
successAtom.current('1', 'id')
|
||||
})
|
||||
expect(importAtom.current[0]).toEqual([
|
||||
{
|
||||
importId: '1',
|
||||
status: 'IMPORTED',
|
||||
percentage: 1,
|
||||
modelId: 'id',
|
||||
},
|
||||
])
|
||||
})
|
||||
|
||||
it('should update with invalid import ID', async () => {
|
||||
const { result: importAtom } = renderHook(() =>
|
||||
useAtom(ModelAtoms.importingModelsAtom)
|
||||
)
|
||||
const { result: successAtom } = renderHook(() =>
|
||||
useSetAtom(ModelAtoms.setImportingModelSuccessAtom)
|
||||
)
|
||||
|
||||
act(() => {
|
||||
importAtom.current[1]([{ importId: '1', status: 'IMPORTING' }] as any)
|
||||
successAtom.current('2', 'id')
|
||||
})
|
||||
expect(importAtom.current[0]).toEqual([
|
||||
{
|
||||
importId: '1',
|
||||
status: 'IMPORTING',
|
||||
},
|
||||
])
|
||||
})
|
||||
it('should not update with valid data', async () => {
|
||||
const { result: importAtom } = renderHook(() =>
|
||||
useAtom(ModelAtoms.importingModelsAtom)
|
||||
)
|
||||
const { result: updateAtom } = renderHook(() =>
|
||||
useSetAtom(ModelAtoms.updateImportingModelAtom)
|
||||
)
|
||||
|
||||
act(() => {
|
||||
importAtom.current[1]([
|
||||
{ importId: '1', status: 'IMPORTING', percentage: 1 },
|
||||
] as any)
|
||||
updateAtom.current('1', 'name', 'description', ['tag'])
|
||||
})
|
||||
expect(importAtom.current[0]).toEqual([
|
||||
{
|
||||
importId: '1',
|
||||
percentage: 1,
|
||||
status: 'IMPORTING',
|
||||
name: 'name',
|
||||
tags: ['tag'],
|
||||
description: 'description',
|
||||
},
|
||||
])
|
||||
})
|
||||
|
||||
it('should not update when there is no importing model', async () => {
|
||||
const { result: importAtom } = renderHook(() =>
|
||||
useAtom(ModelAtoms.importingModelsAtom)
|
||||
)
|
||||
const { result: updateAtom } = renderHook(() =>
|
||||
useSetAtom(ModelAtoms.updateImportingModelAtom)
|
||||
)
|
||||
|
||||
act(() => {
|
||||
importAtom.current[1]([])
|
||||
updateAtom.current('1', 'name', 'description', ['tag'])
|
||||
})
|
||||
expect(importAtom.current[0]).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('defaultModelAtom', () => {
|
||||
it('should initialize as undefined', () => {
|
||||
expect(ModelAtoms.defaultModelAtom.init).toBeUndefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -1,8 +1,6 @@
|
||||
import { ImportingModel, Model, InferenceEngine, ModelFile } from '@janhq/core'
|
||||
import { ImportingModel, Model, ModelFile } from '@janhq/core'
|
||||
import { atom } from 'jotai'
|
||||
|
||||
import { localEngines } from '@/utils/modelEngine'
|
||||
|
||||
export const stateModel = atom({ state: 'start', loading: false, model: '' })
|
||||
export const activeAssistantModelAtom = atom<Model | undefined>(undefined)
|
||||
|
||||
@ -135,4 +133,4 @@ export const updateImportingModelAtom = atom(
|
||||
|
||||
export const selectedModelAtom = atom<ModelFile | undefined>(undefined)
|
||||
|
||||
export const showEngineListModelAtom = atom<InferenceEngine[]>(localEngines)
|
||||
export const showEngineListModelAtom = atom<string[]>([])
|
||||
|
||||
@ -2,7 +2,7 @@ import { useState, useEffect } from 'react'
|
||||
|
||||
import { useAtomValue } from 'jotai'
|
||||
|
||||
import { localEngines } from '@/utils/modelEngine'
|
||||
import { isLocalEngine } from '@/utils/modelEngine'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||
@ -13,7 +13,7 @@ export function useStarterScreen() {
|
||||
const threads = useAtomValue(threadsAtom)
|
||||
|
||||
const isDownloadALocalModel = downloadedModels.some((x) =>
|
||||
localEngines.includes(x.engine)
|
||||
isLocalEngine(x.engine)
|
||||
)
|
||||
|
||||
const [extensionHasSettings, setExtensionHasSettings] = useState<
|
||||
|
||||
@ -16,7 +16,7 @@ import useDeleteModel from '@/hooks/useDeleteModel'
|
||||
|
||||
import { toGibibytes } from '@/utils/converter'
|
||||
|
||||
import { localEngines } from '@/utils/modelEngine'
|
||||
import { isLocalEngine } from '@/utils/modelEngine'
|
||||
|
||||
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
|
||||
|
||||
@ -74,7 +74,7 @@ const MyModelList = ({ model }: Props) => {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{localEngines.includes(model.engine) && (
|
||||
{isLocalEngine(model.engine) && (
|
||||
<div className="flex gap-x-4">
|
||||
<div className="md:min-w-[90px] md:max-w-[90px]">
|
||||
<Badge theme="secondary" className="sm:mr-8">
|
||||
|
||||
@ -29,7 +29,7 @@ import { setImportModelStageAtom } from '@/hooks/useImportModel'
|
||||
import {
|
||||
getLogoEngine,
|
||||
getTitleByEngine,
|
||||
localEngines,
|
||||
isLocalEngine,
|
||||
priorityEngine,
|
||||
} from '@/utils/modelEngine'
|
||||
|
||||
@ -222,7 +222,7 @@ const MyModels = () => {
|
||||
</h6>
|
||||
</div>
|
||||
<div className="flex gap-1">
|
||||
{!localEngines.includes(engine) && (
|
||||
{!isLocalEngine(engine) && (
|
||||
<SetupRemoteModel engine={engine} />
|
||||
)}
|
||||
{!showModel ? (
|
||||
|
||||
@ -28,7 +28,7 @@ import { formatDownloadPercentage, toGibibytes } from '@/utils/converter'
|
||||
import {
|
||||
getLogoEngine,
|
||||
getTitleByEngine,
|
||||
localEngines,
|
||||
isLocalEngine,
|
||||
} from '@/utils/modelEngine'
|
||||
|
||||
import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
|
||||
@ -74,13 +74,11 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
|
||||
}
|
||||
})
|
||||
|
||||
const remoteModel = configuredModels.filter(
|
||||
(x) => !localEngines.includes(x.engine)
|
||||
)
|
||||
const remoteModel = configuredModels.filter((x) => !isLocalEngine(x.engine))
|
||||
|
||||
const filteredModels = configuredModels.filter((model) => {
|
||||
return (
|
||||
localEngines.includes(model.engine) &&
|
||||
isLocalEngine(model.engine) &&
|
||||
model.name.toLowerCase().includes(searchValue.toLowerCase())
|
||||
)
|
||||
})
|
||||
|
||||
@ -24,7 +24,7 @@ import { useActiveModel } from '@/hooks/useActiveModel'
|
||||
|
||||
import useSendChatMessage from '@/hooks/useSendChatMessage'
|
||||
|
||||
import { localEngines } from '@/utils/modelEngine'
|
||||
import { isLocalEngine } from '@/utils/modelEngine'
|
||||
|
||||
import FileUploadPreview from '../FileUploadPreview'
|
||||
import ImageUploadPreview from '../ImageUploadPreview'
|
||||
@ -130,7 +130,7 @@ const ChatInput = () => {
|
||||
|
||||
const isModelSupportRagAndTools =
|
||||
selectedModel?.engine === InferenceEngine.openai ||
|
||||
localEngines.includes(selectedModel?.engine as InferenceEngine)
|
||||
isLocalEngine(selectedModel?.engine as InferenceEngine)
|
||||
|
||||
/**
|
||||
* Handles the change event of the extension file input element by setting the file name state.
|
||||
|
||||
@ -28,7 +28,7 @@ import { useCreateNewThread } from '@/hooks/useCreateNewThread'
|
||||
import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
|
||||
|
||||
import { getConfigurationsData } from '@/utils/componentSettings'
|
||||
import { localEngines } from '@/utils/modelEngine'
|
||||
import { isLocalEngine } from '@/utils/modelEngine'
|
||||
import {
|
||||
extractInferenceParams,
|
||||
extractModelLoadParams,
|
||||
@ -63,7 +63,7 @@ const ThreadRightPanel = () => {
|
||||
|
||||
const isModelSupportRagAndTools =
|
||||
selectedModel?.engine === InferenceEngine.openai ||
|
||||
localEngines.includes(selectedModel?.engine as InferenceEngine)
|
||||
isLocalEngine(selectedModel?.engine as InferenceEngine)
|
||||
|
||||
const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom)
|
||||
const { stopModel } = useActiveModel()
|
||||
|
||||
185
web/utils/modelEngine.test.ts
Normal file
185
web/utils/modelEngine.test.ts
Normal file
@ -0,0 +1,185 @@
|
||||
import { EngineManager, InferenceEngine, LocalOAIEngine } from '@janhq/core'
|
||||
import {
|
||||
getTitleByEngine,
|
||||
isLocalEngine,
|
||||
priorityEngine,
|
||||
getLogoEngine,
|
||||
} from './modelEngine'
|
||||
|
||||
jest.mock('@janhq/core', () => ({
|
||||
...jest.requireActual('@janhq/core'),
|
||||
EngineManager: {
|
||||
instance: jest.fn().mockReturnValue({
|
||||
get: jest.fn(),
|
||||
}),
|
||||
},
|
||||
}))
|
||||
|
||||
describe('isLocalEngine', () => {
|
||||
const mockEngineManagerInstance = EngineManager.instance()
|
||||
const mockGet = mockEngineManagerInstance.get as jest.Mock
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should return false if engine is not found', () => {
|
||||
mockGet.mockReturnValue(null)
|
||||
const result = isLocalEngine('nonexistentEngine')
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should return true if engine is an instance of LocalOAIEngine', () => {
|
||||
const mockEngineObj = {
|
||||
__proto__: {
|
||||
constructor: {
|
||||
__proto__: {
|
||||
name: LocalOAIEngine.name,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
mockGet.mockReturnValue(mockEngineObj)
|
||||
const result = isLocalEngine('localEngine')
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false if engine is not an instance of LocalOAIEngine', () => {
|
||||
const mockEngineObj = {
|
||||
__proto__: {
|
||||
constructor: {
|
||||
__proto__: {
|
||||
name: 'SomeOtherEngine',
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
mockGet.mockReturnValue(mockEngineObj)
|
||||
const result = isLocalEngine('someOtherEngine')
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
jest.mock('@janhq/core', () => ({
|
||||
...jest.requireActual('@janhq/core'),
|
||||
EngineManager: {
|
||||
instance: jest.fn().mockReturnValue({
|
||||
get: jest.fn(),
|
||||
}),
|
||||
},
|
||||
}))
|
||||
|
||||
describe('getTitleByEngine', () => {
|
||||
it('should return correct title for InferenceEngine.nitro', () => {
|
||||
const result = getTitleByEngine(InferenceEngine.nitro)
|
||||
expect(result).toBe('Llama.cpp (Nitro)')
|
||||
})
|
||||
|
||||
it('should return correct title for InferenceEngine.nitro_tensorrt_llm', () => {
|
||||
const result = getTitleByEngine(InferenceEngine.nitro_tensorrt_llm)
|
||||
expect(result).toBe('TensorRT-LLM (Nitro)')
|
||||
})
|
||||
|
||||
it('should return correct title for InferenceEngine.cortex_llamacpp', () => {
|
||||
const result = getTitleByEngine(InferenceEngine.cortex_llamacpp)
|
||||
expect(result).toBe('Llama.cpp (Cortex)')
|
||||
})
|
||||
|
||||
it('should return correct title for InferenceEngine.cortex_onnx', () => {
|
||||
const result = getTitleByEngine(InferenceEngine.cortex_onnx)
|
||||
expect(result).toBe('Onnx (Cortex)')
|
||||
})
|
||||
|
||||
it('should return correct title for InferenceEngine.cortex_tensorrtllm', () => {
|
||||
const result = getTitleByEngine(InferenceEngine.cortex_tensorrtllm)
|
||||
expect(result).toBe('TensorRT-LLM (Cortex)')
|
||||
})
|
||||
|
||||
it('should return correct title for InferenceEngine.openai', () => {
|
||||
const result = getTitleByEngine(InferenceEngine.openai)
|
||||
expect(result).toBe('OpenAI')
|
||||
})
|
||||
|
||||
it('should return correct title for InferenceEngine.openrouter', () => {
|
||||
const result = getTitleByEngine(InferenceEngine.openrouter)
|
||||
expect(result).toBe('OpenRouter')
|
||||
})
|
||||
|
||||
it('should return capitalized engine name for unknown engine', () => {
|
||||
const result = getTitleByEngine('unknownEngine' as InferenceEngine)
|
||||
expect(result).toBe('UnknownEngine')
|
||||
})
|
||||
})
|
||||
|
||||
describe('priorityEngine', () => {
|
||||
it('should contain the correct engines in the correct order', () => {
|
||||
expect(priorityEngine).toEqual([
|
||||
InferenceEngine.cortex_llamacpp,
|
||||
InferenceEngine.cortex_onnx,
|
||||
InferenceEngine.cortex_tensorrtllm,
|
||||
InferenceEngine.nitro,
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('getLogoEngine', () => {
|
||||
it('should return correct logo path for InferenceEngine.anthropic', () => {
|
||||
const result = getLogoEngine(InferenceEngine.anthropic)
|
||||
expect(result).toBe('images/ModelProvider/anthropic.svg')
|
||||
})
|
||||
|
||||
it('should return correct logo path for InferenceEngine.nitro_tensorrt_llm', () => {
|
||||
const result = getLogoEngine(InferenceEngine.nitro_tensorrt_llm)
|
||||
expect(result).toBe('images/ModelProvider/nitro.svg')
|
||||
})
|
||||
|
||||
it('should return correct logo path for InferenceEngine.cortex_llamacpp', () => {
|
||||
const result = getLogoEngine(InferenceEngine.cortex_llamacpp)
|
||||
expect(result).toBe('images/ModelProvider/cortex.svg')
|
||||
})
|
||||
|
||||
it('should return correct logo path for InferenceEngine.mistral', () => {
|
||||
const result = getLogoEngine(InferenceEngine.mistral)
|
||||
expect(result).toBe('images/ModelProvider/mistral.svg')
|
||||
})
|
||||
|
||||
it('should return correct logo path for InferenceEngine.martian', () => {
|
||||
const result = getLogoEngine(InferenceEngine.martian)
|
||||
expect(result).toBe('images/ModelProvider/martian.svg')
|
||||
})
|
||||
|
||||
it('should return correct logo path for InferenceEngine.openrouter', () => {
|
||||
const result = getLogoEngine(InferenceEngine.openrouter)
|
||||
expect(result).toBe('images/ModelProvider/openRouter.svg')
|
||||
})
|
||||
|
||||
it('should return correct logo path for InferenceEngine.openai', () => {
|
||||
const result = getLogoEngine(InferenceEngine.openai)
|
||||
expect(result).toBe('images/ModelProvider/openai.svg')
|
||||
})
|
||||
|
||||
it('should return correct logo path for InferenceEngine.groq', () => {
|
||||
const result = getLogoEngine(InferenceEngine.groq)
|
||||
expect(result).toBe('images/ModelProvider/groq.svg')
|
||||
})
|
||||
|
||||
it('should return correct logo path for InferenceEngine.triton_trtllm', () => {
|
||||
const result = getLogoEngine(InferenceEngine.triton_trtllm)
|
||||
expect(result).toBe('images/ModelProvider/triton_trtllm.svg')
|
||||
})
|
||||
|
||||
it('should return correct logo path for InferenceEngine.cohere', () => {
|
||||
const result = getLogoEngine(InferenceEngine.cohere)
|
||||
expect(result).toBe('images/ModelProvider/cohere.svg')
|
||||
})
|
||||
|
||||
it('should return correct logo path for InferenceEngine.nvidia', () => {
|
||||
const result = getLogoEngine(InferenceEngine.nvidia)
|
||||
expect(result).toBe('images/ModelProvider/nvidia.svg')
|
||||
})
|
||||
|
||||
it('should return undefined for unknown engine', () => {
|
||||
const result = getLogoEngine('unknownEngine' as InferenceEngine)
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -1,4 +1,4 @@
|
||||
import { InferenceEngine } from '@janhq/core'
|
||||
import { EngineManager, InferenceEngine, LocalOAIEngine } from '@janhq/core'
|
||||
|
||||
export const getLogoEngine = (engine: InferenceEngine) => {
|
||||
switch (engine) {
|
||||
@ -32,13 +32,19 @@ export const getLogoEngine = (engine: InferenceEngine) => {
|
||||
}
|
||||
}
|
||||
|
||||
export const localEngines = [
|
||||
InferenceEngine.nitro,
|
||||
InferenceEngine.nitro_tensorrt_llm,
|
||||
InferenceEngine.cortex_llamacpp,
|
||||
InferenceEngine.cortex_onnx,
|
||||
InferenceEngine.cortex_tensorrtllm,
|
||||
]
|
||||
/**
|
||||
* Check whether the engine is conform to LocalOAIEngine
|
||||
* @param engine
|
||||
* @returns
|
||||
*/
|
||||
export const isLocalEngine = (engine: string) => {
|
||||
const engineObj = EngineManager.instance().get(engine)
|
||||
if (!engineObj) return false
|
||||
return (
|
||||
Object.getPrototypeOf(engineObj).constructor.__proto__.name ===
|
||||
LocalOAIEngine.name
|
||||
)
|
||||
}
|
||||
|
||||
export const getTitleByEngine = (engine: InferenceEngine) => {
|
||||
switch (engine) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user