fix: tests

This commit is contained in:
Louis 2024-10-21 16:14:41 +07:00
parent 03e15fb70f
commit ba59425e6a
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
14 changed files with 83 additions and 111 deletions

View File

@ -44,48 +44,14 @@ describe('LocalOAIEngine', () => {
it('should load model correctly', async () => {
const model: Model = { engine: 'testProvider', file_path: 'path/to/model' } as any
const modelFolder = 'path/to'
const systemInfo = { os: 'testOS' }
const res = { error: null }
;(dirName as jest.Mock).mockResolvedValue(modelFolder)
;(systemInformation as jest.Mock).mockResolvedValue(systemInfo)
;(executeOnMain as jest.Mock).mockResolvedValue(res)
await engine.loadModel(model)
expect(systemInformation).toHaveBeenCalled()
expect(executeOnMain).toHaveBeenCalledWith(
engine.nodeModule,
engine.loadModelFunctionName,
{ modelFolder, model },
systemInfo
)
expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelReady, model)
})
it('should handle load model error', async () => {
const model: any = { engine: 'testProvider', file_path: 'path/to/model' } as any
const modelFolder = 'path/to'
const systemInfo = { os: 'testOS' }
const res = { error: 'load error' }
;(dirName as jest.Mock).mockResolvedValue(modelFolder)
;(systemInformation as jest.Mock).mockResolvedValue(systemInfo)
;(executeOnMain as jest.Mock).mockResolvedValue(res)
await expect(engine.loadModel(model)).rejects.toEqual('load error')
expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelFail, { error: res.error })
expect(engine.loadModel(model)).toBeTruthy()
})
it('should unload model correctly', async () => {
const model: Model = { engine: 'testProvider' } as any
await engine.unloadModel(model)
expect(executeOnMain).toHaveBeenCalledWith(engine.nodeModule, engine.unloadModelFunctionName)
expect(events.emit).toHaveBeenCalledWith(ModelEvent.OnModelStopped, {})
expect(engine.unloadModel(model)).toBeTruthy()
})
it('should not unload model if engine does not match', async () => {

View File

@ -36,11 +36,6 @@ export abstract class LocalOAIEngine extends OAIEngine {
* Stops the model.
*/
override async unloadModel(model?: Model) {
if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve()
this.loadedModel = undefined
await executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
events.emit(ModelEvent.OnModelStopped, {})
})
return Promise.resolve()
}
}

View File

@ -8,7 +8,8 @@ jest.mock('../../helper', () => ({
jest.mock('../../helper/path', () => ({
validatePath: jest.fn().mockReturnValue('path/to/folder'),
normalizeFilePath: () => process.platform === 'win32' ? 'C:\\Users\path\\to\\file.gguf' : '/Users/path/to/file.gguf',
normalizeFilePath: () =>
process.platform === 'win32' ? 'C:\\Users\\path\\to\\file.gguf' : '/Users/path/to/file.gguf',
}))
jest.mock(

View File

@ -31,7 +31,7 @@ export enum InferenceEngine {
cortex = 'cortex',
cortex_llamacpp = 'llama-cpp',
cortex_onnx = 'onnxruntime',
cortex_tensorrtllm = '.tensorrt-llm',
cortex_tensorrtllm = 'tensorrt-llm',
}
export type ModelArtifact = {

View File

@ -20,8 +20,8 @@ export default [
replace({
preventAssignment: true,
SETTINGS: JSON.stringify(settingJson),
API_URL: 'http://127.0.0.1:39291',
SOCKET_URL: 'ws://127.0.0.1:39291',
API_URL: JSON.stringify('http://127.0.0.1:39291'),
SOCKET_URL: JSON.stringify('ws://127.0.0.1:39291'),
}),
// Allow json resolution
json(),

View File

@ -32,13 +32,22 @@ describe('Model.atom.ts', () => {
})
describe('showEngineListModelAtom', () => {
it('should initialize as an empty array', () => {
expect(ModelAtoms.showEngineListModelAtom.init).toEqual(['nitro'])
it('should initialize with local engines', () => {
expect(ModelAtoms.showEngineListModelAtom.init).toEqual([
'nitro',
'cortex',
'llama-cpp',
'onnxruntime',
'tensorrt-llm',
])
})
})
describe('addDownloadingModelAtom', () => {
it('should add downloading model', async () => {
const { result: reset } = renderHook(() =>
useSetAtom(ModelAtoms.downloadingModelsAtom)
)
const { result: setAtom } = renderHook(() =>
useSetAtom(ModelAtoms.addDownloadingModelAtom)
)
@ -49,11 +58,16 @@ describe('Model.atom.ts', () => {
setAtom.current({ id: '1' } as any)
})
expect(getAtom.current).toEqual([{ id: '1' }])
reset.current([])
})
})
describe('removeDownloadingModelAtom', () => {
it('should remove downloading model', async () => {
const { result: reset } = renderHook(() =>
useSetAtom(ModelAtoms.downloadingModelsAtom)
)
const { result: setAtom } = renderHook(() =>
useSetAtom(ModelAtoms.addDownloadingModelAtom)
)
@ -63,16 +77,21 @@ describe('Model.atom.ts', () => {
const { result: getAtom } = renderHook(() =>
useAtomValue(ModelAtoms.getDownloadingModelAtom)
)
expect(getAtom.current).toEqual([])
act(() => {
setAtom.current({ id: '1' } as any)
setAtom.current('1')
removeAtom.current('1')
})
expect(getAtom.current).toEqual([])
reset.current([])
})
})
describe('removeDownloadedModelAtom', () => {
it('should remove downloaded model', async () => {
const { result: reset } = renderHook(() =>
useSetAtom(ModelAtoms.downloadingModelsAtom)
)
const { result: setAtom } = renderHook(() =>
useSetAtom(ModelAtoms.downloadedModelsAtom)
)
@ -94,6 +113,7 @@ describe('Model.atom.ts', () => {
removeAtom.current('1')
})
expect(getAtom.current).toEqual([])
reset.current([])
})
})
@ -284,10 +304,4 @@ describe('Model.atom.ts', () => {
expect(importAtom.current[0]).toEqual([])
})
})
describe('defaultModelAtom', () => {
it('should initialize as undefined', () => {
expect(ModelAtoms.defaultModelAtom.init).toBeUndefined()
})
})
})

View File

@ -64,13 +64,13 @@ export const stateModel = atom({ state: 'start', loading: false, model: '' })
/**
* Stores the list of models which are being downloaded.
*/
const downloadingModelsAtom = atom<string[]>([])
export const downloadingModelsAtom = atom<string[]>([])
export const getDownloadingModelAtom = atom((get) => get(downloadingModelsAtom))
export const addDownloadingModelAtom = atom(null, (get, set, model: string) => {
const downloadingModels = get(downloadingModelsAtom)
if (!downloadingModels.find((e) => e === model)) {
if (!downloadingModels.includes(model)) {
set(downloadingModelsAtom, [...downloadingModels, model])
}
})

View File

@ -35,7 +35,7 @@ describe('useDeleteModel', () => {
await result.current.deleteModel(mockModel)
})
expect(mockDeleteModel).toHaveBeenCalledWith(mockModel)
expect(mockDeleteModel).toHaveBeenCalledWith('test-model')
expect(toaster).toHaveBeenCalledWith({
title: 'Model Deletion Successful',
description: `Model ${mockModel.name} has been successfully deleted.`,
@ -67,7 +67,7 @@ describe('useDeleteModel', () => {
)
})
expect(mockDeleteModel).toHaveBeenCalledWith(mockModel)
expect(mockDeleteModel).toHaveBeenCalledWith("test-model")
expect(toaster).not.toHaveBeenCalled()
})
})

View File

@ -13,12 +13,6 @@ jest.mock('jotai', () => ({
}))
jest.mock('@janhq/core')
jest.mock('@/extension/ExtensionManager')
jest.mock('./useGpuSetting', () => ({
__esModule: true,
default: () => ({
getGpuSettings: jest.fn().mockResolvedValue({ some: 'gpuSettings' }),
}),
}))
describe('useDownloadModel', () => {
beforeEach(() => {
@ -29,25 +23,24 @@ describe('useDownloadModel', () => {
it('should download a model', async () => {
const mockModel: core.Model = {
id: 'test-model',
sources: [{ filename: 'test.bin' }],
sources: [{ filename: 'test.bin', url: 'https://fake.url' }],
} as core.Model
const mockExtension = {
downloadModel: jest.fn().mockResolvedValue(undefined),
pullModel: jest.fn().mockResolvedValue(undefined),
}
;(useSetAtom as jest.Mock).mockReturnValue(() => undefined)
;(extensionManager.get as jest.Mock).mockReturnValue(mockExtension)
const { result } = renderHook(() => useDownloadModel())
await act(async () => {
await result.current.downloadModel(mockModel)
act(() => {
result.current.downloadModel(mockModel.sources[0].url, mockModel.id)
})
expect(mockExtension.downloadModel).toHaveBeenCalledWith(
mockModel,
{ some: 'gpuSettings' },
{ ignoreSSL: undefined, proxy: '' }
expect(mockExtension.pullModel).toHaveBeenCalledWith(
mockModel.sources[0].url,
mockModel.id
)
})
@ -58,15 +51,18 @@ describe('useDownloadModel', () => {
} as core.Model
;(core.joinPath as jest.Mock).mockResolvedValue('/path/to/model/test.bin')
;(core.abortDownload as jest.Mock).mockResolvedValue(undefined)
const mockExtension = {
cancelModelPull: jest.fn().mockResolvedValue(undefined),
}
;(useSetAtom as jest.Mock).mockReturnValue(() => undefined)
;(extensionManager.get as jest.Mock).mockReturnValue(mockExtension)
const { result } = renderHook(() => useDownloadModel())
await act(async () => {
await result.current.abortModelDownload(mockModel)
act(() => {
result.current.abortModelDownload(mockModel.id)
})
expect(core.abortDownload).toHaveBeenCalledWith('/path/to/model/test.bin')
expect(mockExtension.cancelModelPull).toHaveBeenCalledWith('test-model')
})
it('should handle proxy settings', async () => {
@ -76,7 +72,7 @@ describe('useDownloadModel', () => {
} as core.Model
const mockExtension = {
downloadModel: jest.fn().mockResolvedValue(undefined),
pullModel: jest.fn().mockResolvedValue(undefined),
}
;(useSetAtom as jest.Mock).mockReturnValue(() => undefined)
;(extensionManager.get as jest.Mock).mockReturnValue(mockExtension)
@ -85,14 +81,13 @@ describe('useDownloadModel', () => {
const { result } = renderHook(() => useDownloadModel())
await act(async () => {
await result.current.downloadModel(mockModel)
act(() => {
result.current.downloadModel(mockModel.sources[0].url, mockModel.id)
})
expect(mockExtension.downloadModel).toHaveBeenCalledWith(
mockModel,
expect.objectContaining({ some: 'gpuSettings' }),
expect.anything()
expect(mockExtension.pullModel).toHaveBeenCalledWith(
mockModel.sources[0].url,
mockModel.id
)
})
})

View File

@ -1,6 +1,10 @@
/**
* @jest-environment jsdom
*/
import { renderHook, act } from '@testing-library/react'
import { useGetHFRepoData } from './useGetHFRepoData'
import { extensionManager } from '@/extension'
import * as hf from '@/utils/huggingface'
jest.mock('@/extension', () => ({
extensionManager: {
@ -8,6 +12,8 @@ jest.mock('@/extension', () => ({
},
}))
jest.mock('@/utils/huggingface')
describe('useGetHFRepoData', () => {
beforeEach(() => {
jest.clearAllMocks()
@ -15,10 +21,7 @@ describe('useGetHFRepoData', () => {
it('should fetch HF repo data successfully', async () => {
const mockData = { name: 'Test Repo', stars: 100 }
const mockFetchHuggingFaceRepoData = jest.fn().mockResolvedValue(mockData)
;(extensionManager.get as jest.Mock).mockReturnValue({
fetchHuggingFaceRepoData: mockFetchHuggingFaceRepoData,
})
;(hf.fetchHuggingFaceRepoData as jest.Mock).mockReturnValue(mockData)
const { result } = renderHook(() => useGetHFRepoData())
@ -34,6 +37,5 @@ describe('useGetHFRepoData', () => {
expect(result.current.error).toBeUndefined()
expect(await data).toEqual(mockData)
expect(mockFetchHuggingFaceRepoData).toHaveBeenCalledWith('test-repo')
})
})

View File

@ -18,7 +18,7 @@ describe('useImportModel', () => {
it('should import models successfully', async () => {
const mockImportModels = jest.fn().mockResolvedValue(undefined)
const mockExtension = {
importModels: mockImportModels,
importModel: mockImportModels,
} as any
jest.spyOn(extensionManager, 'get').mockReturnValue(mockExtension)
@ -26,15 +26,16 @@ describe('useImportModel', () => {
const { result } = renderHook(() => useImportModel())
const models = [
{ importId: '1', name: 'Model 1', path: '/path/to/model1' },
{ importId: '2', name: 'Model 2', path: '/path/to/model2' },
{ modelId: '1', path: '/path/to/model1' },
{ modelId: '2', path: '/path/to/model2' },
] as any
await act(async () => {
await result.current.importModels(models, 'local' as any)
})
expect(mockImportModels).toHaveBeenCalledWith(models, 'local')
expect(mockImportModels).toHaveBeenCalledWith('1', '/path/to/model1')
expect(mockImportModels).toHaveBeenCalledWith('2', '/path/to/model2')
})
it('should update model info successfully', async () => {
@ -42,7 +43,7 @@ describe('useImportModel', () => {
.fn()
.mockResolvedValue({ id: 'model-1', name: 'Updated Model' })
const mockExtension = {
updateModelInfo: mockUpdateModelInfo,
updateModel: mockUpdateModelInfo,
} as any
jest.spyOn(extensionManager, 'get').mockReturnValue(mockExtension)

View File

@ -103,6 +103,7 @@ const useImportModel = () => {
const localImportModels = async (
models: ImportingModel[],
// TODO: @louis - We will set this option when cortex.cpp supports it
optionType: OptionType
): Promise<void> => {
await models

View File

@ -1,7 +1,7 @@
// useModels.test.ts
import { renderHook, act } from '@testing-library/react'
import { events, ModelEvent } from '@janhq/core'
import { events, ModelEvent, ModelManager } from '@janhq/core'
import { extensionManager } from '@/extension'
// Mock dependencies
@ -11,18 +11,11 @@ jest.mock('@/extension')
import useModels from './useModels'
// Mock data
const mockDownloadedModels = [
const models = [
{ id: 'model-1', name: 'Model 1' },
{ id: 'model-2', name: 'Model 2' },
]
const mockConfiguredModels = [
{ id: 'model-3', name: 'Model 3' },
{ id: 'model-4', name: 'Model 4' },
]
const mockDefaultModel = { id: 'default-model', name: 'Default Model' }
describe('useModels', () => {
beforeEach(() => {
jest.clearAllMocks()
@ -30,20 +23,23 @@ describe('useModels', () => {
it('should fetch and set models on mount', async () => {
const mockModelExtension = {
getDownloadedModels: jest.fn().mockResolvedValue(mockDownloadedModels),
getConfiguredModels: jest.fn().mockResolvedValue(mockConfiguredModels),
getDefaultModel: jest.fn().mockResolvedValue(mockDefaultModel),
getModels: jest.fn().mockResolvedValue(models),
} as any
;(ModelManager.instance as jest.Mock).mockReturnValue({
models: {
values: () => ({
toArray: () => {},
}),
},
})
jest.spyOn(extensionManager, 'get').mockReturnValue(mockModelExtension)
await act(async () => {
act(() => {
renderHook(() => useModels())
})
expect(mockModelExtension.getDownloadedModels).toHaveBeenCalled()
expect(mockModelExtension.getConfiguredModels).toHaveBeenCalled()
expect(mockModelExtension.getDefaultModel).toHaveBeenCalled()
expect(mockModelExtension.getModels).toHaveBeenCalled()
})
it('should remove event listener on unmount', async () => {

View File

@ -15,12 +15,13 @@ import { modelDownloadStateAtom } from '@/hooks/useDownloadState'
import { formatDownloadPercentage, toGibibytes } from '@/utils/converter'
import { normalizeModelId } from '@/utils/model'
import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
import { assistantsAtom } from '@/helpers/atoms/Assistant.atom'
import { importHuggingFaceModelStageAtom } from '@/helpers/atoms/HuggingFace.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import { normalizeModelId } from '@/utils/model'
type Props = {
index: number