From 8334076047f83f1a5e77e7b288530ab2dc8e984a Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 30 Sep 2024 11:58:46 +0700 Subject: [PATCH] fix: #3491 - Unable to use tensorrt-llm (#3741) * fix: #3491 - Unable to use tensorrt-llm * fix: abortModelDownload input type --- .../tensorrt-llm-extension/jest.config.js | 9 + .../tensorrt-llm-extension/package.json | 8 +- .../tensorrt-llm-extension/rollup.config.ts | 4 +- .../tensorrt-llm-extension/src/index.test.ts | 186 ++++++++++++++++++ .../tensorrt-llm-extension/src/index.ts | 3 +- .../tensorrt-llm-extension/tsconfig.json | 3 +- web/hooks/useDownloadModel.ts | 9 +- 7 files changed, 214 insertions(+), 8 deletions(-) create mode 100644 extensions/tensorrt-llm-extension/jest.config.js create mode 100644 extensions/tensorrt-llm-extension/src/index.test.ts diff --git a/extensions/tensorrt-llm-extension/jest.config.js b/extensions/tensorrt-llm-extension/jest.config.js new file mode 100644 index 000000000..3e32adceb --- /dev/null +++ b/extensions/tensorrt-llm-extension/jest.config.js @@ -0,0 +1,9 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} */ +module.exports = { + preset: 'ts-jest', + testEnvironment: 'node', + transform: { + 'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest', + }, + transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'], +} diff --git a/extensions/tensorrt-llm-extension/package.json b/extensions/tensorrt-llm-extension/package.json index c5cb54809..7a7ef6ef0 100644 --- a/extensions/tensorrt-llm-extension/package.json +++ b/extensions/tensorrt-llm-extension/package.json @@ -22,6 +22,7 @@ "tensorrtVersion": "0.1.8", "provider": "nitro-tensorrt-llm", "scripts": { + "test": "jest", "build": "tsc --module commonjs && rollup -c rollup.config.ts", "build:publish:win32": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", "build:publish:linux": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", @@ -49,7 +50,12 @@ "rollup-plugin-sourcemaps": "^0.6.3", "rollup-plugin-typescript2": "^0.36.0", "run-script-os": "^1.1.6", - "typescript": "^5.2.2" + "typescript": "^5.2.2", + "@types/jest": "^29.5.12", + "jest": "^29.7.0", + "jest-junit": "^16.0.0", + "jest-runner": "^29.7.0", + "ts-jest": "^29.2.5" }, "dependencies": { "@janhq/core": "file:../../core", diff --git a/extensions/tensorrt-llm-extension/rollup.config.ts b/extensions/tensorrt-llm-extension/rollup.config.ts index 1fad0e711..50b4350e7 100644 --- a/extensions/tensorrt-llm-extension/rollup.config.ts +++ b/extensions/tensorrt-llm-extension/rollup.config.ts @@ -23,10 +23,10 @@ export default [ DOWNLOAD_RUNNER_URL: process.platform === 'win32' ? JSON.stringify( - 'https://github.com/janhq/nitro-tensorrt-llm/releases/download/windows-v-tensorrt-llm-v0.7.1/nitro-windows-v-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz' + 'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v-tensorrt-llm-v0.7.1/nitro-windows-v-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz' ) : JSON.stringify( - 'https://github.com/janhq/nitro-tensorrt-llm/releases/download/linux-v/nitro-linux-v-amd64-tensorrt-llm-.tar.gz' + 'https://github.com/janhq/cortex.tensorrt-llm/releases/download/linux-v/nitro-linux-v-amd64-tensorrt-llm-.tar.gz' ), NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`), INFERENCE_URL: JSON.stringify( diff --git a/extensions/tensorrt-llm-extension/src/index.test.ts b/extensions/tensorrt-llm-extension/src/index.test.ts new file mode 100644 index 000000000..48d6e71d7 --- /dev/null +++ b/extensions/tensorrt-llm-extension/src/index.test.ts @@ -0,0 +1,186 @@ +import TensorRTLLMExtension from '../src/index' +import { + executeOnMain, + systemInformation, + fs, + baseName, + joinPath, + downloadFile, +} from '@janhq/core' + +jest.mock('@janhq/core', () => ({ + ...jest.requireActual('@janhq/core/node'), + LocalOAIEngine: jest.fn().mockImplementation(function () { + // @ts-ignore + this.registerModels = () => { + return Promise.resolve() + } + // @ts-ignore + return this + }), + systemInformation: jest.fn(), + fs: { + existsSync: jest.fn(), + mkdir: jest.fn(), + }, + joinPath: jest.fn(), + baseName: jest.fn(), + downloadFile: jest.fn(), + executeOnMain: jest.fn(), + showToast: jest.fn(), + events: { + emit: jest.fn(), + // @ts-ignore + on: (event, func) => { + func({ fileName: './' }) + }, + off: jest.fn(), + }, +})) + +// @ts-ignore +global.COMPATIBILITY = { + platform: ['win32'], +} +// @ts-ignore +global.PROVIDER = 'tensorrt-llm' +// @ts-ignore +global.INFERENCE_URL = 'http://localhost:5000' +// @ts-ignore +global.NODE = 'node' +// @ts-ignore +global.MODELS = [] +// @ts-ignore +global.TENSORRT_VERSION = '' +// @ts-ignore +global.DOWNLOAD_RUNNER_URL = '' + +describe('TensorRTLLMExtension', () => { + let extension: TensorRTLLMExtension + + beforeEach(() => { + // @ts-ignore + extension = new TensorRTLLMExtension() + jest.clearAllMocks() + }) + + describe('compatibility', () => { + it('should return the correct compatibility', () => { + const result = extension.compatibility() + expect(result).toEqual({ + platform: ['win32'], + }) + }) + }) + + describe('install', () => { + it('should install if compatible', async () => { + const mockSystemInfo: any = { + osInfo: { platform: 'win32' }, + gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] }, + } + ;(executeOnMain as jest.Mock).mockResolvedValue({}) + ;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo) + ;(fs.existsSync as jest.Mock).mockResolvedValue(false) + ;(fs.mkdir as jest.Mock).mockResolvedValue(undefined) + ;(baseName as jest.Mock).mockResolvedValue('./') + ;(joinPath as jest.Mock).mockResolvedValue('./') + ;(downloadFile as jest.Mock).mockResolvedValue({}) + + await extension.install() + + expect(executeOnMain).toHaveBeenCalled() + }) + + it('should not install if not compatible', async () => { + const mockSystemInfo: any = { + osInfo: { platform: 'linux' }, + gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] }, + } + ;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo) + + jest.spyOn(extension, 'registerModels').mockReturnValue(Promise.resolve()) + await extension.install() + + expect(executeOnMain).not.toHaveBeenCalled() + }) + }) + + describe('installationState', () => { + it('should return NotCompatible if not compatible', async () => { + const mockSystemInfo: any = { + osInfo: { platform: 'linux' }, + gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] }, + } + ;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo) + + const result = await extension.installationState() + + expect(result).toBe('NotCompatible') + }) + + it('should return Installed if executable exists', async () => { + const mockSystemInfo: any = { + osInfo: { platform: 'win32' }, + gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] }, + } + ;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo) + ;(fs.existsSync as jest.Mock).mockResolvedValue(true) + + const result = await extension.installationState() + + expect(result).toBe('Installed') + }) + + it('should return NotInstalled if executable does not exist', async () => { + const mockSystemInfo: any = { + osInfo: { platform: 'win32' }, + gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] }, + } + ;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo) + ;(fs.existsSync as jest.Mock).mockResolvedValue(false) + + const result = await extension.installationState() + + expect(result).toBe('NotInstalled') + }) + }) + + describe('isCompatible', () => { + it('should return true for compatible system', () => { + const mockInfo: any = { + osInfo: { platform: 'win32' }, + gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] }, + } + + const result = extension.isCompatible(mockInfo) + + expect(result).toBe(true) + }) + + it('should return false for incompatible system', () => { + const mockInfo: any = { + osInfo: { platform: 'linux' }, + gpuSetting: { gpus: [{ arch: 'pascal', name: 'AMD GPU' }] }, + } + + const result = extension.isCompatible(mockInfo) + + expect(result).toBe(false) + }) + }) +}) + +describe('GitHub Release File URL Test', () => { + const url = 'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v0.1.8-tensorrt-llm-v0.7.1/nitro-windows-v0.1.8-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz'; + + it('should return a status code 200 for the release file URL', async () => { + const response = await fetch(url, { method: 'HEAD' }); + expect(response.status).toBe(200); + }); + + it('should not return a 404 status', async () => { + const response = await fetch(url, { method: 'HEAD' }); + expect(response.status).not.toBe(404); + }); +}); diff --git a/extensions/tensorrt-llm-extension/src/index.ts b/extensions/tensorrt-llm-extension/src/index.ts index 7f68c43bd..11c86a9a7 100644 --- a/extensions/tensorrt-llm-extension/src/index.ts +++ b/extensions/tensorrt-llm-extension/src/index.ts @@ -41,7 +41,6 @@ export default class TensorRTLLMExtension extends LocalOAIEngine { override nodeModule = NODE private supportedGpuArch = ['ampere', 'ada'] - private supportedPlatform = ['win32', 'linux'] override compatibility() { return COMPATIBILITY as unknown as Compatibility @@ -191,7 +190,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine { !!info.gpuSetting && !!firstGpu && info.gpuSetting.gpus.length > 0 && - this.supportedPlatform.includes(info.osInfo.platform) && + this.compatibility().platform.includes(info.osInfo.platform) && !!firstGpu.arch && firstGpu.name.toLowerCase().includes('nvidia') && this.supportedGpuArch.includes(firstGpu.arch) diff --git a/extensions/tensorrt-llm-extension/tsconfig.json b/extensions/tensorrt-llm-extension/tsconfig.json index 478a05728..be07e716c 100644 --- a/extensions/tensorrt-llm-extension/tsconfig.json +++ b/extensions/tensorrt-llm-extension/tsconfig.json @@ -16,5 +16,6 @@ "resolveJsonModule": true, "typeRoots": ["node_modules/@types"] }, - "include": ["src"] + "include": ["src"], + "exclude": ["**/*.test.ts"] } diff --git a/web/hooks/useDownloadModel.ts b/web/hooks/useDownloadModel.ts index d0d13d93b..0cd21ea83 100644 --- a/web/hooks/useDownloadModel.ts +++ b/web/hooks/useDownloadModel.ts @@ -9,6 +9,8 @@ import { ModelArtifact, DownloadState, GpuSetting, + ModelFile, + dirName, } from '@janhq/core' import { useAtomValue, useSetAtom } from 'jotai' @@ -91,9 +93,12 @@ export default function useDownloadModel() { ] ) - const abortModelDownload = useCallback(async (model: Model) => { + const abortModelDownload = useCallback(async (model: Model | ModelFile) => { for (const source of model.sources) { - const path = await joinPath(['models', model.id, source.filename]) + const path = + 'file_path' in model + ? await joinPath([await dirName(model.file_path), source.filename]) + : await joinPath(['models', model.id, source.filename]) await abortDownload(path) } }, [])