Merge pull request #814 from janhq/feat/inference_engines
feat: Multiple inference engines for nitro and openai
commit ee16683d0a
.gitignore (vendored): 8 changes
@@ -17,7 +17,7 @@ package-lock.json
 core/lib/**
 
 # Nitro binary files
-extensions/inference-extension/nitro/*/nitro
-extensions/inference-extension/nitro/*/*.exe
-extensions/inference-extension/nitro/*/*.dll
-extensions/inference-extension/nitro/*/*.metal
+extensions/inference-nitro-extension/bin/*/nitro
+extensions/inference-nitro-extension/bin/*/*.exe
+extensions/inference-nitro-extension/bin/*/*.dll
+extensions/inference-nitro-extension/bin/*/*.metal
@@ -8,6 +8,18 @@ export enum EventName {
   OnMessageResponse = "OnMessageResponse",
   /** The `OnMessageUpdate` event is emitted when a message is updated. */
   OnMessageUpdate = "OnMessageUpdate",
+  /** The `OnModelInit` event is emitted when a model initializes. */
+  OnModelInit = "OnModelInit",
+  /** The `OnModelReady` event is emitted when a model is ready. */
+  OnModelReady = "OnModelReady",
+  /** The `OnModelFail` event is emitted when a model fails to load. */
+  OnModelFail = "OnModelFail",
+  /** The `OnModelStop` event is emitted when a model starts stopping. */
+  OnModelStop = "OnModelStop",
+  /** The `OnModelStopped` event is emitted when a model has stopped successfully. */
+  OnModelStopped = "OnModelStopped",
+  /** The `OnInferenceStopped` event is emitted when an inference is stopped. */
+  OnInferenceStopped = "OnInferenceStopped",
 }
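For context, extensions later in this diff consume these lifecycle events through the core `events` bus. A minimal subscriber sketch (the handler bodies are illustrative, not part of this commit):

```typescript
import { events, EventName, Model } from "@janhq/core";

// React to the model lifecycle events added above.
events.on(EventName.OnModelReady, (model: Model) => {
  console.log(`Model ${model.id} is ready`);
});
events.on(EventName.OnModelFail, (model: Model) => {
  console.error(`Model ${model.id} failed to load`);
});
```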
@@ -5,26 +5,10 @@ import { BaseExtension } from "../extension";
  * Inference extension. Start, stop and inference models.
  */
 export abstract class InferenceExtension extends BaseExtension {
-  /**
-   * Initializes the model for the extension.
-   * @param modelId - The ID of the model to initialize.
-   */
-  abstract initModel(modelId: string, settings?: ModelSettingParams): Promise<void>;
-
-  /**
-   * Stops the model for the extension.
-   */
-  abstract stopModel(): Promise<void>;
-
-  /**
-   * Stops the streaming inference.
-   */
-  abstract stopInference(): Promise<void>;
-
   /**
    * Processes an inference request.
    * @param data - The data for the inference request.
    * @returns The result of the inference request.
    */
-  abstract inferenceRequest(data: MessageRequest): Promise<ThreadMessage>;
+  abstract inference(data: MessageRequest): Promise<ThreadMessage>;
 }
@ -5,52 +5,52 @@
|
||||
* @returns {Promise<any>} A Promise that resolves when the file is written successfully.
|
||||
*/
|
||||
const writeFile: (path: string, data: string) => Promise<any> = (path, data) =>
|
||||
global.core.api?.writeFile(path, data);
|
||||
global.core.api?.writeFile(path, data)
|
||||
|
||||
/**
|
||||
* Checks whether the path is a directory.
|
||||
* @param path - The path to check.
|
||||
* @returns {boolean} A boolean indicating whether the path is a directory.
|
||||
*/
|
||||
const isDirectory = (path: string): Promise<boolean> =>
|
||||
global.core.api?.isDirectory(path);
|
||||
const isDirectory = (path: string): Promise<boolean> => global.core.api?.isDirectory(path)
|
||||
|
||||
/**
|
||||
* Reads the contents of a file at the specified path.
|
||||
* @param {string} path - The path of the file to read.
|
||||
* @returns {Promise<any>} A Promise that resolves with the contents of the file.
|
||||
*/
|
||||
const readFile: (path: string) => Promise<any> = (path) =>
|
||||
global.core.api?.readFile(path);
|
||||
const readFile: (path: string) => Promise<any> = (path) => global.core.api?.readFile(path)
|
||||
/**
|
||||
* Check whether the file exists
|
||||
* @param {string} path
|
||||
* @returns {boolean} A boolean indicating whether the path is a file.
|
||||
*/
|
||||
const exists = (path: string): Promise<boolean> => global.core.api?.exists(path)
|
||||
/**
|
||||
* List the directory files
|
||||
* @param {string} path - The path of the directory to list files.
|
||||
* @returns {Promise<any>} A Promise that resolves with the contents of the directory.
|
||||
*/
|
||||
const listFiles: (path: string) => Promise<any> = (path) =>
|
||||
global.core.api?.listFiles(path);
|
||||
const listFiles: (path: string) => Promise<any> = (path) => global.core.api?.listFiles(path)
|
||||
/**
|
||||
* Creates a directory at the specified path.
|
||||
* @param {string} path - The path of the directory to create.
|
||||
* @returns {Promise<any>} A Promise that resolves when the directory is created successfully.
|
||||
*/
|
||||
const mkdir: (path: string) => Promise<any> = (path) =>
|
||||
global.core.api?.mkdir(path);
|
||||
const mkdir: (path: string) => Promise<any> = (path) => global.core.api?.mkdir(path)
|
||||
|
||||
/**
|
||||
* Removes a directory at the specified path.
|
||||
* @param {string} path - The path of the directory to remove.
|
||||
* @returns {Promise<any>} A Promise that resolves when the directory is removed successfully.
|
||||
*/
|
||||
const rmdir: (path: string) => Promise<any> = (path) =>
|
||||
global.core.api?.rmdir(path);
|
||||
const rmdir: (path: string) => Promise<any> = (path) => global.core.api?.rmdir(path)
|
||||
/**
|
||||
* Deletes a file from the local file system.
|
||||
* @param {string} path - The path of the file to delete.
|
||||
* @returns {Promise<any>} A Promise that resolves when the file is deleted.
|
||||
*/
|
||||
const deleteFile: (path: string) => Promise<any> = (path) =>
|
||||
global.core.api?.deleteFile(path);
|
||||
const deleteFile: (path: string) => Promise<any> = (path) => global.core.api?.deleteFile(path)
|
||||
|
||||
/**
|
||||
* Appends data to a file at the specified path.
|
||||
@ -58,10 +58,10 @@ const deleteFile: (path: string) => Promise<any> = (path) =>
|
||||
* @param data data to append
|
||||
*/
|
||||
const appendFile: (path: string, data: string) => Promise<any> = (path, data) =>
|
||||
global.core.api?.appendFile(path, data);
|
||||
global.core.api?.appendFile(path, data)
|
||||
|
||||
const copyFile: (src: string, dest: string) => Promise<any> = (src, dest) =>
|
||||
global.core.api?.copyFile(src, dest);
|
||||
global.core.api?.copyFile(src, dest)
|
||||
|
||||
/**
|
||||
* Reads a file line by line.
|
||||
@ -69,12 +69,13 @@ const copyFile: (src: string, dest: string) => Promise<any> = (src, dest) =>
|
||||
* @returns {Promise<any>} A promise that resolves to the lines of the file.
|
||||
*/
|
||||
const readLineByLine: (path: string) => Promise<any> = (path) =>
|
||||
global.core.api?.readLineByLine(path);
|
||||
global.core.api?.readLineByLine(path)
|
||||
|
||||
export const fs = {
|
||||
isDirectory,
|
||||
writeFile,
|
||||
readFile,
|
||||
exists,
|
||||
listFiles,
|
||||
mkdir,
|
||||
rmdir,
|
||||
@ -82,4 +83,4 @@ export const fs = {
|
||||
appendFile,
|
||||
readLineByLine,
|
||||
copyFile,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -41,8 +41,8 @@ export type MessageRequest = {
   /** Messages for constructing a chat completion request **/
   messages?: ChatCompletionMessage[];
 
-  /** Runtime parameters for constructing a chat completion request **/
-  parameters?: ModelRuntimeParam;
+  /** Settings for constructing a chat completion request **/
+  model?: ModelInfo;
 };
 
 /**
@@ -153,7 +153,8 @@ export type ThreadAssistantInfo = {
 export type ModelInfo = {
   id: string;
   settings: ModelSettingParams;
-  parameters: ModelRuntimeParam;
+  parameters: ModelRuntimeParams;
+  engine?: InferenceEngine;
 };
 
 /**
@@ -166,6 +167,17 @@ export type ThreadState = {
   error?: Error;
   lastMessage?: string;
 };
+/**
+ * Represents the inference engine.
+ * @stored
+ */
+
+enum InferenceEngine {
+  nitro = "nitro",
+  openai = "openai",
+  nvidia_triton = "nvidia_triton",
+  hf_endpoint = "hf_endpoint",
+}
 
 /**
  * Model type defines the shape of a model object.
@@ -228,12 +240,16 @@ export interface Model {
   /**
    * The model runtime parameters.
    */
-  parameters: ModelRuntimeParam;
+  parameters: ModelRuntimeParams;
 
   /**
    * Metadata of the model.
    */
   metadata: ModelMetadata;
+  /**
+   * The model engine.
+   */
+  engine: InferenceEngine;
 }
 
 export type ModelMetadata = {
@@ -268,7 +284,7 @@ export type ModelSettingParams = {
 /**
  * The available model runtime parameters.
  */
-export type ModelRuntimeParam = {
+export type ModelRuntimeParams = {
   temperature?: number;
   token_limit?: number;
   top_k?: number;
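The new `engine` tag on `ModelInfo` and `Model` is what lets several inference extensions coexist: every extension receives the same request events, but each one checks the request's `model.engine` and returns early when it belongs to another engine (see the `handleMessageRequest` guards later in this diff). A minimal sketch of that routing, with the handler body illustrative:

```typescript
import { events, EventName, MessageRequest } from "@janhq/core";

// Every inference extension subscribes to OnMessageSent; only the one whose
// engine matches the request's model responds, the rest return early.
events.on(EventName.OnMessageSent, (data: MessageRequest) => {
  if (data.model?.engine !== "openai") return; // handled by another extension
  // ...forward data.messages to the OpenAI-compatible endpoint...
});
```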
@@ -289,7 +289,7 @@ components:
       engine:
         type: string
         description: "The engine used by the model."
-        example: "llamacpp"
+        enum: [nitro, openai, hf_inference]
       quantization:
         type: string
         description: "Quantization parameter of the model."
@ -50,6 +50,19 @@ export function handleFsIPCs() {
|
||||
})
|
||||
})
|
||||
|
||||
/**
|
||||
* Checks whether a file exists in the user data directory.
|
||||
* @param event - The event object.
|
||||
* @param path - The path of the file to check.
|
||||
* @returns A promise that resolves with a boolean indicating whether the file exists.
|
||||
*/
|
||||
ipcMain.handle('exists', async (_event, path: string) => {
|
||||
return new Promise((resolve, reject) => {
|
||||
const fullPath = join(userSpacePath, path)
|
||||
fs.existsSync(fullPath) ? resolve(true) : resolve(false)
|
||||
})
|
||||
})
|
||||
|
||||
/**
|
||||
* Writes data to a file in the user data directory.
|
||||
* @param event - The event object.
|
||||
|
||||
@ -27,6 +27,12 @@ export function fsInvokers() {
|
||||
*/
|
||||
readFile: (path: string) => ipcRenderer.invoke('readFile', path),
|
||||
|
||||
/**
|
||||
* Reads a file at the specified path.
|
||||
* @param {string} path - The path of the file to read.
|
||||
*/
|
||||
exists: (path: string) => ipcRenderer.invoke('exists', path),
|
||||
|
||||
/**
|
||||
* Writes data to a file at the specified path.
|
||||
* @param {string} path - The path of the file to write to.
|
||||
|
||||
@ -1,3 +0,0 @@
|
||||
@echo off
|
||||
set /p NITRO_VERSION=<./nitro/version.txt
|
||||
.\node_modules\.bin\download https://github.com/janhq/nitro/releases/download/v%NITRO_VERSION%/nitro-%NITRO_VERSION%-win-amd64-cuda.tar.gz -e --strip 1 -o ./nitro/win-cuda && .\node_modules\.bin\download https://github.com/janhq/nitro/releases/download/v%NITRO_VERSION%/nitro-%NITRO_VERSION%-win-amd64.tar.gz -e --strip 1 -o ./nitro/win-cpu
|
||||
@ -1,57 +0,0 @@
|
||||
{
|
||||
"name": "@janhq/inference-extension",
|
||||
"version": "1.0.0",
|
||||
"description": "Inference Extension, powered by @janhq/nitro, bring a high-performance Llama model inference in pure C++.",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/module.js",
|
||||
"author": "Jan <service@jan.ai>",
|
||||
"license": "AGPL-3.0",
|
||||
"scripts": {
|
||||
"build": "tsc -b . && webpack --config webpack.config.js",
|
||||
"downloadnitro:linux": "NITRO_VERSION=$(cat ./nitro/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64.tar.gz -e --strip 1 -o ./nitro/linux-cpu && chmod +x ./nitro/linux-cpu/nitro && chmod +x ./nitro/linux-start.sh && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda.tar.gz -e --strip 1 -o ./nitro/linux-cuda && chmod +x ./nitro/linux-cuda/nitro && chmod +x ./nitro/linux-start.sh",
|
||||
"downloadnitro:darwin": "NITRO_VERSION=$(cat ./nitro/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-arm64.tar.gz -e --strip 1 -o ./nitro/mac-arm64 && chmod +x ./nitro/mac-arm64/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-amd64.tar.gz -e --strip 1 -o ./nitro/mac-x64 && chmod +x ./nitro/mac-x64/nitro",
|
||||
"downloadnitro:win32": "download.bat",
|
||||
"downloadnitro": "run-script-os",
|
||||
"build:publish:darwin": "rimraf *.tgz --glob && npm run build && npm run downloadnitro && ../../.github/scripts/auto-sign.sh && cpx \"nitro/**\" \"dist/nitro\" && npm pack && cpx *.tgz ../../electron/pre-install",
|
||||
"build:publish:win32": "rimraf *.tgz --glob && npm run build && npm run downloadnitro && cpx \"nitro/**\" \"dist/nitro\" && npm pack && cpx *.tgz ../../electron/pre-install",
|
||||
"build:publish:linux": "rimraf *.tgz --glob && npm run build && npm run downloadnitro && cpx \"nitro/**\" \"dist/nitro\" && npm pack && cpx *.tgz ../../electron/pre-install",
|
||||
"build:publish": "run-script-os"
|
||||
},
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
"./main": "./dist/module.js"
|
||||
},
|
||||
"devDependencies": {
|
||||
"cpx": "^1.5.0",
|
||||
"rimraf": "^3.0.2",
|
||||
"run-script-os": "^1.1.6",
|
||||
"webpack": "^5.88.2",
|
||||
"webpack-cli": "^5.1.4"
|
||||
},
|
||||
"dependencies": {
|
||||
"@janhq/core": "file:../../core",
|
||||
"download-cli": "^1.1.1",
|
||||
"electron-log": "^5.0.1",
|
||||
"fetch-retry": "^5.0.6",
|
||||
"kill-port": "^2.0.1",
|
||||
"path-browserify": "^1.0.1",
|
||||
"rxjs": "^7.8.1",
|
||||
"tcp-port-used": "^1.0.2",
|
||||
"ts-loader": "^9.5.0",
|
||||
"ulid": "^2.3.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
},
|
||||
"files": [
|
||||
"dist/*",
|
||||
"package.json",
|
||||
"README.md"
|
||||
],
|
||||
"bundleDependencies": [
|
||||
"tcp-port-used",
|
||||
"kill-port",
|
||||
"fetch-retry",
|
||||
"electron-log"
|
||||
]
|
||||
}
|
||||
@@ -1,2 +0,0 @@
-declare const MODULE: string;
-declare const INFERENCE_URL: string;
extensions/inference-nitro-extension/download.bat (new file): 3 lines
@@ -0,0 +1,3 @@
@echo off
set /p NITRO_VERSION=<./bin/version.txt
.\node_modules\.bin\download https://github.com/janhq/nitro/releases/download/v%NITRO_VERSION%/nitro-%NITRO_VERSION%-win-amd64-cuda.tar.gz -e --strip 1 -o ./bin/win-cuda && .\node_modules\.bin\download https://github.com/janhq/nitro/releases/download/v%NITRO_VERSION%/nitro-%NITRO_VERSION%-win-amd64.tar.gz -e --strip 1 -o ./bin/win-cpu
extensions/inference-nitro-extension/package.json (new file): 57 lines
@ -0,0 +1,57 @@
|
||||
{
|
||||
"name": "@janhq/inference-nitro-extension",
|
||||
"version": "1.0.0",
|
||||
"description": "Inference Engine for Nitro Extension, powered by @janhq/nitro, bring a high-performance Llama model inference in pure C++.",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/module.js",
|
||||
"author": "Jan <service@jan.ai>",
|
||||
"license": "AGPL-3.0",
|
||||
"scripts": {
|
||||
"build": "tsc -b . && webpack --config webpack.config.js",
|
||||
"downloadnitro:linux": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64.tar.gz -e --strip 1 -o ./bin/linux-cpu && chmod +x ./bin/linux-cpu/nitro && chmod +x ./bin/linux-start.sh && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda.tar.gz -e --strip 1 -o ./bin/linux-cuda && chmod +x ./bin/linux-cuda/nitro && chmod +x ./bin/linux-start.sh",
|
||||
"downloadnitro:darwin": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-arm64.tar.gz -e --strip 1 -o ./bin/mac-arm64 && chmod +x ./bin/mac-arm64/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-amd64.tar.gz -e --strip 1 -o ./bin/mac-x64 && chmod +x ./bin/mac-x64/nitro",
|
||||
"downloadnitro:win32": "download.bat",
|
||||
"downloadnitro": "run-script-os",
|
||||
"build:publish:darwin": "rimraf *.tgz --glob && npm run build && npm run downloadnitro && ../../.github/scripts/auto-sign.sh && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../electron/pre-install",
|
||||
"build:publish:win32": "rimraf *.tgz --glob && npm run build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../electron/pre-install",
|
||||
"build:publish:linux": "rimraf *.tgz --glob && npm run build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../electron/pre-install",
|
||||
"build:publish": "run-script-os"
|
||||
},
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
"./main": "./dist/module.js"
|
||||
},
|
||||
"devDependencies": {
|
||||
"cpx": "^1.5.0",
|
||||
"rimraf": "^3.0.2",
|
||||
"run-script-os": "^1.1.6",
|
||||
"webpack": "^5.88.2",
|
||||
"webpack-cli": "^5.1.4"
|
||||
},
|
||||
"dependencies": {
|
||||
"@janhq/core": "file:../../core",
|
||||
"download-cli": "^1.1.1",
|
||||
"electron-log": "^5.0.1",
|
||||
"fetch-retry": "^5.0.6",
|
||||
"kill-port": "^2.0.1",
|
||||
"path-browserify": "^1.0.1",
|
||||
"rxjs": "^7.8.1",
|
||||
"tcp-port-used": "^1.0.2",
|
||||
"ts-loader": "^9.5.0",
|
||||
"ulid": "^2.3.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
},
|
||||
"files": [
|
||||
"dist/*",
|
||||
"package.json",
|
||||
"README.md"
|
||||
],
|
||||
"bundleDependencies": [
|
||||
"tcp-port-used",
|
||||
"kill-port",
|
||||
"fetch-retry",
|
||||
"electron-log"
|
||||
]
|
||||
}
|
||||
extensions/inference-nitro-extension/src/@types/global.d.ts (vendored, new file): 26 lines
@@ -0,0 +1,26 @@
declare const MODULE: string;
declare const INFERENCE_URL: string;

/**
 * The parameters for the initModel function.
 * @property settings - The settings for the machine learning model.
 * @property settings.ctx_len - The context length.
 * @property settings.ngl - The number of generated tokens.
 * @property settings.cont_batching - Whether to use continuous batching.
 * @property settings.embedding - Whether to use embedding.
 */
interface EngineSettings {
  ctx_len: number;
  ngl: number;
  cont_batching: boolean;
  embedding: boolean;
}

/**
 * The response from the initModel function.
 * @property error - An error message if the model fails to load.
 */
interface ModelOperationResponse {
  error?: any;
  modelFile?: string;
}
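These settings are persisted by the nitro extension to `engines/nitro.json` in the Jan user data folder (see `writeDefaultEngineSettings` later in this diff). The defaults written on first run mirror `JanInferenceNitroExtension._engineSettings` in `index.ts`; a sketch of that object:

```typescript
// Default engine settings written to engines/nitro.json on first load
// (values copied from JanInferenceNitroExtension._engineSettings below).
const defaultNitroSettings: EngineSettings = {
  ctx_len: 2048,
  ngl: 100,
  cont_batching: false,
  embedding: false,
};
```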
@ -1,3 +1,4 @@
|
||||
import { Model } from "@janhq/core";
|
||||
import { Observable } from "rxjs";
|
||||
/**
|
||||
* Sends a request to the inference server to generate a response based on the recent messages.
|
||||
@ -6,21 +7,23 @@ import { Observable } from "rxjs";
|
||||
*/
|
||||
export function requestInference(
|
||||
recentMessages: any[],
|
||||
engine: EngineSettings,
|
||||
model: Model,
|
||||
controller?: AbortController
|
||||
): Observable<string> {
|
||||
return new Observable((subscriber) => {
|
||||
const requestBody = JSON.stringify({
|
||||
messages: recentMessages,
|
||||
model: model.id,
|
||||
stream: true,
|
||||
model: "gpt-3.5-turbo",
|
||||
max_tokens: 2048,
|
||||
// ...model.parameters,
|
||||
});
|
||||
fetch(INFERENCE_URL, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Accept: "text/event-stream",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
Accept: "text/event-stream",
|
||||
},
|
||||
body: requestBody,
|
||||
signal: controller?.signal,
|
||||
@ -19,6 +19,8 @@ import {
|
||||
events,
|
||||
executeOnMain,
|
||||
getUserSpace,
|
||||
fs,
|
||||
Model,
|
||||
} from "@janhq/core";
|
||||
import { InferenceExtension } from "@janhq/core";
|
||||
import { requestInference } from "./helpers/sse";
|
||||
@ -30,7 +32,19 @@ import { join } from "path";
|
||||
* The class provides methods for initializing and stopping a model, and for making inference requests.
|
||||
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
|
||||
*/
|
||||
export default class JanInferenceExtension implements InferenceExtension {
|
||||
export default class JanInferenceNitroExtension implements InferenceExtension {
|
||||
private static readonly _homeDir = "engines";
|
||||
private static readonly _engineMetadataFileName = "nitro.json";
|
||||
|
||||
private static _currentModel: Model;
|
||||
|
||||
private static _engineSettings: EngineSettings = {
|
||||
ctx_len: 2048,
|
||||
ngl: 100,
|
||||
cont_batching: false,
|
||||
embedding: false,
|
||||
};
|
||||
|
||||
controller = new AbortController();
|
||||
isCancelled = false;
|
||||
/**
|
||||
@ -45,51 +59,88 @@ export default class JanInferenceExtension implements InferenceExtension {
|
||||
* Subscribes to events emitted by the @janhq/core package.
|
||||
*/
|
||||
onLoad(): void {
|
||||
fs.mkdir(JanInferenceNitroExtension._homeDir);
|
||||
this.writeDefaultEngineSettings();
|
||||
|
||||
// Events subscription
|
||||
events.on(EventName.OnMessageSent, (data) =>
|
||||
JanInferenceExtension.handleMessageRequest(data, this)
|
||||
JanInferenceNitroExtension.handleMessageRequest(data, this)
|
||||
);
|
||||
|
||||
events.on(EventName.OnModelInit, (model: Model) => {
|
||||
JanInferenceNitroExtension.handleModelInit(model);
|
||||
});
|
||||
|
||||
events.on(EventName.OnModelStop, (model: Model) => {
|
||||
JanInferenceNitroExtension.handleModelStop(model);
|
||||
});
|
||||
|
||||
events.on(EventName.OnInferenceStopped, () => {
|
||||
JanInferenceNitroExtension.handleInferenceStopped(this);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Stops the model inference.
|
||||
*/
|
||||
onUnload(): void {
|
||||
this.stopModel();
|
||||
onUnload(): void {}
|
||||
|
||||
|
||||
private async writeDefaultEngineSettings() {
|
||||
try {
|
||||
const engineFile = join(
|
||||
JanInferenceNitroExtension._homeDir,
|
||||
JanInferenceNitroExtension._engineMetadataFileName
|
||||
);
|
||||
if (await fs.exists(engineFile)) {
|
||||
JanInferenceNitroExtension._engineSettings = JSON.parse(
|
||||
await fs.readFile(engineFile)
|
||||
);
|
||||
} else {
|
||||
await fs.writeFile(
|
||||
engineFile,
|
||||
JSON.stringify(JanInferenceNitroExtension._engineSettings, null, 2)
|
||||
);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the model with the specified file name.
|
||||
* @param {string} modelId - The ID of the model to initialize.
|
||||
* @returns {Promise<void>} A promise that resolves when the model is initialized.
|
||||
*/
|
||||
async initModel(
|
||||
modelId: string,
|
||||
settings?: ModelSettingParams
|
||||
): Promise<void> {
|
||||
private static async handleModelInit(model: Model) {
|
||||
if (model.engine !== "nitro") {
|
||||
return;
|
||||
}
|
||||
const userSpacePath = await getUserSpace();
|
||||
const modelFullPath = join(userSpacePath, "models", modelId, modelId);
|
||||
const modelFullPath = join(userSpacePath, "models", model.id, model.id);
|
||||
|
||||
return executeOnMain(MODULE, "initModel", {
|
||||
modelFullPath,
|
||||
settings,
|
||||
const nitroInitResult = await executeOnMain(MODULE, "initModel", {
|
||||
modelFullPath: modelFullPath,
|
||||
model: model,
|
||||
});
|
||||
|
||||
if (nitroInitResult.error === null) {
|
||||
events.emit(EventName.OnModelFail, model);
|
||||
} else {
|
||||
JanInferenceNitroExtension._currentModel = model;
|
||||
events.emit(EventName.OnModelReady, model);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Stops the model.
|
||||
* @returns {Promise<void>} A promise that resolves when the model is stopped.
|
||||
*/
|
||||
async stopModel(): Promise<void> {
|
||||
return executeOnMain(MODULE, "killSubprocess");
|
||||
private static async handleModelStop(model: Model) {
|
||||
if (model.engine !== "nitro") {
|
||||
return;
|
||||
} else {
|
||||
await executeOnMain(MODULE, "stopModel");
|
||||
events.emit(EventName.OnModelStopped, model);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Stops streaming inference.
|
||||
* @returns {Promise<void>} A promise that resolves when the streaming is stopped.
|
||||
*/
|
||||
async stopInference(): Promise<void> {
|
||||
this.isCancelled = true;
|
||||
this.controller?.abort();
|
||||
private static async handleInferenceStopped(
|
||||
instance: JanInferenceNitroExtension
|
||||
) {
|
||||
instance.isCancelled = true;
|
||||
instance.controller?.abort();
|
||||
}
|
||||
|
||||
/**
|
||||
@ -97,7 +148,7 @@ export default class JanInferenceExtension implements InferenceExtension {
|
||||
* @param {MessageRequest} data - The data for the inference request.
|
||||
* @returns {Promise<any>} A promise that resolves with the inference response.
|
||||
*/
|
||||
async inferenceRequest(data: MessageRequest): Promise<ThreadMessage> {
|
||||
async inference(data: MessageRequest): Promise<ThreadMessage> {
|
||||
const timestamp = Date.now();
|
||||
const message: ThreadMessage = {
|
||||
thread_id: data.threadId,
|
||||
@ -111,7 +162,11 @@ export default class JanInferenceExtension implements InferenceExtension {
|
||||
};
|
||||
|
||||
return new Promise(async (resolve, reject) => {
|
||||
requestInference(data.messages ?? []).subscribe({
|
||||
requestInference(
|
||||
data.messages ?? [],
|
||||
JanInferenceNitroExtension._engineSettings,
|
||||
JanInferenceNitroExtension._currentModel
|
||||
).subscribe({
|
||||
next: (_content) => {},
|
||||
complete: async () => {
|
||||
resolve(message);
|
||||
@ -131,8 +186,11 @@ export default class JanInferenceExtension implements InferenceExtension {
|
||||
*/
|
||||
private static async handleMessageRequest(
|
||||
data: MessageRequest,
|
||||
instance: JanInferenceExtension
|
||||
instance: JanInferenceNitroExtension
|
||||
) {
|
||||
if (data.model.engine !== "nitro") {
|
||||
return;
|
||||
}
|
||||
const timestamp = Date.now();
|
||||
const message: ThreadMessage = {
|
||||
id: ulid(),
|
||||
@ -150,7 +208,12 @@ export default class JanInferenceExtension implements InferenceExtension {
|
||||
instance.isCancelled = false;
|
||||
instance.controller = new AbortController();
|
||||
|
||||
requestInference(data.messages, instance.controller).subscribe({
|
||||
requestInference(
|
||||
data.messages ?? [],
|
||||
JanInferenceNitroExtension._engineSettings,
|
||||
JanInferenceNitroExtension._currentModel,
|
||||
instance.controller
|
||||
).subscribe({
|
||||
next: (content) => {
|
||||
const messageContent: ThreadContent = {
|
||||
type: ContentType.Text,
|
||||
@ -20,51 +20,51 @@ let subprocess = null;
|
||||
let currentModelFile = null;
|
||||
|
||||
/**
|
||||
* The response from the initModel function.
|
||||
* @property error - An error message if the model fails to load.
|
||||
* Stops a Nitro subprocess.
|
||||
* @param wrapper - The model wrapper.
|
||||
* @returns A Promise that resolves when the subprocess is terminated successfully, or rejects with an error message if the subprocess fails to terminate.
|
||||
*/
|
||||
interface InitModelResponse {
|
||||
error?: any;
|
||||
modelFile?: string;
|
||||
function stopModel(): Promise<ModelOperationResponse> {
|
||||
return new Promise((resolve, reject) => {
|
||||
checkAndUnloadNitro();
|
||||
resolve({ error: undefined });
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes a Nitro subprocess to load a machine learning model.
|
||||
* @param modelFile - The name of the machine learning model file.
|
||||
* @param wrapper - The model wrapper.
|
||||
* @returns A Promise that resolves when the model is loaded successfully, or rejects with an error message if the model is not found or fails to load.
|
||||
* TODO: Should pass absolute of the model file instead of just the name - So we can modurize the module.ts to npm package
|
||||
* TODO: Should it be startModel instead?
|
||||
*/
|
||||
function initModel(wrapper: any): Promise<InitModelResponse> {
|
||||
// 1. Check if the model file exists
|
||||
function initModel(wrapper: any): Promise<ModelOperationResponse> {
|
||||
currentModelFile = wrapper.modelFullPath;
|
||||
log.info("Started to load model " + wrapper.modelFullPath);
|
||||
|
||||
const settings = {
|
||||
llama_model_path: currentModelFile,
|
||||
ctx_len: 2048,
|
||||
ngl: 100,
|
||||
cont_batching: false,
|
||||
embedding: false, // Always enable embedding mode on
|
||||
...wrapper.settings,
|
||||
};
|
||||
log.info(`Load model settings: ${JSON.stringify(settings, null, 2)}`);
|
||||
|
||||
return (
|
||||
// 1. Check if the port is used, if used, attempt to unload model / kill nitro process
|
||||
validateModelVersion()
|
||||
.then(checkAndUnloadNitro)
|
||||
// 2. Spawn the Nitro subprocess
|
||||
.then(spawnNitroProcess)
|
||||
// 4. Load the model into the Nitro subprocess (HTTP POST request)
|
||||
.then(() => loadLLMModel(settings))
|
||||
// 5. Check if the model is loaded successfully
|
||||
.then(validateModelStatus)
|
||||
.catch((err) => {
|
||||
log.error("error: " + JSON.stringify(err));
|
||||
return { error: err, currentModelFile };
|
||||
})
|
||||
);
|
||||
if (wrapper.model.engine !== "nitro") {
|
||||
return Promise.resolve({ error: "Not a nitro model" });
|
||||
} else {
|
||||
log.info("Started to load model " + wrapper.model.modelFullPath);
|
||||
const settings = {
|
||||
llama_model_path: currentModelFile,
|
||||
...wrapper.model.settings,
|
||||
};
|
||||
log.info(`Load model settings: ${JSON.stringify(settings, null, 2)}`);
|
||||
return (
|
||||
// 1. Check if the port is used, if used, attempt to unload model / kill nitro process
|
||||
validateModelVersion()
|
||||
.then(checkAndUnloadNitro)
|
||||
// 2. Spawn the Nitro subprocess
|
||||
.then(spawnNitroProcess)
|
||||
// 4. Load the model into the Nitro subprocess (HTTP POST request)
|
||||
.then(() => loadLLMModel(settings))
|
||||
// 5. Check if the model is loaded successfully
|
||||
.then(validateModelStatus)
|
||||
.catch((err) => {
|
||||
log.error("error: " + JSON.stringify(err));
|
||||
return { error: err, currentModelFile };
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -91,11 +91,11 @@ function loadLLMModel(settings): Promise<Response> {
|
||||
|
||||
/**
|
||||
* Validates the status of a model.
|
||||
* @returns {Promise<InitModelResponse>} A promise that resolves to an object.
|
||||
* @returns {Promise<ModelOperationResponse>} A promise that resolves to an object.
|
||||
* If the model is loaded successfully, the object is empty.
|
||||
* If the model is not loaded successfully, the object contains an error message.
|
||||
*/
|
||||
async function validateModelStatus(): Promise<InitModelResponse> {
|
||||
async function validateModelStatus(): Promise<ModelOperationResponse> {
|
||||
// Send a GET request to the validation URL.
|
||||
// Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries.
|
||||
return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, {
|
||||
@ -142,8 +142,8 @@ function killSubprocess(): Promise<void> {
|
||||
* Check port is used or not, if used, attempt to unload model
|
||||
* If unload failed, kill the port
|
||||
*/
|
||||
function checkAndUnloadNitro() {
|
||||
return tcpPortUsed.check(PORT, LOCAL_HOST).then((inUse) => {
|
||||
async function checkAndUnloadNitro() {
|
||||
return tcpPortUsed.check(PORT, LOCAL_HOST).then(async (inUse) => {
|
||||
// If inUse - try unload or kill process, otherwise do nothing
|
||||
if (inUse) {
|
||||
// Attempt to unload model
|
||||
@ -168,7 +168,7 @@ function checkAndUnloadNitro() {
|
||||
*/
|
||||
async function spawnNitroProcess(): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
let binaryFolder = path.join(__dirname, "nitro"); // Current directory by default
|
||||
let binaryFolder = path.join(__dirname, "bin"); // Current directory by default
|
||||
let binaryName;
|
||||
|
||||
if (process.platform === "win32") {
|
||||
extensions/inference-openai-extension/README.md (new file): 78 lines
@ -0,0 +1,78 @@
|
||||
# Jan inference plugin
|
||||
|
||||
Created using Jan app example
|
||||
|
||||
# Create a Jan Plugin using Typescript
|
||||
|
||||
Use this template to bootstrap the creation of a TypeScript Jan plugin. 🚀
|
||||
|
||||
## Create Your Own Plugin
|
||||
|
||||
To create your own plugin, you can use this repository as a template! Just follow the below instructions:
|
||||
|
||||
1. Click the Use this template button at the top of the repository
|
||||
2. Select Create a new repository
|
||||
3. Select an owner and name for your new repository
|
||||
4. Click Create repository
|
||||
5. Clone your new repository
|
||||
|
||||
## Initial Setup
|
||||
|
||||
After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your plugin.
|
||||
|
||||
> [!NOTE]
|
||||
>
|
||||
> You'll need to have a reasonably modern version of
|
||||
> [Node.js](https://nodejs.org) handy. If you are using a version manager like
|
||||
> [`nodenv`](https://github.com/nodenv/nodenv) or
|
||||
> [`nvm`](https://github.com/nvm-sh/nvm), you can run `nodenv install` in the
|
||||
> root of your repository to install the version specified in
|
||||
> [`package.json`](./package.json). Otherwise, 20.x or later should work!
|
||||
|
||||
1. :hammer_and_wrench: Install the dependencies
|
||||
|
||||
```bash
|
||||
npm install
|
||||
```
|
||||
|
||||
1. :building_construction: Package the TypeScript for distribution
|
||||
|
||||
```bash
|
||||
npm run bundle
|
||||
```
|
||||
|
||||
1. :white_check_mark: Check your artifact
|
||||
|
||||
There will be a tgz file in your plugin directory now
|
||||
|
||||
## Update the Plugin Metadata
|
||||
|
||||
The [`package.json`](package.json) file defines metadata about your plugin, such as
|
||||
plugin name, main entry, description and version.
|
||||
|
||||
When you copy this repository, update `package.json` with the name, description for your plugin.
|
||||
|
||||
## Update the Plugin Code
|
||||
|
||||
The [`src/`](./src/) directory is the heart of your plugin! This contains the
|
||||
source code that will be run when your plugin extension functions are invoked. You can replace the
|
||||
contents of this directory with your own code.
|
||||
|
||||
There are a few things to keep in mind when writing your plugin code:
|
||||
|
||||
- Most Jan Plugin Extension functions are processed asynchronously.
|
||||
In `index.ts`, you will see that the extension function will return a `Promise<any>`.
|
||||
|
||||
```typescript
|
||||
import { core } from "@janhq/core";
|
||||
|
||||
function onStart(): Promise<any> {
|
||||
return core.invokePluginFunc(MODULE_PATH, "run", 0);
|
||||
}
|
||||
```
|
||||
|
||||
For more information about the Jan Plugin Core module, see the
|
||||
[documentation](https://github.com/janhq/jan/blob/main/core/README.md).
|
||||
|
||||
So, what are you waiting for? Go ahead and start customizing your plugin!
|
||||
|
||||
extensions/inference-openai-extension/package.json (new file): 41 lines
@ -0,0 +1,41 @@
|
||||
{
|
||||
"name": "@janhq/inference-openai-extension",
|
||||
"version": "1.0.0",
|
||||
"description": "Inference Engine for OpenAI Extension that can be used with any OpenAI compatible API",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/module.js",
|
||||
"author": "Jan <service@jan.ai>",
|
||||
"license": "AGPL-3.0",
|
||||
"scripts": {
|
||||
"build": "tsc -b . && webpack --config webpack.config.js",
|
||||
"build:publish": "rimraf *.tgz --glob && npm run build && npm pack && cpx *.tgz ../../electron/pre-install"
|
||||
},
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
"./main": "./dist/module.js"
|
||||
},
|
||||
"devDependencies": {
|
||||
"cpx": "^1.5.0",
|
||||
"rimraf": "^3.0.2",
|
||||
"webpack": "^5.88.2",
|
||||
"webpack-cli": "^5.1.4"
|
||||
},
|
||||
"dependencies": {
|
||||
"@janhq/core": "file:../../core",
|
||||
"fetch-retry": "^5.0.6",
|
||||
"path-browserify": "^1.0.1",
|
||||
"ts-loader": "^9.5.0",
|
||||
"ulid": "^2.3.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
},
|
||||
"files": [
|
||||
"dist/*",
|
||||
"package.json",
|
||||
"README.md"
|
||||
],
|
||||
"bundleDependencies": [
|
||||
"fetch-retry"
|
||||
]
|
||||
}
|
||||
extensions/inference-openai-extension/src/@types/global.d.ts (vendored, new file): 27 lines
@@ -0,0 +1,27 @@
import { Model } from "@janhq/core";

declare const MODULE: string;

declare interface EngineSettings {
  full_url?: string;
  api_key?: string;
}

enum OpenAIChatCompletionModelName {
  "gpt-3.5-turbo-instruct" = "gpt-3.5-turbo-instruct",
  "gpt-3.5-turbo-instruct-0914" = "gpt-3.5-turbo-instruct-0914",
  "gpt-4-1106-preview" = "gpt-4-1106-preview",
  "gpt-3.5-turbo-0613" = "gpt-3.5-turbo-0613",
  "gpt-3.5-turbo-0301" = "gpt-3.5-turbo-0301",
  "gpt-3.5-turbo" = "gpt-3.5-turbo",
  "gpt-3.5-turbo-16k-0613" = "gpt-3.5-turbo-16k-0613",
  "gpt-3.5-turbo-1106" = "gpt-3.5-turbo-1106",
  "gpt-4-vision-preview" = "gpt-4-vision-preview",
  "gpt-4" = "gpt-4",
  "gpt-4-0314" = "gpt-4-0314",
  "gpt-4-0613" = "gpt-4-0613",
}

declare type OpenAIModel = Omit<Model, "id"> & {
  id: OpenAIChatCompletionModelName;
};
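As with the nitro extension, these settings are persisted to `engines/openai.json` in the Jan user data folder, so pointing the extension at any OpenAI-compatible endpoint is a matter of editing that file (for Azure OpenAI, `sse.ts` below derives the deployment name from the URL path). The defaults mirror `JanInferenceOpenAIExtension._engineSettings` in `index.ts`; a sketch:

```typescript
// Default engine settings written to engines/openai.json on first load
// (values copied from JanInferenceOpenAIExtension._engineSettings below).
const defaultOpenAISettings: EngineSettings = {
  full_url: "https://api.openai.com/v1/chat/completions",
  api_key: "sk-<your key here>",
};
```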
extensions/inference-openai-extension/src/helpers/sse.ts (new file): 68 lines
@ -0,0 +1,68 @@
|
||||
import { Observable } from "rxjs";
|
||||
import { EngineSettings, OpenAIModel } from "../@types/global";
|
||||
|
||||
/**
|
||||
* Sends a request to the inference server to generate a response based on the recent messages.
|
||||
* @param recentMessages - An array of recent messages to use as context for the inference.
|
||||
* @param engine - The engine settings to use for the inference.
|
||||
* @param model - The model to use for the inference.
|
||||
* @returns An Observable that emits the generated response as a string.
|
||||
*/
|
||||
export function requestInference(
|
||||
recentMessages: any[],
|
||||
engine: EngineSettings,
|
||||
model: OpenAIModel,
|
||||
controller?: AbortController
|
||||
): Observable<string> {
|
||||
return new Observable((subscriber) => {
|
||||
let model_id: string = model.id
|
||||
if (engine.full_url.includes("openai.azure.com")){
|
||||
model_id = engine.full_url.split("/")[5]
|
||||
}
|
||||
const requestBody = JSON.stringify({
|
||||
messages: recentMessages,
|
||||
stream: true,
|
||||
model: model_id
|
||||
// ...model.parameters,
|
||||
});
|
||||
fetch(`${engine.full_url}`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Accept: "text/event-stream",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
Authorization: `Bearer ${engine.api_key}`,
|
||||
"api-key": `${engine.api_key}`,
|
||||
},
|
||||
body: requestBody,
|
||||
signal: controller?.signal,
|
||||
})
|
||||
.then(async (response) => {
|
||||
const stream = response.body;
|
||||
const decoder = new TextDecoder("utf-8");
|
||||
const reader = stream?.getReader();
|
||||
let content = "";
|
||||
|
||||
while (true && reader) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
break;
|
||||
}
|
||||
const text = decoder.decode(value);
|
||||
const lines = text.trim().split("\n");
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("data: ") && !line.includes("data: [DONE]")) {
|
||||
const data = JSON.parse(line.replace("data: ", ""));
|
||||
content += data.choices[0]?.delta?.content ?? "";
|
||||
if (content.startsWith("assistant: ")) {
|
||||
content = content.replace("assistant: ", "");
|
||||
}
|
||||
subscriber.next(content);
|
||||
}
|
||||
}
|
||||
}
|
||||
subscriber.complete();
|
||||
})
|
||||
.catch((err) => subscriber.error(err));
|
||||
});
|
||||
}
|
||||
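`requestInference` returns an RxJS `Observable<string>` that emits the accumulated completion text as the stream arrives; `index.ts` below consumes it with `.subscribe`. A minimal consumption sketch (the settings and model values are illustrative):

```typescript
import { requestInference } from "./helpers/sse";

const controller = new AbortController();
requestInference(
  [{ role: "user", content: "Hello!" }],
  // Illustrative engine settings; the extension reads these from engines/openai.json.
  { full_url: "https://api.openai.com/v1/chat/completions", api_key: "sk-..." },
  // Illustrative model (OpenAIModel); the extension receives this via OnModelInit.
  { id: "gpt-3.5-turbo" } as any,
  controller
).subscribe({
  next: (text) => console.log(text), // cumulative text so far
  complete: () => console.log("done"),
  error: (err) => console.error(err),
});
```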
extensions/inference-openai-extension/src/index.ts (new file): 231 lines
@ -0,0 +1,231 @@
|
||||
/**
|
||||
* @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package.
|
||||
* The class provides methods for initializing and stopping a model, and for making inference requests.
|
||||
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
|
||||
* @version 1.0.0
|
||||
* @module inference-openai-extension/src/index
|
||||
*/
|
||||
|
||||
import {
|
||||
ChatCompletionRole,
|
||||
ContentType,
|
||||
EventName,
|
||||
MessageRequest,
|
||||
MessageStatus,
|
||||
ModelSettingParams,
|
||||
ExtensionType,
|
||||
ThreadContent,
|
||||
ThreadMessage,
|
||||
events,
|
||||
fs,
|
||||
} from "@janhq/core";
|
||||
import { InferenceExtension } from "@janhq/core";
|
||||
import { requestInference } from "./helpers/sse";
|
||||
import { ulid } from "ulid";
|
||||
import { join } from "path";
|
||||
import { EngineSettings, OpenAIModel } from "./@types/global";
|
||||
|
||||
/**
|
||||
* A class that implements the InferenceExtension interface from the @janhq/core package.
|
||||
* The class provides methods for initializing and stopping a model, and for making inference requests.
|
||||
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
|
||||
*/
|
||||
export default class JanInferenceOpenAIExtension implements InferenceExtension {
|
||||
private static readonly _homeDir = "engines";
|
||||
private static readonly _engineMetadataFileName = "openai.json";
|
||||
|
||||
private static _currentModel: OpenAIModel;
|
||||
|
||||
private static _engineSettings: EngineSettings = {
|
||||
full_url: "https://api.openai.com/v1/chat/completions",
|
||||
api_key: "sk-<your key here>",
|
||||
};
|
||||
|
||||
controller = new AbortController();
|
||||
isCancelled = false;
|
||||
|
||||
/**
|
||||
* Returns the type of the extension.
|
||||
* @returns {ExtensionType} The type of the extension.
|
||||
*/
|
||||
// TODO: To fix
|
||||
type(): ExtensionType {
|
||||
return undefined;
|
||||
}
|
||||
/**
|
||||
* Subscribes to events emitted by the @janhq/core package.
|
||||
*/
|
||||
onLoad(): void {
|
||||
fs.mkdir(JanInferenceOpenAIExtension._homeDir);
|
||||
JanInferenceOpenAIExtension.writeDefaultEngineSettings();
|
||||
|
||||
// Events subscription
|
||||
events.on(EventName.OnMessageSent, (data) =>
|
||||
JanInferenceOpenAIExtension.handleMessageRequest(data, this)
|
||||
);
|
||||
|
||||
events.on(EventName.OnModelInit, (model: OpenAIModel) => {
|
||||
JanInferenceOpenAIExtension.handleModelInit(model);
|
||||
});
|
||||
|
||||
events.on(EventName.OnModelStop, (model: OpenAIModel) => {
|
||||
JanInferenceOpenAIExtension.handleModelStop(model);
|
||||
});
|
||||
events.on(EventName.OnInferenceStopped, () => {
|
||||
JanInferenceOpenAIExtension.handleInferenceStopped(this);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Stops the model inference.
|
||||
*/
|
||||
onUnload(): void {}
|
||||
|
||||
static async writeDefaultEngineSettings() {
|
||||
try {
|
||||
const engineFile = join(
|
||||
JanInferenceOpenAIExtension._homeDir,
|
||||
JanInferenceOpenAIExtension._engineMetadataFileName
|
||||
);
|
||||
if (await fs.exists(engineFile)) {
|
||||
JanInferenceOpenAIExtension._engineSettings = JSON.parse(
|
||||
await fs.readFile(engineFile)
|
||||
);
|
||||
} else {
|
||||
await fs.writeFile(
|
||||
engineFile,
|
||||
JSON.stringify(JanInferenceOpenAIExtension._engineSettings, null, 2)
|
||||
);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Makes a single response inference request.
|
||||
* @param {MessageRequest} data - The data for the inference request.
|
||||
* @returns {Promise<any>} A promise that resolves with the inference response.
|
||||
*/
|
||||
async inference(data: MessageRequest): Promise<ThreadMessage> {
|
||||
const timestamp = Date.now();
|
||||
const message: ThreadMessage = {
|
||||
thread_id: data.threadId,
|
||||
created: timestamp,
|
||||
updated: timestamp,
|
||||
status: MessageStatus.Ready,
|
||||
id: "",
|
||||
role: ChatCompletionRole.Assistant,
|
||||
object: "thread.message",
|
||||
content: [],
|
||||
};
|
||||
|
||||
return new Promise(async (resolve, reject) => {
|
||||
requestInference(
|
||||
data.messages ?? [],
|
||||
JanInferenceOpenAIExtension._engineSettings,
|
||||
JanInferenceOpenAIExtension._currentModel
|
||||
).subscribe({
|
||||
next: (_content) => {},
|
||||
complete: async () => {
|
||||
resolve(message);
|
||||
},
|
||||
error: async (err) => {
|
||||
reject(err);
|
||||
},
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
private static async handleModelInit(model: OpenAIModel) {
|
||||
if (model.engine !== "openai") {
|
||||
return;
|
||||
} else {
|
||||
JanInferenceOpenAIExtension._currentModel = model;
|
||||
JanInferenceOpenAIExtension.writeDefaultEngineSettings();
|
||||
// Todo: Check model list with API key
|
||||
events.emit(EventName.OnModelReady, model);
|
||||
}
|
||||
}
|
||||
|
||||
private static async handleModelStop(model: OpenAIModel) {
|
||||
if (model.engine !== "openai") {
|
||||
return;
|
||||
}
|
||||
events.emit(EventName.OnModelStopped, model);
|
||||
}
|
||||
|
||||
private static async handleInferenceStopped(
|
||||
instance: JanInferenceOpenAIExtension
|
||||
) {
|
||||
instance.isCancelled = true;
|
||||
instance.controller?.abort();
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles a new message request by making an inference request and emitting events.
|
||||
* Function registered in event manager, should be static to avoid binding issues.
|
||||
* Pass instance as a reference.
|
||||
* @param {MessageRequest} data - The data for the new message request.
|
||||
*/
|
||||
private static async handleMessageRequest(
|
||||
data: MessageRequest,
|
||||
instance: JanInferenceOpenAIExtension
|
||||
) {
|
||||
if (data.model.engine !== "openai") {
|
||||
return;
|
||||
}
|
||||
|
||||
const timestamp = Date.now();
|
||||
const message: ThreadMessage = {
|
||||
id: ulid(),
|
||||
thread_id: data.threadId,
|
||||
assistant_id: data.assistantId,
|
||||
role: ChatCompletionRole.Assistant,
|
||||
content: [],
|
||||
status: MessageStatus.Pending,
|
||||
created: timestamp,
|
||||
updated: timestamp,
|
||||
object: "thread.message",
|
||||
};
|
||||
events.emit(EventName.OnMessageResponse, message);
|
||||
|
||||
instance.isCancelled = false;
|
||||
instance.controller = new AbortController();
|
||||
|
||||
requestInference(
|
||||
data?.messages ?? [],
|
||||
this._engineSettings,
|
||||
JanInferenceOpenAIExtension._currentModel,
|
||||
instance.controller
|
||||
).subscribe({
|
||||
next: (content) => {
|
||||
const messageContent: ThreadContent = {
|
||||
type: ContentType.Text,
|
||||
text: {
|
||||
value: content.trim(),
|
||||
annotations: [],
|
||||
},
|
||||
};
|
||||
message.content = [messageContent];
|
||||
events.emit(EventName.OnMessageUpdate, message);
|
||||
},
|
||||
complete: async () => {
|
||||
message.status = MessageStatus.Ready;
|
||||
events.emit(EventName.OnMessageUpdate, message);
|
||||
},
|
||||
error: async (err) => {
|
||||
const messageContent: ThreadContent = {
|
||||
type: ContentType.Text,
|
||||
text: {
|
||||
value: "Error occurred: " + err.message,
|
||||
annotations: [],
|
||||
},
|
||||
};
|
||||
message.content = [messageContent];
|
||||
message.status = MessageStatus.Ready;
|
||||
events.emit(EventName.OnMessageUpdate, message);
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
extensions/inference-openai-extension/tsconfig.json (new file): 15 lines
@ -0,0 +1,15 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "es2016",
|
||||
"module": "ES6",
|
||||
"moduleResolution": "node",
|
||||
|
||||
"outDir": "./dist",
|
||||
"esModuleInterop": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"strict": false,
|
||||
"skipLibCheck": true,
|
||||
"rootDir": "./src"
|
||||
},
|
||||
"include": ["./src"]
|
||||
}
|
||||
extensions/inference-openai-extension/webpack.config.js (new file): 42 lines
@ -0,0 +1,42 @@
|
||||
const path = require("path");
|
||||
const webpack = require("webpack");
|
||||
const packageJson = require("./package.json");
|
||||
|
||||
module.exports = {
|
||||
experiments: { outputModule: true },
|
||||
entry: "./src/index.ts", // Adjust the entry point to match your project's main file
|
||||
mode: "production",
|
||||
module: {
|
||||
rules: [
|
||||
{
|
||||
test: /\.tsx?$/,
|
||||
use: "ts-loader",
|
||||
exclude: /node_modules/,
|
||||
},
|
||||
],
|
||||
},
|
||||
plugins: [
|
||||
new webpack.DefinePlugin({
|
||||
MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
|
||||
INFERENCE_URL: JSON.stringify(
|
||||
process.env.INFERENCE_URL ||
|
||||
"http://127.0.0.1:3928/inferences/llamacpp/chat_completion"
|
||||
),
|
||||
}),
|
||||
],
|
||||
output: {
|
||||
filename: "index.js", // Adjust the output file name as needed
|
||||
path: path.resolve(__dirname, "dist"),
|
||||
library: { type: "module" }, // Specify ESM output format
|
||||
},
|
||||
resolve: {
|
||||
extensions: [".ts", ".js"],
|
||||
fallback: {
|
||||
path: require.resolve("path-browserify"),
|
||||
},
|
||||
},
|
||||
optimization: {
|
||||
minimize: false,
|
||||
},
|
||||
// Add loaders and other configuration as needed for your project
|
||||
};
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "NousResearch, The Bloke",
|
||||
"tags": ["34B", "Finetuned"],
|
||||
"size": 24320000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
|
||||
{
|
||||
"source_url": "https://huggingface.co/TheBloke/deepseek-coder-1.3b-instruct-GGUF/resolve/main/deepseek-coder-1.3b-instruct.Q8_0.gguf",
|
||||
"id": "deepseek-coder-1.3b",
|
||||
@ -19,5 +20,6 @@
|
||||
"author": "Deepseek, The Bloke",
|
||||
"tags": ["Tiny", "Foundational Model"],
|
||||
"size": 1430000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "Deepseek, The Bloke",
|
||||
"tags": ["34B", "Foundational Model"],
|
||||
"size": 26040000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
models/gpt-3.5-turbo-16k-0613/model.json (new file): 20 lines
@ -0,0 +1,20 @@
|
||||
{
|
||||
"source_url": "https://openai.com",
|
||||
"id": "gpt-3.5-turbo-16k-0613",
|
||||
"object": "model",
|
||||
"name": "OpenAI GPT 3.5 Turbo 16k 0613",
|
||||
"version": 1.0,
|
||||
"description": "OpenAI GPT 3.5 Turbo 16k 0613 model is extremely good",
|
||||
"format": "api",
|
||||
"settings": {},
|
||||
"parameters": {
|
||||
"max_tokens": 4096
|
||||
},
|
||||
"metadata": {
|
||||
"author": "OpenAI",
|
||||
"tags": ["General", "Big Context Length"]
|
||||
},
|
||||
"engine": "openai",
|
||||
"state": "ready"
|
||||
}
|
||||
|
||||
models/gpt-3.5-turbo/model.json (new file): 18 lines
@ -0,0 +1,18 @@
|
||||
{
|
||||
"source_url": "https://openai.com",
|
||||
"id": "gpt-3.5-turbo",
|
||||
"object": "model",
|
||||
"name": "OpenAI GPT 3.5 Turbo",
|
||||
"version": 1.0,
|
||||
"description": "OpenAI GPT 3.5 Turbo model is extremely good",
|
||||
"format": "api",
|
||||
"settings": {},
|
||||
"parameters": {},
|
||||
"metadata": {
|
||||
"author": "OpenAI",
|
||||
"tags": ["General", "Big Context Length"]
|
||||
},
|
||||
"engine": "openai",
|
||||
"state": "ready"
|
||||
}
|
||||
|
||||
models/gpt-4/model.json (new file): 20 lines
@ -0,0 +1,20 @@
|
||||
{
|
||||
"source_url": "https://openai.com",
|
||||
"id": "gpt-4",
|
||||
"object": "model",
|
||||
"name": "OpenAI GPT 3.5",
|
||||
"version": 1.0,
|
||||
"description": "OpenAI GPT 3.5 model is extremely good",
|
||||
"format": "api",
|
||||
"settings": {},
|
||||
"parameters": {
|
||||
"max_tokens": 4096
|
||||
},
|
||||
"metadata": {
|
||||
"author": "OpenAI",
|
||||
"tags": ["General", "Big Context Length"]
|
||||
},
|
||||
"engine": "openai",
|
||||
"state": "ready"
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "MetaAI, The Bloke",
|
||||
"tags": ["70B", "Foundational Model"],
|
||||
"size": 43920000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "MetaAI, The Bloke",
|
||||
"tags": ["7B", "Foundational Model"],
|
||||
"size": 4080000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "MetaAI, The Bloke",
|
||||
"tags": ["7B", "Foundational Model"],
|
||||
"size": 4780000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "Lizpreciatior, The Bloke",
|
||||
"tags": ["70B", "Finetuned"],
|
||||
"size": 48750000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
"tags": ["Featured", "7B", "Foundational Model"],
|
||||
"size": 4370000000,
|
||||
"cover": "https://raw.githubusercontent.com/janhq/jan/main/models/mistral-ins-7b-q4/cover.png"
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "MistralAI, The Bloke",
|
||||
"tags": ["7B", "Foundational Model"],
|
||||
"size": 5130000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "Intel, The Bloke",
|
||||
"tags": ["Recommended", "7B", "Finetuned"],
|
||||
"size": 4370000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "NeverSleep, The Bloke",
|
||||
"tags": ["34B", "Finetuned"],
|
||||
"size": 12040000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -20,5 +20,6 @@
|
||||
"tags": ["Featured", "7B", "Merged"],
|
||||
"size": 4370000000,
|
||||
"cover": "https://raw.githubusercontent.com/janhq/jan/main/models/openhermes-neural-7b/cover.png"
|
||||
}
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "Microsoft, The Bloke",
|
||||
"tags": ["13B", "Finetuned"],
|
||||
"size": 9230000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "Phind, The Bloke",
|
||||
"tags": ["34B", "Finetuned"],
|
||||
"size": 24320000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -19,5 +19,6 @@
|
||||
"author": "Pansophic, The Bloke",
|
||||
"tags": ["Tiny", "Finetuned"],
|
||||
"size": 1710000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "Berkeley-nest, The Bloke",
|
||||
"tags": ["Recommended", "7B","Finetuned"],
|
||||
"size": 4370000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "KoboldAI, The Bloke",
|
||||
"tags": ["13B", "Finetuned"],
|
||||
"size": 9230000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -19,5 +19,6 @@
|
||||
"author": "TinyLlama",
|
||||
"tags": ["Tiny", "Foundation Model"],
|
||||
"size": 637000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "WizardLM, The Bloke",
|
||||
"tags": ["Recommended", "13B", "Finetuned"],
|
||||
"size": 9230000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "01-ai, The Bloke",
|
||||
"tags": ["34B", "Foundational Model"],
|
||||
"size": 24320000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
"author": "HuggingFaceH4, The Bloke",
|
||||
"tags": ["7B", "Finetuned"],
|
||||
"size": 4370000000
|
||||
}
|
||||
},
|
||||
"engine": "nitro"
|
||||
}
|
||||
|
||||
@ -7,10 +7,16 @@ import {
ThreadMessage,
ExtensionType,
MessageStatus,
Model,
} from '@janhq/core'
import { ConversationalExtension } from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'

import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels'

import { toaster } from '../Toast'

import { extensionManager } from '@/extension'
import {
addNewMessageAtom,

@ -24,19 +30,61 @@ import {
export default function EventHandler({ children }: { children: ReactNode }) {
const addNewMessage = useSetAtom(addNewMessageAtom)
const updateMessage = useSetAtom(updateMessageAtom)
const { downloadedModels } = useGetDownloadedModels()
const setActiveModel = useSetAtom(activeModelAtom)
const setStateModel = useSetAtom(stateModelAtom)

const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
const threads = useAtomValue(threadsAtom)
const modelsRef = useRef(downloadedModels)
const threadsRef = useRef(threads)

useEffect(() => {
threadsRef.current = threads
}, [threads])

useEffect(() => {
modelsRef.current = downloadedModels
}, [downloadedModels])

async function handleNewMessageResponse(message: ThreadMessage) {
addNewMessage(message)
}

async function handleModelReady(model: Model) {
setActiveModel(model)
toaster({
title: 'Success!',
description: `Model ${model.id} has been started.`,
})
setStateModel(() => ({
state: 'stop',
loading: false,
model: model.id,
}))
}

async function handleModelStopped(model: Model) {
setTimeout(async () => {
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
toaster({
title: 'Success!',
description: `Model ${model.id} has been stopped.`,
})
}, 500)
}

async function handleModelFail(res: any) {
const errorMessage = `${res.error}`
alert(errorMessage)
setStateModel(() => ({
state: 'start',
loading: false,
model: res.modelId,
}))
}

async function handleMessageResponseUpdate(message: ThreadMessage) {
updateMessage(
message.id,

@ -73,6 +121,9 @@ export default function EventHandler({ children }: { children: ReactNode }) {
if (window.core.events) {
events.on(EventName.OnMessageResponse, handleNewMessageResponse)
events.on(EventName.OnMessageUpdate, handleMessageResponseUpdate)
events.on(EventName.OnModelReady, handleModelReady)
events.on(EventName.OnModelFail, handleModelFail)
events.on(EventName.OnModelStopped, handleModelStopped)
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
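The EventHandler above is the consumer side of the new model lifecycle events; the producer side lives in the inference extensions. A rough sketch of what that producer could look like, assuming an extension that reacts to OnModelInit and reports back (loadModelBinary is a hypothetical placeholder, and the failure payload mirrors the res.modelId / res.error fields that handleModelFail reads):

import { events, EventName, Model } from '@janhq/core'

// Hypothetical stand-in for the engine-specific loading work
// (spawning nitro, probing an OpenAI-compatible endpoint, etc.).
const loadModelBinary = async (model: Model): Promise<void> => {}

// Hedged sketch, not taken verbatim from the nitro or openai extensions.
events.on(EventName.OnModelInit, async (model: Model) => {
  try {
    await loadModelBinary(model)
    events.emit(EventName.OnModelReady, model) // picked up by handleModelReady
  } catch (error) {
    events.emit(EventName.OnModelFail, { modelId: model.id, error }) // picked up by handleModelFail
  }
})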
@ -1,5 +1,8 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { ExtensionType, InferenceExtension } from '@janhq/core'
import {
EventName,
events,
} from '@janhq/core'
import { Model, ModelSettingParams } from '@janhq/core'
import { atom, useAtom } from 'jotai'

@ -9,9 +12,13 @@ import { useGetDownloadedModels } from './useGetDownloadedModels'

import { extensionManager } from '@/extension'

const activeModelAtom = atom<Model | undefined>(undefined)
export const activeModelAtom = atom<Model | undefined>(undefined)

const stateModelAtom = atom({ state: 'start', loading: false, model: '' })
export const stateModelAtom = atom({
state: 'start',
loading: false,
model: '',
})

export function useActiveModel() {
const [activeModel, setActiveModel] = useAtom(activeModelAtom)

@ -47,59 +54,14 @@ export function useActiveModel() {
return
}

const currentTime = Date.now()
const res = await initModel(modelId, model?.settings)
if (res && res.error) {
const errorMessage = `${res.error}`
alert(errorMessage)
setStateModel(() => ({
state: 'start',
loading: false,
model: modelId,
}))
} else {
console.debug(
`Model ${modelId} successfully initialized! Took ${
Date.now() - currentTime
}ms`
)
setActiveModel(model)
toaster({
title: 'Success!',
description: `Model ${modelId} has been started.`,
})
setStateModel(() => ({
state: 'stop',
loading: false,
model: modelId,
}))
}
events.emit(EventName.OnModelInit, model)
}

const stopModel = async (modelId: string) => {
const model = downloadedModels.find((e) => e.id === modelId)
setStateModel({ state: 'stop', loading: true, model: modelId })
setTimeout(async () => {
extensionManager
.get<InferenceExtension>(ExtensionType.Inference)
?.stopModel()

setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
toaster({
title: 'Success!',
description: `Model ${modelId} has been stopped.`,
})
}, 500)
events.emit(EventName.OnModelStop, model)
}

return { activeModel, startModel, stopModel, stateModel }
}

const initModel = async (
modelId: string,
settings?: ModelSettingParams
): Promise<any> => {
return extensionManager
.get<InferenceExtension>(ExtensionType.Inference)
?.initModel(modelId, settings)
}
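This hunk strips the imperative start/stop bodies out of useActiveModel: rather than awaiting initModel() / stopModel() on the inference extension, the hook now emits OnModelInit and OnModelStop and lets the EventHandler above translate OnModelReady / OnModelFail / OnModelStopped into UI state. The new startModel body is not shown in full here, so the outline below is a hedged reconstruction, not the exact implementation:

// Hedged outline of the event-driven start path after this change.
const startModel = async (modelId: string) => {
  const model = downloadedModels.find((m) => m.id === modelId)
  setStateModel({ state: 'start', loading: true, model: modelId }) // assumed loading state
  // Announce the intent; the matching inference extension answers with
  // OnModelReady or OnModelFail, which EventHandler turns into toasts and state.
  events.emit(EventName.OnModelInit, model)
}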
@ -67,6 +67,7 @@ export const useCreateNewThread = () => {
top_p: 0,
stream: false,
},
engine: undefined
},
instructions: assistant.instructions,
}
@ -50,7 +50,6 @@ export default function useSendChatMessage() {
const [queuedMessage, setQueuedMessage] = useState(false)

const modelRef = useRef<Model | undefined>()

useEffect(() => {
modelRef.current = activeModel
}, [activeModel])

@ -91,18 +90,35 @@ export default function useSendChatMessage() {
id: ulid(),
messages: messages,
threadId: activeThread.id,
model: activeThread.assistants[0].model ?? selectedModel,
}

const modelId = selectedModel?.id ?? activeThread.assistants[0].model.id

if (activeModel?.id !== modelId) {
setQueuedMessage(true)
await startModel(modelId)
startModel(modelId)
await WaitForModelStarting(modelId)
setQueuedMessage(false)
}
events.emit(EventName.OnMessageSent, messageRequest)
}

// TODO: Refactor @louis
const WaitForModelStarting = async (modelId: string) => {
return new Promise<void>((resolve) => {
setTimeout(async () => {
if (modelRef.current?.id !== modelId) {
console.log('waiting for model to start')
await WaitForModelStarting(modelId)
resolve()
} else {
resolve()
}
}, 200)
})
}

const sendChatMessage = async () => {
if (!currentPrompt || currentPrompt.trim().length === 0) {
return

@ -132,6 +148,7 @@ export default function useSendChatMessage() {
id: selectedModel.id,
settings: selectedModel.settings,
parameters: selectedModel.parameters,
engine: selectedModel.engine,
},
},
],

@ -178,7 +195,7 @@ export default function useSendChatMessage() {
id: msgId,
threadId: activeThread.id,
messages,
parameters: activeThread.assistants[0].model.parameters,
model: selectedModel ?? activeThread.assistants[0].model,
}
const timestamp = Date.now()
const threadMessage: ThreadMessage = {

@ -210,7 +227,8 @@ export default function useSendChatMessage() {

if (activeModel?.id !== modelId) {
setQueuedMessage(true)
await startModel(modelId)
startModel(modelId)
await WaitForModelStarting(modelId)
setQueuedMessage(false)
}
events.emit(EventName.OnMessageSent, messageRequest)
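Because startModel is no longer awaited end-to-end, WaitForModelStarting polls modelRef every 200 ms until the active model matches the requested id. One possible shape for the refactor flagged by the TODO would be to resolve on the OnModelReady event instead of polling; the sketch below only illustrates that idea and is not part of this PR (listener cleanup is omitted for brevity):

// Hedged sketch: resolve once the model announced by OnModelReady matches.
const waitForModelReady = (modelId: string) =>
  new Promise<void>((resolve) => {
    events.on(EventName.OnModelReady, (model: Model) => {
      if (model.id === modelId) resolve()
    })
  })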
@ -30,9 +30,8 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
const { resendChatMessage } = useSendChatMessage()

const onStopInferenceClick = async () => {
await extensionManager
.get<InferenceExtension>(ExtensionType.Inference)
?.stopInference()
events.emit(EventName.OnInferenceStopped, {})

setTimeout(() => {
events.emit(EventName.OnMessageUpdate, {
...message,
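Stopping generation also moves onto the event bus: instead of calling stopInference() on the inference extension directly, the toolbar emits OnInferenceStopped and each engine extension aborts its own in-flight request. A minimal sketch of a listener on the extension side, assuming an AbortController-based request (the controller variable is illustrative):

import { events, EventName } from '@janhq/core'

// Hedged sketch of how an engine extension might honour OnInferenceStopped.
let controller: AbortController | undefined

events.on(EventName.OnInferenceStopped, () => {
  controller?.abort() // cancel the in-flight completion request, if any
})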
@ -1,4 +1,5 @@
export const toGigabytes = (input: number) => {
if (!input) return ''
if (input > 1024 ** 3) {
return (input / 1000 ** 3).toFixed(2) + 'GB'
} else if (input > 1024 ** 2) {
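For reference, the sizes in the model.json hunks above flow through this helper: 4370000000 bytes clears the 1024 ** 3 threshold and is divided by 1000 ** 3, so it renders as "4.37GB". A quick usage sketch, assuming toGigabytes is imported from this utility module (the smaller-unit branches are cut off by the hunk, so only the GB case is shown):

// Usage sketch with sizes taken from the model definitions above.
toGigabytes(4370000000) // "4.37GB"
toGigabytes(48750000000) // "48.75GB"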