* fix: #3491 - Unable to use tensorrt-llm * fix: abortModelDownload input type
This commit is contained in:
parent
cf0a232001
commit
8334076047
9
extensions/tensorrt-llm-extension/jest.config.js
Normal file
9
extensions/tensorrt-llm-extension/jest.config.js
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
/** @type {import('ts-jest').JestConfigWithTsJest} */
|
||||||
|
module.exports = {
|
||||||
|
preset: 'ts-jest',
|
||||||
|
testEnvironment: 'node',
|
||||||
|
transform: {
|
||||||
|
'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
|
||||||
|
},
|
||||||
|
transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
|
||||||
|
}
|
||||||
@ -22,6 +22,7 @@
|
|||||||
"tensorrtVersion": "0.1.8",
|
"tensorrtVersion": "0.1.8",
|
||||||
"provider": "nitro-tensorrt-llm",
|
"provider": "nitro-tensorrt-llm",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
|
"test": "jest",
|
||||||
"build": "tsc --module commonjs && rollup -c rollup.config.ts",
|
"build": "tsc --module commonjs && rollup -c rollup.config.ts",
|
||||||
"build:publish:win32": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
|
"build:publish:win32": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
|
||||||
"build:publish:linux": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
|
"build:publish:linux": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
|
||||||
@ -49,7 +50,12 @@
|
|||||||
"rollup-plugin-sourcemaps": "^0.6.3",
|
"rollup-plugin-sourcemaps": "^0.6.3",
|
||||||
"rollup-plugin-typescript2": "^0.36.0",
|
"rollup-plugin-typescript2": "^0.36.0",
|
||||||
"run-script-os": "^1.1.6",
|
"run-script-os": "^1.1.6",
|
||||||
"typescript": "^5.2.2"
|
"typescript": "^5.2.2",
|
||||||
|
"@types/jest": "^29.5.12",
|
||||||
|
"jest": "^29.7.0",
|
||||||
|
"jest-junit": "^16.0.0",
|
||||||
|
"jest-runner": "^29.7.0",
|
||||||
|
"ts-jest": "^29.2.5"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@janhq/core": "file:../../core",
|
"@janhq/core": "file:../../core",
|
||||||
|
|||||||
@ -23,10 +23,10 @@ export default [
|
|||||||
DOWNLOAD_RUNNER_URL:
|
DOWNLOAD_RUNNER_URL:
|
||||||
process.platform === 'win32'
|
process.platform === 'win32'
|
||||||
? JSON.stringify(
|
? JSON.stringify(
|
||||||
'https://github.com/janhq/nitro-tensorrt-llm/releases/download/windows-v<version>-tensorrt-llm-v0.7.1/nitro-windows-v<version>-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz'
|
'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v<version>-tensorrt-llm-v0.7.1/nitro-windows-v<version>-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz'
|
||||||
)
|
)
|
||||||
: JSON.stringify(
|
: JSON.stringify(
|
||||||
'https://github.com/janhq/nitro-tensorrt-llm/releases/download/linux-v<version>/nitro-linux-v<version>-amd64-tensorrt-llm-<gpuarch>.tar.gz'
|
'https://github.com/janhq/cortex.tensorrt-llm/releases/download/linux-v<version>/nitro-linux-v<version>-amd64-tensorrt-llm-<gpuarch>.tar.gz'
|
||||||
),
|
),
|
||||||
NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`),
|
NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`),
|
||||||
INFERENCE_URL: JSON.stringify(
|
INFERENCE_URL: JSON.stringify(
|
||||||
|
|||||||
186
extensions/tensorrt-llm-extension/src/index.test.ts
Normal file
186
extensions/tensorrt-llm-extension/src/index.test.ts
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
import TensorRTLLMExtension from '../src/index'
|
||||||
|
import {
|
||||||
|
executeOnMain,
|
||||||
|
systemInformation,
|
||||||
|
fs,
|
||||||
|
baseName,
|
||||||
|
joinPath,
|
||||||
|
downloadFile,
|
||||||
|
} from '@janhq/core'
|
||||||
|
|
||||||
|
jest.mock('@janhq/core', () => ({
|
||||||
|
...jest.requireActual('@janhq/core/node'),
|
||||||
|
LocalOAIEngine: jest.fn().mockImplementation(function () {
|
||||||
|
// @ts-ignore
|
||||||
|
this.registerModels = () => {
|
||||||
|
return Promise.resolve()
|
||||||
|
}
|
||||||
|
// @ts-ignore
|
||||||
|
return this
|
||||||
|
}),
|
||||||
|
systemInformation: jest.fn(),
|
||||||
|
fs: {
|
||||||
|
existsSync: jest.fn(),
|
||||||
|
mkdir: jest.fn(),
|
||||||
|
},
|
||||||
|
joinPath: jest.fn(),
|
||||||
|
baseName: jest.fn(),
|
||||||
|
downloadFile: jest.fn(),
|
||||||
|
executeOnMain: jest.fn(),
|
||||||
|
showToast: jest.fn(),
|
||||||
|
events: {
|
||||||
|
emit: jest.fn(),
|
||||||
|
// @ts-ignore
|
||||||
|
on: (event, func) => {
|
||||||
|
func({ fileName: './' })
|
||||||
|
},
|
||||||
|
off: jest.fn(),
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
// @ts-ignore
|
||||||
|
global.COMPATIBILITY = {
|
||||||
|
platform: ['win32'],
|
||||||
|
}
|
||||||
|
// @ts-ignore
|
||||||
|
global.PROVIDER = 'tensorrt-llm'
|
||||||
|
// @ts-ignore
|
||||||
|
global.INFERENCE_URL = 'http://localhost:5000'
|
||||||
|
// @ts-ignore
|
||||||
|
global.NODE = 'node'
|
||||||
|
// @ts-ignore
|
||||||
|
global.MODELS = []
|
||||||
|
// @ts-ignore
|
||||||
|
global.TENSORRT_VERSION = ''
|
||||||
|
// @ts-ignore
|
||||||
|
global.DOWNLOAD_RUNNER_URL = ''
|
||||||
|
|
||||||
|
describe('TensorRTLLMExtension', () => {
|
||||||
|
let extension: TensorRTLLMExtension
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
// @ts-ignore
|
||||||
|
extension = new TensorRTLLMExtension()
|
||||||
|
jest.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('compatibility', () => {
|
||||||
|
it('should return the correct compatibility', () => {
|
||||||
|
const result = extension.compatibility()
|
||||||
|
expect(result).toEqual({
|
||||||
|
platform: ['win32'],
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('install', () => {
|
||||||
|
it('should install if compatible', async () => {
|
||||||
|
const mockSystemInfo: any = {
|
||||||
|
osInfo: { platform: 'win32' },
|
||||||
|
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
|
||||||
|
}
|
||||||
|
;(executeOnMain as jest.Mock).mockResolvedValue({})
|
||||||
|
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
|
||||||
|
;(fs.existsSync as jest.Mock).mockResolvedValue(false)
|
||||||
|
;(fs.mkdir as jest.Mock).mockResolvedValue(undefined)
|
||||||
|
;(baseName as jest.Mock).mockResolvedValue('./')
|
||||||
|
;(joinPath as jest.Mock).mockResolvedValue('./')
|
||||||
|
;(downloadFile as jest.Mock).mockResolvedValue({})
|
||||||
|
|
||||||
|
await extension.install()
|
||||||
|
|
||||||
|
expect(executeOnMain).toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not install if not compatible', async () => {
|
||||||
|
const mockSystemInfo: any = {
|
||||||
|
osInfo: { platform: 'linux' },
|
||||||
|
gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] },
|
||||||
|
}
|
||||||
|
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
|
||||||
|
|
||||||
|
jest.spyOn(extension, 'registerModels').mockReturnValue(Promise.resolve())
|
||||||
|
await extension.install()
|
||||||
|
|
||||||
|
expect(executeOnMain).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('installationState', () => {
|
||||||
|
it('should return NotCompatible if not compatible', async () => {
|
||||||
|
const mockSystemInfo: any = {
|
||||||
|
osInfo: { platform: 'linux' },
|
||||||
|
gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] },
|
||||||
|
}
|
||||||
|
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
|
||||||
|
|
||||||
|
const result = await extension.installationState()
|
||||||
|
|
||||||
|
expect(result).toBe('NotCompatible')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return Installed if executable exists', async () => {
|
||||||
|
const mockSystemInfo: any = {
|
||||||
|
osInfo: { platform: 'win32' },
|
||||||
|
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
|
||||||
|
}
|
||||||
|
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
|
||||||
|
;(fs.existsSync as jest.Mock).mockResolvedValue(true)
|
||||||
|
|
||||||
|
const result = await extension.installationState()
|
||||||
|
|
||||||
|
expect(result).toBe('Installed')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return NotInstalled if executable does not exist', async () => {
|
||||||
|
const mockSystemInfo: any = {
|
||||||
|
osInfo: { platform: 'win32' },
|
||||||
|
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
|
||||||
|
}
|
||||||
|
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
|
||||||
|
;(fs.existsSync as jest.Mock).mockResolvedValue(false)
|
||||||
|
|
||||||
|
const result = await extension.installationState()
|
||||||
|
|
||||||
|
expect(result).toBe('NotInstalled')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('isCompatible', () => {
|
||||||
|
it('should return true for compatible system', () => {
|
||||||
|
const mockInfo: any = {
|
||||||
|
osInfo: { platform: 'win32' },
|
||||||
|
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = extension.isCompatible(mockInfo)
|
||||||
|
|
||||||
|
expect(result).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return false for incompatible system', () => {
|
||||||
|
const mockInfo: any = {
|
||||||
|
osInfo: { platform: 'linux' },
|
||||||
|
gpuSetting: { gpus: [{ arch: 'pascal', name: 'AMD GPU' }] },
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = extension.isCompatible(mockInfo)
|
||||||
|
|
||||||
|
expect(result).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('GitHub Release File URL Test', () => {
|
||||||
|
const url = 'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v0.1.8-tensorrt-llm-v0.7.1/nitro-windows-v0.1.8-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz';
|
||||||
|
|
||||||
|
it('should return a status code 200 for the release file URL', async () => {
|
||||||
|
const response = await fetch(url, { method: 'HEAD' });
|
||||||
|
expect(response.status).toBe(200);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not return a 404 status', async () => {
|
||||||
|
const response = await fetch(url, { method: 'HEAD' });
|
||||||
|
expect(response.status).not.toBe(404);
|
||||||
|
});
|
||||||
|
});
|
||||||
@ -41,7 +41,6 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
|
|||||||
override nodeModule = NODE
|
override nodeModule = NODE
|
||||||
|
|
||||||
private supportedGpuArch = ['ampere', 'ada']
|
private supportedGpuArch = ['ampere', 'ada']
|
||||||
private supportedPlatform = ['win32', 'linux']
|
|
||||||
|
|
||||||
override compatibility() {
|
override compatibility() {
|
||||||
return COMPATIBILITY as unknown as Compatibility
|
return COMPATIBILITY as unknown as Compatibility
|
||||||
@ -191,7 +190,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
|
|||||||
!!info.gpuSetting &&
|
!!info.gpuSetting &&
|
||||||
!!firstGpu &&
|
!!firstGpu &&
|
||||||
info.gpuSetting.gpus.length > 0 &&
|
info.gpuSetting.gpus.length > 0 &&
|
||||||
this.supportedPlatform.includes(info.osInfo.platform) &&
|
this.compatibility().platform.includes(info.osInfo.platform) &&
|
||||||
!!firstGpu.arch &&
|
!!firstGpu.arch &&
|
||||||
firstGpu.name.toLowerCase().includes('nvidia') &&
|
firstGpu.name.toLowerCase().includes('nvidia') &&
|
||||||
this.supportedGpuArch.includes(firstGpu.arch)
|
this.supportedGpuArch.includes(firstGpu.arch)
|
||||||
|
|||||||
@ -16,5 +16,6 @@
|
|||||||
"resolveJsonModule": true,
|
"resolveJsonModule": true,
|
||||||
"typeRoots": ["node_modules/@types"]
|
"typeRoots": ["node_modules/@types"]
|
||||||
},
|
},
|
||||||
"include": ["src"]
|
"include": ["src"],
|
||||||
|
"exclude": ["**/*.test.ts"]
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,6 +9,8 @@ import {
|
|||||||
ModelArtifact,
|
ModelArtifact,
|
||||||
DownloadState,
|
DownloadState,
|
||||||
GpuSetting,
|
GpuSetting,
|
||||||
|
ModelFile,
|
||||||
|
dirName,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
|
|
||||||
import { useAtomValue, useSetAtom } from 'jotai'
|
import { useAtomValue, useSetAtom } from 'jotai'
|
||||||
@ -91,9 +93,12 @@ export default function useDownloadModel() {
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
const abortModelDownload = useCallback(async (model: Model) => {
|
const abortModelDownload = useCallback(async (model: Model | ModelFile) => {
|
||||||
for (const source of model.sources) {
|
for (const source of model.sources) {
|
||||||
const path = await joinPath(['models', model.id, source.filename])
|
const path =
|
||||||
|
'file_path' in model
|
||||||
|
? await joinPath([await dirName(model.file_path), source.filename])
|
||||||
|
: await joinPath(['models', model.id, source.filename])
|
||||||
await abortDownload(path)
|
await abortDownload(path)
|
||||||
}
|
}
|
||||||
}, [])
|
}, [])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user