fix: Add dynamic values from engine settings and model params

hiro 2023-12-05 00:45:23 +07:00
parent 16f2ffe9b4
commit 0c3e23665b
2 changed files with 86 additions and 56 deletions

View File

@@ -20,19 +20,13 @@ import {
   executeOnMain,
   getUserSpace,
   fs,
+  Model,
 } from "@janhq/core";
 import { InferenceExtension } from "@janhq/core";
 import { requestInference } from "./helpers/sse";
 import { ulid } from "ulid";
 import { join } from "path";
 
-interface EngineSettings {
-  ctx_len: number;
-  ngl: number;
-  cont_batching: boolean;
-  embedding: boolean;
-}
-
 /**
  * A class that implements the InferenceExtension interface from the @janhq/core package.
  * The class provides methods for initializing and stopping a model, and for making inference requests.
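
The EngineSettings interface is deleted from this file here but is still referenced by the _engineSettings field below, so it presumably now lives in a shared location such as @janhq/core; its new home is not visible in this diff. For reference, a sketch of the shape as it stood before removal:

    // Sketch of EngineSettings, reconstructed from the declaration removed
    // above; where the type is declared after this commit is an assumption.
    interface EngineSettings {
      ctx_len: number;        // context window length (defaults to 2048 below)
      ngl: number;            // GPU layers to offload (defaults to 100 below)
      cont_batching: boolean; // continuous batching toggle
      embedding: boolean;     // embedding endpoint toggle
    }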
@@ -42,7 +36,9 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
   private static readonly _homeDir = 'engines'
   private static readonly _engineMetadataFileName = 'nitro.json'
 
-  private _engineSettings: EngineSettings = {
+  static _currentModel: Model;
+
+  static _engineSettings: EngineSettings = {
     "ctx_len": 2048,
     "ngl": 100,
     "cont_batching": false,
@@ -65,9 +61,19 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
   onLoad(): void {
     fs.mkdir(JanInferenceNitroExtension._homeDir)
     this.writeDefaultEngineSettings()
+
+    // Events subscription
     events.on(EventName.OnMessageSent, (data) =>
       JanInferenceNitroExtension.handleMessageRequest(data, this)
     );
+
+    events.on(EventName.OnModelInit, (model: Model) => {
+      JanInferenceNitroExtension.handleModelInit(model);
+    });
+
+    events.on(EventName.OnModelStop, (model: Model) => {
+      JanInferenceNitroExtension.handleModelStop(model);
+    });
   }
 
   /**
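
These subscriptions make the model lifecycle event-driven: a caller emits OnModelInit and waits for the OnModelReady or OnModelFail event that handleModelInit (added further down) emits in response. A minimal sketch of such a caller, assuming the events bus and EventName members from @janhq/core behave as they are used in this file:

    import { events, EventName, Model } from "@janhq/core";

    // Hypothetical caller: request a model start and await the outcome.
    // A real implementation would also deregister these handlers.
    function startModel(model: Model): Promise<Model> {
      return new Promise((resolve, reject) => {
        events.on(EventName.OnModelReady, (m: Model) => {
          if (m.id === model.id) resolve(m);
        });
        events.on(EventName.OnModelFail, (m: Model) => {
          if (m.id === model.id) reject(m);
        });
        events.emit(EventName.OnModelInit, model);
      });
    }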
@@ -85,15 +91,7 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
   async initModel(
     modelId: string,
     settings?: ModelSettingParams
-  ): Promise<void> {
-    const userSpacePath = await getUserSpace();
-    const modelFullPath = join(userSpacePath, "models", modelId, modelId);
-
-    return executeOnMain(MODULE, "initModel", {
-      modelFullPath,
-      settings,
-    });
-  }
+  ): Promise<void> {}
 
   /**
    * Stops the model.
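
With loading moved to the OnModelInit handler below, initModel is left as an empty stub, apparently kept only so the class continues to satisfy the InferenceExtension interface.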
@@ -116,16 +114,41 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
     try {
       const engine_json = join(JanInferenceNitroExtension._homeDir, JanInferenceNitroExtension._engineMetadataFileName)
       if (await fs.checkFileExists(engine_json)) {
-        this._engineSettings = JSON.parse(await fs.readFile(engine_json))
+        JanInferenceNitroExtension._engineSettings = JSON.parse(await fs.readFile(engine_json))
       }
       else {
-        await fs.writeFile(engine_json, JSON.stringify(this._engineSettings, null, 2))
+        await fs.writeFile(engine_json, JSON.stringify(JanInferenceNitroExtension._engineSettings, null, 2))
       }
     } catch (err) {
       console.error(err)
     }
   }
 
+  private static async handleModelInit(model: Model) {
+    if (model.engine !== "nitro") { return }
+
+    const userSpacePath = await getUserSpace();
+    const modelFullPath = join(userSpacePath, "models", model.id, model.id);
+
+    const nitro_init_result = await executeOnMain(MODULE, "initModel", {
+      modelFullPath: modelFullPath,
+      model: model
+    });
+    if (nitro_init_result.error) {
+      events.emit(EventName.OnModelFail, model)
+    }
+    else {
+      events.emit(EventName.OnModelReady, model);
+    }
+  }
+
+  private static async handleModelStop(model: Model) {
+    if (model.engine !== 'nitro') { return }
+    else {
+      events.emit(EventName.OnModelStopped, model)
+    }
+  }
+
   /**
    * Makes a single response inference request.
    * @param {MessageRequest} data - The data for the inference request.
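
handleModelInit treats whatever executeOnMain(MODULE, "initModel", ...) resolves to as an error-first result: any error field means failure. The second file below names this shape InitModelResponse; a sketch of the contract as the call sites appear to assume it (the exact fields beyond error are an assumption):

    // Error-first result returned by the main-process initModel.
    interface InitModelResponse {
      error?: any;               // set on failure, e.g. "Not a nitro model"
      currentModelFile?: string; // model path echoed back on load errors
    }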
@@ -145,14 +168,17 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
     };
 
     return new Promise(async (resolve, reject) => {
-      requestInference(data.messages ?? []).subscribe({
+      requestInference(data.messages ?? [],
+        JanInferenceNitroExtension._engineSettings,
+        JanInferenceNitroExtension._currentModel)
+      .subscribe({
         next: (_content) => {},
         complete: async () => {
           resolve(message);
         },
         error: async (err) => {
           reject(err);
         },
       });
     });
   }
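
Both call sites now thread the engine settings and the active model into requestInference, with the AbortController as an optional trailing argument. The helper in ./helpers/sse is not part of this commit; the signature below is a sketch inferred purely from the two call sites and the rxjs-style subscribe usage:

    import { Observable } from "rxjs";
    import { Model } from "@janhq/core";

    // Assumed shape of the updated SSE helper: streams content chunks as
    // they arrive. EngineSettings is the shape sketched near the imports.
    export declare function requestInference(
      messages: any[],
      engineSettings: EngineSettings,
      model: Model,
      controller?: AbortController
    ): Observable<string>;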
@@ -167,6 +193,8 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
     data: MessageRequest,
     instance: JanInferenceNitroExtension
   ) {
+    if (data.model.engine !== 'nitro') { return }
+
     const timestamp = Date.now();
     const message: ThreadMessage = {
       id: ulid(),
@@ -184,7 +212,11 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
     instance.isCancelled = false;
     instance.controller = new AbortController();
 
-    requestInference(data.messages, instance.controller).subscribe({
+    requestInference(data.messages ?? [],
+      JanInferenceNitroExtension._engineSettings,
+      JanInferenceNitroExtension._currentModel,
+      instance.controller)
+    .subscribe({
       next: (content) => {
         const messageContent: ThreadContent = {
           type: ContentType.Text,

View File

@@ -36,35 +36,33 @@ interface InitModelResponse {
  * TODO: Should it be startModel instead?
  */
 function initModel(wrapper: any): Promise<InitModelResponse> {
-  if (wrapper.settings.engine !== "llamacpp") {
-    return
-  }
-  // 1. Check if the model file exists
   currentModelFile = wrapper.modelFullPath;
-  log.info("Started to load model " + wrapper.modelFullPath);
-
-  const settings = {
-    llama_model_path: currentModelFile,
-    ...wrapper.settings,
-  };
-
-  log.info(`Load model settings: ${JSON.stringify(settings, null, 2)}`);
-
-  return (
-    // 1. Check if the port is used, if used, attempt to unload model / kill nitro process
-    validateModelVersion()
-      .then(checkAndUnloadNitro)
-      // 2. Spawn the Nitro subprocess
-      .then(spawnNitroProcess)
-      // 4. Load the model into the Nitro subprocess (HTTP POST request)
-      .then(() => loadLLMModel(settings))
-      // 5. Check if the model is loaded successfully
-      .then(validateModelStatus)
-      .catch((err) => {
-        log.error("error: " + JSON.stringify(err));
-        return { error: err, currentModelFile };
-      })
-  );
+  if (wrapper.model.engine !== "nitro") {
+    return Promise.resolve({ error: "Not a nitro model" })
+  }
+  else {
+    log.info("Started to load model " + wrapper.model.modelFullPath);
+    const settings = {
+      llama_model_path: currentModelFile,
+      ...wrapper.model.settings,
+    };
+    log.info(`Load model settings: ${JSON.stringify(settings, null, 2)}`);
+    return (
+      // 1. Check if the port is used, if used, attempt to unload model / kill nitro process
+      validateModelVersion()
+        .then(checkAndUnloadNitro)
+        // 2. Spawn the Nitro subprocess
+        .then(spawnNitroProcess)
+        // 4. Load the model into the Nitro subprocess (HTTP POST request)
+        .then(() => loadLLMModel(settings))
+        // 5. Check if the model is loaded successfully
+        .then(validateModelStatus)
+        .catch((err) => {
+          log.error("error: " + JSON.stringify(err));
+          return { error: err, currentModelFile };
+        })
+    );
+  }
 }
 
 /**
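
The reworked initModel expects the richer payload that handleModelInit sends from the first file. A sketch of the wrapper argument with field names taken from that call site; typing it like this instead of any is a suggestion, not part of the commit:

    import { Model } from "@janhq/core";

    // Payload passed via executeOnMain(MODULE, "initModel", ...):
    interface InitModelWrapper {
      modelFullPath: string; // absolute model path resolved in the renderer
      model: Model;          // full model record; engine and settings read here
    }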