diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts
index 614c32586..89d051f16 100644
--- a/extensions/inference-nitro-extension/src/index.ts
+++ b/extensions/inference-nitro-extension/src/index.ts
@@ -20,19 +20,13 @@ import {
   executeOnMain,
   getUserSpace,
   fs,
+  Model,
 } from "@janhq/core";
 import { InferenceExtension } from "@janhq/core";
 import { requestInference } from "./helpers/sse";
 import { ulid } from "ulid";
 import { join } from "path";
 
-interface EngineSettings {
-  ctx_len: number;
-  ngl: number;
-  cont_batching: boolean;
-  embedding: boolean;
-}
-
 /**
  * A class that implements the InferenceExtension interface from the @janhq/core package.
  * The class provides methods for initializing and stopping a model, and for making inference requests.
@@ -42,7 +36,9 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
   private static readonly _homeDir = 'engines'
   private static readonly _engineMetadataFileName = 'nitro.json'
 
-  private _engineSettings: EngineSettings = {
+  static _currentModel: Model;
+
+  static _engineSettings: EngineSettings = {
     "ctx_len": 2048,
     "ngl": 100,
     "cont_batching": false,
@@ -65,9 +61,19 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
   onLoad(): void {
     fs.mkdir(JanInferenceNitroExtension._homeDir)
     this.writeDefaultEngineSettings()
+
+    // Events subscription
     events.on(EventName.OnMessageSent, (data) =>
       JanInferenceNitroExtension.handleMessageRequest(data, this)
     );
+
+    events.on(EventName.OnModelInit, (model: Model) => {
+      JanInferenceNitroExtension.handleModelInit(model);
+    });
+
+    events.on(EventName.OnModelStop, (model: Model) => {
+      JanInferenceNitroExtension.handleModelStop(model);
+    });
   }
 
   /**
@@ -85,15 +91,7 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
   async initModel(
     modelId: string,
     settings?: ModelSettingParams
-  ): Promise<void> {
-    const userSpacePath = await getUserSpace();
-    const modelFullPath = join(userSpacePath, "models", modelId, modelId);
-
-    return executeOnMain(MODULE, "initModel", {
-      modelFullPath,
-      settings,
-    });
-  }
+  ): Promise<void> {}
 
   /**
    * Stops the model.
@@ -116,16 +114,41 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
     try {
       const engine_json = join(JanInferenceNitroExtension._homeDir, JanInferenceNitroExtension._engineMetadataFileName)
       if (await fs.checkFileExists(engine_json)) {
-        this._engineSettings = JSON.parse(await fs.readFile(engine_json))
+        JanInferenceNitroExtension._engineSettings = JSON.parse(await fs.readFile(engine_json))
       } else {
-        await fs.writeFile(engine_json, JSON.stringify(this._engineSettings, null, 2))
+        await fs.writeFile(engine_json, JSON.stringify(JanInferenceNitroExtension._engineSettings, null, 2))
       }
     } catch (err) {
       console.error(err)
     }
   }
 
+  private static async handleModelInit(model: Model) {
+    if (model.engine !== "nitro") { return }
+    const userSpacePath = await getUserSpace();
+    const modelFullPath = join(userSpacePath, "models", model.id, model.id);
+
+    const nitro_init_result = await executeOnMain(MODULE, "initModel", {
+      modelFullPath: modelFullPath,
+      model: model
+    });
+
+    if (nitro_init_result.error) {
+      events.emit(EventName.OnModelFail, model)
+    }
+    else {
+      events.emit(EventName.OnModelReady, model);
+    }
+  }
+
+  private static async handleModelStop(model: Model) {
+    if (model.engine !== 'nitro') { return }
+    else {
+      events.emit(EventName.OnModelStopped, model)
+    }
+  }
+
   /**
    * Makes a single response inference request.
    * @param {MessageRequest} data - The data for the inference request.
@@ -145,14 +168,17 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
     };
 
     return new Promise(async (resolve, reject) => {
-      requestInference(data.messages ?? []).subscribe({
+      requestInference(data.messages ?? [],
+        JanInferenceNitroExtension._engineSettings,
+        JanInferenceNitroExtension._currentModel)
+        .subscribe({
         next: (_content) => {},
-        complete: async () => {
-          resolve(message);
-        },
-        error: async (err) => {
-          reject(err);
-        },
+          complete: async () => {
+            resolve(message);
+          },
+          error: async (err) => {
+            reject(err);
+          },
       });
     });
   }
@@ -167,6 +193,8 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
     data: MessageRequest,
     instance: JanInferenceNitroExtension
   ) {
+    if (data.model.engine !== 'nitro') { return }
+
     const timestamp = Date.now();
     const message: ThreadMessage = {
       id: ulid(),
@@ -184,7 +212,11 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
     instance.isCancelled = false;
     instance.controller = new AbortController();
 
-    requestInference(data.messages, instance.controller).subscribe({
+    requestInference(data.messages ?? [],
+      JanInferenceNitroExtension._engineSettings,
+      JanInferenceNitroExtension._currentModel,
+      instance.controller)
+      .subscribe({
       next: (content) => {
         const messageContent: ThreadContent = {
           type: ContentType.Text,
diff --git a/extensions/inference-nitro-extension/src/module.ts b/extensions/inference-nitro-extension/src/module.ts
index 3eeedec32..5b7a52c60 100644
--- a/extensions/inference-nitro-extension/src/module.ts
+++ b/extensions/inference-nitro-extension/src/module.ts
@@ -36,35 +36,33 @@ interface InitModelResponse {
  * TODO: Should it be startModel instead?
  */
 function initModel(wrapper: any): Promise<InitModelResponse> {
-  if (wrapper.settings.engine !== "llamacpp") {
-    return
-  }
-  // 1. Check if the model file exists
-  currentModelFile = wrapper.modelFullPath;
-  log.info("Started to load model " + wrapper.modelFullPath);
-
-  const settings = {
-    llama_model_path: currentModelFile,
-    ...wrapper.settings,
-  };
-
-  log.info(`Load model settings: ${JSON.stringify(settings, null, 2)}`);
-
-  return (
-    // 1. Check if the port is used, if used, attempt to unload model / kill nitro process
-    validateModelVersion()
-      .then(checkAndUnloadNitro)
-      // 2. Spawn the Nitro subprocess
-      .then(spawnNitroProcess)
-      // 4. Load the model into the Nitro subprocess (HTTP POST request)
-      .then(() => loadLLMModel(settings))
-      // 5. Check if the model is loaded successfully
-      .then(validateModelStatus)
-      .catch((err) => {
-        log.error("error: " + JSON.stringify(err));
-        return { error: err, currentModelFile };
-      })
-  );
+  currentModelFile = wrapper.modelFullPath;
+  if (wrapper.model.engine !== "nitro") {
+    return Promise.resolve({ error: "Not a nitro model" })
+  }
+  else {
+    log.info("Started to load model " + wrapper.modelFullPath);
+    const settings = {
+      llama_model_path: currentModelFile,
+      ...wrapper.model.settings,
+    };
+    log.info(`Load model settings: ${JSON.stringify(settings, null, 2)}`);
+    return (
+      // 1. Check if the port is in use; if so, attempt to unload the model / kill the Nitro process
+      validateModelVersion()
+        .then(checkAndUnloadNitro)
+        // 2. Spawn the Nitro subprocess
+        .then(spawnNitroProcess)
+        // 3. Load the model into the Nitro subprocess (HTTP POST request)
+        .then(() => loadLLMModel(settings))
+        // 4. Check if the model is loaded successfully
+        .then(validateModelStatus)
+        .catch((err) => {
+          log.error("error: " + JSON.stringify(err));
+          return { error: err, currentModelFile };
+        })
+    );
+  }
 }
 
 /**
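
Note on the new flow (not part of the patch): model startup becomes event-driven. Instead of calling initModel() on the extension, a caller emits OnModelInit and waits for OnModelReady or OnModelFail, which handleModelInit emits once the main-process load resolves. A minimal sketch of that contract, assuming `events`, `EventName`, and `Model` are importable from "@janhq/core" exactly as the extension code above uses them; startNitroModel is a hypothetical helper, not part of this patch:

import { events, EventName, Model } from "@janhq/core";

// Hypothetical helper: request a nitro model load and observe the outcome.
function startNitroModel(model: Model): Promise<Model> {
  return new Promise((resolve, reject) => {
    events.on(EventName.OnModelReady, (m: Model) => {
      if (m.id === model.id) resolve(m);
    });
    events.on(EventName.OnModelFail, (m: Model) => {
      if (m.id === model.id) reject(new Error(`model ${m.id} failed to load`));
    });
    // handleModelInit returns early for engines other than "nitro",
    // so this emit is simply ignored by this extension for other models.
    events.emit(EventName.OnModelInit, model);
  });
}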
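
On the main-process side, initModel() still receives an untyped `wrapper: any`. Its shape can be read off the executeOnMain call in handleModelInit; written out as an illustrative interface (InitModelWrapper is a made-up name, not a type declared in either file):

// Illustrative only: the payload handleModelInit sends via executeOnMain.
interface InitModelWrapper {
  modelFullPath: string; // join(userSpacePath, "models", model.id, model.id)
  model: Model;          // model.engine gates the nitro path;
                         // model.settings is spread into the llama.cpp load settings
}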