fix: Update code based on comments from @james

hiro 2023-12-07 15:13:04 +07:00
parent fbf8ff9d3d
commit 1177007ad6
10 changed files with 166 additions and 150 deletions

View File

@@ -25,7 +25,7 @@ const readFile: (path: string) => Promise<any> = (path) => global.core.api?.readFile(path)
  * @param {string} path
  * @returns {boolean} A boolean indicating whether the path is a file.
  */
-const checkFileExists = (path: string): Promise<boolean> => global.core.api?.checkFileExists(path)
+const exists = (path: string): Promise<boolean> => global.core.api?.exists(path)
 /**
  * List the directory files
  * @param {string} path - The path of the directory to list files.
@@ -75,7 +75,7 @@ export const fs = {
   isDirectory,
   writeFile,
   readFile,
-  checkFileExists,
+  exists,
   listFiles,
   mkdir,
   rmdir,
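
Note: callers of the renamed helper are unchanged apart from the name. A minimal usage sketch, assuming the `fs` wrapper shown above is the one re-exported by `@janhq/core`, and using the `engines/nitro.json` path that appears later in this commit:

    import { fs } from "@janhq/core";

    // Read a user-space file only if it exists; `fs.exists` resolves to a boolean.
    async function readIfPresent(path: string): Promise<string | undefined> {
      if (await fs.exists(path)) {
        return fs.readFile(path);
      }
      return undefined;
    }

    // e.g. const raw = await readIfPresent("engines/nitro.json");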

View File

@@ -41,11 +41,8 @@ export type MessageRequest = {
   /** Messages for constructing a chat completion request **/
   messages?: ChatCompletionMessage[];

-  /** Runtime parameters for constructing a chat completion request **/
-  parameters?: ModelRuntimeParams;
-
   /** Settings for constructing a chat completion request **/
-  model?: ModelInfo
+  model?: ModelInfo;
 };

 /**

View File

@@ -56,7 +56,7 @@ export function handleFsIPCs() {
    * @param path - The path of the file to check.
    * @returns A promise that resolves with a boolean indicating whether the file exists.
    */
-  ipcMain.handle('checkFileExists', async (_event, path: string) => {
+  ipcMain.handle('exists', async (_event, path: string) => {
     return new Promise((resolve, reject) => {
       const fullPath = join(userSpacePath, path)
       fs.existsSync(fullPath) ? resolve(true) : resolve(false)

View File

@@ -31,7 +31,7 @@ export function fsInvokers() {
   /**
    * Reads a file at the specified path.
    * @param {string} path - The path of the file to read.
    */
-  checkFileExists: (path: string) => ipcRenderer.invoke('checkFileExists', path),
+  exists: (path: string) => ipcRenderer.invoke('exists', path),
   /**
    * Writes data to a file at the specified path.
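
Note: this rename must stay in sync across three layers — the `fs` wrapper in core (`global.core.api?.exists`), this renderer-side invoker, and the `ipcMain.handle('exists', ...)` registration above. A rough sketch of the round trip; the preload wiring line is an assumption, since that file is not part of this diff:

    // Preload (assumed wiring, not shown in this commit):
    // contextBridge.exposeInMainWorld('core', { api: { ...fsInvokers() } })

    // A call from core...
    const found = await global.core.api?.exists('engines/nitro.json')

    // ...goes through ipcRenderer.invoke('exists', path) here, and resolves in
    // the main process, where the handler checks fs.existsSync(join(userSpacePath, path)).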

View File

@@ -10,10 +10,10 @@ declare const INFERENCE_URL: string;
  * @property settings.embedding - Whether to use embedding.
  */
 interface EngineSettings {
   ctx_len: number;
   ngl: number;
   cont_batching: boolean;
   embedding: boolean;
 }

 /**
@@ -21,7 +21,6 @@ interface EngineSettings {
  * @property error - An error message if the model fails to load.
  */
 interface ModelOperationResponse {
-
   error?: any;
   modelFile?: string;
 }

View File

@ -33,17 +33,17 @@ import { join } from "path";
* It also subscribes to events emitted by the @janhq/core package and handles new message requests. * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/ */
export default class JanInferenceNitroExtension implements InferenceExtension { export default class JanInferenceNitroExtension implements InferenceExtension {
private static readonly _homeDir = 'engines' private static readonly _homeDir = "engines";
private static readonly _engineMetadataFileName = 'nitro.json' private static readonly _engineMetadataFileName = "nitro.json";
static _currentModel: Model; private static _currentModel: Model;
static _engineSettings: EngineSettings = { private static _engineSettings: EngineSettings = {
"ctx_len": 2048, ctx_len: 2048,
"ngl": 100, ngl: 100,
"cont_batching": false, cont_batching: false,
"embedding": false embedding: false,
} };
controller = new AbortController(); controller = new AbortController();
isCancelled = false; isCancelled = false;
@@ -59,12 +59,12 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
   * Subscribes to events emitted by the @janhq/core package.
   */
  onLoad(): void {
-    fs.mkdir(JanInferenceNitroExtension._homeDir)
-    this.writeDefaultEngineSettings()
+    fs.mkdir(JanInferenceNitroExtension._homeDir);
+    this.writeDefaultEngineSettings();

     // Events subscription
     events.on(EventName.OnMessageSent, (data) =>
       JanInferenceNitroExtension.handleMessageRequest(data, this)
     );

     events.on(EventName.OnModelInit, (model: Model) => {
@@ -112,42 +112,51 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
   private async writeDefaultEngineSettings() {
     try {
-      const engine_json = join(JanInferenceNitroExtension._homeDir, JanInferenceNitroExtension._engineMetadataFileName)
-      if (await fs.checkFileExists(engine_json)) {
-        JanInferenceNitroExtension._engineSettings = JSON.parse(await fs.readFile(engine_json))
-      }
-      else {
-        await fs.writeFile(engine_json, JSON.stringify(JanInferenceNitroExtension._engineSettings, null, 2))
+      const engineFile = join(
+        JanInferenceNitroExtension._homeDir,
+        JanInferenceNitroExtension._engineMetadataFileName
+      );
+      if (await fs.exists(engineFile)) {
+        JanInferenceNitroExtension._engineSettings = JSON.parse(
+          await fs.readFile(engineFile)
+        );
+      } else {
+        await fs.writeFile(
+          engineFile,
+          JSON.stringify(JanInferenceNitroExtension._engineSettings, null, 2)
+        );
       }
     } catch (err) {
-      console.error(err)
+      console.error(err);
     }
   }

   private static async handleModelInit(model: Model) {
-    if (model.engine !== "nitro") { return }
+    if (model.engine !== "nitro") {
+      return;
+    }

     const userSpacePath = await getUserSpace();
     const modelFullPath = join(userSpacePath, "models", model.id, model.id);

-    const nitro_init_result = await executeOnMain(MODULE, "initModel", {
+    const nitroInitResult = await executeOnMain(MODULE, "initModel", {
       modelFullPath: modelFullPath,
-      model: model
+      model: model,
     });

-    if (nitro_init_result.error === null) {
-      events.emit(EventName.OnModelFail, model)
-    }
-    else{
+    if (nitroInitResult.error === null) {
+      events.emit(EventName.OnModelFail, model);
+    } else {
       JanInferenceNitroExtension._currentModel = model;
       events.emit(EventName.OnModelReady, model);
     }
   }

   private static async handleModelStop(model: Model) {
-    if (model.engine !== 'nitro') { return }
-    else {
-      await executeOnMain(MODULE, "stopModel")
-      events.emit(EventName.OnModelStopped, model)
+    if (model.engine !== "nitro") {
+      return;
+    } else {
+      await executeOnMain(MODULE, "stopModel");
+      events.emit(EventName.OnModelStopped, model);
     }
   }
@@ -171,18 +180,17 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
     return new Promise(async (resolve, reject) => {
       requestInference(
         data.messages ?? [],
         JanInferenceNitroExtension._engineSettings,
         JanInferenceNitroExtension._currentModel
-      )
-        .subscribe({
+      ).subscribe({
         next: (_content) => {},
         complete: async () => {
           resolve(message);
         },
         error: async (err) => {
           reject(err);
         },
       });
     });
   }
@@ -197,7 +205,9 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
     data: MessageRequest,
     instance: JanInferenceNitroExtension
   ) {
-    if (data.model.engine !== 'nitro') { return }
+    if (data.model.engine !== "nitro") {
+      return;
+    }
     const timestamp = Date.now();
     const message: ThreadMessage = {
       id: ulid(),
@@ -216,11 +226,11 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
     instance.controller = new AbortController();

     requestInference(
       data.messages ?? [],
       JanInferenceNitroExtension._engineSettings,
       JanInferenceNitroExtension._currentModel,
       instance.controller
     ).subscribe({
       next: (content) => {
         const messageContent: ThreadContent = {
           type: ContentType.Text,
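
Note: taken together, this extension's model lifecycle is event-driven — a client emits `OnModelInit`, and the extension answers with `OnModelReady` or `OnModelFail`. A minimal client-side sketch using only the names visible in this diff; the one-shot wrapper is illustrative (the listeners here are never detached):

    import { events, EventName, Model } from "@janhq/core";

    // Resolve when the model is ready, reject if the extension reports failure.
    function startModel(model: Model): Promise<Model> {
      return new Promise((resolve, reject) => {
        events.on(EventName.OnModelReady, (m: Model) => resolve(m));
        events.on(EventName.OnModelFail, (m: Model) => reject(m));
        events.emit(EventName.OnModelInit, model);
      });
    }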

View File

@@ -26,9 +26,9 @@ let currentModelFile = null;
  */
 function stopModel(): Promise<ModelOperationResponse> {
   return new Promise((resolve, reject) => {
-    checkAndUnloadNitro()
-    resolve({ error: undefined})
-  })
+    checkAndUnloadNitro();
+    resolve({ error: undefined });
+  });
 }

 /**
@@ -39,33 +39,32 @@ function stopModel(): Promise<ModelOperationResponse> {
  * TODO: Should it be startModel instead?
  */
 function initModel(wrapper: any): Promise<ModelOperationResponse> {
   currentModelFile = wrapper.modelFullPath;
   if (wrapper.model.engine !== "nitro") {
-    return Promise.resolve({ error: "Not a nitro model" })
-  }
-  else {
+    return Promise.resolve({ error: "Not a nitro model" });
+  } else {
     log.info("Started to load model " + wrapper.model.modelFullPath);
     const settings = {
       llama_model_path: currentModelFile,
       ...wrapper.model.settings,
     };
     log.info(`Load model settings: ${JSON.stringify(settings, null, 2)}`);
     return (
       // 1. Check if the port is used, if used, attempt to unload model / kill nitro process
       validateModelVersion()
         .then(checkAndUnloadNitro)
         // 2. Spawn the Nitro subprocess
         .then(spawnNitroProcess)
         // 4. Load the model into the Nitro subprocess (HTTP POST request)
         .then(() => loadLLMModel(settings))
         // 5. Check if the model is loaded successfully
         .then(validateModelStatus)
         .catch((err) => {
           log.error("error: " + JSON.stringify(err));
           return { error: err, currentModelFile };
         })
     );
   }
 }

 /**
@@ -148,13 +147,12 @@ async function checkAndUnloadNitro() {
   // If inUse - try unload or kill process, otherwise do nothing
   if (inUse) {
     // Attempt to unload model
-    return await fetch(NITRO_HTTP_UNLOAD_MODEL_URL, {
+    return fetch(NITRO_HTTP_UNLOAD_MODEL_URL, {
       method: "GET",
       headers: {
         "Content-Type": "application/json",
       },
-    })
-      .catch((err) => {
+    }).catch((err) => {
       console.error(err);
       // Fallback to kill the port
       return killSubprocess();

View File

@@ -3,25 +3,25 @@ import { Model } from "@janhq/core";
 declare const MODULE: string;

 declare interface EngineSettings {
   base_url?: string;
   api_key?: string;
 }

 enum OpenAIChatCompletionModelName {
-  'gpt-3.5-turbo-instruct' = 'gpt-3.5-turbo-instruct',
-  'gpt-3.5-turbo-instruct-0914' = 'gpt-3.5-turbo-instruct-0914',
-  'gpt-4-1106-preview' = 'gpt-4-1106-preview',
-  'gpt-3.5-turbo-0613' = 'gpt-3.5-turbo-0613',
-  'gpt-3.5-turbo-0301' = 'gpt-3.5-turbo-0301',
-  'gpt-3.5-turbo' = 'gpt-3.5-turbo',
-  'gpt-3.5-turbo-16k-0613' = 'gpt-3.5-turbo-16k-0613',
-  'gpt-3.5-turbo-1106' = 'gpt-3.5-turbo-1106',
-  'gpt-4-vision-preview' = 'gpt-4-vision-preview',
-  'gpt-4' = 'gpt-4',
-  'gpt-4-0314' = 'gpt-4-0314',
-  'gpt-4-0613' = 'gpt-4-0613',
+  "gpt-3.5-turbo-instruct" = "gpt-3.5-turbo-instruct",
+  "gpt-3.5-turbo-instruct-0914" = "gpt-3.5-turbo-instruct-0914",
+  "gpt-4-1106-preview" = "gpt-4-1106-preview",
+  "gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613",
+  "gpt-3.5-turbo-0301" = "gpt-3.5-turbo-0301",
+  "gpt-3.5-turbo" = "gpt-3.5-turbo",
+  "gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613",
+  "gpt-3.5-turbo-1106" = "gpt-3.5-turbo-1106",
+  "gpt-4-vision-preview" = "gpt-4-vision-preview",
+  "gpt-4" = "gpt-4",
+  "gpt-4-0314" = "gpt-4-0314",
+  "gpt-4-0613" = "gpt-4-0613",
 }

 declare type OpenAIModel = Omit<Model, "id"> & {
   id: OpenAIChatCompletionModelName;
 };
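
Note: because the enum's keys and values are identical strings, a model id can be referenced through the enum or written as a literal. A small sketch; the cast elides the remaining `Model` fields, which are not shown in this diff:

    // Hypothetical instance for illustration only — the cast skips the
    // other fields that Omit<Model, "id"> would otherwise require.
    const model = { id: OpenAIChatCompletionModelName["gpt-4"] } as OpenAIModel;
    console.log(model.id); // "gpt-4"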

View File

@@ -31,14 +31,14 @@ import { EngineSettings, OpenAIModel } from "./@types/global";
  * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
  */
 export default class JanInferenceOpenAIExtension implements InferenceExtension {
-  private static readonly _homeDir = 'engines'
-  private static readonly _engineMetadataFileName = 'openai.json'
-  static _currentModel: OpenAIModel;
-  static _engineSettings: EngineSettings = {
-    "base_url": "https://api.openai.com/v1",
-    "api_key": "sk-<your key here>"
+  private static readonly _homeDir = "engines";
+  private static readonly _engineMetadataFileName = "openai.json";
+  private static _currentModel: OpenAIModel;
+  private static _engineSettings: EngineSettings = {
+    base_url: "https://api.openai.com/v1",
+    api_key: "sk-<your key here>",
   };

   controller = new AbortController();
@@ -56,8 +56,8 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
   * Subscribes to events emitted by the @janhq/core package.
   */
  onLoad(): void {
-    fs.mkdir(JanInferenceOpenAIExtension._homeDir)
-    JanInferenceOpenAIExtension.writeDefaultEngineSettings()
+    fs.mkdir(JanInferenceOpenAIExtension._homeDir);
+    JanInferenceOpenAIExtension.writeDefaultEngineSettings();

     // Events subscription
     events.on(EventName.OnMessageSent, (data) =>
@@ -87,20 +87,27 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
     modelId: string,
     settings?: ModelSettingParams
   ): Promise<void> {
-    return
+    return;
   }

   static async writeDefaultEngineSettings() {
     try {
-      const engine_json = join(JanInferenceOpenAIExtension._homeDir, JanInferenceOpenAIExtension._engineMetadataFileName)
-      if (await fs.checkFileExists(engine_json)) {
-        JanInferenceOpenAIExtension._engineSettings = JSON.parse(await fs.readFile(engine_json))
-      }
-      else {
-        await fs.writeFile(engine_json, JSON.stringify(JanInferenceOpenAIExtension._engineSettings, null, 2))
+      const engineFile = join(
+        JanInferenceOpenAIExtension._homeDir,
+        JanInferenceOpenAIExtension._engineMetadataFileName
+      );
+      if (await fs.exists(engineFile)) {
+        JanInferenceOpenAIExtension._engineSettings = JSON.parse(
+          await fs.readFile(engineFile)
+        );
+      } else {
+        await fs.writeFile(
+          engineFile,
+          JSON.stringify(JanInferenceOpenAIExtension._engineSettings, null, 2)
+        );
       }
     } catch (err) {
-      console.error(err)
+      console.error(err);
     }
   }

   /**
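
Note: on first load this writes the defaults above to `engines/openai.json` in the user space; the user then replaces the `api_key` placeholder with a real key. Expected file contents, assuming `JSON.stringify(..., null, 2)` as in the code:

    // engines/openai.json (generated on first run, then edited by the user):
    // {
    //   "base_url": "https://api.openai.com/v1",
    //   "api_key": "sk-<your key here>"
    // }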
@@ -137,35 +144,39 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
     };

     return new Promise(async (resolve, reject) => {
-      requestInference(data.messages ?? [],
-        JanInferenceOpenAIExtension._engineSettings,
-        JanInferenceOpenAIExtension._currentModel)
-        .subscribe({
-          next: (_content) => {},
-          complete: async () => {
-            resolve(message);
-          },
-          error: async (err) => {
-            reject(err);
-          },
+      requestInference(
+        data.messages ?? [],
+        JanInferenceOpenAIExtension._engineSettings,
+        JanInferenceOpenAIExtension._currentModel
+      ).subscribe({
+        next: (_content) => {},
+        complete: async () => {
+          resolve(message);
+        },
+        error: async (err) => {
+          reject(err);
+        },
       });
     });
   }

   private static async handleModelInit(model: OpenAIModel) {
-    if (model.engine !== 'openai') { return }
-    else {
-      JanInferenceOpenAIExtension._currentModel = model
-      JanInferenceOpenAIExtension.writeDefaultEngineSettings()
+    if (model.engine !== "openai") {
+      return;
+    } else {
+      JanInferenceOpenAIExtension._currentModel = model;
+      JanInferenceOpenAIExtension.writeDefaultEngineSettings();
       // Todo: Check model list with API key
-      events.emit(EventName.OnModelReady, model)
+      events.emit(EventName.OnModelReady, model);
       // events.emit(EventName.OnModelFail, model)
     }
   }

   private static async handleModelStop(model: OpenAIModel) {
-    if (model.engine !== 'openai') { return }
-    events.emit(EventName.OnModelStopped, model)
+    if (model.engine !== "openai") {
+      return;
+    }
+    events.emit(EventName.OnModelStopped, model);
   }

   /**
@@ -178,7 +189,9 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
     data: MessageRequest,
     instance: JanInferenceOpenAIExtension
   ) {
-    if (data.model.engine !== 'openai') { return }
+    if (data.model.engine !== "openai") {
+      return;
+    }
     const timestamp = Date.now();
     const message: ThreadMessage = {

View File

@@ -231,7 +231,6 @@ export default function useSendChatMessage() {
       await WaitForModelStarting(modelId)
       setQueuedMessage(false)
     }
-    console.log('messageRequest', messageRequest)

     events.emit(EventName.OnMessageSent, messageRequest)
   }