chore: update legacy tensorrt-llm download and run

This commit is contained in:
Louis 2024-10-28 19:08:32 +07:00
parent 2c11caf87e
commit a466bbca38
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
7 changed files with 78 additions and 36 deletions

View File

@ -1,4 +1,4 @@
import { executeOnMain, systemInformation, dirName } from '../../core'
import { executeOnMain, systemInformation, dirName, joinPath, getJanDataFolderPath } from '../../core'
import { events } from '../../events'
import { Model, ModelEvent } from '../../../types'
import { OAIEngine } from './OAIEngine'
@ -29,13 +29,46 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* Load the model.
*/
override async loadModel(model: Model): Promise<void> {
override async loadModel(model: Model & { file_path?: string }): Promise<void> {
if (model.engine.toString() !== this.provider) return
const modelFolder = 'file_path' in model && model.file_path ? await dirName(model.file_path) : await this.getModelFilePath(model.id)
const systemInfo = await systemInformation()
const res = await executeOnMain(
this.nodeModule,
this.loadModelFunctionName,
{
modelFolder,
model,
},
systemInfo
)
if (res?.error) {
events.emit(ModelEvent.OnModelFail, { error: res.error })
return Promise.reject(res.error)
} else {
this.loadedModel = model
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
}
}
/**
* Stops the model.
*/
override async unloadModel(model?: Model) {
return Promise.resolve()
if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve()
this.loadedModel = undefined
await executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
events.emit(ModelEvent.OnModelStopped, {})
})
}
/// Legacy
// Resolves the on-disk folder for a model id inside the Jan data folder
// (<janDataFolder>/models/<modelId>).
private getModelFilePath = async (modelId: string): Promise<string> =>
  joinPath([await getJanDataFolderPath(), 'models', modelId])
///
}

View File

@ -118,19 +118,6 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
.then()
}
// Resolves the full path to a legacy model's weight file. Models without a
// legacy file_path are addressed by id alone; otherwise the file name is
// taken from the first source, falling back to the llama_model_path setting,
// the last segment of the source URL, and finally the model id.
private async modelPath(
  model: Model & { file_path?: string }
): Promise<string> {
  if (!model.file_path) return model.id

  const folder = await dirName(model.file_path)
  const fileName =
    model.sources[0]?.filename ??
    model.settings?.llama_model_path ??
    model.sources[0]?.url.split('/').pop() ??
    model.id
  return joinPath([folder, fileName])
}
/**
* Do health check on cortex.cpp
* @returns

View File

@ -14,6 +14,7 @@ import { CortexAPI } from './cortex'
import { scanModelsFolder } from './legacy/model-json'
import { downloadModel } from './legacy/download'
import { systemInformation } from '@janhq/core'
import { deleteModelFiles } from './legacy/delete'
declare const SETTINGS: Array<any>
@ -50,7 +51,7 @@ export default class JanModelExtension extends ModelExtension {
* Called when the extension is unloaded.
* @override
*/
async onUnload() {}
async onUnload() { }
/**
* Downloads a machine learning model.
@ -92,7 +93,7 @@ export default class JanModelExtension extends ModelExtension {
) {
for (const source of modelDto.sources) {
const path = await joinPath(['models', modelDto.id, source.filename])
return abortDownload(path)
await abortDownload(path)
}
}
}
@ -108,7 +109,14 @@ export default class JanModelExtension extends ModelExtension {
* @returns A Promise that resolves when the model is deleted.
*/
async deleteModel(model: string): Promise<void> {
const modelDto: Model = ModelManager.instance().get(model)
return this.cortexAPI.deleteModel(model)
.catch(e => console.debug(e))
.finally(async () => {
// Delete legacy model files
await deleteModelFiles(modelDto)
.catch(e => console.debug(e))
})
}
/**

View File

@ -0,0 +1,18 @@
import { fs, joinPath, Model } from "@janhq/core"

/**
 * Removes every file in a legacy model's folder except `model.json`.
 * All failures (missing folder, unlink errors) are caught and logged;
 * this function never throws.
 *
 * @param model - The model whose on-disk files should be removed.
 */
export const deleteModelFiles = async (model: Model) => {
  try {
    const folder = await joinPath(['file://models', model.id])
    const entries: string[] = await fs.readdirSync(folder)
    // Keep model.json; best-effort delete everything else.
    const removals = entries
      .filter((entry) => entry !== 'model.json')
      .map(async (entry) => fs.unlinkSync(await joinPath([folder, entry])))
    await Promise.allSettled(removals)
  } catch (err) {
    console.error(err)
  }
}

View File

@ -1,4 +1,4 @@
import { Model, fs, joinPath } from '@janhq/core'
import { InferenceEngine, Model, fs, joinPath } from '@janhq/core'
//// LEGACY MODEL FOLDER ////
/**
* Scan through models folder and return downloaded models
@ -71,7 +71,7 @@ export const scanModelsFolder = async (): Promise<Model[]> => {
file.toLowerCase().endsWith('.gguf') || // GGUF
file.toLowerCase().endsWith('.engine') // TensorRT-LLM
)
})?.length >= (model.sources?.length ?? 1)
})?.length >= (model.engine === InferenceEngine.nitro_tensorrt_llm ? 1 : (model.sources?.length ?? 1))
)
})

View File

@ -88,6 +88,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => {
if (state.downloadType !== 'extension') {
state.downloadState = 'end'
setDownloadState(state)
if (state.percent !== 0)
removeDownloadingModel(state.modelId)
}
events.emit(ModelEvent.OnModelsUpdate, {})

View File

@ -108,6 +108,7 @@ export const setDownloadStateAtom = atom(
)
modelDownloadState.children = updatedChildren
if (isAnyChildDownloadNotReady) {
// just update the children
currentState[state.modelId] = modelDownloadState
@ -115,23 +116,17 @@ export const setDownloadStateAtom = atom(
return
}
const parentTotalSize = modelDownloadState.size.total
if (parentTotalSize === 0) {
// calculate the total size of the parent by sum all children total size
const totalSize = updatedChildren.reduce(
const parentTotalSize = updatedChildren.reduce(
(acc, m) => acc + m.size.total,
0
)
modelDownloadState.size.total = totalSize
}
// calculate the total transferred size by sum all children transferred size
const transferredSize = updatedChildren.reduce(
(acc, m) => acc + m.size.transferred,
0
)
modelDownloadState.size.transferred = transferredSize
modelDownloadState.percent =
parentTotalSize === 0 ? 0 : transferredSize / parentTotalSize
currentState[state.modelId] = modelDownloadState