diff --git a/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx b/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx index e68f843a9..5e8549c7f 100644 --- a/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx +++ b/web/containers/Layout/BottomPanel/SystemMonitor/TableActiveModel/index.tsx @@ -6,7 +6,7 @@ import { useActiveModel } from '@/hooks/useActiveModel' import { toGibibytes } from '@/utils/converter' -import { localEngines } from '@/utils/modelEngine' +import { isLocalEngine } from '@/utils/modelEngine' import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom' @@ -35,7 +35,7 @@ const TableActiveModel = () => { })} - {activeModel && localEngines.includes(activeModel.engine) ? ( + {activeModel && isLocalEngine(activeModel.engine) ? ( { engine: 'nitro', } - const configuredModels = [remoteModel, localModel] + const configuredModels = [remoteModel, localModel, localModel] const mockConfiguredModel = configuredModels const selectedModel = { id: 'selectedModel', name: 'selectedModel' } @@ -94,8 +94,20 @@ describe('ModelDropdown', () => { await waitFor(() => { expect(screen.getByTestId('model-selector')).toBeInTheDocument() - expect(screen.getByText('On-device')) - expect(screen.getByText('Cloud')) + expect(screen.getByText('On-device')).toBeInTheDocument() + expect(screen.getByText('Cloud')).toBeInTheDocument() + }) + }) + + it('filters models correctly', async () => { + render() + + await waitFor(() => { + expect(screen.getByTestId('model-selector')).toBeInTheDocument() + fireEvent.click(screen.getByText('Cloud')) + fireEvent.change(screen.getByText('Cloud'), { + target: { value: 'remote' }, + }) }) }) }) diff --git a/web/containers/ModelDropdown/index.tsx b/web/containers/ModelDropdown/index.tsx index 2a0c4ffaf..9ebcf4fa2 100644 --- a/web/containers/ModelDropdown/index.tsx +++ b/web/containers/ModelDropdown/index.tsx @@ -40,7 +40,7 @@ import { formatDownloadPercentage, toGibibytes } from '@/utils/converter' import { getLogoEngine, getTitleByEngine, - localEngines, + isLocalEngine, priorityEngine, } from '@/utils/modelEngine' @@ -101,7 +101,7 @@ const ModelDropdown = ({ const isModelSupportRagAndTools = useCallback((model: Model) => { return ( model?.engine === InferenceEngine.openai || - localEngines.includes(model?.engine as InferenceEngine) + isLocalEngine(model?.engine as InferenceEngine) ) }, []) @@ -113,10 +113,10 @@ const ModelDropdown = ({ ) .filter((e) => { if (searchFilter === 'local') { - return localEngines.includes(e.engine) + return isLocalEngine(e.engine) } if (searchFilter === 'remote') { - return !localEngines.includes(e.engine) + return !isLocalEngine(e.engine) } }) .sort((a, b) => a.name.localeCompare(b.name)) @@ -236,7 +236,6 @@ const ModelDropdown = ({ for (const extension of extensions) { if (typeof extension.getSettings === 'function') { const settings = await extension.getSettings() - if ( (settings && settings.length > 0) || (await extension.installationState()) !== 'NotRequired' @@ -295,7 +294,7 @@ const ModelDropdown = ({ }, [setShowEngineListModel, extensionHasSettings]) const isDownloadALocalModel = downloadedModels.some((x) => - localEngines.includes(x.engine) + isLocalEngine(x.engine) ) if (strictedThread && !activeThread) { @@ -377,7 +376,7 @@ const ModelDropdown = ({ {groupByEngine.map((engine, i) => { - const apiKey = !localEngines.includes(engine) + const apiKey = !isLocalEngine(engine) ? extensionHasSettings.filter((x) => x.provider === engine)[0] ?.apiKey.length > 1 : true @@ -417,7 +416,7 @@ const ModelDropdown = ({
- {!localEngines.includes(engine) && ( + {!isLocalEngine(engine) && ( )} {!showModel ? ( @@ -438,7 +437,7 @@ const ModelDropdown = ({
- {engine === InferenceEngine.nitro && + {isLocalEngine(engine) && !isDownloadALocalModel && showModel && !searchText.length && ( @@ -503,10 +502,7 @@ const ModelDropdown = ({ {filteredDownloadedModels .filter((x) => x.engine === engine) .filter((y) => { - if ( - localEngines.includes(y.engine) && - !searchText.length - ) { + if (isLocalEngine(y.engine) && !searchText.length) { return downloadedModels.find((c) => c.id === y.id) } else { return y @@ -530,10 +526,7 @@ const ModelDropdown = ({ : 'text-[hsla(var(--text-primary))]' )} onClick={() => { - if ( - !apiKey && - !localEngines.includes(model.engine) - ) + if (!apiKey && !isLocalEngine(model.engine)) return null if (isdDownloaded) { onClickModelItem(model.id) diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx index 1fbcd3919..5cc92219c 100644 --- a/web/containers/Providers/EventHandler.tsx +++ b/web/containers/Providers/EventHandler.tsx @@ -21,7 +21,7 @@ import { ulid } from 'ulidx' import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel' -import { localEngines } from '@/utils/modelEngine' +import { isLocalEngine } from '@/utils/modelEngine' import { extractInferenceParams } from '@/utils/modelParam' import { extensionManager } from '@/extension' @@ -242,9 +242,7 @@ export default function EventHandler({ children }: { children: ReactNode }) { } // Check model engine; we don't want to generate a title when it's not a local engine. remote model using first promp - if ( - !localEngines.includes(activeModelRef.current?.engine as InferenceEngine) - ) { + if (!isLocalEngine(activeModelRef.current?.engine as InferenceEngine)) { const updatedThread: Thread = { ...thread, title: (thread.metadata?.lastMessage as string) || defaultThreadTitle, diff --git a/web/containers/SetupRemoteModel/index.tsx b/web/containers/SetupRemoteModel/index.tsx index 914f240de..1f5478d73 100644 --- a/web/containers/SetupRemoteModel/index.tsx +++ b/web/containers/SetupRemoteModel/index.tsx @@ -8,7 +8,7 @@ import { SettingsIcon, PlusIcon } from 'lucide-react' import { MainViewState } from '@/constants/screens' -import { localEngines } from '@/utils/modelEngine' +import { isLocalEngine } from '@/utils/modelEngine' import { extensionManager } from '@/extension' import { mainViewStateAtom } from '@/helpers/atoms/App.atom' @@ -74,7 +74,7 @@ const SetupRemoteModel = ({ engine }: Props) => { ) } - const apiKey = !localEngines.includes(engine) + const apiKey = !isLocalEngine(engine) ? extensionHasSettings.filter((x) => x.provider === engine)[0]?.apiKey .length > 1 : true diff --git a/web/helpers/atoms/Model.atom.test.ts b/web/helpers/atoms/Model.atom.test.ts new file mode 100644 index 000000000..36f2ce71c --- /dev/null +++ b/web/helpers/atoms/Model.atom.test.ts @@ -0,0 +1,298 @@ +import { act, renderHook, waitFor } from '@testing-library/react' +import * as ModelAtoms from './Model.atom' +import { useAtom, useAtomValue, useSetAtom } from 'jotai' + +describe('Model.atom.ts', () => { + let mockJotaiGet: jest.Mock + let mockJotaiSet: jest.Mock + + beforeEach(() => { + mockJotaiGet = jest.fn() + mockJotaiSet = jest.fn() + }) + + afterEach(() => { + jest.clearAllMocks() + }) + + describe('stateModel', () => { + it('should initialize with correct default values', () => { + expect(ModelAtoms.stateModel.init).toEqual({ + state: 'start', + loading: false, + model: '', + }) + }) + }) + describe('activeAssistantModelAtom', () => { + it('should initialize as undefined', () => { + expect(ModelAtoms.activeAssistantModelAtom.init).toBeUndefined() + }) + }) + + describe('selectedModelAtom', () => { + it('should initialize as undefined', () => { + expect(ModelAtoms.selectedModelAtom.init).toBeUndefined() + }) + }) + + describe('showEngineListModelAtom', () => { + it('should initialize as an empty array', () => { + expect(ModelAtoms.showEngineListModelAtom.init).toEqual([]) + }) + }) + + describe('addDownloadingModelAtom', () => { + it('should add downloading model', async () => { + const { result: setAtom } = renderHook(() => + useSetAtom(ModelAtoms.addDownloadingModelAtom) + ) + const { result: getAtom } = renderHook(() => + useAtomValue(ModelAtoms.getDownloadingModelAtom) + ) + act(() => { + setAtom.current({ id: '1' } as any) + }) + expect(getAtom.current).toEqual([{ id: '1' }]) + }) + }) + + describe('removeDownloadingModelAtom', () => { + it('should remove downloading model', async () => { + const { result: setAtom } = renderHook(() => + useSetAtom(ModelAtoms.addDownloadingModelAtom) + ) + const { result: removeAtom } = renderHook(() => + useSetAtom(ModelAtoms.removeDownloadingModelAtom) + ) + const { result: getAtom } = renderHook(() => + useAtomValue(ModelAtoms.getDownloadingModelAtom) + ) + act(() => { + setAtom.current({ id: '1' } as any) + removeAtom.current('1') + }) + expect(getAtom.current).toEqual([]) + }) + }) + + describe('removeDownloadedModelAtom', () => { + it('should remove downloaded model', async () => { + const { result: setAtom } = renderHook(() => + useSetAtom(ModelAtoms.downloadedModelsAtom) + ) + const { result: removeAtom } = renderHook(() => + useSetAtom(ModelAtoms.removeDownloadedModelAtom) + ) + const { result: getAtom } = renderHook(() => + useAtomValue(ModelAtoms.downloadedModelsAtom) + ) + act(() => { + setAtom.current([{ id: '1' }] as any) + }) + expect(getAtom.current).toEqual([ + { + id: '1', + }, + ]) + act(() => { + removeAtom.current('1') + }) + expect(getAtom.current).toEqual([]) + }) + }) + + describe('importingModelAtom', () => { + afterEach(() => { + jest.resetAllMocks() + jest.clearAllMocks() + }) + it('should not update for non-existing import', async () => { + const { result: importAtom } = renderHook(() => + useAtom(ModelAtoms.importingModelsAtom) + ) + const { result: updateAtom } = renderHook(() => + useSetAtom(ModelAtoms.updateImportingModelProgressAtom) + ) + act(() => { + importAtom.current[1]([]) + updateAtom.current('2', 50) + }) + expect(importAtom.current[0]).toEqual([]) + }) + it('should update progress for existing import', async () => { + const { result: importAtom } = renderHook(() => + useAtom(ModelAtoms.importingModelsAtom) + ) + const { result: updateAtom } = renderHook(() => + useSetAtom(ModelAtoms.updateImportingModelProgressAtom) + ) + + act(() => { + importAtom.current[1]([ + { importId: '1', status: 'MODEL_SELECTED' }, + ] as any) + updateAtom.current('1', 50) + }) + expect(importAtom.current[0]).toEqual([ + { + importId: '1', + status: 'IMPORTING', + percentage: 50, + }, + ]) + }) + + it('should not update with invalid data', async () => { + const { result: importAtom } = renderHook(() => + useAtom(ModelAtoms.importingModelsAtom) + ) + const { result: updateAtom } = renderHook(() => + useSetAtom(ModelAtoms.updateImportingModelProgressAtom) + ) + + act(() => { + importAtom.current[1]([ + { importId: '1', status: 'MODEL_SELECTED' }, + ] as any) + updateAtom.current('2', 50) + }) + expect(importAtom.current[0]).toEqual([ + { + importId: '1', + status: 'MODEL_SELECTED', + }, + ]) + }) + it('should update import error', async () => { + const { result: importAtom } = renderHook(() => + useAtom(ModelAtoms.importingModelsAtom) + ) + const { result: errorAtom } = renderHook(() => + useSetAtom(ModelAtoms.setImportingModelErrorAtom) + ) + act(() => { + importAtom.current[1]([ + { importId: '1', status: 'IMPORTING', percentage: 50 }, + ] as any) + errorAtom.current('1', 'unknown') + }) + expect(importAtom.current[0]).toEqual([ + { + importId: '1', + status: 'FAILED', + percentage: 50, + }, + ]) + }) + it('should not update import error on invalid import ID', async () => { + const { result: importAtom } = renderHook(() => + useAtom(ModelAtoms.importingModelsAtom) + ) + const { result: errorAtom } = renderHook(() => + useSetAtom(ModelAtoms.setImportingModelErrorAtom) + ) + act(() => { + importAtom.current[1]([ + { importId: '1', status: 'IMPORTING', percentage: 50 }, + ] as any) + errorAtom.current('2', 'unknown') + }) + expect(importAtom.current[0]).toEqual([ + { + importId: '1', + status: 'IMPORTING', + percentage: 50, + }, + ]) + }) + + it('should update import success', async () => { + const { result: importAtom } = renderHook(() => + useAtom(ModelAtoms.importingModelsAtom) + ) + const { result: successAtom } = renderHook(() => + useSetAtom(ModelAtoms.setImportingModelSuccessAtom) + ) + + act(() => { + importAtom.current[1]([{ importId: '1', status: 'IMPORTING' }] as any) + successAtom.current('1', 'id') + }) + expect(importAtom.current[0]).toEqual([ + { + importId: '1', + status: 'IMPORTED', + percentage: 1, + modelId: 'id', + }, + ]) + }) + + it('should update with invalid import ID', async () => { + const { result: importAtom } = renderHook(() => + useAtom(ModelAtoms.importingModelsAtom) + ) + const { result: successAtom } = renderHook(() => + useSetAtom(ModelAtoms.setImportingModelSuccessAtom) + ) + + act(() => { + importAtom.current[1]([{ importId: '1', status: 'IMPORTING' }] as any) + successAtom.current('2', 'id') + }) + expect(importAtom.current[0]).toEqual([ + { + importId: '1', + status: 'IMPORTING', + }, + ]) + }) + it('should not update with valid data', async () => { + const { result: importAtom } = renderHook(() => + useAtom(ModelAtoms.importingModelsAtom) + ) + const { result: updateAtom } = renderHook(() => + useSetAtom(ModelAtoms.updateImportingModelAtom) + ) + + act(() => { + importAtom.current[1]([ + { importId: '1', status: 'IMPORTING', percentage: 1 }, + ] as any) + updateAtom.current('1', 'name', 'description', ['tag']) + }) + expect(importAtom.current[0]).toEqual([ + { + importId: '1', + percentage: 1, + status: 'IMPORTING', + name: 'name', + tags: ['tag'], + description: 'description', + }, + ]) + }) + + it('should not update when there is no importing model', async () => { + const { result: importAtom } = renderHook(() => + useAtom(ModelAtoms.importingModelsAtom) + ) + const { result: updateAtom } = renderHook(() => + useSetAtom(ModelAtoms.updateImportingModelAtom) + ) + + act(() => { + importAtom.current[1]([]) + updateAtom.current('1', 'name', 'description', ['tag']) + }) + expect(importAtom.current[0]).toEqual([]) + }) + }) + + describe('defaultModelAtom', () => { + it('should initialize as undefined', () => { + expect(ModelAtoms.defaultModelAtom.init).toBeUndefined() + }) + }) +}) diff --git a/web/helpers/atoms/Model.atom.ts b/web/helpers/atoms/Model.atom.ts index d2d0ca9f4..28a6384eb 100644 --- a/web/helpers/atoms/Model.atom.ts +++ b/web/helpers/atoms/Model.atom.ts @@ -1,8 +1,6 @@ -import { ImportingModel, Model, InferenceEngine, ModelFile } from '@janhq/core' +import { ImportingModel, Model, ModelFile } from '@janhq/core' import { atom } from 'jotai' -import { localEngines } from '@/utils/modelEngine' - export const stateModel = atom({ state: 'start', loading: false, model: '' }) export const activeAssistantModelAtom = atom(undefined) @@ -135,4 +133,4 @@ export const updateImportingModelAtom = atom( export const selectedModelAtom = atom(undefined) -export const showEngineListModelAtom = atom(localEngines) +export const showEngineListModelAtom = atom([]) diff --git a/web/hooks/useStarterScreen.ts b/web/hooks/useStarterScreen.ts index 1a6bbfbc7..3305c0072 100644 --- a/web/hooks/useStarterScreen.ts +++ b/web/hooks/useStarterScreen.ts @@ -2,7 +2,7 @@ import { useState, useEffect } from 'react' import { useAtomValue } from 'jotai' -import { localEngines } from '@/utils/modelEngine' +import { isLocalEngine } from '@/utils/modelEngine' import { extensionManager } from '@/extension' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' @@ -13,7 +13,7 @@ export function useStarterScreen() { const threads = useAtomValue(threadsAtom) const isDownloadALocalModel = downloadedModels.some((x) => - localEngines.includes(x.engine) + isLocalEngine(x.engine) ) const [extensionHasSettings, setExtensionHasSettings] = useState< diff --git a/web/screens/Settings/MyModels/MyModelList/index.tsx b/web/screens/Settings/MyModels/MyModelList/index.tsx index 329248923..c9ca6e867 100644 --- a/web/screens/Settings/MyModels/MyModelList/index.tsx +++ b/web/screens/Settings/MyModels/MyModelList/index.tsx @@ -16,7 +16,7 @@ import useDeleteModel from '@/hooks/useDeleteModel' import { toGibibytes } from '@/utils/converter' -import { localEngines } from '@/utils/modelEngine' +import { isLocalEngine } from '@/utils/modelEngine' import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom' @@ -74,7 +74,7 @@ const MyModelList = ({ model }: Props) => { - {localEngines.includes(model.engine) && ( + {isLocalEngine(model.engine) && (
diff --git a/web/screens/Settings/MyModels/index.tsx b/web/screens/Settings/MyModels/index.tsx index 8dafd6e20..547e6153b 100644 --- a/web/screens/Settings/MyModels/index.tsx +++ b/web/screens/Settings/MyModels/index.tsx @@ -29,7 +29,7 @@ import { setImportModelStageAtom } from '@/hooks/useImportModel' import { getLogoEngine, getTitleByEngine, - localEngines, + isLocalEngine, priorityEngine, } from '@/utils/modelEngine' @@ -222,7 +222,7 @@ const MyModels = () => {
- {!localEngines.includes(engine) && ( + {!isLocalEngine(engine) && ( )} {!showModel ? ( diff --git a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx index 26036a627..b1e9d081a 100644 --- a/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ChatBody/OnDeviceStarterScreen/index.tsx @@ -28,7 +28,7 @@ import { formatDownloadPercentage, toGibibytes } from '@/utils/converter' import { getLogoEngine, getTitleByEngine, - localEngines, + isLocalEngine, } from '@/utils/modelEngine' import { mainViewStateAtom } from '@/helpers/atoms/App.atom' @@ -74,13 +74,11 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => { } }) - const remoteModel = configuredModels.filter( - (x) => !localEngines.includes(x.engine) - ) + const remoteModel = configuredModels.filter((x) => !isLocalEngine(x.engine)) const filteredModels = configuredModels.filter((model) => { return ( - localEngines.includes(model.engine) && + isLocalEngine(model.engine) && model.name.toLowerCase().includes(searchValue.toLowerCase()) ) }) diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx index 235ebeae6..a7c5ad121 100644 --- a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx @@ -24,7 +24,7 @@ import { useActiveModel } from '@/hooks/useActiveModel' import useSendChatMessage from '@/hooks/useSendChatMessage' -import { localEngines } from '@/utils/modelEngine' +import { isLocalEngine } from '@/utils/modelEngine' import FileUploadPreview from '../FileUploadPreview' import ImageUploadPreview from '../ImageUploadPreview' @@ -130,7 +130,7 @@ const ChatInput = () => { const isModelSupportRagAndTools = selectedModel?.engine === InferenceEngine.openai || - localEngines.includes(selectedModel?.engine as InferenceEngine) + isLocalEngine(selectedModel?.engine as InferenceEngine) /** * Handles the change event of the extension file input element by setting the file name state. diff --git a/web/screens/Thread/ThreadRightPanel/index.tsx b/web/screens/Thread/ThreadRightPanel/index.tsx index 78119ba6d..027d1b0b6 100644 --- a/web/screens/Thread/ThreadRightPanel/index.tsx +++ b/web/screens/Thread/ThreadRightPanel/index.tsx @@ -28,7 +28,7 @@ import { useCreateNewThread } from '@/hooks/useCreateNewThread' import useUpdateModelParameters from '@/hooks/useUpdateModelParameters' import { getConfigurationsData } from '@/utils/componentSettings' -import { localEngines } from '@/utils/modelEngine' +import { isLocalEngine } from '@/utils/modelEngine' import { extractInferenceParams, extractModelLoadParams, @@ -63,7 +63,7 @@ const ThreadRightPanel = () => { const isModelSupportRagAndTools = selectedModel?.engine === InferenceEngine.openai || - localEngines.includes(selectedModel?.engine as InferenceEngine) + isLocalEngine(selectedModel?.engine as InferenceEngine) const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom) const { stopModel } = useActiveModel() diff --git a/web/utils/modelEngine.test.ts b/web/utils/modelEngine.test.ts new file mode 100644 index 000000000..738e04c2a --- /dev/null +++ b/web/utils/modelEngine.test.ts @@ -0,0 +1,185 @@ +import { EngineManager, InferenceEngine, LocalOAIEngine } from '@janhq/core' +import { + getTitleByEngine, + isLocalEngine, + priorityEngine, + getLogoEngine, +} from './modelEngine' + +jest.mock('@janhq/core', () => ({ + ...jest.requireActual('@janhq/core'), + EngineManager: { + instance: jest.fn().mockReturnValue({ + get: jest.fn(), + }), + }, +})) + +describe('isLocalEngine', () => { + const mockEngineManagerInstance = EngineManager.instance() + const mockGet = mockEngineManagerInstance.get as jest.Mock + + beforeEach(() => { + jest.clearAllMocks() + }) + + it('should return false if engine is not found', () => { + mockGet.mockReturnValue(null) + const result = isLocalEngine('nonexistentEngine') + expect(result).toBe(false) + }) + + it('should return true if engine is an instance of LocalOAIEngine', () => { + const mockEngineObj = { + __proto__: { + constructor: { + __proto__: { + name: LocalOAIEngine.name, + }, + }, + }, + } + mockGet.mockReturnValue(mockEngineObj) + const result = isLocalEngine('localEngine') + expect(result).toBe(true) + }) + + it('should return false if engine is not an instance of LocalOAIEngine', () => { + const mockEngineObj = { + __proto__: { + constructor: { + __proto__: { + name: 'SomeOtherEngine', + }, + }, + }, + } + mockGet.mockReturnValue(mockEngineObj) + const result = isLocalEngine('someOtherEngine') + expect(result).toBe(false) + }) + + jest.mock('@janhq/core', () => ({ + ...jest.requireActual('@janhq/core'), + EngineManager: { + instance: jest.fn().mockReturnValue({ + get: jest.fn(), + }), + }, + })) + + describe('getTitleByEngine', () => { + it('should return correct title for InferenceEngine.nitro', () => { + const result = getTitleByEngine(InferenceEngine.nitro) + expect(result).toBe('Llama.cpp (Nitro)') + }) + + it('should return correct title for InferenceEngine.nitro_tensorrt_llm', () => { + const result = getTitleByEngine(InferenceEngine.nitro_tensorrt_llm) + expect(result).toBe('TensorRT-LLM (Nitro)') + }) + + it('should return correct title for InferenceEngine.cortex_llamacpp', () => { + const result = getTitleByEngine(InferenceEngine.cortex_llamacpp) + expect(result).toBe('Llama.cpp (Cortex)') + }) + + it('should return correct title for InferenceEngine.cortex_onnx', () => { + const result = getTitleByEngine(InferenceEngine.cortex_onnx) + expect(result).toBe('Onnx (Cortex)') + }) + + it('should return correct title for InferenceEngine.cortex_tensorrtllm', () => { + const result = getTitleByEngine(InferenceEngine.cortex_tensorrtllm) + expect(result).toBe('TensorRT-LLM (Cortex)') + }) + + it('should return correct title for InferenceEngine.openai', () => { + const result = getTitleByEngine(InferenceEngine.openai) + expect(result).toBe('OpenAI') + }) + + it('should return correct title for InferenceEngine.openrouter', () => { + const result = getTitleByEngine(InferenceEngine.openrouter) + expect(result).toBe('OpenRouter') + }) + + it('should return capitalized engine name for unknown engine', () => { + const result = getTitleByEngine('unknownEngine' as InferenceEngine) + expect(result).toBe('UnknownEngine') + }) + }) + + describe('priorityEngine', () => { + it('should contain the correct engines in the correct order', () => { + expect(priorityEngine).toEqual([ + InferenceEngine.cortex_llamacpp, + InferenceEngine.cortex_onnx, + InferenceEngine.cortex_tensorrtllm, + InferenceEngine.nitro, + ]) + }) + }) + + describe('getLogoEngine', () => { + it('should return correct logo path for InferenceEngine.anthropic', () => { + const result = getLogoEngine(InferenceEngine.anthropic) + expect(result).toBe('images/ModelProvider/anthropic.svg') + }) + + it('should return correct logo path for InferenceEngine.nitro_tensorrt_llm', () => { + const result = getLogoEngine(InferenceEngine.nitro_tensorrt_llm) + expect(result).toBe('images/ModelProvider/nitro.svg') + }) + + it('should return correct logo path for InferenceEngine.cortex_llamacpp', () => { + const result = getLogoEngine(InferenceEngine.cortex_llamacpp) + expect(result).toBe('images/ModelProvider/cortex.svg') + }) + + it('should return correct logo path for InferenceEngine.mistral', () => { + const result = getLogoEngine(InferenceEngine.mistral) + expect(result).toBe('images/ModelProvider/mistral.svg') + }) + + it('should return correct logo path for InferenceEngine.martian', () => { + const result = getLogoEngine(InferenceEngine.martian) + expect(result).toBe('images/ModelProvider/martian.svg') + }) + + it('should return correct logo path for InferenceEngine.openrouter', () => { + const result = getLogoEngine(InferenceEngine.openrouter) + expect(result).toBe('images/ModelProvider/openRouter.svg') + }) + + it('should return correct logo path for InferenceEngine.openai', () => { + const result = getLogoEngine(InferenceEngine.openai) + expect(result).toBe('images/ModelProvider/openai.svg') + }) + + it('should return correct logo path for InferenceEngine.groq', () => { + const result = getLogoEngine(InferenceEngine.groq) + expect(result).toBe('images/ModelProvider/groq.svg') + }) + + it('should return correct logo path for InferenceEngine.triton_trtllm', () => { + const result = getLogoEngine(InferenceEngine.triton_trtllm) + expect(result).toBe('images/ModelProvider/triton_trtllm.svg') + }) + + it('should return correct logo path for InferenceEngine.cohere', () => { + const result = getLogoEngine(InferenceEngine.cohere) + expect(result).toBe('images/ModelProvider/cohere.svg') + }) + + it('should return correct logo path for InferenceEngine.nvidia', () => { + const result = getLogoEngine(InferenceEngine.nvidia) + expect(result).toBe('images/ModelProvider/nvidia.svg') + }) + + it('should return undefined for unknown engine', () => { + const result = getLogoEngine('unknownEngine' as InferenceEngine) + expect(result).toBeUndefined() + }) + }) +}) diff --git a/web/utils/modelEngine.ts b/web/utils/modelEngine.ts index 3d132c5d5..33b3ec3e1 100644 --- a/web/utils/modelEngine.ts +++ b/web/utils/modelEngine.ts @@ -1,4 +1,4 @@ -import { InferenceEngine } from '@janhq/core' +import { EngineManager, InferenceEngine, LocalOAIEngine } from '@janhq/core' export const getLogoEngine = (engine: InferenceEngine) => { switch (engine) { @@ -32,13 +32,19 @@ export const getLogoEngine = (engine: InferenceEngine) => { } } -export const localEngines = [ - InferenceEngine.nitro, - InferenceEngine.nitro_tensorrt_llm, - InferenceEngine.cortex_llamacpp, - InferenceEngine.cortex_onnx, - InferenceEngine.cortex_tensorrtllm, -] +/** + * Check whether the engine is conform to LocalOAIEngine + * @param engine + * @returns + */ +export const isLocalEngine = (engine: string) => { + const engineObj = EngineManager.instance().get(engine) + if (!engineObj) return false + return ( + Object.getPrototypeOf(engineObj).constructor.__proto__.name === + LocalOAIEngine.name + ) +} export const getTitleByEngine = (engine: InferenceEngine) => { switch (engine) {