fix: correct model dropdown for local models (#3736)

* fix: correct model dropdown for local models

* fix: clean unused import

* test: add Model.atom and model.Engine tests
This commit is contained in:
Louis 2024-09-30 11:58:55 +07:00 committed by GitHub
parent 8334076047
commit ba1ddacde3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 545 additions and 57 deletions

View File

@ -6,7 +6,7 @@ import { useActiveModel } from '@/hooks/useActiveModel'
import { toGibibytes } from '@/utils/converter'
import { localEngines } from '@/utils/modelEngine'
import { isLocalEngine } from '@/utils/modelEngine'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
@ -35,7 +35,7 @@ const TableActiveModel = () => {
})}
</tr>
</thead>
{activeModel && localEngines.includes(activeModel.engine) ? (
{activeModel && isLocalEngine(activeModel.engine) ? (
<tbody>
<tr>
<td

View File

@ -1,5 +1,5 @@
import { render, screen, waitFor } from '@testing-library/react'
import { useAtomValue, useAtom, useSetAtom } from 'jotai'
import { render, screen, waitFor, fireEvent } from '@testing-library/react'
import { useAtomValue, useAtom } from 'jotai'
import ModelDropdown from './index'
import useRecommendedModel from '@/hooks/useRecommendedModel'
import '@testing-library/jest-dom'
@ -38,7 +38,7 @@ describe('ModelDropdown', () => {
engine: 'nitro',
}
const configuredModels = [remoteModel, localModel]
const configuredModels = [remoteModel, localModel, localModel]
const mockConfiguredModel = configuredModels
const selectedModel = { id: 'selectedModel', name: 'selectedModel' }
@ -94,8 +94,20 @@ describe('ModelDropdown', () => {
await waitFor(() => {
expect(screen.getByTestId('model-selector')).toBeInTheDocument()
expect(screen.getByText('On-device'))
expect(screen.getByText('Cloud'))
expect(screen.getByText('On-device')).toBeInTheDocument()
expect(screen.getByText('Cloud')).toBeInTheDocument()
})
})
it('filters models correctly', async () => {
render(<ModelDropdown />)
await waitFor(() => {
expect(screen.getByTestId('model-selector')).toBeInTheDocument()
fireEvent.click(screen.getByText('Cloud'))
fireEvent.change(screen.getByText('Cloud'), {
target: { value: 'remote' },
})
})
})
})

View File

@ -40,7 +40,7 @@ import { formatDownloadPercentage, toGibibytes } from '@/utils/converter'
import {
getLogoEngine,
getTitleByEngine,
localEngines,
isLocalEngine,
priorityEngine,
} from '@/utils/modelEngine'
@ -101,7 +101,7 @@ const ModelDropdown = ({
const isModelSupportRagAndTools = useCallback((model: Model) => {
return (
model?.engine === InferenceEngine.openai ||
localEngines.includes(model?.engine as InferenceEngine)
isLocalEngine(model?.engine as InferenceEngine)
)
}, [])
@ -113,10 +113,10 @@ const ModelDropdown = ({
)
.filter((e) => {
if (searchFilter === 'local') {
return localEngines.includes(e.engine)
return isLocalEngine(e.engine)
}
if (searchFilter === 'remote') {
return !localEngines.includes(e.engine)
return !isLocalEngine(e.engine)
}
})
.sort((a, b) => a.name.localeCompare(b.name))
@ -236,7 +236,6 @@ const ModelDropdown = ({
for (const extension of extensions) {
if (typeof extension.getSettings === 'function') {
const settings = await extension.getSettings()
if (
(settings && settings.length > 0) ||
(await extension.installationState()) !== 'NotRequired'
@ -295,7 +294,7 @@ const ModelDropdown = ({
}, [setShowEngineListModel, extensionHasSettings])
const isDownloadALocalModel = downloadedModels.some((x) =>
localEngines.includes(x.engine)
isLocalEngine(x.engine)
)
if (strictedThread && !activeThread) {
@ -377,7 +376,7 @@ const ModelDropdown = ({
</div>
<ScrollArea className="h-[calc(100%-90px)] w-full">
{groupByEngine.map((engine, i) => {
const apiKey = !localEngines.includes(engine)
const apiKey = !isLocalEngine(engine)
? extensionHasSettings.filter((x) => x.provider === engine)[0]
?.apiKey.length > 1
: true
@ -417,7 +416,7 @@ const ModelDropdown = ({
</h6>
</div>
<div className="-mr-2 flex gap-1">
{!localEngines.includes(engine) && (
{!isLocalEngine(engine) && (
<SetupRemoteModel engine={engine} />
)}
{!showModel ? (
@ -438,7 +437,7 @@ const ModelDropdown = ({
</div>
</div>
{engine === InferenceEngine.nitro &&
{isLocalEngine(engine) &&
!isDownloadALocalModel &&
showModel &&
!searchText.length && (
@ -503,10 +502,7 @@ const ModelDropdown = ({
{filteredDownloadedModels
.filter((x) => x.engine === engine)
.filter((y) => {
if (
localEngines.includes(y.engine) &&
!searchText.length
) {
if (isLocalEngine(y.engine) && !searchText.length) {
return downloadedModels.find((c) => c.id === y.id)
} else {
return y
@ -530,10 +526,7 @@ const ModelDropdown = ({
: 'text-[hsla(var(--text-primary))]'
)}
onClick={() => {
if (
!apiKey &&
!localEngines.includes(model.engine)
)
if (!apiKey && !isLocalEngine(model.engine))
return null
if (isdDownloaded) {
onClickModelItem(model.id)

View File

@ -21,7 +21,7 @@ import { ulid } from 'ulidx'
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
import { localEngines } from '@/utils/modelEngine'
import { isLocalEngine } from '@/utils/modelEngine'
import { extractInferenceParams } from '@/utils/modelParam'
import { extensionManager } from '@/extension'
@ -242,9 +242,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {
}
// Check model engine; we don't want to generate a title when it's not a local engine. remote model using first promp
if (
!localEngines.includes(activeModelRef.current?.engine as InferenceEngine)
) {
if (!isLocalEngine(activeModelRef.current?.engine as InferenceEngine)) {
const updatedThread: Thread = {
...thread,
title: (thread.metadata?.lastMessage as string) || defaultThreadTitle,

View File

@ -8,7 +8,7 @@ import { SettingsIcon, PlusIcon } from 'lucide-react'
import { MainViewState } from '@/constants/screens'
import { localEngines } from '@/utils/modelEngine'
import { isLocalEngine } from '@/utils/modelEngine'
import { extensionManager } from '@/extension'
import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
@ -74,7 +74,7 @@ const SetupRemoteModel = ({ engine }: Props) => {
)
}
const apiKey = !localEngines.includes(engine)
const apiKey = !isLocalEngine(engine)
? extensionHasSettings.filter((x) => x.provider === engine)[0]?.apiKey
.length > 1
: true

View File

@ -0,0 +1,298 @@
import { act, renderHook, waitFor } from '@testing-library/react'
import * as ModelAtoms from './Model.atom'
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
describe('Model.atom.ts', () => {
let mockJotaiGet: jest.Mock
let mockJotaiSet: jest.Mock
beforeEach(() => {
mockJotaiGet = jest.fn()
mockJotaiSet = jest.fn()
})
afterEach(() => {
jest.clearAllMocks()
})
describe('stateModel', () => {
it('should initialize with correct default values', () => {
expect(ModelAtoms.stateModel.init).toEqual({
state: 'start',
loading: false,
model: '',
})
})
})
describe('activeAssistantModelAtom', () => {
it('should initialize as undefined', () => {
expect(ModelAtoms.activeAssistantModelAtom.init).toBeUndefined()
})
})
describe('selectedModelAtom', () => {
it('should initialize as undefined', () => {
expect(ModelAtoms.selectedModelAtom.init).toBeUndefined()
})
})
describe('showEngineListModelAtom', () => {
it('should initialize as an empty array', () => {
expect(ModelAtoms.showEngineListModelAtom.init).toEqual([])
})
})
describe('addDownloadingModelAtom', () => {
it('should add downloading model', async () => {
const { result: setAtom } = renderHook(() =>
useSetAtom(ModelAtoms.addDownloadingModelAtom)
)
const { result: getAtom } = renderHook(() =>
useAtomValue(ModelAtoms.getDownloadingModelAtom)
)
act(() => {
setAtom.current({ id: '1' } as any)
})
expect(getAtom.current).toEqual([{ id: '1' }])
})
})
describe('removeDownloadingModelAtom', () => {
it('should remove downloading model', async () => {
const { result: setAtom } = renderHook(() =>
useSetAtom(ModelAtoms.addDownloadingModelAtom)
)
const { result: removeAtom } = renderHook(() =>
useSetAtom(ModelAtoms.removeDownloadingModelAtom)
)
const { result: getAtom } = renderHook(() =>
useAtomValue(ModelAtoms.getDownloadingModelAtom)
)
act(() => {
setAtom.current({ id: '1' } as any)
removeAtom.current('1')
})
expect(getAtom.current).toEqual([])
})
})
describe('removeDownloadedModelAtom', () => {
it('should remove downloaded model', async () => {
const { result: setAtom } = renderHook(() =>
useSetAtom(ModelAtoms.downloadedModelsAtom)
)
const { result: removeAtom } = renderHook(() =>
useSetAtom(ModelAtoms.removeDownloadedModelAtom)
)
const { result: getAtom } = renderHook(() =>
useAtomValue(ModelAtoms.downloadedModelsAtom)
)
act(() => {
setAtom.current([{ id: '1' }] as any)
})
expect(getAtom.current).toEqual([
{
id: '1',
},
])
act(() => {
removeAtom.current('1')
})
expect(getAtom.current).toEqual([])
})
})
describe('importingModelAtom', () => {
afterEach(() => {
jest.resetAllMocks()
jest.clearAllMocks()
})
it('should not update for non-existing import', async () => {
const { result: importAtom } = renderHook(() =>
useAtom(ModelAtoms.importingModelsAtom)
)
const { result: updateAtom } = renderHook(() =>
useSetAtom(ModelAtoms.updateImportingModelProgressAtom)
)
act(() => {
importAtom.current[1]([])
updateAtom.current('2', 50)
})
expect(importAtom.current[0]).toEqual([])
})
it('should update progress for existing import', async () => {
const { result: importAtom } = renderHook(() =>
useAtom(ModelAtoms.importingModelsAtom)
)
const { result: updateAtom } = renderHook(() =>
useSetAtom(ModelAtoms.updateImportingModelProgressAtom)
)
act(() => {
importAtom.current[1]([
{ importId: '1', status: 'MODEL_SELECTED' },
] as any)
updateAtom.current('1', 50)
})
expect(importAtom.current[0]).toEqual([
{
importId: '1',
status: 'IMPORTING',
percentage: 50,
},
])
})
it('should not update with invalid data', async () => {
const { result: importAtom } = renderHook(() =>
useAtom(ModelAtoms.importingModelsAtom)
)
const { result: updateAtom } = renderHook(() =>
useSetAtom(ModelAtoms.updateImportingModelProgressAtom)
)
act(() => {
importAtom.current[1]([
{ importId: '1', status: 'MODEL_SELECTED' },
] as any)
updateAtom.current('2', 50)
})
expect(importAtom.current[0]).toEqual([
{
importId: '1',
status: 'MODEL_SELECTED',
},
])
})
it('should update import error', async () => {
const { result: importAtom } = renderHook(() =>
useAtom(ModelAtoms.importingModelsAtom)
)
const { result: errorAtom } = renderHook(() =>
useSetAtom(ModelAtoms.setImportingModelErrorAtom)
)
act(() => {
importAtom.current[1]([
{ importId: '1', status: 'IMPORTING', percentage: 50 },
] as any)
errorAtom.current('1', 'unknown')
})
expect(importAtom.current[0]).toEqual([
{
importId: '1',
status: 'FAILED',
percentage: 50,
},
])
})
it('should not update import error on invalid import ID', async () => {
const { result: importAtom } = renderHook(() =>
useAtom(ModelAtoms.importingModelsAtom)
)
const { result: errorAtom } = renderHook(() =>
useSetAtom(ModelAtoms.setImportingModelErrorAtom)
)
act(() => {
importAtom.current[1]([
{ importId: '1', status: 'IMPORTING', percentage: 50 },
] as any)
errorAtom.current('2', 'unknown')
})
expect(importAtom.current[0]).toEqual([
{
importId: '1',
status: 'IMPORTING',
percentage: 50,
},
])
})
it('should update import success', async () => {
const { result: importAtom } = renderHook(() =>
useAtom(ModelAtoms.importingModelsAtom)
)
const { result: successAtom } = renderHook(() =>
useSetAtom(ModelAtoms.setImportingModelSuccessAtom)
)
act(() => {
importAtom.current[1]([{ importId: '1', status: 'IMPORTING' }] as any)
successAtom.current('1', 'id')
})
expect(importAtom.current[0]).toEqual([
{
importId: '1',
status: 'IMPORTED',
percentage: 1,
modelId: 'id',
},
])
})
it('should update with invalid import ID', async () => {
const { result: importAtom } = renderHook(() =>
useAtom(ModelAtoms.importingModelsAtom)
)
const { result: successAtom } = renderHook(() =>
useSetAtom(ModelAtoms.setImportingModelSuccessAtom)
)
act(() => {
importAtom.current[1]([{ importId: '1', status: 'IMPORTING' }] as any)
successAtom.current('2', 'id')
})
expect(importAtom.current[0]).toEqual([
{
importId: '1',
status: 'IMPORTING',
},
])
})
it('should not update with valid data', async () => {
const { result: importAtom } = renderHook(() =>
useAtom(ModelAtoms.importingModelsAtom)
)
const { result: updateAtom } = renderHook(() =>
useSetAtom(ModelAtoms.updateImportingModelAtom)
)
act(() => {
importAtom.current[1]([
{ importId: '1', status: 'IMPORTING', percentage: 1 },
] as any)
updateAtom.current('1', 'name', 'description', ['tag'])
})
expect(importAtom.current[0]).toEqual([
{
importId: '1',
percentage: 1,
status: 'IMPORTING',
name: 'name',
tags: ['tag'],
description: 'description',
},
])
})
it('should not update when there is no importing model', async () => {
const { result: importAtom } = renderHook(() =>
useAtom(ModelAtoms.importingModelsAtom)
)
const { result: updateAtom } = renderHook(() =>
useSetAtom(ModelAtoms.updateImportingModelAtom)
)
act(() => {
importAtom.current[1]([])
updateAtom.current('1', 'name', 'description', ['tag'])
})
expect(importAtom.current[0]).toEqual([])
})
})
describe('defaultModelAtom', () => {
it('should initialize as undefined', () => {
expect(ModelAtoms.defaultModelAtom.init).toBeUndefined()
})
})
})

View File

@ -1,8 +1,6 @@
import { ImportingModel, Model, InferenceEngine, ModelFile } from '@janhq/core'
import { ImportingModel, Model, ModelFile } from '@janhq/core'
import { atom } from 'jotai'
import { localEngines } from '@/utils/modelEngine'
export const stateModel = atom({ state: 'start', loading: false, model: '' })
export const activeAssistantModelAtom = atom<Model | undefined>(undefined)
@ -135,4 +133,4 @@ export const updateImportingModelAtom = atom(
export const selectedModelAtom = atom<ModelFile | undefined>(undefined)
export const showEngineListModelAtom = atom<InferenceEngine[]>(localEngines)
export const showEngineListModelAtom = atom<string[]>([])

View File

@ -2,7 +2,7 @@ import { useState, useEffect } from 'react'
import { useAtomValue } from 'jotai'
import { localEngines } from '@/utils/modelEngine'
import { isLocalEngine } from '@/utils/modelEngine'
import { extensionManager } from '@/extension'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
@ -13,7 +13,7 @@ export function useStarterScreen() {
const threads = useAtomValue(threadsAtom)
const isDownloadALocalModel = downloadedModels.some((x) =>
localEngines.includes(x.engine)
isLocalEngine(x.engine)
)
const [extensionHasSettings, setExtensionHasSettings] = useState<

View File

@ -16,7 +16,7 @@ import useDeleteModel from '@/hooks/useDeleteModel'
import { toGibibytes } from '@/utils/converter'
import { localEngines } from '@/utils/modelEngine'
import { isLocalEngine } from '@/utils/modelEngine'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
@ -74,7 +74,7 @@ const MyModelList = ({ model }: Props) => {
</div>
</div>
{localEngines.includes(model.engine) && (
{isLocalEngine(model.engine) && (
<div className="flex gap-x-4">
<div className="md:min-w-[90px] md:max-w-[90px]">
<Badge theme="secondary" className="sm:mr-8">

View File

@ -29,7 +29,7 @@ import { setImportModelStageAtom } from '@/hooks/useImportModel'
import {
getLogoEngine,
getTitleByEngine,
localEngines,
isLocalEngine,
priorityEngine,
} from '@/utils/modelEngine'
@ -222,7 +222,7 @@ const MyModels = () => {
</h6>
</div>
<div className="flex gap-1">
{!localEngines.includes(engine) && (
{!isLocalEngine(engine) && (
<SetupRemoteModel engine={engine} />
)}
{!showModel ? (

View File

@ -28,7 +28,7 @@ import { formatDownloadPercentage, toGibibytes } from '@/utils/converter'
import {
getLogoEngine,
getTitleByEngine,
localEngines,
isLocalEngine,
} from '@/utils/modelEngine'
import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
@ -74,13 +74,11 @@ const OnDeviceStarterScreen = ({ extensionHasSettings }: Props) => {
}
})
const remoteModel = configuredModels.filter(
(x) => !localEngines.includes(x.engine)
)
const remoteModel = configuredModels.filter((x) => !isLocalEngine(x.engine))
const filteredModels = configuredModels.filter((model) => {
return (
localEngines.includes(model.engine) &&
isLocalEngine(model.engine) &&
model.name.toLowerCase().includes(searchValue.toLowerCase())
)
})

View File

@ -24,7 +24,7 @@ import { useActiveModel } from '@/hooks/useActiveModel'
import useSendChatMessage from '@/hooks/useSendChatMessage'
import { localEngines } from '@/utils/modelEngine'
import { isLocalEngine } from '@/utils/modelEngine'
import FileUploadPreview from '../FileUploadPreview'
import ImageUploadPreview from '../ImageUploadPreview'
@ -130,7 +130,7 @@ const ChatInput = () => {
const isModelSupportRagAndTools =
selectedModel?.engine === InferenceEngine.openai ||
localEngines.includes(selectedModel?.engine as InferenceEngine)
isLocalEngine(selectedModel?.engine as InferenceEngine)
/**
* Handles the change event of the extension file input element by setting the file name state.

View File

@ -28,7 +28,7 @@ import { useCreateNewThread } from '@/hooks/useCreateNewThread'
import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
import { getConfigurationsData } from '@/utils/componentSettings'
import { localEngines } from '@/utils/modelEngine'
import { isLocalEngine } from '@/utils/modelEngine'
import {
extractInferenceParams,
extractModelLoadParams,
@ -63,7 +63,7 @@ const ThreadRightPanel = () => {
const isModelSupportRagAndTools =
selectedModel?.engine === InferenceEngine.openai ||
localEngines.includes(selectedModel?.engine as InferenceEngine)
isLocalEngine(selectedModel?.engine as InferenceEngine)
const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom)
const { stopModel } = useActiveModel()

View File

@ -0,0 +1,185 @@
import { EngineManager, InferenceEngine, LocalOAIEngine } from '@janhq/core'
import {
getTitleByEngine,
isLocalEngine,
priorityEngine,
getLogoEngine,
} from './modelEngine'
jest.mock('@janhq/core', () => ({
...jest.requireActual('@janhq/core'),
EngineManager: {
instance: jest.fn().mockReturnValue({
get: jest.fn(),
}),
},
}))
describe('isLocalEngine', () => {
const mockEngineManagerInstance = EngineManager.instance()
const mockGet = mockEngineManagerInstance.get as jest.Mock
beforeEach(() => {
jest.clearAllMocks()
})
it('should return false if engine is not found', () => {
mockGet.mockReturnValue(null)
const result = isLocalEngine('nonexistentEngine')
expect(result).toBe(false)
})
it('should return true if engine is an instance of LocalOAIEngine', () => {
const mockEngineObj = {
__proto__: {
constructor: {
__proto__: {
name: LocalOAIEngine.name,
},
},
},
}
mockGet.mockReturnValue(mockEngineObj)
const result = isLocalEngine('localEngine')
expect(result).toBe(true)
})
it('should return false if engine is not an instance of LocalOAIEngine', () => {
const mockEngineObj = {
__proto__: {
constructor: {
__proto__: {
name: 'SomeOtherEngine',
},
},
},
}
mockGet.mockReturnValue(mockEngineObj)
const result = isLocalEngine('someOtherEngine')
expect(result).toBe(false)
})
jest.mock('@janhq/core', () => ({
...jest.requireActual('@janhq/core'),
EngineManager: {
instance: jest.fn().mockReturnValue({
get: jest.fn(),
}),
},
}))
describe('getTitleByEngine', () => {
it('should return correct title for InferenceEngine.nitro', () => {
const result = getTitleByEngine(InferenceEngine.nitro)
expect(result).toBe('Llama.cpp (Nitro)')
})
it('should return correct title for InferenceEngine.nitro_tensorrt_llm', () => {
const result = getTitleByEngine(InferenceEngine.nitro_tensorrt_llm)
expect(result).toBe('TensorRT-LLM (Nitro)')
})
it('should return correct title for InferenceEngine.cortex_llamacpp', () => {
const result = getTitleByEngine(InferenceEngine.cortex_llamacpp)
expect(result).toBe('Llama.cpp (Cortex)')
})
it('should return correct title for InferenceEngine.cortex_onnx', () => {
const result = getTitleByEngine(InferenceEngine.cortex_onnx)
expect(result).toBe('Onnx (Cortex)')
})
it('should return correct title for InferenceEngine.cortex_tensorrtllm', () => {
const result = getTitleByEngine(InferenceEngine.cortex_tensorrtllm)
expect(result).toBe('TensorRT-LLM (Cortex)')
})
it('should return correct title for InferenceEngine.openai', () => {
const result = getTitleByEngine(InferenceEngine.openai)
expect(result).toBe('OpenAI')
})
it('should return correct title for InferenceEngine.openrouter', () => {
const result = getTitleByEngine(InferenceEngine.openrouter)
expect(result).toBe('OpenRouter')
})
it('should return capitalized engine name for unknown engine', () => {
const result = getTitleByEngine('unknownEngine' as InferenceEngine)
expect(result).toBe('UnknownEngine')
})
})
describe('priorityEngine', () => {
it('should contain the correct engines in the correct order', () => {
expect(priorityEngine).toEqual([
InferenceEngine.cortex_llamacpp,
InferenceEngine.cortex_onnx,
InferenceEngine.cortex_tensorrtllm,
InferenceEngine.nitro,
])
})
})
describe('getLogoEngine', () => {
it('should return correct logo path for InferenceEngine.anthropic', () => {
const result = getLogoEngine(InferenceEngine.anthropic)
expect(result).toBe('images/ModelProvider/anthropic.svg')
})
it('should return correct logo path for InferenceEngine.nitro_tensorrt_llm', () => {
const result = getLogoEngine(InferenceEngine.nitro_tensorrt_llm)
expect(result).toBe('images/ModelProvider/nitro.svg')
})
it('should return correct logo path for InferenceEngine.cortex_llamacpp', () => {
const result = getLogoEngine(InferenceEngine.cortex_llamacpp)
expect(result).toBe('images/ModelProvider/cortex.svg')
})
it('should return correct logo path for InferenceEngine.mistral', () => {
const result = getLogoEngine(InferenceEngine.mistral)
expect(result).toBe('images/ModelProvider/mistral.svg')
})
it('should return correct logo path for InferenceEngine.martian', () => {
const result = getLogoEngine(InferenceEngine.martian)
expect(result).toBe('images/ModelProvider/martian.svg')
})
it('should return correct logo path for InferenceEngine.openrouter', () => {
const result = getLogoEngine(InferenceEngine.openrouter)
expect(result).toBe('images/ModelProvider/openRouter.svg')
})
it('should return correct logo path for InferenceEngine.openai', () => {
const result = getLogoEngine(InferenceEngine.openai)
expect(result).toBe('images/ModelProvider/openai.svg')
})
it('should return correct logo path for InferenceEngine.groq', () => {
const result = getLogoEngine(InferenceEngine.groq)
expect(result).toBe('images/ModelProvider/groq.svg')
})
it('should return correct logo path for InferenceEngine.triton_trtllm', () => {
const result = getLogoEngine(InferenceEngine.triton_trtllm)
expect(result).toBe('images/ModelProvider/triton_trtllm.svg')
})
it('should return correct logo path for InferenceEngine.cohere', () => {
const result = getLogoEngine(InferenceEngine.cohere)
expect(result).toBe('images/ModelProvider/cohere.svg')
})
it('should return correct logo path for InferenceEngine.nvidia', () => {
const result = getLogoEngine(InferenceEngine.nvidia)
expect(result).toBe('images/ModelProvider/nvidia.svg')
})
it('should return undefined for unknown engine', () => {
const result = getLogoEngine('unknownEngine' as InferenceEngine)
expect(result).toBeUndefined()
})
})
})

View File

@ -1,4 +1,4 @@
import { InferenceEngine } from '@janhq/core'
import { EngineManager, InferenceEngine, LocalOAIEngine } from '@janhq/core'
export const getLogoEngine = (engine: InferenceEngine) => {
switch (engine) {
@ -32,13 +32,19 @@ export const getLogoEngine = (engine: InferenceEngine) => {
}
}
export const localEngines = [
InferenceEngine.nitro,
InferenceEngine.nitro_tensorrt_llm,
InferenceEngine.cortex_llamacpp,
InferenceEngine.cortex_onnx,
InferenceEngine.cortex_tensorrtllm,
]
/**
* Check whether the engine is conform to LocalOAIEngine
* @param engine
* @returns
*/
export const isLocalEngine = (engine: string) => {
const engineObj = EngineManager.instance().get(engine)
if (!engineObj) return false
return (
Object.getPrototypeOf(engineObj).constructor.__proto__.name ===
LocalOAIEngine.name
)
}
export const getTitleByEngine = (engine: InferenceEngine) => {
switch (engine) {