feat: nitro additional dependencies (#2674)

This commit is contained in:
Louis 2024-04-11 09:13:02 +07:00 committed by GitHub
parent 8917be5ef3
commit d93d74c86b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 371 additions and 339 deletions

View File

@ -23,9 +23,9 @@ export interface Compatibility {
const ALL_INSTALLATION_STATE = [ const ALL_INSTALLATION_STATE = [
'NotRequired', // not required. 'NotRequired', // not required.
'Installed', // require and installed. Good to go. 'Installed', // require and installed. Good to go.
'Updatable', // require and installed but need to be updated.
'NotInstalled', // require to be installed. 'NotInstalled', // require to be installed.
'Corrupted', // require but corrupted. Need to redownload. 'Corrupted', // require but corrupted. Need to redownload.
'NotCompatible', // require but not compatible.
] as const ] as const
export type InstallationStateTuple = typeof ALL_INSTALLATION_STATE export type InstallationStateTuple = typeof ALL_INSTALLATION_STATE
@ -98,13 +98,6 @@ export abstract class BaseExtension implements ExtensionType {
return undefined return undefined
} }
/**
* Determine if the extension is updatable.
*/
updatable(): boolean {
return false
}
async registerSettings(settings: SettingComponentProps[]): Promise<void> { async registerSettings(settings: SettingComponentProps[]): Promise<void> {
if (!this.name) { if (!this.name) {
console.error('Extension name is not defined') console.error('Extension name is not defined')

View File

@ -1,7 +1,7 @@
{ {
"name": "@janhq/inference-nitro-extension", "name": "@janhq/inference-nitro-extension",
"version": "1.0.0", "version": "1.0.0",
"description": "This extension embeds Nitro, a lightweight (3mb) inference engine written in C++. See nitro.jan.ai", "description": "This extension embeds Nitro, a lightweight (3mb) inference engine written in C++. See https://nitro.jan.ai.\nUse this setting if you encounter errors related to **CUDA toolkit** during application execution.",
"main": "dist/index.js", "main": "dist/index.js",
"node": "dist/node/index.cjs.js", "node": "dist/node/index.cjs.js",
"author": "Jan <service@jan.ai>", "author": "Jan <service@jan.ai>",
@ -29,6 +29,7 @@
"@rollup/plugin-json": "^6.1.0", "@rollup/plugin-json": "^6.1.0",
"@rollup/plugin-node-resolve": "^15.2.3", "@rollup/plugin-node-resolve": "^15.2.3",
"@rollup/plugin-replace": "^5.0.5", "@rollup/plugin-replace": "^5.0.5",
"@types/decompress": "^4.2.7",
"@types/jest": "^29.5.12", "@types/jest": "^29.5.12",
"@types/node": "^20.11.4", "@types/node": "^20.11.4",
"@types/os-utils": "^0.0.4", "@types/os-utils": "^0.0.4",
@ -47,6 +48,7 @@
}, },
"dependencies": { "dependencies": {
"@janhq/core": "file:../../core", "@janhq/core": "file:../../core",
"decompress": "^4.2.1",
"fetch-retry": "^5.0.6", "fetch-retry": "^5.0.6",
"path-browserify": "^1.0.1", "path-browserify": "^1.0.1",
"rxjs": "^7.8.1", "rxjs": "^7.8.1",
@ -65,6 +67,7 @@
"bundleDependencies": [ "bundleDependencies": [
"tcp-port-used", "tcp-port-used",
"fetch-retry", "fetch-retry",
"@janhq/core" "@janhq/core",
"decompress"
] ]
} }

View File

@ -92,6 +92,9 @@ export default [
JAN_SERVER_INFERENCE_URL: JSON.stringify( JAN_SERVER_INFERENCE_URL: JSON.stringify(
'http://localhost:1337/v1/chat/completions' 'http://localhost:1337/v1/chat/completions'
), ),
CUDA_DOWNLOAD_URL: JSON.stringify(
'https://catalog.jan.ai/dist/cuda-dependencies/<version>/<platform>/cuda.tar.gz'
),
}), }),
// Allow json resolution // Allow json resolution
json(), json(),

View File

@ -12,8 +12,19 @@ import {
Model, Model,
ModelEvent, ModelEvent,
LocalOAIEngine, LocalOAIEngine,
InstallationState,
systemInformation,
fs,
getJanDataFolderPath,
joinPath,
DownloadRequest,
baseName,
downloadFile,
DownloadState,
DownloadEvent,
} from '@janhq/core' } from '@janhq/core'
declare const CUDA_DOWNLOAD_URL: string
/** /**
* A class that implements the InferenceExtension interface from the @janhq/core package. * A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests. * The class provides methods for initializing and stopping a model, and for making inference requests.
@ -61,6 +72,11 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
const models = MODELS as unknown as Model[] const models = MODELS as unknown as Model[]
this.registerModels(models) this.registerModels(models)
super.onLoad() super.onLoad()
executeOnMain(NODE, 'addAdditionalDependencies', {
name: this.name,
version: this.version,
})
} }
/** /**
@ -96,4 +112,80 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
} }
return super.unloadModel(model) return super.unloadModel(model)
} }
/**
 * Starts the download of the CUDA toolkit dependency archive for this machine
 * and arranges for it to be extracted once the download succeeds.
 *
 * The URL is built from the `CUDA_DOWNLOAD_URL` template by substituting the
 * detected CUDA version (falling back to '12.4') and the platform name. The
 * archive is saved to, and extracted into,
 * `<jan data folder>/engines/<name>/<version>`.
 *
 * The returned promise resolves as soon as the download has been requested;
 * extraction happens later, inside the download-success handler.
 */
override async install(): Promise<void> {
  const info = await systemInformation()
  // Every platform other than win32 is mapped to the linux archive.
  const platform = info.osInfo?.platform === 'win32' ? 'windows' : 'linux'
  const url = CUDA_DOWNLOAD_URL.replace(
    '<version>',
    info.gpuSetting.cuda?.version ?? '12.4'
  ).replace('<platform>', platform)

  console.debug('Downloading Cuda Toolkit Dependency: ', url)

  const janDataFolderPath = await getJanDataFolderPath()
  const executableFolderPath = await joinPath([
    janDataFolderPath,
    'engines',
    this.name ?? 'nitro',
    this.version ?? '1.0.0',
  ])

  // NOTE(review): installationState() in this class reports 'Installed' as
  // soon as this folder exists, i.e. before the download has finished.
  if (!(await fs.existsSync(executableFolderPath))) {
    await fs.mkdir(executableFolderPath)
  }

  const tarball = await baseName(url)
  const tarballFullPath = await joinPath([executableFolderPath, tarball])
  const downloadRequest: DownloadRequest = {
    url,
    localPath: tarballFullPath,
    extensionId: this.name,
    downloadType: 'extension',
  }

  const onFileDownloadSuccess = async (state: DownloadState) => {
    // Events for other downloads share this channel; ignore them.
    if (state.fileName !== tarball) return
    // One-shot handler: detach before doing the (slow) extraction.
    events.off(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess)
    try {
      await executeOnMain(
        NODE,
        'decompressRunner',
        tarballFullPath,
        executableFolderPath
      )
    } catch (err) {
      // Without this, a failed main-process call would surface as an
      // unhandled rejection and be followed by a bogus success event.
      console.error(`Failed to extract ${tarballFullPath}: ${err}`)
      return
    }
    events.emit(DownloadEvent.onFileUnzipSuccess, state)
  }

  // Subscribe before kicking off the download so the completion event can
  // never be missed, regardless of how downloadFile is implemented.
  events.on(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess)
  downloadFile(downloadRequest)
}
/**
 * Reports whether the additional CUDA dependencies are needed and present.
 *
 * They are considered required only when the app runs in GPU mode without
 * Vulkan, OS information is available, the platform is not macOS, and no
 * CUDA installation was detected. In that case the result depends solely on
 * whether `<jan data folder>/engines/<name>/<version>` exists.
 *
 * @returns 'NotRequired', 'Installed' or 'NotInstalled'.
 */
override async installationState(): Promise<InstallationState> {
  const { gpuSetting, osInfo } = await systemInformation()

  const needsCudaDependencies =
    gpuSetting.run_mode === 'gpu' &&
    !gpuSetting.vulkan &&
    osInfo &&
    osInfo.platform !== 'darwin' &&
    !gpuSetting.cuda?.exist

  if (!needsCudaDependencies) return 'NotRequired'

  const dataFolder = await getJanDataFolderPath()
  const dependencyFolder = await joinPath([
    dataFolder,
    'engines',
    this.name ?? 'nitro',
    this.version ?? '1.0.0',
  ])

  // Only the folder's existence is checked, not its contents.
  const folderExists = await fs.existsSync(dependencyFolder)
  return folderExists ? 'Installed' : 'NotInstalled'
}
} }

View File

@ -11,9 +11,11 @@ import {
ModelSettingParams, ModelSettingParams,
PromptTemplate, PromptTemplate,
SystemInformation, SystemInformation,
getJanDataFolderPath,
} from '@janhq/core/node' } from '@janhq/core/node'
import { executableNitroFile } from './execute' import { executableNitroFile } from './execute'
import terminate from 'terminate' import terminate from 'terminate'
import decompress from 'decompress'
// Polyfill fetch with retry // Polyfill fetch with retry
const fetchRetry = fetchRT(fetch) const fetchRetry = fetchRT(fetch)
@ -420,9 +422,32 @@ const getCurrentNitroProcessInfo = (): NitroProcessInfo => {
} }
} }
/**
 * Makes the extension's downloaded dependency folder
 * (`<jan data folder>/engines/<name>/<version>`) discoverable by appending it
 * to both `PATH` and `LD_LIBRARY_PATH` of the current process.
 *
 * The directory is added at most once per variable, and no leading delimiter
 * is produced when a variable is unset or empty (an empty search-path entry
 * is commonly interpreted as the current working directory).
 *
 * @param data - Extension name and version identifying the folder.
 */
const addAdditionalDependencies = (data: { name: string; version: string }) => {
  const dependencyDir = path.join(
    getJanDataFolderPath(),
    'engines',
    data.name,
    data.version
  )

  // Returns `current` with dependencyDir appended, leaving existing entries
  // untouched.
  const withDependencyDir = (current: string | undefined): string => {
    if (!current) return dependencyDir
    if (current.split(path.delimiter).includes(dependencyDir)) return current
    return current.concat(path.delimiter, dependencyDir)
  }

  // Set the updated search paths
  process.env.PATH = withDependencyDir(process.env.PATH)
  process.env.LD_LIBRARY_PATH = withDependencyDir(process.env.LD_LIBRARY_PATH)
}
/**
 * Extracts the archive at `zipPath` into the `output` directory.
 *
 * Failures are logged and swallowed: the returned promise always resolves,
 * so callers cannot tell from it whether extraction succeeded.
 *
 * @param zipPath - Path of the archive to extract.
 * @param output - Directory the archive contents are written to.
 */
const decompressRunner = async (zipPath: string, output: string) => {
  console.debug(`Decompressing ${zipPath} to ${output}...`)
  let extracted
  try {
    extracted = await decompress(zipPath, output)
  } catch (failure) {
    console.error(`Decompress ${zipPath} failed: ${failure}`)
    return
  }
  console.debug('Decompress finished!', extracted)
}
export default { export default {
loadModel, loadModel,
unloadModel, unloadModel,
dispose, dispose,
getCurrentNitroProcessInfo, getCurrentNitroProcessInfo,
addAdditionalDependencies,
decompressRunner,
} }

View File

@ -85,8 +85,8 @@ export default class JanModelExtension extends ModelExtension {
} }
if (!JanModelExtension._supportedGpuArch.includes(gpuArch)) { if (!JanModelExtension._supportedGpuArch.includes(gpuArch)) {
console.error( console.debug(
`Your GPU: ${firstGpu} is not supported. Only 20xx, 30xx, 40xx series are supported.` `Your GPU: ${JSON.stringify(firstGpu)} is not supported. Only 30xx, 40xx series are supported.`
) )
return return
} }

View File

@ -200,7 +200,7 @@ const updateGpuInfo = async () =>
process.platform === 'win32' process.platform === 'win32'
? `${__dirname}\\..\\bin\\vulkaninfoSDK.exe --summary` ? `${__dirname}\\..\\bin\\vulkaninfoSDK.exe --summary`
: `${__dirname}/../bin/vulkaninfo --summary`, : `${__dirname}/../bin/vulkaninfo --summary`,
(error, stdout) => { async (error, stdout) => {
if (!error) { if (!error) {
const output = stdout.toString() const output = stdout.toString()
@ -221,7 +221,7 @@ const updateGpuInfo = async () =>
data.gpus_in_use = [data.gpus.length > 1 ? '1' : '0'] data.gpus_in_use = [data.gpus.length > 1 ? '1' : '0']
} }
data = updateCudaExistence(data) data = await updateCudaExistence(data)
writeFileSync(GPU_INFO_FILE, JSON.stringify(data, null, 2)) writeFileSync(GPU_INFO_FILE, JSON.stringify(data, null, 2))
log(`[APP]::${JSON.stringify(data)}`) log(`[APP]::${JSON.stringify(data)}`)
resolve({}) resolve({})
@ -233,7 +233,7 @@ const updateGpuInfo = async () =>
} else { } else {
exec( exec(
'nvidia-smi --query-gpu=index,memory.total,name --format=csv,noheader,nounits', 'nvidia-smi --query-gpu=index,memory.total,name --format=csv,noheader,nounits',
(error, stdout) => { async (error, stdout) => {
if (!error) { if (!error) {
log(`[SPECS]::${stdout}`) log(`[SPECS]::${stdout}`)
// Get GPU info and gpu has higher memory first // Get GPU info and gpu has higher memory first
@ -264,7 +264,8 @@ const updateGpuInfo = async () =>
data.gpus_in_use = [data.gpu_highest_vram] data.gpus_in_use = [data.gpu_highest_vram]
} }
data = updateCudaExistence(data) data = await updateCudaExistence(data)
console.log(data)
writeFileSync(GPU_INFO_FILE, JSON.stringify(data, null, 2)) writeFileSync(GPU_INFO_FILE, JSON.stringify(data, null, 2))
log(`[APP]::${JSON.stringify(data)}`) log(`[APP]::${JSON.stringify(data)}`)
resolve({}) resolve({})
@ -283,9 +284,9 @@ const checkFileExistenceInPaths = (file: string, paths: string[]): boolean => {
/** /**
* Validate cuda for linux and windows * Validate cuda for linux and windows
*/ */
const updateCudaExistence = ( const updateCudaExistence = async (
data: GpuSetting = DEFAULT_SETTINGS data: GpuSetting = DEFAULT_SETTINGS
): GpuSetting => { ): Promise<GpuSetting> => {
let filesCuda12: string[] let filesCuda12: string[]
let filesCuda11: string[] let filesCuda11: string[]
let paths: string[] let paths: string[]
@ -329,6 +330,23 @@ const updateCudaExistence = (
} }
data.is_initial = false data.is_initial = false
// Attempt to query CUDA using NVIDIA SMI
if (!cudaExists) {
await new Promise<void>((resolve, reject) => {
exec('nvidia-smi', (error, stdout) => {
if (!error) {
const regex = /CUDA\s*Version:\s*(\d+\.\d+)/g
const match = regex.exec(stdout)
if (match && match[1]) {
data.cuda.version = match[1]
}
}
console.log(data)
resolve()
})
})
}
return data return data
} }

View File

@ -22,6 +22,7 @@ import {
MessageRequest, MessageRequest,
ModelEvent, ModelEvent,
getJanDataFolderPath, getJanDataFolderPath,
SystemInformation,
} from '@janhq/core' } from '@janhq/core'
/** /**
@ -40,7 +41,6 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
private supportedGpuArch = ['ampere', 'ada'] private supportedGpuArch = ['ampere', 'ada']
private supportedPlatform = ['win32', 'linux'] private supportedPlatform = ['win32', 'linux']
private isUpdateAvailable = false
override compatibility() { override compatibility() {
return COMPATIBILITY as unknown as Compatibility return COMPATIBILITY as unknown as Compatibility
@ -59,33 +59,8 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
await this.removePopulatedModels() await this.removePopulatedModels()
const info = await systemInformation() const info = await systemInformation()
console.debug(
`TensorRTLLMExtension installing pre-requisites... ${JSON.stringify(info)}`
)
const gpuSetting: GpuSetting | undefined = info.gpuSetting
if (gpuSetting === undefined || gpuSetting.gpus.length === 0) {
console.error('No GPU setting found. Please check your GPU setting.')
return
}
// TODO: we only check for the first graphics card. Need to refactor this later. if (!this.isCompatible(info)) return
const firstGpu = gpuSetting.gpus[0]
if (!firstGpu.name.toLowerCase().includes('nvidia')) {
console.error('No Nvidia GPU found. Please check your GPU setting.')
return
}
if (firstGpu.arch === undefined) {
console.error('No GPU architecture found. Please check your GPU setting.')
return
}
if (!this.supportedGpuArch.includes(firstGpu.arch)) {
console.error(
`Your GPU: ${firstGpu} is not supported. Only 20xx, 30xx, 40xx series are supported.`
)
return
}
const janDataFolderPath = await getJanDataFolderPath() const janDataFolderPath = await getJanDataFolderPath()
const engineVersion = TENSORRT_VERSION const engineVersion = TENSORRT_VERSION
@ -95,7 +70,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
'engines', 'engines',
this.provider, this.provider,
engineVersion, engineVersion,
firstGpu.arch, info.gpuSetting.gpus[0].arch,
]) ])
if (!(await fs.existsSync(executableFolderPath))) { if (!(await fs.existsSync(executableFolderPath))) {
@ -107,7 +82,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
const url = placeholderUrl const url = placeholderUrl
.replace(/<version>/g, tensorrtVersion) .replace(/<version>/g, tensorrtVersion)
.replace(/<gpuarch>/g, firstGpu.arch) .replace(/<gpuarch>/g, info.gpuSetting!.gpus[0]!.arch!)
const tarball = await baseName(url) const tarball = await baseName(url)
@ -163,70 +138,17 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
} }
override async loadModel(model: Model): Promise<void> { override async loadModel(model: Model): Promise<void> {
if (model.engine !== this.provider) return
if ((await this.installationState()) === 'Installed') if ((await this.installationState()) === 'Installed')
return super.loadModel(model) return super.loadModel(model)
else {
events.emit(ModelEvent.OnModelFail, {
...model,
error: {
message: 'EXTENSION_IS_NOT_INSTALLED::TensorRT-LLM extension',
},
})
}
}
override updatable() { throw new Error('EXTENSION_IS_NOT_INSTALLED::TensorRT-LLM extension')
return this.isUpdateAvailable
} }
override async installationState(): Promise<InstallationState> { override async installationState(): Promise<InstallationState> {
const info = await systemInformation() const info = await systemInformation()
const gpuSetting: GpuSetting | undefined = info.gpuSetting if (!this.isCompatible(info)) return 'NotCompatible'
if (gpuSetting === undefined) { const firstGpu = info.gpuSetting.gpus[0]
console.warn(
'No GPU setting found. TensorRT-LLM extension is not installed'
)
return 'NotInstalled' // TODO: maybe disabled / incompatible is more appropriate
}
if (gpuSetting.gpus.length === 0) {
console.warn('No GPU found. TensorRT-LLM extension is not installed')
return 'NotInstalled'
}
const firstGpu = gpuSetting.gpus[0]
if (!firstGpu.name.toLowerCase().includes('nvidia')) {
console.error('No Nvidia GPU found. Please check your GPU setting.')
return 'NotInstalled'
}
if (firstGpu.arch === undefined) {
console.error('No GPU architecture found. Please check your GPU setting.')
return 'NotInstalled'
}
if (!this.supportedGpuArch.includes(firstGpu.arch)) {
console.error(
`Your GPU: ${firstGpu} is not supported. Only 20xx, 30xx, 40xx series are supported.`
)
return 'NotInstalled'
}
const osInfo = info.osInfo
if (!osInfo) {
console.error('No OS information found. Please check your OS setting.')
return 'NotInstalled'
}
if (!this.supportedPlatform.includes(osInfo.platform)) {
console.error(
`Your OS: ${osInfo.platform} is not supported. Only Windows and Linux are supported.`
)
return 'NotInstalled'
}
const janDataFolderPath = await getJanDataFolderPath() const janDataFolderPath = await getJanDataFolderPath()
const engineVersion = TENSORRT_VERSION const engineVersion = TENSORRT_VERSION
@ -236,7 +158,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
this.provider, this.provider,
engineVersion, engineVersion,
firstGpu.arch, firstGpu.arch,
osInfo.platform === 'win32' ? 'nitro.exe' : 'nitro', info.osInfo.platform === 'win32' ? 'nitro.exe' : 'nitro',
]) ])
// For now, we just check the executable of nitro x tensor rt // For now, we just check the executable of nitro x tensor rt
@ -258,4 +180,19 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
if (data.model) data.model.parameters.stream = true if (data.model) data.model.parameters.stream = true
super.inference(data) super.inference(data)
} }
/**
 * Type guard telling whether this machine can run the extension.
 *
 * Compatible means: OS information is present and its platform is listed in
 * `supportedPlatform`, and the first GPU exists, has an architecture listed
 * in `supportedGpuArch`, and has a name containing "nvidia".
 *
 * Only the first GPU is inspected. Missing `gpuSetting` or `gpus` yields
 * `false` rather than throwing.
 *
 * @param info - System information to inspect.
 * @returns `true` when all of the above hold; narrows `info` accordingly.
 */
isCompatible(info: SystemInformation): info is Required<SystemInformation> & {
  gpuSetting: { gpus: { arch: string }[] }
} {
  // Read the first GPU defensively: gpuSetting / gpus may be absent, and an
  // empty list simply yields undefined here.
  const firstGpu = info.gpuSetting?.gpus?.[0]
  return (
    !!info.osInfo &&
    this.supportedPlatform.includes(info.osInfo.platform) &&
    !!firstGpu &&
    !!firstGpu.arch &&
    firstGpu.name.toLowerCase().includes('nvidia') &&
    this.supportedGpuArch.includes(firstGpu.arch)
  )
}
} }

View File

@ -34,7 +34,7 @@ const InstallingExtension: React.FC = () => {
onClick={onClick} onClick={onClick}
> >
<p className="text-xs font-semibold text-muted-foreground"> <p className="text-xs font-semibold text-muted-foreground">
Installing Extension Installing Additional Dependencies
</p> </p>
<div className="flex flex-row items-center justify-center space-x-2 rounded-md bg-secondary px-2 py-[2px]"> <div className="flex flex-row items-center justify-center space-x-2 rounded-md bg-secondary px-2 py-[2px]">

View File

@ -2,32 +2,30 @@ import { useCallback, useEffect, useState } from 'react'
import { fs, joinPath } from '@janhq/core' import { fs, joinPath } from '@janhq/core'
type NvidiaDriver = {
exist: boolean
version: string
}
export type AppSettings = { export type AppSettings = {
run_mode: 'cpu' | 'gpu' | undefined run_mode: 'cpu' | 'gpu' | undefined
notify: boolean notify: boolean
gpus_in_use: string[] gpus_in_use: string[]
vulkan: boolean vulkan: boolean
gpus: string[] gpus: string[]
nvidia_driver: NvidiaDriver
cuda: NvidiaDriver
} }
export const useSettings = () => { export const useSettings = () => {
const [isGPUModeEnabled, setIsGPUModeEnabled] = useState(false) // New state for GPU mode
const [settings, setSettings] = useState<AppSettings>() const [settings, setSettings] = useState<AppSettings>()
useEffect(() => { useEffect(() => {
readSettings().then((settings) => setSettings(settings as AppSettings)) readSettings().then((settings) => setSettings(settings as AppSettings))
setTimeout(() => validateSettings, 3000)
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, []) }, [])
const validateSettings = async () => {
readSettings().then((settings) => {
// Check if run_mode is 'gpu' or 'cpu' and update state accordingly
setIsGPUModeEnabled(settings?.run_mode === 'gpu')
})
}
const readSettings = useCallback(async () => { const readSettings = useCallback(async () => {
if (!window?.core?.api) { if (!window?.core?.api) {
return return
@ -69,10 +67,8 @@ export const useSettings = () => {
} }
return { return {
isGPUModeEnabled,
readSettings, readSettings,
saveSettings, saveSettings,
validateSettings,
settings, settings,
} }
} }

View File

@ -9,6 +9,8 @@ import { MainViewState } from '@/constants/screens'
import { loadModelErrorAtom } from '@/hooks/useActiveModel' import { loadModelErrorAtom } from '@/hooks/useActiveModel'
import { useSettings } from '@/hooks/useSettings'
import { mainViewStateAtom } from '@/helpers/atoms/App.atom' import { mainViewStateAtom } from '@/helpers/atoms/App.atom'
import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom' import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
@ -17,15 +19,16 @@ const LoadModelError = () => {
const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom) const setModalTroubleShooting = useSetAtom(modalTroubleShootingAtom)
const loadModelError = useAtomValue(loadModelErrorAtom) const loadModelError = useAtomValue(loadModelErrorAtom)
const setMainState = useSetAtom(mainViewStateAtom) const setMainState = useSetAtom(mainViewStateAtom)
const activeThread = useAtomValue(activeThreadAtom)
const setSelectedSettingScreen = useSetAtom(selectedSettingAtom) const setSelectedSettingScreen = useSetAtom(selectedSettingAtom)
const activeThread = useAtomValue(activeThreadAtom)
const { settings } = useSettings()
const PORT_NOT_AVAILABLE = 'PORT_NOT_AVAILABLE' const PORT_NOT_AVAILABLE = 'PORT_NOT_AVAILABLE'
const ErrorMessage = () => {
if (loadModelError === PORT_NOT_AVAILABLE) {
return ( return (
<div className="mt-10"> <p>
{loadModelError === PORT_NOT_AVAILABLE ? (
<div className="flex w-full flex-col items-center text-center text-sm font-medium text-gray-500">
<p className="w-[90%]">
Port 3928 is currently unavailable. Check for conflicting apps, or Port 3928 is currently unavailable. Check for conflicting apps, or
access&nbsp; access&nbsp;
<span <span
@ -34,19 +37,18 @@ const LoadModelError = () => {
> >
troubleshooting assistance troubleshooting assistance
</span> </span>
&nbsp;for further support.
</p> </p>
<ModalTroubleShooting /> )
</div> } else if (
) : loadModelError && typeof loadModelError?.includes === 'function' &&
typeof loadModelError.includes === 'function' && loadModelError.includes('EXTENSION_IS_NOT_INSTALLED')
loadModelError.includes('EXTENSION_IS_NOT_INSTALLED') ? ( ) {
<div className="flex w-full flex-col items-center text-center text-sm font-medium text-gray-500"> return (
<p className="w-[90%]"> <p>
Model is currently unavailable. Please switch to a different model Model is currently unavailable. Please switch to a different model or
or install the{' '} install the{' '}
<button <span
className="font-medium text-primary dark:text-blue-400" className="cursor-pointer font-medium text-primary dark:text-blue-400"
onClick={() => { onClick={() => {
setMainState(MainViewState.Settings) setMainState(MainViewState.Settings)
if (activeThread?.assistants[0]?.model.engine) { if (activeThread?.assistants[0]?.model.engine) {
@ -58,12 +60,56 @@ const LoadModelError = () => {
}} }}
> >
{loadModelError.split('::')[1] ?? ''} {loadModelError.split('::')[1] ?? ''}
</button>{' '} </span>{' '}
to continue using it. to continue using it.
</p> </p>
</div> )
} else if (
settings &&
settings.run_mode === 'gpu' &&
!settings.vulkan &&
(!settings.nvidia_driver?.exist || !settings.cuda?.exist)
) {
return (
<>
{!settings?.cuda.exist ? (
<p>
The CUDA toolkit may be unavailable. Please use the{' '}
<span
className="cursor-pointer font-medium text-primary dark:text-blue-400"
onClick={() => {
setMainState(MainViewState.Settings)
if (activeThread?.assistants[0]?.model.engine) {
const engine = EngineManager.instance().get(
activeThread.assistants[0].model.engine
)
engine?.name && setSelectedSettingScreen(engine.name)
}
}}
>
Install Additional Dependencies
</span>{' '}
setting to proceed with the download / installation process.
</p>
) : ( ) : (
<div className="mx-6 flex flex-col items-center space-y-2 text-center text-sm font-medium text-gray-500"> <div>
Problem with Nvidia drivers. Please follow the{' '}
<a
className="font-medium text-primary dark:text-blue-400"
href="https://www.nvidia.com/Download/index.aspx"
target="_blank"
>
Nvidia Drivers guideline
</a>{' '}
to access installation instructions and ensure proper functioning
of the application.
</div>
)}
</>
)
} else {
return (
<div>
Apologies, somethings amiss! Apologies, somethings amiss!
<p> <p>
Jans in beta. Access&nbsp; Jans in beta. Access&nbsp;
@ -75,9 +121,19 @@ const LoadModelError = () => {
</span> </span>
&nbsp;now. &nbsp;now.
</p> </p>
</div>
)
}
}
return (
<div className="mt-10">
<div className="flex w-full flex-col items-center text-center text-sm font-medium text-gray-500">
<p className="w-[90%]">
<ErrorMessage />
</p>
<ModalTroubleShooting /> <ModalTroubleShooting />
</div> </div>
)}
</div> </div>
) )
} }

View File

@ -67,7 +67,7 @@ const Advanced = () => {
const [gpuList, setGpuList] = useState<GPU[]>([]) const [gpuList, setGpuList] = useState<GPU[]>([])
const [gpusInUse, setGpusInUse] = useState<string[]>([]) const [gpusInUse, setGpusInUse] = useState<string[]>([])
const { readSettings, saveSettings, validateSettings } = useSettings() const { readSettings, saveSettings } = useSettings()
const { stopModel } = useActiveModel() const { stopModel } = useActiveModel()
const selectedGpu = gpuList const selectedGpu = gpuList
@ -277,9 +277,6 @@ const Advanced = () => {
'Successfully turned on GPU Acceleration', 'Successfully turned on GPU Acceleration',
type: 'success', type: 'success',
}) })
setTimeout(() => {
validateSettings()
}, 300)
} else { } else {
saveSettings({ runMode: 'cpu' }) saveSettings({ runMode: 'cpu' })
setGpuEnabled(false) setGpuEnabled(false)

View File

@ -1,14 +1,12 @@
import { useCallback, useEffect, useState } from 'react' import { useCallback, useEffect, useState } from 'react'
import { import {
BaseExtension,
Compatibility, Compatibility,
GpuSetting,
InstallationState, InstallationState,
abortDownload, abortDownload,
systemInformation,
} from '@janhq/core' } from '@janhq/core'
import { import {
Badge,
Button, Button,
Progress, Progress,
Tooltip, Tooltip,
@ -23,25 +21,20 @@ import { useAtomValue } from 'jotai'
import { Marked, Renderer } from 'marked' import { Marked, Renderer } from 'marked'
import UpdateExtensionModal from './UpdateExtensionModal'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import Extension from '@/extension/Extension'
import { installingExtensionAtom } from '@/helpers/atoms/Extension.atom' import { installingExtensionAtom } from '@/helpers/atoms/Extension.atom'
type Props = { type Props = {
item: Extension item: BaseExtension
} }
const TensorRtExtensionItem: React.FC<Props> = ({ item }) => { const ExtensionItem: React.FC<Props> = ({ item }) => {
const [compatibility, setCompatibility] = useState<Compatibility | undefined>( const [compatibility, setCompatibility] = useState<Compatibility | undefined>(
undefined undefined
) )
const [installState, setInstallState] = const [installState, setInstallState] =
useState<InstallationState>('NotRequired') useState<InstallationState>('NotRequired')
const installingExtensions = useAtomValue(installingExtensionAtom) const installingExtensions = useAtomValue(installingExtensionAtom)
const [isGpuSupported, setIsGpuSupported] = useState<boolean>(false)
const [promptUpdateModal, setPromptUpdateModal] = useState<boolean>(false)
const isInstalling = installingExtensions.some( const isInstalling = installingExtensions.some(
(e) => e.extensionId === item.name (e) => e.extensionId === item.name
) )
@ -51,32 +44,6 @@ const TensorRtExtensionItem: React.FC<Props> = ({ item }) => {
?.percentage ?? -1 ?.percentage ?? -1
: -1 : -1
useEffect(() => {
const getSystemInfos = async () => {
const info = await systemInformation()
if (!info) {
setIsGpuSupported(false)
return
}
const gpuSettings: GpuSetting | undefined = info.gpuSetting
if (!gpuSettings || gpuSettings.gpus.length === 0) {
setIsGpuSupported(false)
return
}
const arch = gpuSettings.gpus[0].arch
if (!arch) {
setIsGpuSupported(false)
return
}
const supportedGpuArch = ['ampere', 'ada']
setIsGpuSupported(supportedGpuArch.includes(arch))
}
getSystemInfos()
}, [])
useEffect(() => { useEffect(() => {
const getExtensionInstallationState = async () => { const getExtensionInstallationState = async () => {
const extension = extensionManager.getByName(item.name ?? '') const extension = extensionManager.getByName(item.name ?? '')
@ -116,16 +83,10 @@ const TensorRtExtensionItem: React.FC<Props> = ({ item }) => {
const description = marked.parse(item.description ?? '', { async: false }) const description = marked.parse(item.description ?? '', { async: false })
return ( return (
<div className="flex w-full items-start justify-between border-b border-border py-4 first:pt-4 last:border-none"> <div className="mx-6 flex w-full items-start justify-between border-b border-border py-4 py-6 first:pt-4 last:border-none">
<div className="flex-1 flex-shrink-0 space-y-1.5"> <div className="flex-1 flex-shrink-0 space-y-1.5">
<div className="flex items-center gap-x-2"> <div className="flex items-center gap-x-2">
<h6 className="text-sm font-semibold capitalize"> <h6 className="text-base font-bold">Additional Dependencies</h6>
TensorRT-LLM Extension
</h6>
<p className="whitespace-pre-wrap text-sm font-semibold leading-relaxed">
v{item.version}
</p>
<Badge>Experimental</Badge>
</div> </div>
{ {
// eslint-disable-next-line @typescript-eslint/naming-convention // eslint-disable-next-line @typescript-eslint/naming-convention
@ -133,18 +94,62 @@ const TensorRtExtensionItem: React.FC<Props> = ({ item }) => {
} }
</div> </div>
{(!compatibility || compatibility['platform']?.includes(PLATFORM)) && {(!compatibility || compatibility['platform']?.includes(PLATFORM)) && (
isGpuSupported ? (
<div className="flex min-w-[150px] flex-row justify-end"> <div className="flex min-w-[150px] flex-row justify-end">
<InstallStateIndicator <InstallStateIndicator
installProgress={progress} installProgress={progress}
installState={installState} installState={installState}
compatibility={compatibility}
onInstallClick={onInstallClick} onInstallClick={onInstallClick}
onUpdateClick={() => setPromptUpdateModal(true)}
onCancelClick={onCancelInstallingClick} onCancelClick={onCancelInstallingClick}
/> />
</div> </div>
) : ( )}
</div>
)
}
type InstallStateProps = {
installProgress: number
compatibility?: Compatibility
installState: InstallationState
onInstallClick: () => void
onCancelClick: () => void
}
const InstallStateIndicator: React.FC<InstallStateProps> = ({
installProgress,
compatibility,
installState,
onInstallClick,
onCancelClick,
}) => {
if (installProgress !== -1) {
const progress = installProgress * 100
return (
<div className="flex h-10 flex-row items-center justify-center space-x-2 rounded-lg bg-[#EFF8FF] px-4 text-primary dark:bg-secondary">
<button onClick={onCancelClick} className="font-semibold text-primary">
Cancel
</button>
<div className="flex w-[113px] flex-row items-center justify-center space-x-2 rounded-md bg-[#D1E9FF] px-2 py-[2px] dark:bg-black/50">
<Progress className="h-1 w-[69px]" value={progress} />
<span className="text-xs font-bold text-primary">
{progress.toFixed(0)}%
</span>
</div>
</div>
)
}
switch (installState) {
case 'Installed':
return (
<div className="rounded-md bg-secondary px-3 py-1.5 text-sm font-semibold text-gray-400">
Installed
</div>
)
case 'NotCompatible':
return (
<div className="rounded-md bg-secondary px-3 py-1.5 text-sm font-semibold text-gray-400"> <div className="rounded-md bg-secondary px-3 py-1.5 text-sm font-semibold text-gray-400">
<div className="flex flex-row items-center justify-center gap-1"> <div className="flex flex-row items-center justify-center gap-1">
Incompatible{' '} Incompatible{' '}
@ -179,58 +184,6 @@ const TensorRtExtensionItem: React.FC<Props> = ({ item }) => {
</Tooltip> </Tooltip>
</div> </div>
</div> </div>
)}
{promptUpdateModal && (
<UpdateExtensionModal onUpdateClick={onInstallClick} />
)}
</div>
)
}
type InstallStateProps = {
installProgress: number
installState: InstallationState
onInstallClick: () => void
onUpdateClick: () => void
onCancelClick: () => void
}
const InstallStateIndicator: React.FC<InstallStateProps> = ({
installProgress,
installState,
onInstallClick,
onUpdateClick,
onCancelClick,
}) => {
if (installProgress !== -1) {
const progress = installProgress * 100
return (
<div className="flex h-10 flex-row items-center justify-center space-x-2 rounded-lg bg-[#EFF8FF] px-4 text-primary dark:bg-secondary">
<button onClick={onCancelClick} className="font-semibold text-primary">
Cancel
</button>
<div className="flex w-[113px] flex-row items-center justify-center space-x-2 rounded-md bg-[#D1E9FF] px-2 py-[2px] dark:bg-black/50">
<Progress className="h-1 w-[69px]" value={progress} />
<span className="text-xs font-bold text-primary">
{progress.toFixed(0)}%
</span>
</div>
</div>
)
}
switch (installState) {
case 'Installed':
return (
<div className="rounded-md bg-secondary px-3 py-1.5 text-sm font-semibold text-gray-400">
Installed
</div>
)
case 'Updatable':
return (
<Button themes="secondaryBlue" size="sm" onClick={onUpdateClick}>
Update
</Button>
) )
case 'NotInstalled': case 'NotInstalled':
return ( return (
@ -253,4 +206,4 @@ const marked: Marked = new Marked({
}, },
}) })
export default TensorRtExtensionItem export default ExtensionItem

View File

@ -1,58 +0,0 @@
import React from 'react'
import {
Button,
Modal,
ModalClose,
ModalContent,
ModalFooter,
ModalHeader,
ModalPortal,
ModalTitle,
ModalTrigger,
} from '@janhq/uikit'
import { Paintbrush } from 'lucide-react'
/** Props for the UpdateExtensionModal component. */
type Props = {
/** Invoked when the user clicks the "Yes" button of the modal. */
onUpdateClick: () => void
}
/**
 * Confirmation modal shown before an extension is updated.
 *
 * Renders an "Update extension" trigger row; opening it shows a warning that
 * updating may lose custom models or data, with "No" / "Yes" buttons. Both
 * buttons are wrapped in `ModalClose`; only "Yes" calls `onUpdateClick`.
 *
 * The title previously read "Clean Thread", a leftover that did not match the
 * trigger label or the body text; it now reads "Update Extension".
 *
 * NOTE(review): `text-bold` on the trigger label is not a stock Tailwind
 * utility (`font-bold` is) — confirm whether a custom class is intended.
 *
 * @param onUpdateClick - Called when the user confirms with "Yes".
 */
const UpdateExtensionModal: React.FC<Props> = ({ onUpdateClick }) => {
  return (
    <Modal>
      {/* Stop propagation so the click does not reach parent handlers. */}
      <ModalTrigger asChild onClick={(e) => e.stopPropagation()}>
        <div className="flex cursor-pointer items-center space-x-2 px-4 py-2 hover:bg-secondary">
          <Paintbrush size={16} className="text-muted-foreground" />
          <span className="text-bold text-black dark:text-muted-foreground">
            Update extension
          </span>
        </div>
      </ModalTrigger>
      <ModalPortal />
      <ModalContent>
        <ModalHeader>
          <ModalTitle>Update Extension</ModalTitle>
        </ModalHeader>
        <p>
          Updating this extension may result in the loss of any custom models or
          data associated with the current version. We recommend backing up any
          important data before proceeding with the update.
        </p>
        <ModalFooter>
          <div className="flex gap-x-2">
            <ModalClose asChild onClick={(e) => e.stopPropagation()}>
              <Button themes="ghost">No</Button>
            </ModalClose>
            <ModalClose asChild>
              <Button themes="danger" onClick={onUpdateClick} autoFocus>
                Yes
              </Button>
            </ModalClose>
          </div>
        </ModalFooter>
      </ModalContent>
    </Modal>
  )
}
export default React.memo(UpdateExtensionModal)

View File

@ -8,7 +8,7 @@ import Loader from '@/containers/Loader'
import { formatExtensionsName } from '@/utils/converter' import { formatExtensionsName } from '@/utils/converter'
import TensorRtExtensionItem from './TensorRtExtensionItem' import ExtensionItem from './ExtensionItem'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import Extension from '@/extension/Extension' import Extension from '@/extension/Extension'
@ -78,11 +78,6 @@ const ExtensionCatalog = () => {
<ScrollArea className="h-full w-full px-4"> <ScrollArea className="h-full w-full px-4">
<div className="block w-full"> <div className="block w-full">
{activeExtensions.map((item, i) => { {activeExtensions.map((item, i) => {
// TODO: this is bad code, rewrite it
if (item.name === '@janhq/tensorrt-llm-extension') {
return <TensorRtExtensionItem key={i} item={item} />
}
return ( return (
<div <div
key={i} key={i}

View File

@ -1,9 +1,14 @@
import React, { useEffect, useState } from 'react' import React, { useEffect, useState } from 'react'
import { SettingComponentProps } from '@janhq/core/.' import {
BaseExtension,
InstallationState,
SettingComponentProps,
} from '@janhq/core/.'
import { useAtomValue } from 'jotai' import { useAtomValue } from 'jotai'
import ExtensionItem from '../CoreExtensions/ExtensionItem'
import SettingDetailItem from '../SettingDetail/SettingDetailItem' import SettingDetailItem from '../SettingDetail/SettingDetailItem'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
@ -12,6 +17,11 @@ import { selectedSettingAtom } from '@/helpers/atoms/Setting.atom'
const ExtensionSetting: React.FC = () => { const ExtensionSetting: React.FC = () => {
const selectedExtensionName = useAtomValue(selectedSettingAtom) const selectedExtensionName = useAtomValue(selectedSettingAtom)
const [settings, setSettings] = useState<SettingComponentProps[]>([]) const [settings, setSettings] = useState<SettingComponentProps[]>([])
const [installationState, setInstallationState] =
useState<InstallationState>('NotRequired')
const [baseExtension, setBaseExtension] = useState<BaseExtension | undefined>(
undefined
)
useEffect(() => { useEffect(() => {
const getExtensionSettings = async () => { const getExtensionSettings = async () => {
@ -19,11 +29,15 @@ const ExtensionSetting: React.FC = () => {
const allSettings: SettingComponentProps[] = [] const allSettings: SettingComponentProps[] = []
const baseExtension = extensionManager.getByName(selectedExtensionName) const baseExtension = extensionManager.getByName(selectedExtensionName)
if (!baseExtension) return if (!baseExtension) return
setBaseExtension(baseExtension)
if (typeof baseExtension.getSettings === 'function') { if (typeof baseExtension.getSettings === 'function') {
const setting = await baseExtension.getSettings() const setting = await baseExtension.getSettings()
if (setting) allSettings.push(...setting) if (setting) allSettings.push(...setting)
} }
setSettings(allSettings) setSettings(allSettings)
setInstallationState(await baseExtension.installationState())
} }
getExtensionSettings() getExtensionSettings()
}, [selectedExtensionName]) }, [selectedExtensionName])
@ -48,13 +62,18 @@ const ExtensionSetting: React.FC = () => {
setSettings(newSettings) setSettings(newSettings)
} }
if (settings.length === 0) return null
return ( return (
<>
{settings.length > 0 && (
<SettingDetailItem <SettingDetailItem
componentProps={settings} componentProps={settings}
onValueUpdated={onValueChanged} onValueUpdated={onValueChanged}
/> />
)}
{baseExtension && installationState !== 'NotRequired' && (
<ExtensionItem item={baseExtension} />
)}
</>
) )
} }

View File

@ -20,7 +20,10 @@ const SettingMenu: React.FC = () => {
for (const extension of extensions) { for (const extension of extensions) {
if (typeof extension.getSettings === 'function') { if (typeof extension.getSettings === 'function') {
const settings = await extension.getSettings() const settings = await extension.getSettings()
if (settings && settings.length > 0) { if (
(settings && settings.length > 0) ||
(await extension.installationState()) !== 'NotRequired'
) {
extensionsMenu.push(extension.name ?? extension.url) extensionsMenu.push(extension.name ?? extension.url)
} }
} }