* fix: #3491 - Unable to use tensorrt-llm * fix: abortModelDownload input type
This commit is contained in:
parent
cf0a232001
commit
8334076047
9
extensions/tensorrt-llm-extension/jest.config.js
Normal file
9
extensions/tensorrt-llm-extension/jest.config.js
Normal file
@ -0,0 +1,9 @@
|
||||
/** @type {import('ts-jest').JestConfigWithTsJest} */
|
||||
module.exports = {
|
||||
preset: 'ts-jest',
|
||||
testEnvironment: 'node',
|
||||
transform: {
|
||||
'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
|
||||
},
|
||||
transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
|
||||
}
|
||||
@ -22,6 +22,7 @@
|
||||
"tensorrtVersion": "0.1.8",
|
||||
"provider": "nitro-tensorrt-llm",
|
||||
"scripts": {
|
||||
"test": "jest",
|
||||
"build": "tsc --module commonjs && rollup -c rollup.config.ts",
|
||||
"build:publish:win32": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
|
||||
"build:publish:linux": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
|
||||
@ -49,7 +50,12 @@
|
||||
"rollup-plugin-sourcemaps": "^0.6.3",
|
||||
"rollup-plugin-typescript2": "^0.36.0",
|
||||
"run-script-os": "^1.1.6",
|
||||
"typescript": "^5.2.2"
|
||||
"typescript": "^5.2.2",
|
||||
"@types/jest": "^29.5.12",
|
||||
"jest": "^29.7.0",
|
||||
"jest-junit": "^16.0.0",
|
||||
"jest-runner": "^29.7.0",
|
||||
"ts-jest": "^29.2.5"
|
||||
},
|
||||
"dependencies": {
|
||||
"@janhq/core": "file:../../core",
|
||||
|
||||
@ -23,10 +23,10 @@ export default [
|
||||
DOWNLOAD_RUNNER_URL:
|
||||
process.platform === 'win32'
|
||||
? JSON.stringify(
|
||||
'https://github.com/janhq/nitro-tensorrt-llm/releases/download/windows-v<version>-tensorrt-llm-v0.7.1/nitro-windows-v<version>-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz'
|
||||
'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v<version>-tensorrt-llm-v0.7.1/nitro-windows-v<version>-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz'
|
||||
)
|
||||
: JSON.stringify(
|
||||
'https://github.com/janhq/nitro-tensorrt-llm/releases/download/linux-v<version>/nitro-linux-v<version>-amd64-tensorrt-llm-<gpuarch>.tar.gz'
|
||||
'https://github.com/janhq/cortex.tensorrt-llm/releases/download/linux-v<version>/nitro-linux-v<version>-amd64-tensorrt-llm-<gpuarch>.tar.gz'
|
||||
),
|
||||
NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`),
|
||||
INFERENCE_URL: JSON.stringify(
|
||||
|
||||
186
extensions/tensorrt-llm-extension/src/index.test.ts
Normal file
186
extensions/tensorrt-llm-extension/src/index.test.ts
Normal file
@ -0,0 +1,186 @@
|
||||
import TensorRTLLMExtension from '../src/index'
|
||||
import {
|
||||
executeOnMain,
|
||||
systemInformation,
|
||||
fs,
|
||||
baseName,
|
||||
joinPath,
|
||||
downloadFile,
|
||||
} from '@janhq/core'
|
||||
|
||||
jest.mock('@janhq/core', () => ({
|
||||
...jest.requireActual('@janhq/core/node'),
|
||||
LocalOAIEngine: jest.fn().mockImplementation(function () {
|
||||
// @ts-ignore
|
||||
this.registerModels = () => {
|
||||
return Promise.resolve()
|
||||
}
|
||||
// @ts-ignore
|
||||
return this
|
||||
}),
|
||||
systemInformation: jest.fn(),
|
||||
fs: {
|
||||
existsSync: jest.fn(),
|
||||
mkdir: jest.fn(),
|
||||
},
|
||||
joinPath: jest.fn(),
|
||||
baseName: jest.fn(),
|
||||
downloadFile: jest.fn(),
|
||||
executeOnMain: jest.fn(),
|
||||
showToast: jest.fn(),
|
||||
events: {
|
||||
emit: jest.fn(),
|
||||
// @ts-ignore
|
||||
on: (event, func) => {
|
||||
func({ fileName: './' })
|
||||
},
|
||||
off: jest.fn(),
|
||||
},
|
||||
}))
|
||||
|
||||
// @ts-ignore
|
||||
global.COMPATIBILITY = {
|
||||
platform: ['win32'],
|
||||
}
|
||||
// @ts-ignore
|
||||
global.PROVIDER = 'tensorrt-llm'
|
||||
// @ts-ignore
|
||||
global.INFERENCE_URL = 'http://localhost:5000'
|
||||
// @ts-ignore
|
||||
global.NODE = 'node'
|
||||
// @ts-ignore
|
||||
global.MODELS = []
|
||||
// @ts-ignore
|
||||
global.TENSORRT_VERSION = ''
|
||||
// @ts-ignore
|
||||
global.DOWNLOAD_RUNNER_URL = ''
|
||||
|
||||
describe('TensorRTLLMExtension', () => {
|
||||
let extension: TensorRTLLMExtension
|
||||
|
||||
beforeEach(() => {
|
||||
// @ts-ignore
|
||||
extension = new TensorRTLLMExtension()
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('compatibility', () => {
|
||||
it('should return the correct compatibility', () => {
|
||||
const result = extension.compatibility()
|
||||
expect(result).toEqual({
|
||||
platform: ['win32'],
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('install', () => {
|
||||
it('should install if compatible', async () => {
|
||||
const mockSystemInfo: any = {
|
||||
osInfo: { platform: 'win32' },
|
||||
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
|
||||
}
|
||||
;(executeOnMain as jest.Mock).mockResolvedValue({})
|
||||
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
|
||||
;(fs.existsSync as jest.Mock).mockResolvedValue(false)
|
||||
;(fs.mkdir as jest.Mock).mockResolvedValue(undefined)
|
||||
;(baseName as jest.Mock).mockResolvedValue('./')
|
||||
;(joinPath as jest.Mock).mockResolvedValue('./')
|
||||
;(downloadFile as jest.Mock).mockResolvedValue({})
|
||||
|
||||
await extension.install()
|
||||
|
||||
expect(executeOnMain).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not install if not compatible', async () => {
|
||||
const mockSystemInfo: any = {
|
||||
osInfo: { platform: 'linux' },
|
||||
gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] },
|
||||
}
|
||||
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
|
||||
|
||||
jest.spyOn(extension, 'registerModels').mockReturnValue(Promise.resolve())
|
||||
await extension.install()
|
||||
|
||||
expect(executeOnMain).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('installationState', () => {
|
||||
it('should return NotCompatible if not compatible', async () => {
|
||||
const mockSystemInfo: any = {
|
||||
osInfo: { platform: 'linux' },
|
||||
gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] },
|
||||
}
|
||||
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
|
||||
|
||||
const result = await extension.installationState()
|
||||
|
||||
expect(result).toBe('NotCompatible')
|
||||
})
|
||||
|
||||
it('should return Installed if executable exists', async () => {
|
||||
const mockSystemInfo: any = {
|
||||
osInfo: { platform: 'win32' },
|
||||
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
|
||||
}
|
||||
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
|
||||
;(fs.existsSync as jest.Mock).mockResolvedValue(true)
|
||||
|
||||
const result = await extension.installationState()
|
||||
|
||||
expect(result).toBe('Installed')
|
||||
})
|
||||
|
||||
it('should return NotInstalled if executable does not exist', async () => {
|
||||
const mockSystemInfo: any = {
|
||||
osInfo: { platform: 'win32' },
|
||||
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
|
||||
}
|
||||
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
|
||||
;(fs.existsSync as jest.Mock).mockResolvedValue(false)
|
||||
|
||||
const result = await extension.installationState()
|
||||
|
||||
expect(result).toBe('NotInstalled')
|
||||
})
|
||||
})
|
||||
|
||||
describe('isCompatible', () => {
|
||||
it('should return true for compatible system', () => {
|
||||
const mockInfo: any = {
|
||||
osInfo: { platform: 'win32' },
|
||||
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
|
||||
}
|
||||
|
||||
const result = extension.isCompatible(mockInfo)
|
||||
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for incompatible system', () => {
|
||||
const mockInfo: any = {
|
||||
osInfo: { platform: 'linux' },
|
||||
gpuSetting: { gpus: [{ arch: 'pascal', name: 'AMD GPU' }] },
|
||||
}
|
||||
|
||||
const result = extension.isCompatible(mockInfo)
|
||||
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('GitHub Release File URL Test', () => {
|
||||
const url = 'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v0.1.8-tensorrt-llm-v0.7.1/nitro-windows-v0.1.8-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz';
|
||||
|
||||
it('should return a status code 200 for the release file URL', async () => {
|
||||
const response = await fetch(url, { method: 'HEAD' });
|
||||
expect(response.status).toBe(200);
|
||||
});
|
||||
|
||||
it('should not return a 404 status', async () => {
|
||||
const response = await fetch(url, { method: 'HEAD' });
|
||||
expect(response.status).not.toBe(404);
|
||||
});
|
||||
});
|
||||
@ -41,7 +41,6 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
|
||||
override nodeModule = NODE
|
||||
|
||||
private supportedGpuArch = ['ampere', 'ada']
|
||||
private supportedPlatform = ['win32', 'linux']
|
||||
|
||||
override compatibility() {
|
||||
return COMPATIBILITY as unknown as Compatibility
|
||||
@ -191,7 +190,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
|
||||
!!info.gpuSetting &&
|
||||
!!firstGpu &&
|
||||
info.gpuSetting.gpus.length > 0 &&
|
||||
this.supportedPlatform.includes(info.osInfo.platform) &&
|
||||
this.compatibility().platform.includes(info.osInfo.platform) &&
|
||||
!!firstGpu.arch &&
|
||||
firstGpu.name.toLowerCase().includes('nvidia') &&
|
||||
this.supportedGpuArch.includes(firstGpu.arch)
|
||||
|
||||
@ -16,5 +16,6 @@
|
||||
"resolveJsonModule": true,
|
||||
"typeRoots": ["node_modules/@types"]
|
||||
},
|
||||
"include": ["src"]
|
||||
"include": ["src"],
|
||||
"exclude": ["**/*.test.ts"]
|
||||
}
|
||||
|
||||
@ -9,6 +9,8 @@ import {
|
||||
ModelArtifact,
|
||||
DownloadState,
|
||||
GpuSetting,
|
||||
ModelFile,
|
||||
dirName,
|
||||
} from '@janhq/core'
|
||||
|
||||
import { useAtomValue, useSetAtom } from 'jotai'
|
||||
@ -91,9 +93,12 @@ export default function useDownloadModel() {
|
||||
]
|
||||
)
|
||||
|
||||
const abortModelDownload = useCallback(async (model: Model) => {
|
||||
const abortModelDownload = useCallback(async (model: Model | ModelFile) => {
|
||||
for (const source of model.sources) {
|
||||
const path = await joinPath(['models', model.id, source.filename])
|
||||
const path =
|
||||
'file_path' in model
|
||||
? await joinPath([await dirName(model.file_path), source.filename])
|
||||
: await joinPath(['models', model.id, source.filename])
|
||||
await abortDownload(path)
|
||||
}
|
||||
}, [])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user