fix: #3491 - Unable to use tensorrt-llm (#3741)

* fix: #3491 - Unable to use tensorrt-llm

* fix: abortModelDownload input type
Author: Louis, 2024-09-30 11:58:46 +07:00, committed by GitHub
parent cf0a232001
commit 8334076047
7 changed files with 214 additions and 8 deletions

View File

@@ -0,0 +1,9 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
transform: {
'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
},
transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
}
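
Editor's note on the Jest config above: Jest does not transform anything under node_modules by default, but @janhq/core is linked locally ("file:../../core"), so the transform entry and the transformIgnorePatterns exception make Jest run that package through ts-jest instead of skipping it. A minimal sketch of a test that relies on this, illustrative only and not part of the commit:

    import { joinPath } from '@janhq/core'

    test('core helpers resolve through ts-jest', () => {
      // Without the transform/transformIgnorePatterns pair above, this import would fail to parse.
      expect(typeof joinPath).toBe('function')
    })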

View File

@@ -22,6 +22,7 @@
"tensorrtVersion": "0.1.8",
"provider": "nitro-tensorrt-llm",
"scripts": {
"test": "jest",
"build": "tsc --module commonjs && rollup -c rollup.config.ts",
"build:publish:win32": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
"build:publish:linux": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
@@ -49,7 +50,12 @@
"rollup-plugin-sourcemaps": "^0.6.3",
"rollup-plugin-typescript2": "^0.36.0",
"run-script-os": "^1.1.6",
"typescript": "^5.2.2"
"typescript": "^5.2.2",
"@types/jest": "^29.5.12",
"jest": "^29.7.0",
"jest-junit": "^16.0.0",
"jest-runner": "^29.7.0",
"ts-jest": "^29.2.5"
},
"dependencies": {
"@janhq/core": "file:../../core",

View File

@@ -23,10 +23,10 @@ export default [
DOWNLOAD_RUNNER_URL:
process.platform === 'win32'
? JSON.stringify(
'https://github.com/janhq/nitro-tensorrt-llm/releases/download/windows-v<version>-tensorrt-llm-v0.7.1/nitro-windows-v<version>-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz'
'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v<version>-tensorrt-llm-v0.7.1/nitro-windows-v<version>-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz'
)
: JSON.stringify(
'https://github.com/janhq/nitro-tensorrt-llm/releases/download/linux-v<version>/nitro-linux-v<version>-amd64-tensorrt-llm-<gpuarch>.tar.gz'
'https://github.com/janhq/cortex.tensorrt-llm/releases/download/linux-v<version>/nitro-linux-v<version>-amd64-tensorrt-llm-<gpuarch>.tar.gz'
),
NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`),
INFERENCE_URL: JSON.stringify(
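
Editor's note on the DOWNLOAD_RUNNER_URL change above: only the repository moves, from janhq/nitro-tensorrt-llm to janhq/cortex.tensorrt-llm; the <version> and <gpuarch> placeholders in the templates are unchanged. A rough sketch of the kind of substitution the extension presumably performs before downloading the runner (identifier names are illustrative, not taken from this diff):

    // Assumed substitution step; the real logic lives in the extension's install path.
    const gpuArch = 'ampere' // would come from the detected GPU in practice
    const runnerUrl = (DOWNLOAD_RUNNER_URL as string)
      .replace(/<version>/g, TENSORRT_VERSION)
      .replace(/<gpuarch>/g, gpuArch) // only used by the Linux artifact URL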

View File

@@ -0,0 +1,186 @@
import TensorRTLLMExtension from '../src/index'
import {
executeOnMain,
systemInformation,
fs,
baseName,
joinPath,
downloadFile,
} from '@janhq/core'
jest.mock('@janhq/core', () => ({
...jest.requireActual('@janhq/core/node'),
LocalOAIEngine: jest.fn().mockImplementation(function () {
// @ts-ignore
this.registerModels = () => {
return Promise.resolve()
}
// @ts-ignore
return this
}),
systemInformation: jest.fn(),
fs: {
existsSync: jest.fn(),
mkdir: jest.fn(),
},
joinPath: jest.fn(),
baseName: jest.fn(),
downloadFile: jest.fn(),
executeOnMain: jest.fn(),
showToast: jest.fn(),
events: {
emit: jest.fn(),
// @ts-ignore
on: (event, func) => {
func({ fileName: './' })
},
off: jest.fn(),
},
}))
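// Editor's comment: the mock above replaces only the parts of @janhq/core that the extension
// touches at runtime (filesystem, download, events, and the main-process bridge), while
// LocalOAIEngine is stubbed so the subclass can be constructed and registerModels resolves
// immediately. events.on invokes its handler right away, so download-event driven paths
// complete synchronously inside each test.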
// @ts-ignore
global.COMPATIBILITY = {
platform: ['win32'],
}
// @ts-ignore
global.PROVIDER = 'tensorrt-llm'
// @ts-ignore
global.INFERENCE_URL = 'http://localhost:5000'
// @ts-ignore
global.NODE = 'node'
// @ts-ignore
global.MODELS = []
// @ts-ignore
global.TENSORRT_VERSION = ''
// @ts-ignore
global.DOWNLOAD_RUNNER_URL = ''
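// Editor's comment: these globals stand in for the build-time constants that the rollup replace
// step injects into the bundled extension; DOWNLOAD_RUNNER_URL, NODE and INFERENCE_URL are visible
// in rollup.config.ts above, and the remaining ones are presumably defined the same way.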
describe('TensorRTLLMExtension', () => {
let extension: TensorRTLLMExtension
beforeEach(() => {
// @ts-ignore
extension = new TensorRTLLMExtension()
jest.clearAllMocks()
})
describe('compatibility', () => {
it('should return the correct compatibility', () => {
const result = extension.compatibility()
expect(result).toEqual({
platform: ['win32'],
})
})
})
describe('install', () => {
it('should install if compatible', async () => {
const mockSystemInfo: any = {
osInfo: { platform: 'win32' },
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
}
;(executeOnMain as jest.Mock).mockResolvedValue({})
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
;(fs.existsSync as jest.Mock).mockResolvedValue(false)
;(fs.mkdir as jest.Mock).mockResolvedValue(undefined)
;(baseName as jest.Mock).mockResolvedValue('./')
;(joinPath as jest.Mock).mockResolvedValue('./')
;(downloadFile as jest.Mock).mockResolvedValue({})
await extension.install()
expect(executeOnMain).toHaveBeenCalled()
})
it('should not install if not compatible', async () => {
const mockSystemInfo: any = {
osInfo: { platform: 'linux' },
gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] },
}
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
jest.spyOn(extension, 'registerModels').mockReturnValue(Promise.resolve())
await extension.install()
expect(executeOnMain).not.toHaveBeenCalled()
})
})
describe('installationState', () => {
it('should return NotCompatible if not compatible', async () => {
const mockSystemInfo: any = {
osInfo: { platform: 'linux' },
gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] },
}
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
const result = await extension.installationState()
expect(result).toBe('NotCompatible')
})
it('should return Installed if executable exists', async () => {
const mockSystemInfo: any = {
osInfo: { platform: 'win32' },
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
}
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
;(fs.existsSync as jest.Mock).mockResolvedValue(true)
const result = await extension.installationState()
expect(result).toBe('Installed')
})
it('should return NotInstalled if executable does not exist', async () => {
const mockSystemInfo: any = {
osInfo: { platform: 'win32' },
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
}
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
;(fs.existsSync as jest.Mock).mockResolvedValue(false)
const result = await extension.installationState()
expect(result).toBe('NotInstalled')
})
})
describe('isCompatible', () => {
it('should return true for compatible system', () => {
const mockInfo: any = {
osInfo: { platform: 'win32' },
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
}
const result = extension.isCompatible(mockInfo)
expect(result).toBe(true)
})
it('should return false for incompatible system', () => {
const mockInfo: any = {
osInfo: { platform: 'linux' },
gpuSetting: { gpus: [{ arch: 'pascal', name: 'AMD GPU' }] },
}
const result = extension.isCompatible(mockInfo)
expect(result).toBe(false)
})
})
})
describe('GitHub Release File URL Test', () => {
const url = 'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v0.1.8-tensorrt-llm-v0.7.1/nitro-windows-v0.1.8-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz';
it('should return a status code 200 for the release file URL', async () => {
const response = await fetch(url, { method: 'HEAD' });
expect(response.status).toBe(200);
});
it('should not return a 404 status', async () => {
const response = await fetch(url, { method: 'HEAD' });
expect(response.status).not.toBe(404);
});
});
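
Editor's note on the release-URL checks above: they issue real HEAD requests to github.com, so they will fail on machines without outbound network access. One way to make that dependency explicit, sketched under an assumed environment flag that is not part of this commit:

    // Hypothetical guard; OFFLINE_CI is an assumed variable name, not used elsewhere in this repo.
    const describeOnline = process.env.OFFLINE_CI ? describe.skip : describe

    describeOnline('GitHub Release File URL Test (network required)', () => {
      // same HEAD-request assertions as above
    })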

View File

@@ -41,7 +41,6 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
override nodeModule = NODE
private supportedGpuArch = ['ampere', 'ada']
private supportedPlatform = ['win32', 'linux']
override compatibility() {
return COMPATIBILITY as unknown as Compatibility
@@ -191,7 +190,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
!!info.gpuSetting &&
!!firstGpu &&
info.gpuSetting.gpus.length > 0 &&
this.supportedPlatform.includes(info.osInfo.platform) &&
this.compatibility().platform.includes(info.osInfo.platform) &&
!!firstGpu.arch &&
firstGpu.name.toLowerCase().includes('nvidia') &&
this.supportedGpuArch.includes(firstGpu.arch)
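
Editor's note on the two src/index.ts hunks above: the hard-coded supportedPlatform array is removed, and the platform check now reads from the COMPATIBILITY constant through compatibility(), so a single source of truth drives both the reported compatibility and the install gate. A reconstruction of the predicate as it presumably reads after this change (the surrounding method signature is inferred from the tests, not shown in this diff):

    // Sketch only; parameter typed loosely because the exact type is not visible in this diff.
    isCompatible(info: any): boolean {
      const firstGpu = info.gpuSetting?.gpus[0]
      return (
        !!info.gpuSetting &&
        !!firstGpu &&
        info.gpuSetting.gpus.length > 0 &&
        this.compatibility().platform.includes(info.osInfo.platform) &&
        !!firstGpu.arch &&
        firstGpu.name.toLowerCase().includes('nvidia') &&
        this.supportedGpuArch.includes(firstGpu.arch)
      )
    }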

View File

@@ -16,5 +16,6 @@
"resolveJsonModule": true,
"typeRoots": ["node_modules/@types"]
},
"include": ["src"]
"include": ["src"],
"exclude": ["**/*.test.ts"]
}

View File

@@ -9,6 +9,8 @@ import {
ModelArtifact,
DownloadState,
GpuSetting,
ModelFile,
dirName,
} from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
@@ -91,9 +93,12 @@ export default function useDownloadModel() {
]
)
const abortModelDownload = useCallback(async (model: Model) => {
const abortModelDownload = useCallback(async (model: Model | ModelFile) => {
for (const source of model.sources) {
const path = await joinPath(['models', model.id, source.filename])
const path =
'file_path' in model
? await joinPath([await dirName(model.file_path), source.filename])
: await joinPath(['models', model.id, source.filename])
await abortDownload(path)
}
}, [])
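
Editor's note on the abortModelDownload change above (the second fix named in the commit message): the callback now accepts both Model and ModelFile. A ModelFile carries a file_path, so the download artifact is resolved next to that file via dirName(model.file_path), while plain Model entries keep resolving to models/<model.id> as before. The branch can be read as the following helper (the helper name is illustrative, not part of the commit):

    // Hypothetical helper restating the path resolution introduced above.
    const resolveDownloadPath = async (model: Model | ModelFile, filename: string) =>
      'file_path' in model
        ? joinPath([await dirName(model.file_path), filename])
        : joinPath(['models', model.id, filename])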