fix(InferenceExtension): #1067 sync the nitro process state (#1493)

Signed-off-by: James <james@jan.ai>
Co-authored-by: James <james@jan.ai>
This commit is contained in:
NamH 2024-01-10 14:15:17 +07:00 committed by GitHub
parent 31fdd89f0e
commit 9183330480
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 93 additions and 63 deletions

View File

@ -20,8 +20,8 @@ import {
fs, fs,
Model, Model,
joinPath, joinPath,
InferenceExtension,
} from "@janhq/core"; } from "@janhq/core";
import { InferenceExtension } from "@janhq/core";
import { requestInference } from "./helpers/sse"; import { requestInference } from "./helpers/sse";
import { ulid } from "ulid"; import { ulid } from "ulid";
import { join } from "path"; import { join } from "path";
@ -36,9 +36,14 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
private static readonly _settingsDir = "file://settings"; private static readonly _settingsDir = "file://settings";
private static readonly _engineMetadataFileName = "nitro.json"; private static readonly _engineMetadataFileName = "nitro.json";
private static _currentModel: Model; /**
* Interval between health checks of the Nitro process: every 5 seconds.
*/
private static readonly _intervalHealthCheck = 5 * 1000;
private static _engineSettings: EngineSettings = { private _currentModel: Model;
private _engineSettings: EngineSettings = {
ctx_len: 2048, ctx_len: 2048,
ngl: 100, ngl: 100,
cpu_threads: 1, cpu_threads: 1,
@ -48,6 +53,18 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
controller = new AbortController(); controller = new AbortController();
isCancelled = false; isCancelled = false;
/**
* The interval id for the health check. Used to stop the health check.
*/
private getNitroProcesHealthIntervalId: NodeJS.Timeout | undefined =
undefined;
/**
* Tracking the current state of nitro process.
*/
private nitroProcessInfo: any = undefined;
/** /**
* Returns the type of the extension. * Returns the type of the extension.
* @returns {ExtensionType} The type of the extension. * @returns {ExtensionType} The type of the extension.
@ -71,21 +88,13 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
this.writeDefaultEngineSettings(); this.writeDefaultEngineSettings();
// Events subscription // Events subscription
events.on(EventName.OnMessageSent, (data) => events.on(EventName.OnMessageSent, (data) => this.onMessageRequest(data));
JanInferenceNitroExtension.handleMessageRequest(data, this)
);
events.on(EventName.OnModelInit, (model: Model) => { events.on(EventName.OnModelInit, (model: Model) => this.onModelInit(model));
JanInferenceNitroExtension.handleModelInit(model);
});
events.on(EventName.OnModelStop, (model: Model) => { events.on(EventName.OnModelStop, (model: Model) => this.onModelStop(model));
JanInferenceNitroExtension.handleModelStop(model);
});
events.on(EventName.OnInferenceStopped, () => { events.on(EventName.OnInferenceStopped, () => this.onInferenceStopped());
JanInferenceNitroExtension.handleInferenceStopped(this);
});
// Attempt to fetch nvidia info // Attempt to fetch nvidia info
await executeOnMain(MODULE, "updateNvidiaInfo", {}); await executeOnMain(MODULE, "updateNvidiaInfo", {});
@ -104,12 +113,12 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
); );
if (await fs.existsSync(engineFile)) { if (await fs.existsSync(engineFile)) {
const engine = await fs.readFileSync(engineFile, "utf-8"); const engine = await fs.readFileSync(engineFile, "utf-8");
JanInferenceNitroExtension._engineSettings = this._engineSettings =
typeof engine === "object" ? engine : JSON.parse(engine); typeof engine === "object" ? engine : JSON.parse(engine);
} else { } else {
await fs.writeFileSync( await fs.writeFileSync(
engineFile, engineFile,
JSON.stringify(JanInferenceNitroExtension._engineSettings, null, 2) JSON.stringify(this._engineSettings, null, 2)
); );
} }
} catch (err) { } catch (err) {
@ -117,10 +126,9 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
} }
} }
private static async handleModelInit(model: Model) { private async onModelInit(model: Model) {
if (model.engine !== "nitro") { if (model.engine !== "nitro") return;
return;
}
const modelFullPath = await joinPath(["models", model.id]); const modelFullPath = await joinPath(["models", model.id]);
const nitroInitResult = await executeOnMain(MODULE, "initModel", { const nitroInitResult = await executeOnMain(MODULE, "initModel", {
@ -130,26 +138,49 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
if (nitroInitResult.error === null) { if (nitroInitResult.error === null) {
events.emit(EventName.OnModelFail, model); events.emit(EventName.OnModelFail, model);
} else {
JanInferenceNitroExtension._currentModel = model;
events.emit(EventName.OnModelReady, model);
}
}
private static async handleModelStop(model: Model) {
if (model.engine !== "nitro") {
return; return;
} else { }
await executeOnMain(MODULE, "stopModel");
events.emit(EventName.OnModelStopped, model); this._currentModel = model;
events.emit(EventName.OnModelReady, model);
this.getNitroProcesHealthIntervalId = setInterval(
() => this.periodicallyGetNitroHealth(),
JanInferenceNitroExtension._intervalHealthCheck
);
}
private async onModelStop(model: Model) {
if (model.engine !== "nitro") return;
await executeOnMain(MODULE, "stopModel");
events.emit(EventName.OnModelStopped, {});
// Stop the periodic health check.
if (this.getNitroProcesHealthIntervalId) {
console.debug("Stop calling Nitro process health check");
clearInterval(this.getNitroProcesHealthIntervalId);
this.getNitroProcesHealthIntervalId = undefined;
} }
} }
private static async handleInferenceStopped( /**
instance: JanInferenceNitroExtension * Periodically check for nitro process's health.
) { */
instance.isCancelled = true; private async periodicallyGetNitroHealth(): Promise<void> {
instance.controller?.abort(); const health = await executeOnMain(MODULE, "getCurrentNitroProcessInfo");
const isRunning = this.nitroProcessInfo?.isRunning ?? false;
if (isRunning && health.isRunning === false) {
console.debug("Nitro process is stopped");
events.emit(EventName.OnModelStopped, {});
}
this.nitroProcessInfo = health;
}
private async onInferenceStopped() {
this.isCancelled = true;
this.controller?.abort();
} }
/** /**
@ -171,10 +202,7 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
}; };
return new Promise(async (resolve, reject) => { return new Promise(async (resolve, reject) => {
requestInference( requestInference(data.messages ?? [], this._currentModel).subscribe({
data.messages ?? [],
JanInferenceNitroExtension._currentModel
).subscribe({
next: (_content) => {}, next: (_content) => {},
complete: async () => { complete: async () => {
resolve(message); resolve(message);
@ -192,13 +220,9 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
* Pass instance as a reference. * Pass instance as a reference.
* @param {MessageRequest} data - The data for the new message request. * @param {MessageRequest} data - The data for the new message request.
*/ */
private static async handleMessageRequest( private async onMessageRequest(data: MessageRequest) {
data: MessageRequest, if (data.model.engine !== "nitro") return;
instance: JanInferenceNitroExtension
) {
if (data.model.engine !== "nitro") {
return;
}
const timestamp = Date.now(); const timestamp = Date.now();
const message: ThreadMessage = { const message: ThreadMessage = {
id: ulid(), id: ulid(),
@ -213,13 +237,13 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
}; };
events.emit(EventName.OnMessageResponse, message); events.emit(EventName.OnMessageResponse, message);
instance.isCancelled = false; this.isCancelled = false;
instance.controller = new AbortController(); this.controller = new AbortController();
requestInference( requestInference(
data.messages ?? [], data.messages ?? [],
{ ...JanInferenceNitroExtension._currentModel, ...data.model }, { ...this._currentModel, ...data.model },
instance.controller this.controller
).subscribe({ ).subscribe({
next: (content) => { next: (content) => {
const messageContent: ThreadContent = { const messageContent: ThreadContent = {
@ -239,7 +263,7 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
events.emit(EventName.OnMessageUpdate, message); events.emit(EventName.OnMessageUpdate, message);
}, },
error: async (err) => { error: async (err) => {
if (instance.isCancelled || message.content.length) { if (this.isCancelled || message.content.length) {
message.status = MessageStatus.Stopped; message.status = MessageStatus.Stopped;
events.emit(EventName.OnMessageUpdate, message); events.emit(EventName.OnMessageUpdate, message);
return; return;

View File

@ -43,6 +43,8 @@ let subprocess = undefined;
let currentModelFile: string = undefined; let currentModelFile: string = undefined;
let currentSettings = undefined; let currentSettings = undefined;
let nitroProcessInfo = undefined;
/** /**
* Stops a Nitro subprocess. * Stops a Nitro subprocess.
* @param wrapper - The model wrapper. * @param wrapper - The model wrapper.
@ -80,7 +82,7 @@ async function updateNvidiaDriverInfo(): Promise<void> {
); );
} }
function checkFileExistenceInPaths(file: string, paths: string[]): boolean { function isExists(file: string, paths: string[]): boolean {
return paths.some((p) => existsSync(path.join(p, file))); return paths.some((p) => existsSync(path.join(p, file)));
} }
@ -104,12 +106,12 @@ function updateCudaExistence() {
} }
let cudaExists = filesCuda12.every( let cudaExists = filesCuda12.every(
(file) => existsSync(file) || checkFileExistenceInPaths(file, paths) (file) => existsSync(file) || isExists(file, paths)
); );
if (!cudaExists) { if (!cudaExists) {
cudaExists = filesCuda11.every( cudaExists = filesCuda11.every(
(file) => existsSync(file) || checkFileExistenceInPaths(file, paths) (file) => existsSync(file) || isExists(file, paths)
); );
if (cudaExists) { if (cudaExists) {
cudaVersion = "11"; cudaVersion = "11";
@ -461,7 +463,7 @@ function spawnNitroProcess(nitroResourceProbe: any): Promise<any> {
function getResourcesInfo(): Promise<ResourcesInfo> { function getResourcesInfo(): Promise<ResourcesInfo> {
return new Promise(async (resolve) => { return new Promise(async (resolve) => {
const cpu = await osUtils.cpuCount(); const cpu = await osUtils.cpuCount();
console.log("cpu: ", cpu); console.debug("cpu: ", cpu);
const response: ResourcesInfo = { const response: ResourcesInfo = {
numCpuPhysicalCore: cpu, numCpuPhysicalCore: cpu,
memAvailable: 0, memAvailable: 0,
@ -470,6 +472,13 @@ function getResourcesInfo(): Promise<ResourcesInfo> {
}); });
} }
/**
 * Reports whether this module currently holds a Nitro subprocess handle.
 *
 * The snapshot is also stored in the module-level `nitroProcessInfo`
 * variable, so the most recently reported state stays available there.
 *
 * Declared `async` so the value handed back really is a Promise, as the
 * return type states; previously a plain object was returned synchronously
 * despite the `Promise` annotation.
 *
 * @returns A promise resolving to `{ isRunning }`, where `isRunning` is true
 *   when the module-level `subprocess` is neither `null` nor `undefined`.
 */
const getCurrentNitroProcessInfo = async (): Promise<{
  isRunning: boolean;
}> => {
  // `!= null` intentionally matches both `null` and `undefined`.
  const info = { isRunning: subprocess != null };
  nitroProcessInfo = info;
  return info;
};
function dispose() { function dispose() {
// clean other registered resources here // clean other registered resources here
killSubprocess(); killSubprocess();
@ -481,4 +490,5 @@ module.exports = {
killSubprocess, killSubprocess,
dispose, dispose,
updateNvidiaInfo, updateNvidiaInfo,
getCurrentNitroProcessInfo,
}; };

View File

@ -8,8 +8,8 @@ import {
ExtensionType, ExtensionType,
MessageStatus, MessageStatus,
Model, Model,
ConversationalExtension,
} from '@janhq/core' } from '@janhq/core'
import { ConversationalExtension } from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai' import { useAtomValue, useSetAtom } from 'jotai'
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel' import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
@ -64,14 +64,10 @@ export default function EventHandler({ children }: { children: ReactNode }) {
})) }))
} }
async function handleModelStopped(model: Model) { async function handleModelStopped() {
setTimeout(async () => { setTimeout(async () => {
setActiveModel(undefined) setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' }) setStateModel({ state: 'start', loading: false, model: '' })
// toaster({
// title: 'Success!',
// description: `Model ${model.id} has been stopped.`,
// })
}, 500) }, 500)
} }