Merge pull request #814 from janhq/feat/inference_engines

feat: Multiple inference engines for nitro and openai
This commit is contained in:
hiro 2023-12-09 01:09:47 +07:00 committed by GitHub
commit ee16683d0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
66 changed files with 999 additions and 263 deletions

8
.gitignore vendored
View File

@ -17,7 +17,7 @@ package-lock.json
core/lib/**
# Nitro binary files
extensions/inference-extension/nitro/*/nitro
extensions/inference-extension/nitro/*/*.exe
extensions/inference-extension/nitro/*/*.dll
extensions/inference-extension/nitro/*/*.metal
extensions/inference-nitro-extension/bin/*/nitro
extensions/inference-nitro-extension/bin/*/*.exe
extensions/inference-nitro-extension/bin/*/*.dll
extensions/inference-nitro-extension/bin/*/*.metal

View File

@ -8,6 +8,18 @@ export enum EventName {
OnMessageResponse = "OnMessageResponse",
/** The `OnMessageUpdate` event is emitted when a message is updated. */
OnMessageUpdate = "OnMessageUpdate",
/** The `OnModelInit` event is emitted when a model inits. */
OnModelInit = "OnModelInit",
/** The `OnModelReady` event is emitted when a model ready. */
OnModelReady = "OnModelReady",
/** The `OnModelFail` event is emitted when a model fails loading. */
OnModelFail = "OnModelFail",
/** The `OnModelStop` event is emitted when a model start to stop. */
OnModelStop = "OnModelStop",
/** The `OnModelStopped` event is emitted when a model stopped ok. */
OnModelStopped = "OnModelStopped",
/** The `OnInferenceStopped` event is emitted when a inference is stopped. */
OnInferenceStopped = "OnInferenceStopped",
}
/**

View File

@ -5,26 +5,10 @@ import { BaseExtension } from "../extension";
* Inference extension. Start, stop and inference models.
*/
export abstract class InferenceExtension extends BaseExtension {
/**
* Initializes the model for the extension.
* @param modelId - The ID of the model to initialize.
*/
abstract initModel(modelId: string, settings?: ModelSettingParams): Promise<void>;
/**
* Stops the model for the extension.
*/
abstract stopModel(): Promise<void>;
/**
* Stops the streaming inference.
*/
abstract stopInference(): Promise<void>;
/**
* Processes an inference request.
* @param data - The data for the inference request.
* @returns The result of the inference request.
*/
abstract inferenceRequest(data: MessageRequest): Promise<ThreadMessage>;
abstract inference(data: MessageRequest): Promise<ThreadMessage>;
}

View File

@ -5,52 +5,52 @@
* @returns {Promise<any>} A Promise that resolves when the file is written successfully.
*/
const writeFile: (path: string, data: string) => Promise<any> = (path, data) =>
global.core.api?.writeFile(path, data);
global.core.api?.writeFile(path, data)
/**
* Checks whether the path is a directory.
* @param path - The path to check.
* @returns {boolean} A boolean indicating whether the path is a directory.
*/
const isDirectory = (path: string): Promise<boolean> =>
global.core.api?.isDirectory(path);
const isDirectory = (path: string): Promise<boolean> => global.core.api?.isDirectory(path)
/**
* Reads the contents of a file at the specified path.
* @param {string} path - The path of the file to read.
* @returns {Promise<any>} A Promise that resolves with the contents of the file.
*/
const readFile: (path: string) => Promise<any> = (path) =>
global.core.api?.readFile(path);
const readFile: (path: string) => Promise<any> = (path) => global.core.api?.readFile(path)
/**
* Check whether the file exists
* @param {string} path
* @returns {boolean} A boolean indicating whether the path is a file.
*/
const exists = (path: string): Promise<boolean> => global.core.api?.exists(path)
/**
* List the directory files
* @param {string} path - The path of the directory to list files.
* @returns {Promise<any>} A Promise that resolves with the contents of the directory.
*/
const listFiles: (path: string) => Promise<any> = (path) =>
global.core.api?.listFiles(path);
const listFiles: (path: string) => Promise<any> = (path) => global.core.api?.listFiles(path)
/**
* Creates a directory at the specified path.
* @param {string} path - The path of the directory to create.
* @returns {Promise<any>} A Promise that resolves when the directory is created successfully.
*/
const mkdir: (path: string) => Promise<any> = (path) =>
global.core.api?.mkdir(path);
const mkdir: (path: string) => Promise<any> = (path) => global.core.api?.mkdir(path)
/**
* Removes a directory at the specified path.
* @param {string} path - The path of the directory to remove.
* @returns {Promise<any>} A Promise that resolves when the directory is removed successfully.
*/
const rmdir: (path: string) => Promise<any> = (path) =>
global.core.api?.rmdir(path);
const rmdir: (path: string) => Promise<any> = (path) => global.core.api?.rmdir(path)
/**
* Deletes a file from the local file system.
* @param {string} path - The path of the file to delete.
* @returns {Promise<any>} A Promise that resolves when the file is deleted.
*/
const deleteFile: (path: string) => Promise<any> = (path) =>
global.core.api?.deleteFile(path);
const deleteFile: (path: string) => Promise<any> = (path) => global.core.api?.deleteFile(path)
/**
* Appends data to a file at the specified path.
@ -58,10 +58,10 @@ const deleteFile: (path: string) => Promise<any> = (path) =>
* @param data data to append
*/
const appendFile: (path: string, data: string) => Promise<any> = (path, data) =>
global.core.api?.appendFile(path, data);
global.core.api?.appendFile(path, data)
const copyFile: (src: string, dest: string) => Promise<any> = (src, dest) =>
global.core.api?.copyFile(src, dest);
global.core.api?.copyFile(src, dest)
/**
* Reads a file line by line.
@ -69,12 +69,13 @@ const copyFile: (src: string, dest: string) => Promise<any> = (src, dest) =>
* @returns {Promise<any>} A promise that resolves to the lines of the file.
*/
const readLineByLine: (path: string) => Promise<any> = (path) =>
global.core.api?.readLineByLine(path);
global.core.api?.readLineByLine(path)
export const fs = {
isDirectory,
writeFile,
readFile,
exists,
listFiles,
mkdir,
rmdir,
@ -82,4 +83,4 @@ export const fs = {
appendFile,
readLineByLine,
copyFile,
};
}

View File

@ -41,8 +41,8 @@ export type MessageRequest = {
/** Messages for constructing a chat completion request **/
messages?: ChatCompletionMessage[];
/** Runtime parameters for constructing a chat completion request **/
parameters?: ModelRuntimeParam;
/** Settings for constructing a chat completion request **/
model?: ModelInfo;
};
/**
@ -153,7 +153,8 @@ export type ThreadAssistantInfo = {
export type ModelInfo = {
id: string;
settings: ModelSettingParams;
parameters: ModelRuntimeParam;
parameters: ModelRuntimeParams;
engine?: InferenceEngine;
};
/**
@ -166,6 +167,17 @@ export type ThreadState = {
error?: Error;
lastMessage?: string;
};
/**
* Represents the inference engine.
* @stored
*/
enum InferenceEngine {
nitro = "nitro",
openai = "openai",
nvidia_triton = "nvidia_triton",
hf_endpoint = "hf_endpoint",
}
/**
* Model type defines the shape of a model object.
@ -228,12 +240,16 @@ export interface Model {
/**
* The model runtime parameters.
*/
parameters: ModelRuntimeParam;
parameters: ModelRuntimeParams;
/**
* Metadata of the model.
*/
metadata: ModelMetadata;
/**
* The model engine.
*/
engine: InferenceEngine;
}
export type ModelMetadata = {
@ -268,7 +284,7 @@ export type ModelSettingParams = {
/**
* The available model runtime parameters.
*/
export type ModelRuntimeParam = {
export type ModelRuntimeParams = {
temperature?: number;
token_limit?: number;
top_k?: number;

View File

@ -289,7 +289,7 @@ components:
engine:
type: string
description: "The engine used by the model."
example: "llamacpp"
enum: [nitro, openai, hf_inference]
quantization:
type: string
description: "Quantization parameter of the model."

View File

@ -50,6 +50,19 @@ export function handleFsIPCs() {
})
})
/**
* Checks whether a file exists in the user data directory.
* @param event - The event object.
* @param path - The path of the file to check.
* @returns A promise that resolves with a boolean indicating whether the file exists.
*/
ipcMain.handle('exists', async (_event, path: string) => {
return new Promise((resolve, reject) => {
const fullPath = join(userSpacePath, path)
fs.existsSync(fullPath) ? resolve(true) : resolve(false)
})
})
/**
* Writes data to a file in the user data directory.
* @param event - The event object.

View File

@ -27,6 +27,12 @@ export function fsInvokers() {
*/
readFile: (path: string) => ipcRenderer.invoke('readFile', path),
/**
* Reads a file at the specified path.
* @param {string} path - The path of the file to read.
*/
exists: (path: string) => ipcRenderer.invoke('exists', path),
/**
* Writes data to a file at the specified path.
* @param {string} path - The path of the file to write to.

View File

@ -1,3 +0,0 @@
@echo off
set /p NITRO_VERSION=<./nitro/version.txt
.\node_modules\.bin\download https://github.com/janhq/nitro/releases/download/v%NITRO_VERSION%/nitro-%NITRO_VERSION%-win-amd64-cuda.tar.gz -e --strip 1 -o ./nitro/win-cuda && .\node_modules\.bin\download https://github.com/janhq/nitro/releases/download/v%NITRO_VERSION%/nitro-%NITRO_VERSION%-win-amd64.tar.gz -e --strip 1 -o ./nitro/win-cpu

View File

@ -1,57 +0,0 @@
{
"name": "@janhq/inference-extension",
"version": "1.0.0",
"description": "Inference Extension, powered by @janhq/nitro, bring a high-performance Llama model inference in pure C++.",
"main": "dist/index.js",
"module": "dist/module.js",
"author": "Jan <service@jan.ai>",
"license": "AGPL-3.0",
"scripts": {
"build": "tsc -b . && webpack --config webpack.config.js",
"downloadnitro:linux": "NITRO_VERSION=$(cat ./nitro/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64.tar.gz -e --strip 1 -o ./nitro/linux-cpu && chmod +x ./nitro/linux-cpu/nitro && chmod +x ./nitro/linux-start.sh && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda.tar.gz -e --strip 1 -o ./nitro/linux-cuda && chmod +x ./nitro/linux-cuda/nitro && chmod +x ./nitro/linux-start.sh",
"downloadnitro:darwin": "NITRO_VERSION=$(cat ./nitro/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-arm64.tar.gz -e --strip 1 -o ./nitro/mac-arm64 && chmod +x ./nitro/mac-arm64/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-amd64.tar.gz -e --strip 1 -o ./nitro/mac-x64 && chmod +x ./nitro/mac-x64/nitro",
"downloadnitro:win32": "download.bat",
"downloadnitro": "run-script-os",
"build:publish:darwin": "rimraf *.tgz --glob && npm run build && npm run downloadnitro && ../../.github/scripts/auto-sign.sh && cpx \"nitro/**\" \"dist/nitro\" && npm pack && cpx *.tgz ../../electron/pre-install",
"build:publish:win32": "rimraf *.tgz --glob && npm run build && npm run downloadnitro && cpx \"nitro/**\" \"dist/nitro\" && npm pack && cpx *.tgz ../../electron/pre-install",
"build:publish:linux": "rimraf *.tgz --glob && npm run build && npm run downloadnitro && cpx \"nitro/**\" \"dist/nitro\" && npm pack && cpx *.tgz ../../electron/pre-install",
"build:publish": "run-script-os"
},
"exports": {
".": "./dist/index.js",
"./main": "./dist/module.js"
},
"devDependencies": {
"cpx": "^1.5.0",
"rimraf": "^3.0.2",
"run-script-os": "^1.1.6",
"webpack": "^5.88.2",
"webpack-cli": "^5.1.4"
},
"dependencies": {
"@janhq/core": "file:../../core",
"download-cli": "^1.1.1",
"electron-log": "^5.0.1",
"fetch-retry": "^5.0.6",
"kill-port": "^2.0.1",
"path-browserify": "^1.0.1",
"rxjs": "^7.8.1",
"tcp-port-used": "^1.0.2",
"ts-loader": "^9.5.0",
"ulid": "^2.3.0"
},
"engines": {
"node": ">=18.0.0"
},
"files": [
"dist/*",
"package.json",
"README.md"
],
"bundleDependencies": [
"tcp-port-used",
"kill-port",
"fetch-retry",
"electron-log"
]
}

View File

@ -1,2 +0,0 @@
declare const MODULE: string;
declare const INFERENCE_URL: string;

View File

@ -0,0 +1,3 @@
@echo off
set /p NITRO_VERSION=<./bin/version.txt
.\node_modules\.bin\download https://github.com/janhq/nitro/releases/download/v%NITRO_VERSION%/nitro-%NITRO_VERSION%-win-amd64-cuda.tar.gz -e --strip 1 -o ./bin/win-cuda && .\node_modules\.bin\download https://github.com/janhq/nitro/releases/download/v%NITRO_VERSION%/nitro-%NITRO_VERSION%-win-amd64.tar.gz -e --strip 1 -o ./bin/win-cpu

View File

@ -0,0 +1,57 @@
{
"name": "@janhq/inference-nitro-extension",
"version": "1.0.0",
"description": "Inference Engine for Nitro Extension, powered by @janhq/nitro, bring a high-performance Llama model inference in pure C++.",
"main": "dist/index.js",
"module": "dist/module.js",
"author": "Jan <service@jan.ai>",
"license": "AGPL-3.0",
"scripts": {
"build": "tsc -b . && webpack --config webpack.config.js",
"downloadnitro:linux": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64.tar.gz -e --strip 1 -o ./bin/linux-cpu && chmod +x ./bin/linux-cpu/nitro && chmod +x ./bin/linux-start.sh && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda.tar.gz -e --strip 1 -o ./bin/linux-cuda && chmod +x ./bin/linux-cuda/nitro && chmod +x ./bin/linux-start.sh",
"downloadnitro:darwin": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-arm64.tar.gz -e --strip 1 -o ./bin/mac-arm64 && chmod +x ./bin/mac-arm64/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-amd64.tar.gz -e --strip 1 -o ./bin/mac-x64 && chmod +x ./bin/mac-x64/nitro",
"downloadnitro:win32": "download.bat",
"downloadnitro": "run-script-os",
"build:publish:darwin": "rimraf *.tgz --glob && npm run build && npm run downloadnitro && ../../.github/scripts/auto-sign.sh && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../electron/pre-install",
"build:publish:win32": "rimraf *.tgz --glob && npm run build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../electron/pre-install",
"build:publish:linux": "rimraf *.tgz --glob && npm run build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../electron/pre-install",
"build:publish": "run-script-os"
},
"exports": {
".": "./dist/index.js",
"./main": "./dist/module.js"
},
"devDependencies": {
"cpx": "^1.5.0",
"rimraf": "^3.0.2",
"run-script-os": "^1.1.6",
"webpack": "^5.88.2",
"webpack-cli": "^5.1.4"
},
"dependencies": {
"@janhq/core": "file:../../core",
"download-cli": "^1.1.1",
"electron-log": "^5.0.1",
"fetch-retry": "^5.0.6",
"kill-port": "^2.0.1",
"path-browserify": "^1.0.1",
"rxjs": "^7.8.1",
"tcp-port-used": "^1.0.2",
"ts-loader": "^9.5.0",
"ulid": "^2.3.0"
},
"engines": {
"node": ">=18.0.0"
},
"files": [
"dist/*",
"package.json",
"README.md"
],
"bundleDependencies": [
"tcp-port-used",
"kill-port",
"fetch-retry",
"electron-log"
]
}

View File

@ -0,0 +1,26 @@
declare const MODULE: string;
declare const INFERENCE_URL: string;
/**
* The parameters for the initModel function.
* @property settings - The settings for the machine learning model.
* @property settings.ctx_len - The context length.
* @property settings.ngl - The number of generated tokens.
* @property settings.cont_batching - Whether to use continuous batching.
* @property settings.embedding - Whether to use embedding.
*/
interface EngineSettings {
ctx_len: number;
ngl: number;
cont_batching: boolean;
embedding: boolean;
}
/**
* The response from the initModel function.
* @property error - An error message if the model fails to load.
*/
interface ModelOperationResponse {
error?: any;
modelFile?: string;
}

View File

@ -1,3 +1,4 @@
import { Model } from "@janhq/core";
import { Observable } from "rxjs";
/**
* Sends a request to the inference server to generate a response based on the recent messages.
@ -6,21 +7,23 @@ import { Observable } from "rxjs";
*/
export function requestInference(
recentMessages: any[],
engine: EngineSettings,
model: Model,
controller?: AbortController
): Observable<string> {
return new Observable((subscriber) => {
const requestBody = JSON.stringify({
messages: recentMessages,
model: model.id,
stream: true,
model: "gpt-3.5-turbo",
max_tokens: 2048,
// ...model.parameters,
});
fetch(INFERENCE_URL, {
method: "POST",
headers: {
"Content-Type": "application/json",
Accept: "text/event-stream",
"Access-Control-Allow-Origin": "*",
Accept: "text/event-stream",
},
body: requestBody,
signal: controller?.signal,

View File

@ -19,6 +19,8 @@ import {
events,
executeOnMain,
getUserSpace,
fs,
Model,
} from "@janhq/core";
import { InferenceExtension } from "@janhq/core";
import { requestInference } from "./helpers/sse";
@ -30,7 +32,19 @@ import { join } from "path";
* The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/
export default class JanInferenceExtension implements InferenceExtension {
export default class JanInferenceNitroExtension implements InferenceExtension {
private static readonly _homeDir = "engines";
private static readonly _engineMetadataFileName = "nitro.json";
private static _currentModel: Model;
private static _engineSettings: EngineSettings = {
ctx_len: 2048,
ngl: 100,
cont_batching: false,
embedding: false,
};
controller = new AbortController();
isCancelled = false;
/**
@ -45,51 +59,88 @@ export default class JanInferenceExtension implements InferenceExtension {
* Subscribes to events emitted by the @janhq/core package.
*/
onLoad(): void {
fs.mkdir(JanInferenceNitroExtension._homeDir);
this.writeDefaultEngineSettings();
// Events subscription
events.on(EventName.OnMessageSent, (data) =>
JanInferenceExtension.handleMessageRequest(data, this)
JanInferenceNitroExtension.handleMessageRequest(data, this)
);
events.on(EventName.OnModelInit, (model: Model) => {
JanInferenceNitroExtension.handleModelInit(model);
});
events.on(EventName.OnModelStop, (model: Model) => {
JanInferenceNitroExtension.handleModelStop(model);
});
events.on(EventName.OnInferenceStopped, () => {
JanInferenceNitroExtension.handleInferenceStopped(this);
});
}
/**
* Stops the model inference.
*/
onUnload(): void {
this.stopModel();
onUnload(): void {}
private async writeDefaultEngineSettings() {
try {
const engineFile = join(
JanInferenceNitroExtension._homeDir,
JanInferenceNitroExtension._engineMetadataFileName
);
if (await fs.exists(engineFile)) {
JanInferenceNitroExtension._engineSettings = JSON.parse(
await fs.readFile(engineFile)
);
} else {
await fs.writeFile(
engineFile,
JSON.stringify(JanInferenceNitroExtension._engineSettings, null, 2)
);
}
} catch (err) {
console.error(err);
}
}
/**
* Initializes the model with the specified file name.
* @param {string} modelId - The ID of the model to initialize.
* @returns {Promise<void>} A promise that resolves when the model is initialized.
*/
async initModel(
modelId: string,
settings?: ModelSettingParams
): Promise<void> {
private static async handleModelInit(model: Model) {
if (model.engine !== "nitro") {
return;
}
const userSpacePath = await getUserSpace();
const modelFullPath = join(userSpacePath, "models", modelId, modelId);
const modelFullPath = join(userSpacePath, "models", model.id, model.id);
return executeOnMain(MODULE, "initModel", {
modelFullPath,
settings,
const nitroInitResult = await executeOnMain(MODULE, "initModel", {
modelFullPath: modelFullPath,
model: model,
});
if (nitroInitResult.error === null) {
events.emit(EventName.OnModelFail, model);
} else {
JanInferenceNitroExtension._currentModel = model;
events.emit(EventName.OnModelReady, model);
}
}
/**
* Stops the model.
* @returns {Promise<void>} A promise that resolves when the model is stopped.
*/
async stopModel(): Promise<void> {
return executeOnMain(MODULE, "killSubprocess");
private static async handleModelStop(model: Model) {
if (model.engine !== "nitro") {
return;
} else {
await executeOnMain(MODULE, "stopModel");
events.emit(EventName.OnModelStopped, model);
}
}
/**
* Stops streaming inference.
* @returns {Promise<void>} A promise that resolves when the streaming is stopped.
*/
async stopInference(): Promise<void> {
this.isCancelled = true;
this.controller?.abort();
private static async handleInferenceStopped(
instance: JanInferenceNitroExtension
) {
instance.isCancelled = true;
instance.controller?.abort();
}
/**
@ -97,7 +148,7 @@ export default class JanInferenceExtension implements InferenceExtension {
* @param {MessageRequest} data - The data for the inference request.
* @returns {Promise<any>} A promise that resolves with the inference response.
*/
async inferenceRequest(data: MessageRequest): Promise<ThreadMessage> {
async inference(data: MessageRequest): Promise<ThreadMessage> {
const timestamp = Date.now();
const message: ThreadMessage = {
thread_id: data.threadId,
@ -111,7 +162,11 @@ export default class JanInferenceExtension implements InferenceExtension {
};
return new Promise(async (resolve, reject) => {
requestInference(data.messages ?? []).subscribe({
requestInference(
data.messages ?? [],
JanInferenceNitroExtension._engineSettings,
JanInferenceNitroExtension._currentModel
).subscribe({
next: (_content) => {},
complete: async () => {
resolve(message);
@ -131,8 +186,11 @@ export default class JanInferenceExtension implements InferenceExtension {
*/
private static async handleMessageRequest(
data: MessageRequest,
instance: JanInferenceExtension
instance: JanInferenceNitroExtension
) {
if (data.model.engine !== "nitro") {
return;
}
const timestamp = Date.now();
const message: ThreadMessage = {
id: ulid(),
@ -150,7 +208,12 @@ export default class JanInferenceExtension implements InferenceExtension {
instance.isCancelled = false;
instance.controller = new AbortController();
requestInference(data.messages, instance.controller).subscribe({
requestInference(
data.messages ?? [],
JanInferenceNitroExtension._engineSettings,
JanInferenceNitroExtension._currentModel,
instance.controller
).subscribe({
next: (content) => {
const messageContent: ThreadContent = {
type: ContentType.Text,

View File

@ -20,51 +20,51 @@ let subprocess = null;
let currentModelFile = null;
/**
* The response from the initModel function.
* @property error - An error message if the model fails to load.
* Stops a Nitro subprocess.
* @param wrapper - The model wrapper.
* @returns A Promise that resolves when the subprocess is terminated successfully, or rejects with an error message if the subprocess fails to terminate.
*/
interface InitModelResponse {
error?: any;
modelFile?: string;
function stopModel(): Promise<ModelOperationResponse> {
return new Promise((resolve, reject) => {
checkAndUnloadNitro();
resolve({ error: undefined });
});
}
/**
* Initializes a Nitro subprocess to load a machine learning model.
* @param modelFile - The name of the machine learning model file.
* @param wrapper - The model wrapper.
* @returns A Promise that resolves when the model is loaded successfully, or rejects with an error message if the model is not found or fails to load.
* TODO: Should pass absolute of the model file instead of just the name - So we can modurize the module.ts to npm package
* TODO: Should it be startModel instead?
*/
function initModel(wrapper: any): Promise<InitModelResponse> {
// 1. Check if the model file exists
function initModel(wrapper: any): Promise<ModelOperationResponse> {
currentModelFile = wrapper.modelFullPath;
log.info("Started to load model " + wrapper.modelFullPath);
const settings = {
llama_model_path: currentModelFile,
ctx_len: 2048,
ngl: 100,
cont_batching: false,
embedding: false, // Always enable embedding mode on
...wrapper.settings,
};
log.info(`Load model settings: ${JSON.stringify(settings, null, 2)}`);
return (
// 1. Check if the port is used, if used, attempt to unload model / kill nitro process
validateModelVersion()
.then(checkAndUnloadNitro)
// 2. Spawn the Nitro subprocess
.then(spawnNitroProcess)
// 4. Load the model into the Nitro subprocess (HTTP POST request)
.then(() => loadLLMModel(settings))
// 5. Check if the model is loaded successfully
.then(validateModelStatus)
.catch((err) => {
log.error("error: " + JSON.stringify(err));
return { error: err, currentModelFile };
})
);
if (wrapper.model.engine !== "nitro") {
return Promise.resolve({ error: "Not a nitro model" });
} else {
log.info("Started to load model " + wrapper.model.modelFullPath);
const settings = {
llama_model_path: currentModelFile,
...wrapper.model.settings,
};
log.info(`Load model settings: ${JSON.stringify(settings, null, 2)}`);
return (
// 1. Check if the port is used, if used, attempt to unload model / kill nitro process
validateModelVersion()
.then(checkAndUnloadNitro)
// 2. Spawn the Nitro subprocess
.then(spawnNitroProcess)
// 4. Load the model into the Nitro subprocess (HTTP POST request)
.then(() => loadLLMModel(settings))
// 5. Check if the model is loaded successfully
.then(validateModelStatus)
.catch((err) => {
log.error("error: " + JSON.stringify(err));
return { error: err, currentModelFile };
})
);
}
}
/**
@ -91,11 +91,11 @@ function loadLLMModel(settings): Promise<Response> {
/**
* Validates the status of a model.
* @returns {Promise<InitModelResponse>} A promise that resolves to an object.
* @returns {Promise<ModelOperationResponse>} A promise that resolves to an object.
* If the model is loaded successfully, the object is empty.
* If the model is not loaded successfully, the object contains an error message.
*/
async function validateModelStatus(): Promise<InitModelResponse> {
async function validateModelStatus(): Promise<ModelOperationResponse> {
// Send a GET request to the validation URL.
// Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries.
return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, {
@ -142,8 +142,8 @@ function killSubprocess(): Promise<void> {
* Check port is used or not, if used, attempt to unload model
* If unload failed, kill the port
*/
function checkAndUnloadNitro() {
return tcpPortUsed.check(PORT, LOCAL_HOST).then((inUse) => {
async function checkAndUnloadNitro() {
return tcpPortUsed.check(PORT, LOCAL_HOST).then(async (inUse) => {
// If inUse - try unload or kill process, otherwise do nothing
if (inUse) {
// Attempt to unload model
@ -168,7 +168,7 @@ function checkAndUnloadNitro() {
*/
async function spawnNitroProcess(): Promise<void> {
return new Promise((resolve, reject) => {
let binaryFolder = path.join(__dirname, "nitro"); // Current directory by default
let binaryFolder = path.join(__dirname, "bin"); // Current directory by default
let binaryName;
if (process.platform === "win32") {

View File

@ -0,0 +1,78 @@
# Jan inference plugin
Created using Jan app example
# Create a Jan Plugin using Typescript
Use this template to bootstrap the creation of a TypeScript Jan plugin. 🚀
## Create Your Own Plugin
To create your own plugin, you can use this repository as a template! Just follow the below instructions:
1. Click the Use this template button at the top of the repository
2. Select Create a new repository
3. Select an owner and name for your new repository
4. Click Create repository
5. Clone your new repository
## Initial Setup
After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your plugin.
> [!NOTE]
>
> You'll need to have a reasonably modern version of
> [Node.js](https://nodejs.org) handy. If you are using a version manager like
> [`nodenv`](https://github.com/nodenv/nodenv) or
> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the
> root of your repository to install the version specified in
> [`package.json`](./package.json). Otherwise, 20.x or later should work!
1. :hammer_and_wrench: Install the dependencies
```bash
npm install
```
1. :building_construction: Package the TypeScript for distribution
```bash
npm run bundle
```
1. :white_check_mark: Check your artifact
There will be a tgz file in your plugin directory now
## Update the Plugin Metadata
The [`package.json`](package.json) file defines metadata about your plugin, such as
plugin name, main entry, description and version.
When you copy this repository, update `package.json` with the name, description for your plugin.
## Update the Plugin Code
The [`src/`](./src/) directory is the heart of your plugin! This contains the
source code that will be run when your plugin extension functions are invoked. You can replace the
contents of this directory with your own code.
There are a few things to keep in mind when writing your plugin code:
- Most Jan Plugin Extension functions are processed asynchronously.
In `index.ts`, you will see that the extension function will return a `Promise<any>`.
```typescript
import { core } from "@janhq/core";
function onStart(): Promise<any> {
return core.invokePluginFunc(MODULE_PATH, "run", 0);
}
```
For more information about the Jan Plugin Core module, see the
[documentation](https://github.com/janhq/jan/blob/main/core/README.md).
So, what are you waiting for? Go ahead and start customizing your plugin!

View File

@ -0,0 +1,41 @@
{
"name": "@janhq/inference-openai-extension",
"version": "1.0.0",
"description": "Inference Engine for OpenAI Extension that can be used with any OpenAI compatible API",
"main": "dist/index.js",
"module": "dist/module.js",
"author": "Jan <service@jan.ai>",
"license": "AGPL-3.0",
"scripts": {
"build": "tsc -b . && webpack --config webpack.config.js",
"build:publish": "rimraf *.tgz --glob && npm run build && npm pack && cpx *.tgz ../../electron/pre-install"
},
"exports": {
".": "./dist/index.js",
"./main": "./dist/module.js"
},
"devDependencies": {
"cpx": "^1.5.0",
"rimraf": "^3.0.2",
"webpack": "^5.88.2",
"webpack-cli": "^5.1.4"
},
"dependencies": {
"@janhq/core": "file:../../core",
"fetch-retry": "^5.0.6",
"path-browserify": "^1.0.1",
"ts-loader": "^9.5.0",
"ulid": "^2.3.0"
},
"engines": {
"node": ">=18.0.0"
},
"files": [
"dist/*",
"package.json",
"README.md"
],
"bundleDependencies": [
"fetch-retry"
]
}

View File

@ -0,0 +1,27 @@
import { Model } from "@janhq/core";
declare const MODULE: string;
declare interface EngineSettings {
full_url?: string;
api_key?: string;
}
enum OpenAIChatCompletionModelName {
"gpt-3.5-turbo-instruct" = "gpt-3.5-turbo-instruct",
"gpt-3.5-turbo-instruct-0914" = "gpt-3.5-turbo-instruct-0914",
"gpt-4-1106-preview" = "gpt-4-1106-preview",
"gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613",
"gpt-3.5-turbo-0301" = "gpt-3.5-turbo-0301",
"gpt-3.5-turbo" = "gpt-3.5-turbo",
"gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-1106" = "gpt-3.5-turbo-1106",
"gpt-4-vision-preview" = "gpt-4-vision-preview",
"gpt-4" = "gpt-4",
"gpt-4-0314" = "gpt-4-0314",
"gpt-4-0613" = "gpt-4-0613",
}
declare type OpenAIModel = Omit<Model, "id"> & {
id: OpenAIChatCompletionModelName;
};

View File

@ -0,0 +1,68 @@
import { Observable } from "rxjs";
import { EngineSettings, OpenAIModel } from "../@types/global";
/**
* Sends a request to the inference server to generate a response based on the recent messages.
* @param recentMessages - An array of recent messages to use as context for the inference.
* @param engine - The engine settings to use for the inference.
* @param model - The model to use for the inference.
* @returns An Observable that emits the generated response as a string.
*/
export function requestInference(
recentMessages: any[],
engine: EngineSettings,
model: OpenAIModel,
controller?: AbortController
): Observable<string> {
return new Observable((subscriber) => {
let model_id: string = model.id
if (engine.full_url.includes("openai.azure.com")){
model_id = engine.full_url.split("/")[5]
}
const requestBody = JSON.stringify({
messages: recentMessages,
stream: true,
model: model_id
// ...model.parameters,
});
fetch(`${engine.full_url}`, {
method: "POST",
headers: {
"Content-Type": "application/json",
Accept: "text/event-stream",
"Access-Control-Allow-Origin": "*",
Authorization: `Bearer ${engine.api_key}`,
"api-key": `${engine.api_key}`,
},
body: requestBody,
signal: controller?.signal,
})
.then(async (response) => {
const stream = response.body;
const decoder = new TextDecoder("utf-8");
const reader = stream?.getReader();
let content = "";
while (true && reader) {
const { done, value } = await reader.read();
if (done) {
break;
}
const text = decoder.decode(value);
const lines = text.trim().split("\n");
for (const line of lines) {
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
const data = JSON.parse(line.replace("data: ", ""));
content += data.choices[0]?.delta?.content ?? "";
if (content.startsWith("assistant: ")) {
content = content.replace("assistant: ", "");
}
subscriber.next(content);
}
}
}
subscriber.complete();
})
.catch((err) => subscriber.error(err));
});
}

View File

@ -0,0 +1,231 @@
/**
* @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
* @version 1.0.0
* @module inference-openai-extension/src/index
*/
import {
ChatCompletionRole,
ContentType,
EventName,
MessageRequest,
MessageStatus,
ModelSettingParams,
ExtensionType,
ThreadContent,
ThreadMessage,
events,
fs,
} from "@janhq/core";
import { InferenceExtension } from "@janhq/core";
import { requestInference } from "./helpers/sse";
import { ulid } from "ulid";
import { join } from "path";
import { EngineSettings, OpenAIModel } from "./@types/global";
/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/
export default class JanInferenceOpenAIExtension implements InferenceExtension {
private static readonly _homeDir = "engines";
private static readonly _engineMetadataFileName = "openai.json";
private static _currentModel: OpenAIModel;
private static _engineSettings: EngineSettings = {
full_url: "https://api.openai.com/v1/chat/completions",
api_key: "sk-<your key here>",
};
controller = new AbortController();
isCancelled = false;
/**
* Returns the type of the extension.
* @returns {ExtensionType} The type of the extension.
*/
// TODO: To fix
type(): ExtensionType {
return undefined;
}
/**
* Subscribes to events emitted by the @janhq/core package.
*/
onLoad(): void {
fs.mkdir(JanInferenceOpenAIExtension._homeDir);
JanInferenceOpenAIExtension.writeDefaultEngineSettings();
// Events subscription
events.on(EventName.OnMessageSent, (data) =>
JanInferenceOpenAIExtension.handleMessageRequest(data, this)
);
events.on(EventName.OnModelInit, (model: OpenAIModel) => {
JanInferenceOpenAIExtension.handleModelInit(model);
});
events.on(EventName.OnModelStop, (model: OpenAIModel) => {
JanInferenceOpenAIExtension.handleModelStop(model);
});
events.on(EventName.OnInferenceStopped, () => {
JanInferenceOpenAIExtension.handleInferenceStopped(this);
});
}
/**
* Stops the model inference.
*/
onUnload(): void {}
static async writeDefaultEngineSettings() {
try {
const engineFile = join(
JanInferenceOpenAIExtension._homeDir,
JanInferenceOpenAIExtension._engineMetadataFileName
);
if (await fs.exists(engineFile)) {
JanInferenceOpenAIExtension._engineSettings = JSON.parse(
await fs.readFile(engineFile)
);
} else {
await fs.writeFile(
engineFile,
JSON.stringify(JanInferenceOpenAIExtension._engineSettings, null, 2)
);
}
} catch (err) {
console.error(err);
}
}
/**
* Makes a single response inference request.
* @param {MessageRequest} data - The data for the inference request.
* @returns {Promise<any>} A promise that resolves with the inference response.
*/
async inference(data: MessageRequest): Promise<ThreadMessage> {
const timestamp = Date.now();
const message: ThreadMessage = {
thread_id: data.threadId,
created: timestamp,
updated: timestamp,
status: MessageStatus.Ready,
id: "",
role: ChatCompletionRole.Assistant,
object: "thread.message",
content: [],
};
return new Promise(async (resolve, reject) => {
requestInference(
data.messages ?? [],
JanInferenceOpenAIExtension._engineSettings,
JanInferenceOpenAIExtension._currentModel
).subscribe({
next: (_content) => {},
complete: async () => {
resolve(message);
},
error: async (err) => {
reject(err);
},
});
});
}
private static async handleModelInit(model: OpenAIModel) {
if (model.engine !== "openai") {
return;
} else {
JanInferenceOpenAIExtension._currentModel = model;
JanInferenceOpenAIExtension.writeDefaultEngineSettings();
// Todo: Check model list with API key
events.emit(EventName.OnModelReady, model);
}
}
private static async handleModelStop(model: OpenAIModel) {
if (model.engine !== "openai") {
return;
}
events.emit(EventName.OnModelStopped, model);
}
private static async handleInferenceStopped(
instance: JanInferenceOpenAIExtension
) {
instance.isCancelled = true;
instance.controller?.abort();
}
/**
* Handles a new message request by making an inference request and emitting events.
* Function registered in event manager, should be static to avoid binding issues.
* Pass instance as a reference.
* @param {MessageRequest} data - The data for the new message request.
*/
private static async handleMessageRequest(
data: MessageRequest,
instance: JanInferenceOpenAIExtension
) {
if (data.model.engine !== "openai") {
return;
}
const timestamp = Date.now();
const message: ThreadMessage = {
id: ulid(),
thread_id: data.threadId,
assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant,
content: [],
status: MessageStatus.Pending,
created: timestamp,
updated: timestamp,
object: "thread.message",
};
events.emit(EventName.OnMessageResponse, message);
instance.isCancelled = false;
instance.controller = new AbortController();
requestInference(
data?.messages ?? [],
this._engineSettings,
JanInferenceOpenAIExtension._currentModel,
instance.controller
).subscribe({
next: (content) => {
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: content.trim(),
annotations: [],
},
};
message.content = [messageContent];
events.emit(EventName.OnMessageUpdate, message);
},
complete: async () => {
message.status = MessageStatus.Ready;
events.emit(EventName.OnMessageUpdate, message);
},
error: async (err) => {
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: "Error occurred: " + err.message,
annotations: [],
},
};
message.content = [messageContent];
message.status = MessageStatus.Ready;
events.emit(EventName.OnMessageUpdate, message);
},
});
}
}

View File

@ -0,0 +1,15 @@
{
"compilerOptions": {
"target": "es2016",
"module": "ES6",
"moduleResolution": "node",
"outDir": "./dist",
"esModuleInterop": true,
"forceConsistentCasingInFileNames": true,
"strict": false,
"skipLibCheck": true,
"rootDir": "./src"
},
"include": ["./src"]
}

View File

@ -0,0 +1,42 @@
const path = require("path");
const webpack = require("webpack");
const packageJson = require("./package.json");
module.exports = {
experiments: { outputModule: true },
entry: "./src/index.ts", // Adjust the entry point to match your project's main file
mode: "production",
module: {
rules: [
{
test: /\.tsx?$/,
use: "ts-loader",
exclude: /node_modules/,
},
],
},
plugins: [
new webpack.DefinePlugin({
MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
INFERENCE_URL: JSON.stringify(
process.env.INFERENCE_URL ||
"http://127.0.0.1:3928/inferences/llamacpp/chat_completion"
),
}),
],
output: {
filename: "index.js", // Adjust the output file name as needed
path: path.resolve(__dirname, "dist"),
library: { type: "module" }, // Specify ESM output format
},
resolve: {
extensions: [".ts", ".js"],
fallback: {
path: require.resolve("path-browserify"),
},
},
optimization: {
minimize: false,
},
// Add loaders and other configuration as needed for your project
};

View File

@ -19,6 +19,7 @@
"author": "NousResearch, The Bloke",
"tags": ["34B", "Finetuned"],
"size": 24320000000
}
},
"engine": "nitro"
}

View File

@ -1,3 +1,4 @@
{
"source_url": "https://huggingface.co/TheBloke/deepseek-coder-1.3b-instruct-GGUF/resolve/main/deepseek-coder-1.3b-instruct.Q8_0.gguf",
"id": "deepseek-coder-1.3b",
@ -19,5 +20,6 @@
"author": "Deepseek, The Bloke",
"tags": ["Tiny", "Foundational Model"],
"size": 1430000000
}
},
"engine": "nitro"
}

View File

@ -19,6 +19,7 @@
"author": "Deepseek, The Bloke",
"tags": ["34B", "Foundational Model"],
"size": 26040000000
}
},
"engine": "nitro"
}

View File

@ -0,0 +1,20 @@
{
"source_url": "https://openai.com",
"id": "gpt-3.5-turbo-16k-0613",
"object": "model",
"name": "OpenAI GPT 3.5 Turbo 16k 0613",
"version": 1.0,
"description": "OpenAI GPT 3.5 Turbo 16k 0613 model is extremely good",
"format": "api",
"settings": {},
"parameters": {
"max_tokens": 4096
},
"metadata": {
"author": "OpenAI",
"tags": ["General", "Big Context Length"]
},
"engine": "openai",
"state": "ready"
}

View File

@ -0,0 +1,18 @@
{
"source_url": "https://openai.com",
"id": "gpt-3.5-turbo",
"object": "model",
"name": "OpenAI GPT 3.5 Turbo",
"version": 1.0,
"description": "OpenAI GPT 3.5 Turbo model is extremely good",
"format": "api",
"settings": {},
"parameters": {},
"metadata": {
"author": "OpenAI",
"tags": ["General", "Big Context Length"]
},
"engine": "openai",
"state": "ready"
}

20
models/gpt-4/model.json Normal file
View File

@ -0,0 +1,20 @@
{
"source_url": "https://openai.com",
"id": "gpt-4",
"object": "model",
"name": "OpenAI GPT 3.5",
"version": 1.0,
"description": "OpenAI GPT 3.5 model is extremely good",
"format": "api",
"settings": {},
"parameters": {
"max_tokens": 4096
},
"metadata": {
"author": "OpenAI",
"tags": ["General", "Big Context Length"]
},
"engine": "openai",
"state": "ready"
}

View File

@ -19,6 +19,7 @@
"author": "MetaAI, The Bloke",
"tags": ["70B", "Foundational Model"],
"size": 43920000000
}
},
"engine": "nitro"
}

View File

@ -19,6 +19,7 @@
"author": "MetaAI, The Bloke",
"tags": ["7B", "Foundational Model"],
"size": 4080000000
}
},
"engine": "nitro"
}

View File

@ -19,6 +19,7 @@
"author": "MetaAI, The Bloke",
"tags": ["7B", "Foundational Model"],
"size": 4780000000
}
},
"engine": "nitro"
}

View File

@ -19,6 +19,7 @@
"author": "Lizpreciatior, The Bloke",
"tags": ["70B", "Finetuned"],
"size": 48750000000
}
},
"engine": "nitro"
}

View File

@ -20,6 +20,7 @@
"tags": ["Featured", "7B", "Foundational Model"],
"size": 4370000000,
"cover": "https://raw.githubusercontent.com/janhq/jan/main/models/mistral-ins-7b-q4/cover.png"
}
},
"engine": "nitro"
}

View File

@ -19,6 +19,7 @@
"author": "MistralAI, The Bloke",
"tags": ["7B", "Foundational Model"],
"size": 5130000000
}
},
"engine": "nitro"
}

View File

@ -19,6 +19,7 @@
"author": "Intel, The Bloke",
"tags": ["Recommended", "7B", "Finetuned"],
"size": 4370000000
}
},
"engine": "nitro"
}

View File

@ -19,6 +19,7 @@
"author": "NeverSleep, The Bloke",
"tags": ["34B", "Finetuned"],
"size": 12040000000
}
},
"engine": "nitro"
}

View File

@ -20,5 +20,6 @@
"tags": ["Featured", "7B", "Merged"],
"size": 4370000000,
"cover": "https://raw.githubusercontent.com/janhq/jan/main/models/openhermes-neural-7b/cover.png"
}
},
"engine": "nitro"
}

View File

@ -19,6 +19,7 @@
"author": "Microsoft, The Bloke",
"tags": ["13B", "Finetuned"],
"size": 9230000000
}
},
"engine": "nitro"
}

View File

@ -19,6 +19,7 @@
"author": "Phind, The Bloke",
"tags": ["34B", "Finetuned"],
"size": 24320000000
}
},
"engine": "nitro"
}

View File

@ -19,5 +19,6 @@
"author": "Pansophic, The Bloke",
"tags": ["Tiny", "Finetuned"],
"size": 1710000000
}
},
"engine": "nitro"
}

View File

@ -19,6 +19,7 @@
"author": "Berkeley-nest, The Bloke",
"tags": ["Recommended", "7B","Finetuned"],
"size": 4370000000
}
},
"engine": "nitro"
}

View File

@ -19,6 +19,7 @@
"author": "KoboldAI, The Bloke",
"tags": ["13B", "Finetuned"],
"size": 9230000000
}
},
"engine": "nitro"
}

View File

@ -19,5 +19,6 @@
"author": "TinyLlama",
"tags": ["Tiny", "Foundation Model"],
"size": 637000000
}
},
"engine": "nitro"
}

View File

@ -19,6 +19,7 @@
"author": "WizardLM, The Bloke",
"tags": ["Recommended", "13B", "Finetuned"],
"size": 9230000000
}
},
"engine": "nitro"
}

View File

@ -19,6 +19,7 @@
"author": "01-ai, The Bloke",
"tags": ["34B", "Foundational Model"],
"size": 24320000000
}
},
"engine": "nitro"
}

View File

@ -19,6 +19,7 @@
"author": "HuggingFaceH4, The Bloke",
"tags": ["7B", "Finetuned"],
"size": 4370000000
}
},
"engine": "nitro"
}

View File

@ -7,10 +7,16 @@ import {
ThreadMessage,
ExtensionType,
MessageStatus,
Model,
} from '@janhq/core'
import { ConversationalExtension } from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels'
import { toaster } from '../Toast'
import { extensionManager } from '@/extension'
import {
addNewMessageAtom,
@ -24,19 +30,61 @@ import {
export default function EventHandler({ children }: { children: ReactNode }) {
const addNewMessage = useSetAtom(addNewMessageAtom)
const updateMessage = useSetAtom(updateMessageAtom)
const { downloadedModels } = useGetDownloadedModels()
const setActiveModel = useSetAtom(activeModelAtom)
const setStateModel = useSetAtom(stateModelAtom)
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
const threads = useAtomValue(threadsAtom)
const modelsRef = useRef(downloadedModels)
const threadsRef = useRef(threads)
useEffect(() => {
threadsRef.current = threads
}, [threads])
useEffect(() => {
modelsRef.current = downloadedModels
}, [downloadedModels])
async function handleNewMessageResponse(message: ThreadMessage) {
addNewMessage(message)
}
async function handleModelReady(model: Model) {
setActiveModel(model)
toaster({
title: 'Success!',
description: `Model ${model.id} has been started.`,
})
setStateModel(() => ({
state: 'stop',
loading: false,
model: model.id,
}))
}
async function handleModelStopped(model: Model) {
setTimeout(async () => {
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
toaster({
title: 'Success!',
description: `Model ${model.id} has been stopped.`,
})
}, 500)
}
async function handleModelFail(res: any) {
const errorMessage = `${res.error}`
alert(errorMessage)
setStateModel(() => ({
state: 'start',
loading: false,
model: res.modelId,
}))
}
async function handleMessageResponseUpdate(message: ThreadMessage) {
updateMessage(
message.id,
@ -73,6 +121,9 @@ export default function EventHandler({ children }: { children: ReactNode }) {
if (window.core.events) {
events.on(EventName.OnMessageResponse, handleNewMessageResponse)
events.on(EventName.OnMessageUpdate, handleMessageResponseUpdate)
events.on(EventName.OnModelReady, handleModelReady)
events.on(EventName.OnModelFail, handleModelFail)
events.on(EventName.OnModelStopped, handleModelStopped)
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])

View File

@ -1,5 +1,8 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { ExtensionType, InferenceExtension } from '@janhq/core'
import {
EventName,
events,
} from '@janhq/core'
import { Model, ModelSettingParams } from '@janhq/core'
import { atom, useAtom } from 'jotai'
@ -9,9 +12,13 @@ import { useGetDownloadedModels } from './useGetDownloadedModels'
import { extensionManager } from '@/extension'
const activeModelAtom = atom<Model | undefined>(undefined)
export const activeModelAtom = atom<Model | undefined>(undefined)
const stateModelAtom = atom({ state: 'start', loading: false, model: '' })
export const stateModelAtom = atom({
state: 'start',
loading: false,
model: '',
})
export function useActiveModel() {
const [activeModel, setActiveModel] = useAtom(activeModelAtom)
@ -47,59 +54,14 @@ export function useActiveModel() {
return
}
const currentTime = Date.now()
const res = await initModel(modelId, model?.settings)
if (res && res.error) {
const errorMessage = `${res.error}`
alert(errorMessage)
setStateModel(() => ({
state: 'start',
loading: false,
model: modelId,
}))
} else {
console.debug(
`Model ${modelId} successfully initialized! Took ${
Date.now() - currentTime
}ms`
)
setActiveModel(model)
toaster({
title: 'Success!',
description: `Model ${modelId} has been started.`,
})
setStateModel(() => ({
state: 'stop',
loading: false,
model: modelId,
}))
}
events.emit(EventName.OnModelInit, model)
}
const stopModel = async (modelId: string) => {
const model = downloadedModels.find((e) => e.id === modelId)
setStateModel({ state: 'stop', loading: true, model: modelId })
setTimeout(async () => {
extensionManager
.get<InferenceExtension>(ExtensionType.Inference)
?.stopModel()
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
toaster({
title: 'Success!',
description: `Model ${modelId} has been stopped.`,
})
}, 500)
events.emit(EventName.OnModelStop, model)
}
return { activeModel, startModel, stopModel, stateModel }
}
const initModel = async (
modelId: string,
settings?: ModelSettingParams
): Promise<any> => {
return extensionManager
.get<InferenceExtension>(ExtensionType.Inference)
?.initModel(modelId, settings)
}

View File

@ -67,6 +67,7 @@ export const useCreateNewThread = () => {
top_p: 0,
stream: false,
},
engine: undefined
},
instructions: assistant.instructions,
}

View File

@ -50,7 +50,6 @@ export default function useSendChatMessage() {
const [queuedMessage, setQueuedMessage] = useState(false)
const modelRef = useRef<Model | undefined>()
useEffect(() => {
modelRef.current = activeModel
}, [activeModel])
@ -91,18 +90,35 @@ export default function useSendChatMessage() {
id: ulid(),
messages: messages,
threadId: activeThread.id,
model: activeThread.assistants[0].model ?? selectedModel,
}
const modelId = selectedModel?.id ?? activeThread.assistants[0].model.id
if (activeModel?.id !== modelId) {
setQueuedMessage(true)
await startModel(modelId)
startModel(modelId)
await WaitForModelStarting(modelId)
setQueuedMessage(false)
}
events.emit(EventName.OnMessageSent, messageRequest)
}
// TODO: Refactor @louis
const WaitForModelStarting = async (modelId: string) => {
return new Promise<void>((resolve) => {
setTimeout(async () => {
if (modelRef.current?.id !== modelId) {
console.log('waiting for model to start')
await WaitForModelStarting(modelId)
resolve()
} else {
resolve()
}
}, 200)
})
}
const sendChatMessage = async () => {
if (!currentPrompt || currentPrompt.trim().length === 0) {
return
@ -132,6 +148,7 @@ export default function useSendChatMessage() {
id: selectedModel.id,
settings: selectedModel.settings,
parameters: selectedModel.parameters,
engine: selectedModel.engine,
},
},
],
@ -178,7 +195,7 @@ export default function useSendChatMessage() {
id: msgId,
threadId: activeThread.id,
messages,
parameters: activeThread.assistants[0].model.parameters,
model: selectedModel ?? activeThread.assistants[0].model,
}
const timestamp = Date.now()
const threadMessage: ThreadMessage = {
@ -210,7 +227,8 @@ export default function useSendChatMessage() {
if (activeModel?.id !== modelId) {
setQueuedMessage(true)
await startModel(modelId)
startModel(modelId)
await WaitForModelStarting(modelId)
setQueuedMessage(false)
}
events.emit(EventName.OnMessageSent, messageRequest)

View File

@ -30,9 +30,8 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
const { resendChatMessage } = useSendChatMessage()
const onStopInferenceClick = async () => {
await extensionManager
.get<InferenceExtension>(ExtensionType.Inference)
?.stopInference()
events.emit(EventName.OnInferenceStopped, {})
setTimeout(() => {
events.emit(EventName.OnMessageUpdate, {
...message,

View File

@ -1,4 +1,5 @@
export const toGigabytes = (input: number) => {
if (!input) return ''
if (input > 1024 ** 3) {
return (input / 1000 ** 3).toFixed(2) + 'GB'
} else if (input > 1024 ** 2) {