diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts index 9af2230dc..f8370676f 100644 --- a/extensions/inference-nitro-extension/src/index.ts +++ b/extensions/inference-nitro-extension/src/index.ts @@ -20,8 +20,8 @@ import { fs, Model, joinPath, + InferenceExtension, } from "@janhq/core"; -import { InferenceExtension } from "@janhq/core"; import { requestInference } from "./helpers/sse"; import { ulid } from "ulid"; import { join } from "path"; @@ -36,9 +36,14 @@ export default class JanInferenceNitroExtension implements InferenceExtension { private static readonly _settingsDir = "file://settings"; private static readonly _engineMetadataFileName = "nitro.json"; - private static _currentModel: Model; + /** + * Checking the health for Nitro's process each 5 secs. + */ + private static readonly _intervalHealthCheck = 5 * 1000; - private static _engineSettings: EngineSettings = { + private _currentModel: Model; + + private _engineSettings: EngineSettings = { ctx_len: 2048, ngl: 100, cpu_threads: 1, @@ -48,6 +53,18 @@ export default class JanInferenceNitroExtension implements InferenceExtension { controller = new AbortController(); isCancelled = false; + + /** + * The interval id for the health check. Used to stop the health check. + */ + private getNitroProcesHealthIntervalId: NodeJS.Timeout | undefined = + undefined; + + /** + * Tracking the current state of nitro process. + */ + private nitroProcessInfo: any = undefined; + /** * Returns the type of the extension. * @returns {ExtensionType} The type of the extension. @@ -71,21 +88,13 @@ export default class JanInferenceNitroExtension implements InferenceExtension { this.writeDefaultEngineSettings(); // Events subscription - events.on(EventName.OnMessageSent, (data) => - JanInferenceNitroExtension.handleMessageRequest(data, this) - ); + events.on(EventName.OnMessageSent, (data) => this.onMessageRequest(data)); - events.on(EventName.OnModelInit, (model: Model) => { - JanInferenceNitroExtension.handleModelInit(model); - }); + events.on(EventName.OnModelInit, (model: Model) => this.onModelInit(model)); - events.on(EventName.OnModelStop, (model: Model) => { - JanInferenceNitroExtension.handleModelStop(model); - }); + events.on(EventName.OnModelStop, (model: Model) => this.onModelStop(model)); - events.on(EventName.OnInferenceStopped, () => { - JanInferenceNitroExtension.handleInferenceStopped(this); - }); + events.on(EventName.OnInferenceStopped, () => this.onInferenceStopped()); // Attempt to fetch nvidia info await executeOnMain(MODULE, "updateNvidiaInfo", {}); @@ -104,12 +113,12 @@ export default class JanInferenceNitroExtension implements InferenceExtension { ); if (await fs.existsSync(engineFile)) { const engine = await fs.readFileSync(engineFile, "utf-8"); - JanInferenceNitroExtension._engineSettings = + this._engineSettings = typeof engine === "object" ? engine : JSON.parse(engine); } else { await fs.writeFileSync( engineFile, - JSON.stringify(JanInferenceNitroExtension._engineSettings, null, 2) + JSON.stringify(this._engineSettings, null, 2) ); } } catch (err) { @@ -117,10 +126,9 @@ export default class JanInferenceNitroExtension implements InferenceExtension { } } - private static async handleModelInit(model: Model) { - if (model.engine !== "nitro") { - return; - } + private async onModelInit(model: Model) { + if (model.engine !== "nitro") return; + const modelFullPath = await joinPath(["models", model.id]); const nitroInitResult = await executeOnMain(MODULE, "initModel", { @@ -130,26 +138,49 @@ export default class JanInferenceNitroExtension implements InferenceExtension { if (nitroInitResult.error === null) { events.emit(EventName.OnModelFail, model); - } else { - JanInferenceNitroExtension._currentModel = model; - events.emit(EventName.OnModelReady, model); - } - } - - private static async handleModelStop(model: Model) { - if (model.engine !== "nitro") { return; - } else { - await executeOnMain(MODULE, "stopModel"); - events.emit(EventName.OnModelStopped, model); + } + + this._currentModel = model; + events.emit(EventName.OnModelReady, model); + + this.getNitroProcesHealthIntervalId = setInterval( + () => this.periodicallyGetNitroHealth(), + JanInferenceNitroExtension._intervalHealthCheck + ); + } + + private async onModelStop(model: Model) { + if (model.engine !== "nitro") return; + + await executeOnMain(MODULE, "stopModel"); + events.emit(EventName.OnModelStopped, {}); + + // stop the periocally health check + if (this.getNitroProcesHealthIntervalId) { + console.debug("Stop calling Nitro process health check"); + clearInterval(this.getNitroProcesHealthIntervalId); + this.getNitroProcesHealthIntervalId = undefined; } } - private static async handleInferenceStopped( - instance: JanInferenceNitroExtension - ) { - instance.isCancelled = true; - instance.controller?.abort(); + /** + * Periodically check for nitro process's health. + */ + private async periodicallyGetNitroHealth(): Promise { + const health = await executeOnMain(MODULE, "getCurrentNitroProcessInfo"); + + const isRunning = this.nitroProcessInfo?.isRunning ?? false; + if (isRunning && health.isRunning === false) { + console.debug("Nitro process is stopped"); + events.emit(EventName.OnModelStopped, {}); + } + this.nitroProcessInfo = health; + } + + private async onInferenceStopped() { + this.isCancelled = true; + this.controller?.abort(); } /** @@ -171,10 +202,7 @@ export default class JanInferenceNitroExtension implements InferenceExtension { }; return new Promise(async (resolve, reject) => { - requestInference( - data.messages ?? [], - JanInferenceNitroExtension._currentModel - ).subscribe({ + requestInference(data.messages ?? [], this._currentModel).subscribe({ next: (_content) => {}, complete: async () => { resolve(message); @@ -192,13 +220,9 @@ export default class JanInferenceNitroExtension implements InferenceExtension { * Pass instance as a reference. * @param {MessageRequest} data - The data for the new message request. */ - private static async handleMessageRequest( - data: MessageRequest, - instance: JanInferenceNitroExtension - ) { - if (data.model.engine !== "nitro") { - return; - } + private async onMessageRequest(data: MessageRequest) { + if (data.model.engine !== "nitro") return; + const timestamp = Date.now(); const message: ThreadMessage = { id: ulid(), @@ -213,13 +237,13 @@ export default class JanInferenceNitroExtension implements InferenceExtension { }; events.emit(EventName.OnMessageResponse, message); - instance.isCancelled = false; - instance.controller = new AbortController(); + this.isCancelled = false; + this.controller = new AbortController(); requestInference( data.messages ?? [], - { ...JanInferenceNitroExtension._currentModel, ...data.model }, - instance.controller + { ...this._currentModel, ...data.model }, + this.controller ).subscribe({ next: (content) => { const messageContent: ThreadContent = { @@ -239,7 +263,7 @@ export default class JanInferenceNitroExtension implements InferenceExtension { events.emit(EventName.OnMessageUpdate, message); }, error: async (err) => { - if (instance.isCancelled || message.content.length) { + if (this.isCancelled || message.content.length) { message.status = MessageStatus.Stopped; events.emit(EventName.OnMessageUpdate, message); return; diff --git a/extensions/inference-nitro-extension/src/module.ts b/extensions/inference-nitro-extension/src/module.ts index 3654410d4..eb61afe75 100644 --- a/extensions/inference-nitro-extension/src/module.ts +++ b/extensions/inference-nitro-extension/src/module.ts @@ -43,6 +43,8 @@ let subprocess = undefined; let currentModelFile: string = undefined; let currentSettings = undefined; +let nitroProcessInfo = undefined; + /** * Stops a Nitro subprocess. * @param wrapper - The model wrapper. @@ -80,7 +82,7 @@ async function updateNvidiaDriverInfo(): Promise { ); } -function checkFileExistenceInPaths(file: string, paths: string[]): boolean { +function isExists(file: string, paths: string[]): boolean { return paths.some((p) => existsSync(path.join(p, file))); } @@ -104,12 +106,12 @@ function updateCudaExistence() { } let cudaExists = filesCuda12.every( - (file) => existsSync(file) || checkFileExistenceInPaths(file, paths) + (file) => existsSync(file) || isExists(file, paths) ); if (!cudaExists) { cudaExists = filesCuda11.every( - (file) => existsSync(file) || checkFileExistenceInPaths(file, paths) + (file) => existsSync(file) || isExists(file, paths) ); if (cudaExists) { cudaVersion = "11"; @@ -461,7 +463,7 @@ function spawnNitroProcess(nitroResourceProbe: any): Promise { function getResourcesInfo(): Promise { return new Promise(async (resolve) => { const cpu = await osUtils.cpuCount(); - console.log("cpu: ", cpu); + console.debug("cpu: ", cpu); const response: ResourcesInfo = { numCpuPhysicalCore: cpu, memAvailable: 0, @@ -470,6 +472,13 @@ function getResourcesInfo(): Promise { }); } +const getCurrentNitroProcessInfo = (): Promise => { + nitroProcessInfo = { + isRunning: subprocess != null, + }; + return nitroProcessInfo; +}; + function dispose() { // clean other registered resources here killSubprocess(); @@ -481,4 +490,5 @@ module.exports = { killSubprocess, dispose, updateNvidiaInfo, + getCurrentNitroProcessInfo, }; diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx index aafbbb787..66622b1b6 100644 --- a/web/containers/Providers/EventHandler.tsx +++ b/web/containers/Providers/EventHandler.tsx @@ -8,8 +8,8 @@ import { ExtensionType, MessageStatus, Model, + ConversationalExtension, } from '@janhq/core' -import { ConversationalExtension } from '@janhq/core' import { useAtomValue, useSetAtom } from 'jotai' import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel' @@ -64,14 +64,10 @@ export default function EventHandler({ children }: { children: ReactNode }) { })) } - async function handleModelStopped(model: Model) { + async function handleModelStopped() { setTimeout(async () => { setActiveModel(undefined) setStateModel({ state: 'start', loading: false, model: '' }) - // toaster({ - // title: 'Success!', - // description: `Model ${model.id} has been stopped.`, - // }) }, 500) }