diff --git a/core/src/core.ts b/core/src/core.ts index 3339759b2..32244e784 100644 --- a/core/src/core.ts +++ b/core/src/core.ts @@ -56,7 +56,8 @@ const openFileExplorer: (path: string) => Promise = (path) => * @param paths - The paths to join. * @returns {Promise} A promise that resolves with the joined path. */ -const joinPath: (paths: string[]) => Promise = (paths) => globalThis.core.api?.joinPath(paths) +const joinPath: (paths: string[]) => Promise = (paths) => + globalThis.core.api?.joinPath(paths) /** * Retrive the basename from an url. diff --git a/core/src/extensions/ai-engines/AIEngine.ts b/core/src/extensions/ai-engines/AIEngine.ts index 608b5c193..c65c081fd 100644 --- a/core/src/extensions/ai-engines/AIEngine.ts +++ b/core/src/extensions/ai-engines/AIEngine.ts @@ -14,7 +14,9 @@ export abstract class AIEngine extends BaseExtension { // The model folder modelFolder: string = 'models' - abstract models(): Promise + models(): Promise { + return Promise.resolve([]) + } /** * On extension load, subscribe to events. diff --git a/core/src/extensions/ai-engines/LocalOAIEngine.ts b/core/src/extensions/ai-engines/LocalOAIEngine.ts index 89444ff0f..f6557cd8f 100644 --- a/core/src/extensions/ai-engines/LocalOAIEngine.ts +++ b/core/src/extensions/ai-engines/LocalOAIEngine.ts @@ -9,9 +9,9 @@ import { OAIEngine } from './OAIEngine' */ export abstract class LocalOAIEngine extends OAIEngine { // The inference engine + abstract nodeModule: string loadModelFunctionName: string = 'loadModel' unloadModelFunctionName: string = 'unloadModel' - isRunning: boolean = false /** * On extension load, subscribe to events. @@ -19,22 +19,27 @@ export abstract class LocalOAIEngine extends OAIEngine { onLoad() { super.onLoad() // These events are applicable to local inference providers - events.on(ModelEvent.OnModelInit, (model: Model) => this.onModelInit(model)) - events.on(ModelEvent.OnModelStop, (model: Model) => this.onModelStop(model)) + events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) + events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) } /** * Load the model. */ - async onModelInit(model: Model) { + async loadModel(model: Model) { if (model.engine.toString() !== this.provider) return const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id]) const systemInfo = await systemInformation() - const res = await executeOnMain(this.nodeModule, this.loadModelFunctionName, { - modelFolder, - model, - }, systemInfo) + const res = await executeOnMain( + this.nodeModule, + this.loadModelFunctionName, + { + modelFolder, + model, + }, + systemInfo + ) if (res?.error) { events.emit(ModelEvent.OnModelFail, { @@ -45,16 +50,14 @@ export abstract class LocalOAIEngine extends OAIEngine { } else { this.loadedModel = model events.emit(ModelEvent.OnModelReady, model) - this.isRunning = true } } /** * Stops the model. 
*/ - onModelStop(model: Model) { - if (model.engine?.toString() !== this.provider) return - - this.isRunning = false + unloadModel(model: Model) { + if (model.engine && model.engine?.toString() !== this.provider) return + this.loadedModel = undefined executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => { events.emit(ModelEvent.OnModelStopped, {}) diff --git a/core/src/extensions/ai-engines/OAIEngine.ts b/core/src/extensions/ai-engines/OAIEngine.ts index 948de56ca..5936005bb 100644 --- a/core/src/extensions/ai-engines/OAIEngine.ts +++ b/core/src/extensions/ai-engines/OAIEngine.ts @@ -23,7 +23,6 @@ import { events } from '../../events' export abstract class OAIEngine extends AIEngine { // The inference engine abstract inferenceUrl: string - abstract nodeModule: string // Controller to handle stop requests controller = new AbortController() @@ -38,7 +37,7 @@ export abstract class OAIEngine extends AIEngine { onLoad() { super.onLoad() events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data)) - events.on(InferenceEvent.OnInferenceStopped, () => this.onInferenceStopped()) + events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference()) } /** @@ -78,7 +77,13 @@ export abstract class OAIEngine extends AIEngine { ...data.model, } - requestInference(this.inferenceUrl, data.messages ?? [], model, this.controller).subscribe({ + requestInference( + this.inferenceUrl, + data.messages ?? [], + model, + this.controller, + this.headers() + ).subscribe({ next: (content: any) => { const messageContent: ThreadContent = { type: ContentType.Text, @@ -109,8 +114,15 @@ export abstract class OAIEngine extends AIEngine { /** * Stops the inference. */ - onInferenceStopped() { + stopInference() { this.isCancelled = true this.controller?.abort() } + + /** + * Headers for the inference request + */ + headers(): HeadersInit { + return {} + } } diff --git a/core/src/extensions/ai-engines/RemoteOAIEngine.ts b/core/src/extensions/ai-engines/RemoteOAIEngine.ts new file mode 100644 index 000000000..5e9804b23 --- /dev/null +++ b/core/src/extensions/ai-engines/RemoteOAIEngine.ts @@ -0,0 +1,46 @@ +import { events } from '../../events' +import { Model, ModelEvent } from '../../types' +import { OAIEngine } from './OAIEngine' + +/** + * Base OAI Remote Inference Provider + * Added the implementation of loading and unloading model (applicable to local inference providers) + */ +export abstract class RemoteOAIEngine extends OAIEngine { + // The inference engine + abstract apiKey: string + /** + * On extension load, subscribe to events. + */ + onLoad() { + super.onLoad() + // These events are applicable to local inference providers + events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) + events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) + } + + /** + * Load the model. + */ + async loadModel(model: Model) { + if (model.engine.toString() !== this.provider) return + events.emit(ModelEvent.OnModelReady, model) + } + /** + * Stops the model. 
+ */ + unloadModel(model: Model) { + if (model.engine && model.engine.toString() !== this.provider) return + events.emit(ModelEvent.OnModelStopped, {}) + } + + /** + * Headers for the inference request + */ + override headers(): HeadersInit { + return { + 'Authorization': `Bearer ${this.apiKey}`, + 'api-key': `${this.apiKey}`, + } + } +} diff --git a/core/src/extensions/ai-engines/helpers/sse.ts b/core/src/extensions/ai-engines/helpers/sse.ts index 3d810d934..723d0dc13 100644 --- a/core/src/extensions/ai-engines/helpers/sse.ts +++ b/core/src/extensions/ai-engines/helpers/sse.ts @@ -12,7 +12,8 @@ export function requestInference( id: string parameters: ModelRuntimeParams }, - controller?: AbortController + controller?: AbortController, + headers?: HeadersInit ): Observable { return new Observable((subscriber) => { const requestBody = JSON.stringify({ @@ -27,6 +28,7 @@ export function requestInference( 'Content-Type': 'application/json', 'Access-Control-Allow-Origin': '*', 'Accept': model.parameters.stream ? 'text/event-stream' : 'application/json', + ...headers, }, body: requestBody, signal: controller?.signal, diff --git a/core/src/extensions/ai-engines/index.ts b/core/src/extensions/ai-engines/index.ts index f4da62a7c..fc341380a 100644 --- a/core/src/extensions/ai-engines/index.ts +++ b/core/src/extensions/ai-engines/index.ts @@ -1,3 +1,4 @@ export * from './AIEngine' export * from './OAIEngine' export * from './LocalOAIEngine' +export * from './RemoteOAIEngine' diff --git a/extensions/inference-groq-extension/package.json b/extensions/inference-groq-extension/package.json index c6faf5418..78efd3552 100644 --- a/extensions/inference-groq-extension/package.json +++ b/extensions/inference-groq-extension/package.json @@ -25,7 +25,7 @@ "@janhq/core": "file:../../core", "fetch-retry": "^5.0.6", "path-browserify": "^1.0.1", - "ulid": "^2.3.0" + "ulidx": "^2.3.0" }, "engines": { "node": ">=18.0.0" diff --git a/extensions/inference-groq-extension/src/@types/global.d.ts b/extensions/inference-groq-extension/src/@types/global.d.ts deleted file mode 100644 index f817fb406..000000000 --- a/extensions/inference-groq-extension/src/@types/global.d.ts +++ /dev/null @@ -1,16 +0,0 @@ -declare const MODULE: string -declare const GROQ_DOMAIN: string - -declare interface EngineSettings { - full_url?: string - api_key?: string -} - -enum GroqChatCompletionModelName { - 'mixtral-8x7b-32768' = 'mixtral-8x7b-32768', - 'llama2-70b-4096' = 'llama2-70b-4096', -} - -declare type GroqModel = Omit & { - id: GroqChatCompletionModelName -} diff --git a/extensions/inference-groq-extension/src/helpers/sse.ts b/extensions/inference-groq-extension/src/helpers/sse.ts deleted file mode 100644 index 35c40053c..000000000 --- a/extensions/inference-groq-extension/src/helpers/sse.ts +++ /dev/null @@ -1,83 +0,0 @@ -import { ErrorCode } from '@janhq/core' -import { Observable } from 'rxjs' - -/** - * Sends a request to the inference server to generate a response based on the recent messages. - * @param recentMessages - An array of recent messages to use as context for the inference. - * @param engine - The engine settings to use for the inference. - * @param model - The model to use for the inference. - * @returns An Observable that emits the generated response as a string. 
- */ -export function requestInference( - recentMessages: any[], - engine: EngineSettings, - model: GroqModel, - controller?: AbortController -): Observable { - return new Observable((subscriber) => { - // let model_id: string = model.id - - const requestBody = JSON.stringify({ - messages: recentMessages, - stream: true, - model: model.id, - ...model.parameters, - }) - fetch(`${engine.full_url}`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'Accept': model.parameters.stream - ? 'text/event-stream' - : 'application/json', - 'Access-Control-Allow-Origin': '*', - 'Authorization': `Bearer ${engine.api_key}`, - // 'api-key': `${engine.api_key}`, - }, - body: requestBody, - signal: controller?.signal, - }) - .then(async (response) => { - if (!response.ok) { - const data = await response.json() - const error = { - message: data.error?.message ?? 'An error occurred.', - code: data.error?.code ?? ErrorCode.Unknown, - } - subscriber.error(error) - subscriber.complete() - return - } - if (model.parameters.stream === false) { - const data = await response.json() - subscriber.next(data.choices[0]?.message?.content ?? '') - } else { - const stream = response.body - const decoder = new TextDecoder('utf-8') - const reader = stream?.getReader() - let content = '' - - while (true && reader) { - const { done, value } = await reader.read() - if (done) { - break - } - const text = decoder.decode(value) - const lines = text.trim().split('\n') - for (const line of lines) { - if (line.startsWith('data: ') && !line.includes('data: [DONE]')) { - const data = JSON.parse(line.replace('data: ', '')) - content += data.choices[0]?.delta?.content ?? '' - if (content.startsWith('assistant: ')) { - content = content.replace('assistant: ', '') - } - subscriber.next(content) - } - } - } - } - subscriber.complete() - }) - .catch((err) => subscriber.error(err)) - }) -} diff --git a/extensions/inference-groq-extension/src/index.ts b/extensions/inference-groq-extension/src/index.ts index 0fe22a11c..f4dc23d1c 100644 --- a/extensions/inference-groq-extension/src/index.ts +++ b/extensions/inference-groq-extension/src/index.ts @@ -7,218 +7,77 @@ */ import { - ChatCompletionRole, - ContentType, - MessageRequest, - MessageStatus, - ThreadContent, - ThreadMessage, events, fs, - InferenceEngine, - BaseExtension, - MessageEvent, - MessageRequestType, - ModelEvent, - InferenceEvent, AppConfigurationEventName, joinPath, + RemoteOAIEngine, } from '@janhq/core' -import { requestInference } from './helpers/sse' -import { ulid } from 'ulid' import { join } from 'path' +declare const COMPLETION_URL: string /** * A class that implements the InferenceExtension interface from the @janhq/core package. * The class provides methods for initializing and stopping a model, and for making inference requests. * It also subscribes to events emitted by the @janhq/core package and handles new message requests. 
*/ -export default class JanInferenceGroqExtension extends BaseExtension { - private static readonly _engineDir = 'file://engines' - private static readonly _engineMetadataFileName = 'groq.json' +export default class JanInferenceGroqExtension extends RemoteOAIEngine { + private readonly _engineDir = 'file://engines' + private readonly _engineMetadataFileName = 'groq.json' - private static _currentModel: GroqModel + inferenceUrl: string = COMPLETION_URL + provider = 'groq' + apiKey = '' - private static _engineSettings: EngineSettings = { - full_url: 'https://api.groq.com/openai/v1/chat/completions', + private _engineSettings = { + full_url: COMPLETION_URL, api_key: 'gsk-', } - controller = new AbortController() - isCancelled = false - /** * Subscribes to events emitted by the @janhq/core package. */ async onLoad() { - if (!(await fs.existsSync(JanInferenceGroqExtension._engineDir))) { - await fs - .mkdirSync(JanInferenceGroqExtension._engineDir) - .catch((err) => console.debug(err)) + super.onLoad() + + if (!(await fs.existsSync(this._engineDir))) { + await fs.mkdirSync(this._engineDir).catch((err) => console.debug(err)) } - JanInferenceGroqExtension.writeDefaultEngineSettings() - - // Events subscription - events.on(MessageEvent.OnMessageSent, (data) => - JanInferenceGroqExtension.handleMessageRequest(data, this) - ) - - events.on(ModelEvent.OnModelInit, (model: GroqModel) => { - JanInferenceGroqExtension.handleModelInit(model) - }) - - events.on(ModelEvent.OnModelStop, (model: GroqModel) => { - JanInferenceGroqExtension.handleModelStop(model) - }) - events.on(InferenceEvent.OnInferenceStopped, () => { - JanInferenceGroqExtension.handleInferenceStopped(this) - }) + this.writeDefaultEngineSettings() const settingsFilePath = await joinPath([ - JanInferenceGroqExtension._engineDir, - JanInferenceGroqExtension._engineMetadataFileName, + this._engineDir, + this._engineMetadataFileName, ]) + // Events subscription events.on( AppConfigurationEventName.OnConfigurationUpdate, (settingsKey: string) => { // Update settings on changes - if (settingsKey === settingsFilePath) - JanInferenceGroqExtension.writeDefaultEngineSettings() + if (settingsKey === settingsFilePath) this.writeDefaultEngineSettings() } ) } - /** - * Stops the model inference. - */ - onUnload(): void {} - - static async writeDefaultEngineSettings() { + async writeDefaultEngineSettings() { try { - const engineFile = join( - JanInferenceGroqExtension._engineDir, - JanInferenceGroqExtension._engineMetadataFileName - ) + const engineFile = join(this._engineDir, this._engineMetadataFileName) if (await fs.existsSync(engineFile)) { const engine = await fs.readFileSync(engineFile, 'utf-8') - JanInferenceGroqExtension._engineSettings = + this._engineSettings = typeof engine === 'object' ? 
engine : JSON.parse(engine) + this.inferenceUrl = this._engineSettings.full_url + this.apiKey = this._engineSettings.api_key } else { await fs.writeFileSync( engineFile, - JSON.stringify(JanInferenceGroqExtension._engineSettings, null, 2) + JSON.stringify(this._engineSettings, null, 2) ) } } catch (err) { console.error(err) } } - private static async handleModelInit(model: GroqModel) { - if (model.engine !== InferenceEngine.groq) { - return - } else { - JanInferenceGroqExtension._currentModel = model - JanInferenceGroqExtension.writeDefaultEngineSettings() - // Todo: Check model list with API key - events.emit(ModelEvent.OnModelReady, model) - } - } - - private static async handleModelStop(model: GroqModel) { - if (model.engine !== 'groq') { - return - } - events.emit(ModelEvent.OnModelStopped, model) - } - - private static async handleInferenceStopped( - instance: JanInferenceGroqExtension - ) { - instance.isCancelled = true - instance.controller?.abort() - } - - /** - * Handles a new message request by making an inference request and emitting events. - * Function registered in event manager, should be static to avoid binding issues. - * Pass instance as a reference. - * @param {MessageRequest} data - The data for the new message request. - */ - private static async handleMessageRequest( - data: MessageRequest, - instance: JanInferenceGroqExtension - ) { - if (data.model.engine !== 'groq') { - return - } - - const timestamp = Date.now() - const message: ThreadMessage = { - id: ulid(), - thread_id: data.threadId, - type: data.type, - assistant_id: data.assistantId, - role: ChatCompletionRole.Assistant, - content: [], - status: MessageStatus.Pending, - created: timestamp, - updated: timestamp, - object: 'thread.message', - } - - if (data.type !== MessageRequestType.Summary) { - events.emit(MessageEvent.OnMessageResponse, message) - } - - instance.isCancelled = false - instance.controller = new AbortController() - - requestInference( - data?.messages ?? [], - this._engineSettings, - { - ...JanInferenceGroqExtension._currentModel, - parameters: data.model.parameters, - }, - instance.controller - ).subscribe({ - next: (content) => { - const messageContent: ThreadContent = { - type: ContentType.Text, - text: { - value: content.trim(), - annotations: [], - }, - } - message.content = [messageContent] - events.emit(MessageEvent.OnMessageUpdate, message) - }, - complete: async () => { - message.status = message.content.length - ? MessageStatus.Ready - : MessageStatus.Error - events.emit(MessageEvent.OnMessageUpdate, message) - }, - error: async (err) => { - if (instance.isCancelled || message.content.length > 0) { - message.status = MessageStatus.Stopped - events.emit(MessageEvent.OnMessageUpdate, message) - return - } - const messageContent: ThreadContent = { - type: ContentType.Text, - text: { - value: 'An error occurred. 
' + err.message, - annotations: [], - }, - } - message.content = [messageContent] - message.status = MessageStatus.Error - message.error_code = err.code - events.emit(MessageEvent.OnMessageUpdate, message) - }, - }) - } } diff --git a/extensions/inference-groq-extension/webpack.config.js b/extensions/inference-groq-extension/webpack.config.js index 96110e818..5352b56b7 100644 --- a/extensions/inference-groq-extension/webpack.config.js +++ b/extensions/inference-groq-extension/webpack.config.js @@ -18,7 +18,7 @@ module.exports = { plugins: [ new webpack.DefinePlugin({ MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`), - GROQ_DOMAIN: JSON.stringify('api.groq.com'), + COMPLETION_URL: JSON.stringify('https://api.groq.com/openai/v1/chat/completions'), }), ], output: { diff --git a/extensions/inference-nitro-extension/jest.config.js b/extensions/inference-nitro-extension/jest.config.js new file mode 100644 index 000000000..b413e106d --- /dev/null +++ b/extensions/inference-nitro-extension/jest.config.js @@ -0,0 +1,5 @@ +/** @type {import('ts-jest').JestConfigWithTsJest} */ +module.exports = { + preset: 'ts-jest', + testEnvironment: 'node', +}; \ No newline at end of file diff --git a/extensions/inference-nitro-extension/package.json b/extensions/inference-nitro-extension/package.json index 45bd8307a..25abaf049 100644 --- a/extensions/inference-nitro-extension/package.json +++ b/extensions/inference-nitro-extension/package.json @@ -7,6 +7,7 @@ "author": "Jan ", "license": "AGPL-3.0", "scripts": { + "test": "jest", "build": "tsc --module commonjs && rollup -c rollup.config.ts", "downloadnitro:linux": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64.tar.gz -e --strip 1 -o ./bin/linux-cpu && chmod +x ./bin/linux-cpu/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda-12-0.tar.gz -e --strip 1 -o ./bin/linux-cuda-12-0 && chmod +x ./bin/linux-cuda-12-0/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda-11-7.tar.gz -e --strip 1 -o ./bin/linux-cuda-11-7 && chmod +x ./bin/linux-cuda-11-7/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-vulkan.tar.gz -e --strip 1 -o ./bin/linux-vulkan && chmod +x ./bin/linux-vulkan/nitro", "downloadnitro:darwin": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-arm64.tar.gz -e --strip 1 -o ./bin/mac-arm64 && chmod +x ./bin/mac-arm64/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-amd64.tar.gz -e --strip 1 -o ./bin/mac-x64 && chmod +x ./bin/mac-x64/nitro", @@ -15,29 +16,34 @@ "build:publish:darwin": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && ../../.github/scripts/auto-sign.sh && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", "build:publish:win32": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", "build:publish:linux": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", - "build:publish": "run-script-os" + "build:publish": "yarn test && run-script-os" }, "exports": { 
".": "./dist/index.js", "./main": "./dist/node/index.cjs.js" }, "devDependencies": { + "@babel/preset-typescript": "^7.24.1", + "@jest/globals": "^29.7.0", "@rollup/plugin-commonjs": "^25.0.7", "@rollup/plugin-json": "^6.1.0", "@rollup/plugin-node-resolve": "^15.2.3", + "@rollup/plugin-replace": "^5.0.5", + "@types/jest": "^29.5.12", "@types/node": "^20.11.4", + "@types/os-utils": "^0.0.4", "@types/tcp-port-used": "^1.0.4", "cpx": "^1.5.0", "download-cli": "^1.1.1", + "jest": "^29.7.0", "rimraf": "^3.0.2", "rollup": "^2.38.5", "rollup-plugin-define": "^1.0.1", "rollup-plugin-sourcemaps": "^0.6.3", "rollup-plugin-typescript2": "^0.36.0", "run-script-os": "^1.1.6", - "typescript": "^5.3.3", - "@types/os-utils": "^0.0.4", - "@rollup/plugin-replace": "^5.0.5" + "ts-jest": "^29.1.2", + "typescript": "^5.3.3" }, "dependencies": { "@janhq/core": "file:../../core", diff --git a/extensions/inference-nitro-extension/src/babel.config.js b/extensions/inference-nitro-extension/src/babel.config.js new file mode 100644 index 000000000..befbdd148 --- /dev/null +++ b/extensions/inference-nitro-extension/src/babel.config.js @@ -0,0 +1,6 @@ +module.exports = { + presets: [ + ['@babel/preset-env', { targets: { node: 'current' } }], + '@babel/preset-typescript', + ], +} diff --git a/extensions/inference-nitro-extension/src/helpers/sse.ts b/extensions/inference-nitro-extension/src/helpers/sse.ts deleted file mode 100644 index 06176c9b9..000000000 --- a/extensions/inference-nitro-extension/src/helpers/sse.ts +++ /dev/null @@ -1,66 +0,0 @@ -import { Model } from '@janhq/core' -import { Observable } from 'rxjs' -/** - * Sends a request to the inference server to generate a response based on the recent messages. - * @param recentMessages - An array of recent messages to use as context for the inference. - * @returns An Observable that emits the generated response as a string. - */ -export function requestInference( - inferenceUrl: string, - recentMessages: any[], - model: Model, - controller?: AbortController -): Observable { - return new Observable((subscriber) => { - const requestBody = JSON.stringify({ - messages: recentMessages, - model: model.id, - stream: true, - ...model.parameters, - }) - fetch(inferenceUrl, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'Access-Control-Allow-Origin': '*', - 'Accept': model.parameters.stream - ? 'text/event-stream' - : 'application/json', - }, - body: requestBody, - signal: controller?.signal, - }) - .then(async (response) => { - if (model.parameters.stream === false) { - const data = await response.json() - subscriber.next(data.choices[0]?.message?.content ?? '') - } else { - const stream = response.body - const decoder = new TextDecoder('utf-8') - const reader = stream?.getReader() - let content = '' - - while (true && reader) { - const { done, value } = await reader.read() - if (done) { - break - } - const text = decoder.decode(value) - const lines = text.trim().split('\n') - for (const line of lines) { - if (line.startsWith('data: ') && !line.includes('data: [DONE]')) { - const data = JSON.parse(line.replace('data: ', '')) - content += data.choices[0]?.delta?.content ?? 
'' - if (content.startsWith('assistant: ')) { - content = content.replace('assistant: ', '') - } - subscriber.next(content) - } - } - } - } - subscriber.complete() - }) - .catch((err) => subscriber.error(err)) - }) -} diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts index e398cb643..3a23082ba 100644 --- a/extensions/inference-nitro-extension/src/index.ts +++ b/extensions/inference-nitro-extension/src/index.ts @@ -7,58 +7,31 @@ */ import { - ChatCompletionRole, - ContentType, - MessageRequest, - MessageRequestType, - MessageStatus, - ThreadContent, - ThreadMessage, events, executeOnMain, - fs, Model, - joinPath, - InferenceExtension, - log, - InferenceEngine, - MessageEvent, ModelEvent, - InferenceEvent, - ModelSettingParams, - getJanDataFolderPath, + LocalOAIEngine, } from '@janhq/core' -import { requestInference } from './helpers/sse' -import { ulid } from 'ulidx' /** * A class that implements the InferenceExtension interface from the @janhq/core package. * The class provides methods for initializing and stopping a model, and for making inference requests. * It also subscribes to events emitted by the @janhq/core package and handles new message requests. */ -export default class JanInferenceNitroExtension extends InferenceExtension { - private static readonly _homeDir = 'file://engines' - private static readonly _settingsDir = 'file://settings' - private static readonly _engineMetadataFileName = 'nitro.json' +export default class JanInferenceNitroExtension extends LocalOAIEngine { + nodeModule: string = NODE + provider: string = 'nitro' + + models(): Promise { + return Promise.resolve([]) + } /** * Checking the health for Nitro's process each 5 secs. */ private static readonly _intervalHealthCheck = 5 * 1000 - private _currentModel: Model | undefined - - private _engineSettings: ModelSettingParams = { - ctx_len: 2048, - ngl: 100, - cpu_threads: 1, - cont_batching: false, - embedding: true, - } - - controller = new AbortController() - isCancelled = false - /** * The interval id for the health check. Used to stop the health check. */ @@ -69,114 +42,30 @@ export default class JanInferenceNitroExtension extends InferenceExtension { */ private nitroProcessInfo: any = undefined - private inferenceUrl = '' + /** + * The URL for making inference requests. + */ + inferenceUrl = '' /** * Subscribes to events emitted by the @janhq/core package. */ async onLoad() { - if (!(await fs.existsSync(JanInferenceNitroExtension._homeDir))) { - try { - await fs.mkdirSync(JanInferenceNitroExtension._homeDir) - } catch (e) { - console.debug(e) - } - } - - // init inference url - // @ts-ignore - const electronApi = window?.electronAPI this.inferenceUrl = INFERENCE_URL - if (!electronApi) { + + // If the extension is running in the browser, use the base API URL from the core package. 
+ if (!('electronAPI' in window)) { this.inferenceUrl = `${window.core?.api?.baseApiUrl}/v1/chat/completions` } + console.debug('Inference url: ', this.inferenceUrl) - if (!(await fs.existsSync(JanInferenceNitroExtension._settingsDir))) - await fs.mkdirSync(JanInferenceNitroExtension._settingsDir) - this.writeDefaultEngineSettings() - - // Events subscription - events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => - this.onMessageRequest(data) - ) - - events.on(ModelEvent.OnModelInit, (model: Model) => this.onModelInit(model)) - - events.on(ModelEvent.OnModelStop, (model: Model) => this.onModelStop(model)) - - events.on(InferenceEvent.OnInferenceStopped, () => - this.onInferenceStopped() - ) - } - - /** - * Stops the model inference. - */ - onUnload(): void {} - - private async writeDefaultEngineSettings() { - try { - const engineFile = await joinPath([ - JanInferenceNitroExtension._homeDir, - JanInferenceNitroExtension._engineMetadataFileName, - ]) - if (await fs.existsSync(engineFile)) { - const engine = await fs.readFileSync(engineFile, 'utf-8') - this._engineSettings = - typeof engine === 'object' ? engine : JSON.parse(engine) - } else { - await fs.writeFileSync( - engineFile, - JSON.stringify(this._engineSettings, null, 2) - ) - } - } catch (err) { - console.error(err) - } - } - - private async onModelInit(model: Model) { - if (model.engine !== InferenceEngine.nitro) return - - const modelFolder = await joinPath([ - await getJanDataFolderPath(), - 'models', - model.id, - ]) - this._currentModel = model - const nitroInitResult = await executeOnMain(NODE, 'runModel', { - modelFolder, - model, - }) - - if (nitroInitResult?.error) { - events.emit(ModelEvent.OnModelFail, { - ...model, - error: nitroInitResult.error, - }) - return - } - - events.emit(ModelEvent.OnModelReady, model) - this.getNitroProcesHealthIntervalId = setInterval( () => this.periodicallyGetNitroHealth(), JanInferenceNitroExtension._intervalHealthCheck ) - } - private async onModelStop(model: Model) { - if (model.engine !== 'nitro') return - - await executeOnMain(NODE, 'stopModel') - events.emit(ModelEvent.OnModelStopped, {}) - - // stop the periocally health check - if (this.getNitroProcesHealthIntervalId) { - clearInterval(this.getNitroProcesHealthIntervalId) - this.getNitroProcesHealthIntervalId = undefined - } + super.onLoad() } /** @@ -193,118 +82,24 @@ export default class JanInferenceNitroExtension extends InferenceExtension { this.nitroProcessInfo = health } - private async onInferenceStopped() { - this.isCancelled = true - this.controller?.abort() + override loadModel(model: Model): Promise { + if (model.engine !== this.provider) return Promise.resolve() + this.getNitroProcesHealthIntervalId = setInterval( + () => this.periodicallyGetNitroHealth(), + JanInferenceNitroExtension._intervalHealthCheck + ) + return super.loadModel(model) } - /** - * Makes a single response inference request. - * @param {MessageRequest} data - The data for the inference request. - * @returns {Promise} A promise that resolves with the inference response. 
- */ - async inference(data: MessageRequest): Promise { - const timestamp = Date.now() - const message: ThreadMessage = { - thread_id: data.threadId, - created: timestamp, - updated: timestamp, - status: MessageStatus.Ready, - id: '', - role: ChatCompletionRole.Assistant, - object: 'thread.message', - content: [], + override unloadModel(model: Model): void { + super.unloadModel(model) + + if (model.engine && model.engine !== this.provider) return + + // stop the periocally health check + if (this.getNitroProcesHealthIntervalId) { + clearInterval(this.getNitroProcesHealthIntervalId) + this.getNitroProcesHealthIntervalId = undefined } - - return new Promise(async (resolve, reject) => { - if (!this._currentModel) return Promise.reject('No model loaded') - - requestInference( - this.inferenceUrl, - data.messages ?? [], - this._currentModel - ).subscribe({ - next: (_content: any) => {}, - complete: async () => { - resolve(message) - }, - error: async (err: any) => { - reject(err) - }, - }) - }) - } - - /** - * Handles a new message request by making an inference request and emitting events. - * Function registered in event manager, should be static to avoid binding issues. - * Pass instance as a reference. - * @param {MessageRequest} data - The data for the new message request. - */ - private async onMessageRequest(data: MessageRequest) { - if (data.model?.engine !== InferenceEngine.nitro || !this._currentModel) { - return - } - - const timestamp = Date.now() - const message: ThreadMessage = { - id: ulid(), - thread_id: data.threadId, - type: data.type, - assistant_id: data.assistantId, - role: ChatCompletionRole.Assistant, - content: [], - status: MessageStatus.Pending, - created: timestamp, - updated: timestamp, - object: 'thread.message', - } - - if (data.type !== MessageRequestType.Summary) { - events.emit(MessageEvent.OnMessageResponse, message) - } - - this.isCancelled = false - this.controller = new AbortController() - - // @ts-ignore - const model: Model = { - ...(this._currentModel || {}), - ...(data.model || {}), - } - requestInference( - this.inferenceUrl, - data.messages ?? [], - model, - this.controller - ).subscribe({ - next: (content: any) => { - const messageContent: ThreadContent = { - type: ContentType.Text, - text: { - value: content.trim(), - annotations: [], - }, - } - message.content = [messageContent] - events.emit(MessageEvent.OnMessageUpdate, message) - }, - complete: async () => { - message.status = message.content.length - ? 
MessageStatus.Ready - : MessageStatus.Error - events.emit(MessageEvent.OnMessageUpdate, message) - }, - error: async (err: any) => { - if (this.isCancelled || message.content.length) { - message.status = MessageStatus.Stopped - events.emit(MessageEvent.OnMessageUpdate, message) - return - } - message.status = MessageStatus.Error - events.emit(MessageEvent.OnMessageUpdate, message) - log(`[APP]::Error: ${err.message}`) - }, - }) } } diff --git a/extensions/inference-nitro-extension/src/node/execute.test.ts b/extensions/inference-nitro-extension/src/node/execute.test.ts new file mode 100644 index 000000000..62ffdc707 --- /dev/null +++ b/extensions/inference-nitro-extension/src/node/execute.test.ts @@ -0,0 +1,233 @@ +import { describe, expect, it } from '@jest/globals' +import { executableNitroFile } from './execute' +import { GpuSetting } from '@janhq/core' +import { sep } from 'path' + +let testSettings: GpuSetting = { + run_mode: 'cpu', + vulkan: false, + cuda: { + exist: false, + version: '11', + }, + gpu_highest_vram: '0', + gpus: [], + gpus_in_use: [], + is_initial: false, + notify: true, + nvidia_driver: { + exist: false, + version: '11', + }, +} +const originalPlatform = process.platform + +describe('test executable nitro file', () => { + afterAll(function () { + Object.defineProperty(process, 'platform', { + value: originalPlatform, + }) + }) + + it('executes on MacOS ARM', () => { + Object.defineProperty(process, 'platform', { + value: 'darwin', + }) + Object.defineProperty(process, 'arch', { + value: 'arm64', + }) + expect(executableNitroFile(testSettings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`mac-arm64${sep}nitro`), + cudaVisibleDevices: '', + vkVisibleDevices: '', + }) + ) + }) + + it('executes on MacOS Intel', () => { + Object.defineProperty(process, 'platform', { + value: 'darwin', + }) + Object.defineProperty(process, 'arch', { + value: 'x64', + }) + expect(executableNitroFile(testSettings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`mac-x64${sep}nitro`), + cudaVisibleDevices: '', + vkVisibleDevices: '', + }) + ) + }) + + it('executes on Windows CPU', () => { + Object.defineProperty(process, 'platform', { + value: 'win32', + }) + const settings: GpuSetting = { + ...testSettings, + run_mode: 'cpu', + cuda: { + exist: true, + version: '11', + }, + } + expect(executableNitroFile(settings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`win-cpu${sep}nitro.exe`), + cudaVisibleDevices: '', + vkVisibleDevices: '', + }) + ) + }) + + it('executes on Windows Cuda 11', () => { + Object.defineProperty(process, 'platform', { + value: 'win32', + }) + const settings: GpuSetting = { + ...testSettings, + run_mode: 'gpu', + cuda: { + exist: true, + version: '11', + }, + nvidia_driver: { + exist: true, + version: '12', + }, + gpus_in_use: ['0'], + gpus: [ + { + id: '0', + name: 'NVIDIA GeForce GTX 1080', + vram: '80000000', + }, + ], + } + expect(executableNitroFile(settings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`win-cuda-11-7${sep}nitro.exe`), + cudaVisibleDevices: '0', + vkVisibleDevices: '0', + }) + ) + }) + + it('executes on Windows Cuda 12', () => { + Object.defineProperty(process, 'platform', { + value: 'win32', + }) + const settings: GpuSetting = { + ...testSettings, + run_mode: 'gpu', + cuda: { + exist: true, + version: '12', + }, + nvidia_driver: { + exist: true, + version: '12', + }, + gpus_in_use: ['0'], + gpus: [ + { + id: '0', + 
name: 'NVIDIA GeForce GTX 1080', + vram: '80000000', + }, + ], + } + expect(executableNitroFile(settings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`win-cuda-12-0${sep}nitro.exe`), + cudaVisibleDevices: '0', + vkVisibleDevices: '0', + }) + ) + }) + + it('executes on Linux CPU', () => { + Object.defineProperty(process, 'platform', { + value: 'linux', + }) + const settings: GpuSetting = { + ...testSettings, + run_mode: 'cpu', + } + expect(executableNitroFile(settings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`linux-cpu${sep}nitro`), + cudaVisibleDevices: '', + vkVisibleDevices: '', + }) + ) + }) + + it('executes on Linux Cuda 11', () => { + Object.defineProperty(process, 'platform', { + value: 'linux', + }) + const settings: GpuSetting = { + ...testSettings, + run_mode: 'gpu', + cuda: { + exist: true, + version: '11', + }, + nvidia_driver: { + exist: true, + version: '12', + }, + gpus_in_use: ['0'], + gpus: [ + { + id: '0', + name: 'NVIDIA GeForce GTX 1080', + vram: '80000000', + }, + ], + } + expect(executableNitroFile(settings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`linux-cuda-11-7${sep}nitro`), + cudaVisibleDevices: '0', + vkVisibleDevices: '0', + }) + ) + }) + + it('executes on Linux Cuda 12', () => { + Object.defineProperty(process, 'platform', { + value: 'linux', + }) + const settings: GpuSetting = { + ...testSettings, + run_mode: 'gpu', + cuda: { + exist: true, + version: '12', + }, + nvidia_driver: { + exist: true, + version: '12', + }, + gpus_in_use: ['0'], + gpus: [ + { + id: '0', + name: 'NVIDIA GeForce GTX 1080', + vram: '80000000', + }, + ], + } + expect(executableNitroFile(settings)).toEqual( + expect.objectContaining({ + executablePath: expect.stringContaining(`linux-cuda-12-0${sep}nitro`), + cudaVisibleDevices: '0', + vkVisibleDevices: '0', + }) + ) + }) +}) diff --git a/extensions/inference-nitro-extension/src/node/execute.ts b/extensions/inference-nitro-extension/src/node/execute.ts index 8bcc75ae4..c9d541654 100644 --- a/extensions/inference-nitro-extension/src/node/execute.ts +++ b/extensions/inference-nitro-extension/src/node/execute.ts @@ -1,5 +1,4 @@ -import { getJanDataFolderPath } from '@janhq/core/node' -import { readFileSync } from 'fs' +import { GpuSetting, SystemInformation } from '@janhq/core' import * as path from 'path' export interface NitroExecutableOptions { @@ -7,79 +6,56 @@ export interface NitroExecutableOptions { cudaVisibleDevices: string vkVisibleDevices: string } +const runMode = (settings?: GpuSetting): string => { + if (process.platform === 'darwin') + // MacOS use arch instead of cpu / cuda + return process.arch === 'arm64' ? 'arm64' : 'x64' -export const GPU_INFO_FILE = path.join( - getJanDataFolderPath(), - 'settings', - 'settings.json' -) + if (!settings) return 'cpu' + + return settings.vulkan === true + ? 'vulkan' + : settings.run_mode === 'cpu' + ? 'cpu' + : 'cuda' +} + +const os = (): string => { + return process.platform === 'win32' + ? 'win' + : process.platform === 'darwin' + ? 'mac' + : 'linux' +} + +const extension = (): '.exe' | '' => { + return process.platform === 'win32' ? '.exe' : '' +} + +const cudaVersion = (settings?: GpuSetting): '11-7' | '12-0' | undefined => { + const isUsingCuda = + settings?.vulkan !== true && settings?.run_mode === 'gpu' && os() !== 'mac' + + if (!isUsingCuda) return undefined + return settings?.cuda?.version === '11' ? 
'11-7' : '12-0' +} /** * Find which executable file to run based on the current platform. * @returns The name of the executable file to run. */ -export const executableNitroFile = (): NitroExecutableOptions => { - let binaryFolder = path.join(__dirname, '..', 'bin') // Current directory by default - let cudaVisibleDevices = '' - let vkVisibleDevices = '' - let binaryName = 'nitro' - /** - * The binary folder is different for each platform. - */ - if (process.platform === 'win32') { - /** - * For Windows: win-cpu, win-vulkan, win-cuda-11-7, win-cuda-12-0 - */ - let gpuInfo = JSON.parse(readFileSync(GPU_INFO_FILE, 'utf-8')) - if (gpuInfo['run_mode'] === 'cpu') { - binaryFolder = path.join(binaryFolder, 'win-cpu') - } else { - if (gpuInfo['cuda']?.version === '11') { - binaryFolder = path.join(binaryFolder, 'win-cuda-11-7') - } else { - binaryFolder = path.join(binaryFolder, 'win-cuda-12-0') - } - cudaVisibleDevices = gpuInfo['gpus_in_use'].join(',') - } - if (gpuInfo['vulkan'] === true) { - binaryFolder = path.join(__dirname, '..', 'bin') - binaryFolder = path.join(binaryFolder, 'win-vulkan') - vkVisibleDevices = gpuInfo['gpus_in_use'].toString() - } - binaryName = 'nitro.exe' - } else if (process.platform === 'darwin') { - /** - * For MacOS: mac-arm64 (Silicon), mac-x64 (InteL) - */ - if (process.arch === 'arm64') { - binaryFolder = path.join(binaryFolder, 'mac-arm64') - } else { - binaryFolder = path.join(binaryFolder, 'mac-x64') - } - } else { - /** - * For Linux: linux-cpu, linux-vulkan, linux-cuda-11-7, linux-cuda-12-0 - */ - let gpuInfo = JSON.parse(readFileSync(GPU_INFO_FILE, 'utf-8')) - if (gpuInfo['run_mode'] === 'cpu') { - binaryFolder = path.join(binaryFolder, 'linux-cpu') - } else { - if (gpuInfo['cuda']?.version === '11') { - binaryFolder = path.join(binaryFolder, 'linux-cuda-11-7') - } else { - binaryFolder = path.join(binaryFolder, 'linux-cuda-12-0') - } - cudaVisibleDevices = gpuInfo['gpus_in_use'].join(',') - } +export const executableNitroFile = ( + gpuSetting?: GpuSetting +): NitroExecutableOptions => { + let binaryFolder = [os(), runMode(gpuSetting), cudaVersion(gpuSetting)] + .filter((e) => !!e) + .join('-') + let cudaVisibleDevices = gpuSetting?.gpus_in_use.join(',') ?? '' + let vkVisibleDevices = gpuSetting?.gpus_in_use.join(',') ?? '' + let binaryName = `nitro${extension()}` - if (gpuInfo['vulkan'] === true) { - binaryFolder = path.join(__dirname, '..', 'bin') - binaryFolder = path.join(binaryFolder, 'linux-vulkan') - vkVisibleDevices = gpuInfo['gpus_in_use'].toString() - } - } return { - executablePath: path.join(binaryFolder, binaryName), + executablePath: path.join(__dirname, '..', 'bin', binaryFolder, binaryName), cudaVisibleDevices, vkVisibleDevices, } diff --git a/extensions/inference-nitro-extension/src/node/index.ts b/extensions/inference-nitro-extension/src/node/index.ts index 5e8b97188..71adac72d 100644 --- a/extensions/inference-nitro-extension/src/node/index.ts +++ b/extensions/inference-nitro-extension/src/node/index.ts @@ -10,6 +10,7 @@ import { InferenceEngine, ModelSettingParams, PromptTemplate, + SystemInformation, } from '@janhq/core/node' import { executableNitroFile } from './execute' @@ -51,7 +52,7 @@ let currentSettings: ModelSettingParams | undefined = undefined * @param wrapper - The model wrapper. * @returns A Promise that resolves when the subprocess is terminated successfully, or rejects with an error message if the subprocess fails to terminate. 
*/ -function stopModel(): Promise { +function unloadModel(): Promise { return killSubprocess() } @@ -61,46 +62,47 @@ function stopModel(): Promise { * @returns A Promise that resolves when the model is loaded successfully, or rejects with an error message if the model is not found or fails to load. * TODO: Should pass absolute of the model file instead of just the name - So we can modurize the module.ts to npm package */ -async function runModel( - wrapper: ModelInitOptions +async function loadModel( + params: ModelInitOptions, + systemInfo?: SystemInformation ): Promise { - if (wrapper.model.engine !== InferenceEngine.nitro) { + if (params.model.engine !== InferenceEngine.nitro) { // Not a nitro model return Promise.resolve() } - if (wrapper.model.engine !== InferenceEngine.nitro) { + if (params.model.engine !== InferenceEngine.nitro) { return Promise.reject('Not a nitro model') } else { const nitroResourceProbe = await getSystemResourceInfo() // Convert settings.prompt_template to system_prompt, user_prompt, ai_prompt - if (wrapper.model.settings.prompt_template) { - const promptTemplate = wrapper.model.settings.prompt_template + if (params.model.settings.prompt_template) { + const promptTemplate = params.model.settings.prompt_template const prompt = promptTemplateConverter(promptTemplate) if (prompt?.error) { return Promise.reject(prompt.error) } - wrapper.model.settings.system_prompt = prompt.system_prompt - wrapper.model.settings.user_prompt = prompt.user_prompt - wrapper.model.settings.ai_prompt = prompt.ai_prompt + params.model.settings.system_prompt = prompt.system_prompt + params.model.settings.user_prompt = prompt.user_prompt + params.model.settings.ai_prompt = prompt.ai_prompt } // modelFolder is the absolute path to the running model folder // e.g. ~/jan/models/llama-2 - let modelFolder = wrapper.modelFolder + let modelFolder = params.modelFolder - let llama_model_path = wrapper.model.settings.llama_model_path + let llama_model_path = params.model.settings.llama_model_path // Absolute model path support if ( - wrapper.model?.sources.length && - wrapper.model.sources.every((e) => fs.existsSync(e.url)) + params.model?.sources.length && + params.model.sources.every((e) => fs.existsSync(e.url)) ) { llama_model_path = - wrapper.model.sources.length === 1 - ? wrapper.model.sources[0].url - : wrapper.model.sources.find((e) => - e.url.includes(llama_model_path ?? wrapper.model.id) + params.model.sources.length === 1 + ? params.model.sources[0].url + : params.model.sources.find((e) => + e.url.includes(llama_model_path ?? params.model.id) )?.url } @@ -114,7 +116,7 @@ async function runModel( // 2. Prioritize GGUF File (manual import) file.toLowerCase().includes(SUPPORTED_MODEL_FORMAT) || // 3. Fallback Model ID (for backward compatibility) - file === wrapper.model.id + file === params.model.id ) if (ggufBinFile) llama_model_path = path.join(modelFolder, ggufBinFile) } @@ -124,17 +126,17 @@ async function runModel( if (!llama_model_path) return Promise.reject('No GGUF model file found') currentSettings = { - ...wrapper.model.settings, + ...params.model.settings, llama_model_path, // This is critical and requires real CPU physical core count (or performance core) cpu_threads: Math.max(1, nitroResourceProbe.numCpuPhysicalCore), - ...(wrapper.model.settings.mmproj && { - mmproj: path.isAbsolute(wrapper.model.settings.mmproj) - ? 
wrapper.model.settings.mmproj - : path.join(modelFolder, wrapper.model.settings.mmproj), + ...(params.model.settings.mmproj && { + mmproj: path.isAbsolute(params.model.settings.mmproj) + ? params.model.settings.mmproj + : path.join(modelFolder, params.model.settings.mmproj), }), } - return runNitroAndLoadModel() + return runNitroAndLoadModel(systemInfo) } } @@ -144,7 +146,7 @@ async function runModel( * 3. Validate model status * @returns */ -async function runNitroAndLoadModel() { +async function runNitroAndLoadModel(systemInfo?: SystemInformation) { // Gather system information for CPU physical cores and memory return killSubprocess() .then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000)) @@ -160,7 +162,7 @@ async function runNitroAndLoadModel() { return Promise.resolve() } }) - .then(spawnNitroProcess) + .then(() => spawnNitroProcess(systemInfo)) .then(() => loadLLMModel(currentSettings)) .then(validateModelStatus) .catch((err) => { @@ -325,12 +327,12 @@ async function killSubprocess(): Promise { * Spawns a Nitro subprocess. * @returns A promise that resolves when the Nitro subprocess is started. */ -function spawnNitroProcess(): Promise { +function spawnNitroProcess(systemInfo?: SystemInformation): Promise { log(`[NITRO]::Debug: Spawning Nitro subprocess...`) return new Promise(async (resolve, reject) => { let binaryFolder = path.join(__dirname, '..', 'bin') // Current directory by default - let executableOptions = executableNitroFile() + let executableOptions = executableNitroFile(systemInfo?.gpuSetting) const args: string[] = ['1', LOCAL_HOST, PORT.toString()] // Execute the binary @@ -402,9 +404,8 @@ const getCurrentNitroProcessInfo = (): NitroProcessInfo => { } export default { - runModel, - stopModel, - killSubprocess, + loadModel, + unloadModel, dispose, getCurrentNitroProcessInfo, } diff --git a/extensions/inference-openai-extension/README.md b/extensions/inference-openai-extension/README.md index 455783efb..c716c725c 100644 --- a/extensions/inference-openai-extension/README.md +++ b/extensions/inference-openai-extension/README.md @@ -1,14 +1,14 @@ -# Jan inference plugin +# OpenAI Engine Extension -Created using Jan app example +Created using Jan extension example -# Create a Jan Plugin using Typescript +# Create a Jan Extension using Typescript -Use this template to bootstrap the creation of a TypeScript Jan plugin. 🚀 +Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀 -## Create Your Own Plugin +## Create Your Own Extension -To create your own plugin, you can use this repository as a template! Just follow the below instructions: +To create your own extension, you can use this repository as a template! Just follow the below instructions: 1. Click the Use this template button at the top of the repository 2. Select Create a new repository @@ -18,7 +18,7 @@ To create your own plugin, you can use this repository as a template! Just follo ## Initial Setup -After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your plugin. +After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension. > [!NOTE] > @@ -43,36 +43,37 @@ After you've cloned the repository to your local machine or codespace, you'll ne 1. 
:white_check_mark: Check your artifact - There will be a tgz file in your plugin directory now + There will be a tgz file in your extension directory now -## Update the Plugin Metadata +## Update the Extension Metadata -The [`package.json`](package.json) file defines metadata about your plugin, such as -plugin name, main entry, description and version. +The [`package.json`](package.json) file defines metadata about your extension, such as +extension name, main entry, description and version. -When you copy this repository, update `package.json` with the name, description for your plugin. +When you copy this repository, update `package.json` with the name, description for your extension. -## Update the Plugin Code +## Update the Extension Code -The [`src/`](./src/) directory is the heart of your plugin! This contains the -source code that will be run when your plugin extension functions are invoked. You can replace the +The [`src/`](./src/) directory is the heart of your extension! This contains the +source code that will be run when your extension functions are invoked. You can replace the contents of this directory with your own code. -There are a few things to keep in mind when writing your plugin code: +There are a few things to keep in mind when writing your extension code: -- Most Jan Plugin Extension functions are processed asynchronously. +- Most Jan Extension functions are processed asynchronously. In `index.ts`, you will see that the extension function will return a `Promise`. ```typescript - import { core } from "@janhq/core"; + import { events, MessageEvent, MessageRequest } from '@janhq/core' function onStart(): Promise { - return core.invokePluginFunc(MODULE_PATH, "run", 0); + return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => + this.inference(data) + ) } ``` - For more information about the Jan Plugin Core module, see the + For more information about the Jan Extension Core module, see the [documentation](https://github.com/janhq/jan/blob/main/core/README.md). -So, what are you waiting for? Go ahead and start customizing your plugin! - +So, what are you waiting for? Go ahead and start customizing your extension! 
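For orientation, the README hunk above only shows the asynchronous event-subscription style as a fragment. A slightly fuller, self-contained sketch of the same pattern follows; the function name and handler body are illustrative placeholders and are not part of this diff, while `events`, `MessageEvent`, and `MessageRequest` are the `@janhq/core` exports used elsewhere in the patch.

```typescript
// Sketch of the asynchronous event-subscription pattern described in the README above.
// The handler body is a placeholder; real extensions kick off inference here.
import { events, MessageEvent, MessageRequest } from '@janhq/core'

export function onStart(): void {
  // Handlers receive the MessageRequest payload asynchronously; the caller
  // does not await the inference work triggered here.
  events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => {
    console.debug('Received message request for thread', data.threadId)
  })
}
```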
diff --git a/extensions/inference-openai-extension/package.json b/extensions/inference-openai-extension/package.json index 9139661fd..7b112a8fb 100644 --- a/extensions/inference-openai-extension/package.json +++ b/extensions/inference-openai-extension/package.json @@ -4,6 +4,7 @@ "description": "This extension enables OpenAI chat completion API calls", "main": "dist/index.js", "module": "dist/module.js", + "engine": "openai", "author": "Jan ", "license": "AGPL-3.0", "scripts": { diff --git a/extensions/inference-openai-extension/src/@types/global.d.ts b/extensions/inference-openai-extension/src/@types/global.d.ts deleted file mode 100644 index a49bb5a2f..000000000 --- a/extensions/inference-openai-extension/src/@types/global.d.ts +++ /dev/null @@ -1,26 +0,0 @@ -declare const MODULE: string -declare const OPENAI_DOMAIN: string - -declare interface EngineSettings { - full_url?: string - api_key?: string -} - -enum OpenAIChatCompletionModelName { - 'gpt-3.5-turbo-instruct' = 'gpt-3.5-turbo-instruct', - 'gpt-3.5-turbo-instruct-0914' = 'gpt-3.5-turbo-instruct-0914', - 'gpt-4-1106-preview' = 'gpt-4-1106-preview', - 'gpt-3.5-turbo-0613' = 'gpt-3.5-turbo-0613', - 'gpt-3.5-turbo-0301' = 'gpt-3.5-turbo-0301', - 'gpt-3.5-turbo' = 'gpt-3.5-turbo', - 'gpt-3.5-turbo-16k-0613' = 'gpt-3.5-turbo-16k-0613', - 'gpt-3.5-turbo-1106' = 'gpt-3.5-turbo-1106', - 'gpt-4-vision-preview' = 'gpt-4-vision-preview', - 'gpt-4' = 'gpt-4', - 'gpt-4-0314' = 'gpt-4-0314', - 'gpt-4-0613' = 'gpt-4-0613', -} - -declare type OpenAIModel = Omit & { - id: OpenAIChatCompletionModelName -} diff --git a/extensions/inference-openai-extension/src/helpers/sse.ts b/extensions/inference-openai-extension/src/helpers/sse.ts deleted file mode 100644 index bee2e65bc..000000000 --- a/extensions/inference-openai-extension/src/helpers/sse.ts +++ /dev/null @@ -1,85 +0,0 @@ -import { ErrorCode } from '@janhq/core' -import { Observable } from 'rxjs' - -/** - * Sends a request to the inference server to generate a response based on the recent messages. - * @param recentMessages - An array of recent messages to use as context for the inference. - * @param engine - The engine settings to use for the inference. - * @param model - The model to use for the inference. - * @returns An Observable that emits the generated response as a string. - */ -export function requestInference( - recentMessages: any[], - engine: EngineSettings, - model: OpenAIModel, - controller?: AbortController -): Observable { - return new Observable((subscriber) => { - let model_id: string = model.id - if (engine.full_url.includes(OPENAI_DOMAIN)) { - model_id = engine.full_url.split('/')[5] - } - const requestBody = JSON.stringify({ - messages: recentMessages, - stream: true, - model: model_id, - ...model.parameters, - }) - fetch(`${engine.full_url}`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'Accept': model.parameters.stream - ? 'text/event-stream' - : 'application/json', - 'Access-Control-Allow-Origin': '*', - 'Authorization': `Bearer ${engine.api_key}`, - 'api-key': `${engine.api_key}`, - }, - body: requestBody, - signal: controller?.signal, - }) - .then(async (response) => { - if (!response.ok) { - const data = await response.json() - const error = { - message: data.error?.message ?? 'An error occurred.', - code: data.error?.code ?? 
ErrorCode.Unknown, - } - subscriber.error(error) - subscriber.complete() - return - } - if (model.parameters.stream === false) { - const data = await response.json() - subscriber.next(data.choices[0]?.message?.content ?? '') - } else { - const stream = response.body - const decoder = new TextDecoder('utf-8') - const reader = stream?.getReader() - let content = '' - - while (true && reader) { - const { done, value } = await reader.read() - if (done) { - break - } - const text = decoder.decode(value) - const lines = text.trim().split('\n') - for (const line of lines) { - if (line.startsWith('data: ') && !line.includes('data: [DONE]')) { - const data = JSON.parse(line.replace('data: ', '')) - content += data.choices[0]?.delta?.content ?? '' - if (content.startsWith('assistant: ')) { - content = content.replace('assistant: ', '') - } - subscriber.next(content) - } - } - } - } - subscriber.complete() - }) - .catch((err) => subscriber.error(err)) - }) -} diff --git a/extensions/inference-openai-extension/src/index.ts b/extensions/inference-openai-extension/src/index.ts index ab0c2bde6..ad5b73a1e 100644 --- a/extensions/inference-openai-extension/src/index.ts +++ b/extensions/inference-openai-extension/src/index.ts @@ -5,75 +5,52 @@ * @version 1.0.0 * @module inference-openai-extension/src/index */ +declare const ENGINE: string import { - ChatCompletionRole, - ContentType, - MessageRequest, - MessageStatus, - ThreadContent, - ThreadMessage, events, fs, - InferenceEngine, - BaseExtension, - MessageEvent, - MessageRequestType, - ModelEvent, - InferenceEvent, AppConfigurationEventName, joinPath, + RemoteOAIEngine, } from '@janhq/core' -import { requestInference } from './helpers/sse' -import { ulid } from 'ulidx' import { join } from 'path' +declare const COMPLETION_URL: string + /** * A class that implements the InferenceExtension interface from the @janhq/core package. * The class provides methods for initializing and stopping a model, and for making inference requests. * It also subscribes to events emitted by the @janhq/core package and handles new message requests. */ -export default class JanInferenceOpenAIExtension extends BaseExtension { +export default class JanInferenceOpenAIExtension extends RemoteOAIEngine { private static readonly _engineDir = 'file://engines' - private static readonly _engineMetadataFileName = 'openai.json' + private static readonly _engineMetadataFileName = `${ENGINE}.json` - private static _currentModel: OpenAIModel - - private static _engineSettings: EngineSettings = { - full_url: 'https://api.openai.com/v1/chat/completions', + private _engineSettings = { + full_url: COMPLETION_URL, api_key: 'sk-', } - controller = new AbortController() - isCancelled = false + inferenceUrl: string = COMPLETION_URL + provider: string = 'openai' + apiKey: string = '' + // TODO: Just use registerSettings from BaseExtension + // Remove these methods /** * Subscribes to events emitted by the @janhq/core package. 
*/ async onLoad() { + super.onLoad() + if (!(await fs.existsSync(JanInferenceOpenAIExtension._engineDir))) { await fs .mkdirSync(JanInferenceOpenAIExtension._engineDir) .catch((err) => console.debug(err)) } - JanInferenceOpenAIExtension.writeDefaultEngineSettings() - - // Events subscription - events.on(MessageEvent.OnMessageSent, (data) => - JanInferenceOpenAIExtension.handleMessageRequest(data, this) - ) - - events.on(ModelEvent.OnModelInit, (model: OpenAIModel) => { - JanInferenceOpenAIExtension.handleModelInit(model) - }) - - events.on(ModelEvent.OnModelStop, (model: OpenAIModel) => { - JanInferenceOpenAIExtension.handleModelStop(model) - }) - events.on(InferenceEvent.OnInferenceStopped, () => { - JanInferenceOpenAIExtension.handleInferenceStopped(this) - }) + this.writeDefaultEngineSettings() const settingsFilePath = await joinPath([ JanInferenceOpenAIExtension._engineDir, @@ -84,18 +61,12 @@ export default class JanInferenceOpenAIExtension extends BaseExtension { AppConfigurationEventName.OnConfigurationUpdate, (settingsKey: string) => { // Update settings on changes - if (settingsKey === settingsFilePath) - JanInferenceOpenAIExtension.writeDefaultEngineSettings() + if (settingsKey === settingsFilePath) this.writeDefaultEngineSettings() } ) } - /** - * Stops the model inference. - */ - onUnload(): void {} - - static async writeDefaultEngineSettings() { + async writeDefaultEngineSettings() { try { const engineFile = join( JanInferenceOpenAIExtension._engineDir, @@ -103,122 +74,18 @@ export default class JanInferenceOpenAIExtension extends BaseExtension { ) if (await fs.existsSync(engineFile)) { const engine = await fs.readFileSync(engineFile, 'utf-8') - JanInferenceOpenAIExtension._engineSettings = + this._engineSettings = typeof engine === 'object' ? engine : JSON.parse(engine) + this.inferenceUrl = this._engineSettings.full_url + this.apiKey = this._engineSettings.api_key } else { await fs.writeFileSync( engineFile, - JSON.stringify(JanInferenceOpenAIExtension._engineSettings, null, 2) + JSON.stringify(this._engineSettings, null, 2) ) } } catch (err) { console.error(err) } } - private static async handleModelInit(model: OpenAIModel) { - if (model.engine !== InferenceEngine.openai) { - return - } else { - JanInferenceOpenAIExtension._currentModel = model - JanInferenceOpenAIExtension.writeDefaultEngineSettings() - // Todo: Check model list with API key - events.emit(ModelEvent.OnModelReady, model) - } - } - - private static async handleModelStop(model: OpenAIModel) { - if (model.engine !== 'openai') { - return - } - events.emit(ModelEvent.OnModelStopped, model) - } - - private static async handleInferenceStopped( - instance: JanInferenceOpenAIExtension - ) { - instance.isCancelled = true - instance.controller?.abort() - } - - /** - * Handles a new message request by making an inference request and emitting events. - * Function registered in event manager, should be static to avoid binding issues. - * Pass instance as a reference. - * @param {MessageRequest} data - The data for the new message request. 
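Note on the `writeDefaultEngineSettings` hunk above: the method now seeds `openai.json` under the `file://engines` directory with default values and re-reads it whenever the configuration-update event fires for that settings path, copying `full_url` into `inferenceUrl` and `api_key` into `apiKey`. The read path tolerates `fs.readFileSync` handing back either an already-parsed object or the raw JSON string. A minimal typed sketch of that normalization, assuming the shape implied by the defaults (the type and helper names are illustrative, not part of the diff):

```ts
// Illustrative only; the extension keeps these values untyped.
type OpenAIEngineSettings = {
  full_url: string // copied into inferenceUrl
  api_key: string // copied into apiKey
}

// Mirrors `typeof engine === 'object' ? engine : JSON.parse(engine)`:
// the file read may come back already parsed or as the raw JSON string.
const parseEngineSettings = (raw: unknown): OpenAIEngineSettings =>
  typeof raw === 'object' && raw !== null
    ? (raw as OpenAIEngineSettings)
    : (JSON.parse(String(raw)) as OpenAIEngineSettings)
```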
- */ - private static async handleMessageRequest( - data: MessageRequest, - instance: JanInferenceOpenAIExtension - ) { - if (data.model.engine !== 'openai') { - return - } - - const timestamp = Date.now() - const message: ThreadMessage = { - id: ulid(), - thread_id: data.threadId, - type: data.type, - assistant_id: data.assistantId, - role: ChatCompletionRole.Assistant, - content: [], - status: MessageStatus.Pending, - created: timestamp, - updated: timestamp, - object: 'thread.message', - } - - if (data.type !== MessageRequestType.Summary) { - events.emit(MessageEvent.OnMessageResponse, message) - } - - instance.isCancelled = false - instance.controller = new AbortController() - - requestInference( - data?.messages ?? [], - this._engineSettings, - { - ...JanInferenceOpenAIExtension._currentModel, - parameters: data.model.parameters, - }, - instance.controller - ).subscribe({ - next: (content) => { - const messageContent: ThreadContent = { - type: ContentType.Text, - text: { - value: content.trim(), - annotations: [], - }, - } - message.content = [messageContent] - events.emit(MessageEvent.OnMessageUpdate, message) - }, - complete: async () => { - message.status = message.content.length - ? MessageStatus.Ready - : MessageStatus.Error - events.emit(MessageEvent.OnMessageUpdate, message) - }, - error: async (err) => { - if (instance.isCancelled || message.content.length > 0) { - message.status = MessageStatus.Stopped - events.emit(MessageEvent.OnMessageUpdate, message) - return - } - const messageContent: ThreadContent = { - type: ContentType.Text, - text: { - value: 'An error occurred. ' + err.message, - annotations: [], - }, - } - message.content = [messageContent] - message.status = MessageStatus.Error - message.error_code = err.code - events.emit(MessageEvent.OnMessageUpdate, message) - }, - }) - } } diff --git a/extensions/inference-openai-extension/webpack.config.js b/extensions/inference-openai-extension/webpack.config.js index ee2e3b624..ee18035f2 100644 --- a/extensions/inference-openai-extension/webpack.config.js +++ b/extensions/inference-openai-extension/webpack.config.js @@ -17,8 +17,8 @@ module.exports = { }, plugins: [ new webpack.DefinePlugin({ - MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`), - OPENAI_DOMAIN: JSON.stringify('openai.azure.com'), + ENGINE: JSON.stringify(packageJson.engine), + COMPLETION_URL: JSON.stringify('https://api.openai.com/v1/chat/completions'), }), ], output: { diff --git a/extensions/inference-triton-trtllm-extension/src/@types/global.d.ts b/extensions/inference-triton-trtllm-extension/src/@types/global.d.ts deleted file mode 100644 index c834feba0..000000000 --- a/extensions/inference-triton-trtllm-extension/src/@types/global.d.ts +++ /dev/null @@ -1,5 +0,0 @@ -import { Model } from '@janhq/core' - -declare interface EngineSettings { - base_url?: string -} diff --git a/extensions/inference-triton-trtllm-extension/src/helpers/sse.ts b/extensions/inference-triton-trtllm-extension/src/helpers/sse.ts deleted file mode 100644 index 9aff61265..000000000 --- a/extensions/inference-triton-trtllm-extension/src/helpers/sse.ts +++ /dev/null @@ -1,63 +0,0 @@ -import { Observable } from 'rxjs' -import { EngineSettings } from '../@types/global' -import { Model } from '@janhq/core' - -/** - * Sends a request to the inference server to generate a response based on the recent messages. - * @param recentMessages - An array of recent messages to use as context for the inference. - * @param engine - The engine settings to use for the inference. 
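With `helpers/sse.ts` and the per-extension message handling deleted above, `JanInferenceOpenAIExtension` now only supplies `provider`, `inferenceUrl`, and `apiKey`, and the webpack `DefinePlugin` injects `ENGINE` and `COMPLETION_URL` in place of `MODULE` and `OPENAI_DOMAIN`. Two things worth noting: the deleted helper sent both an `Authorization: Bearer …` header and an Azure-style `api-key` header, and it special-cased `openai.azure.com` URLs by deriving the model id from `full_url`; neither behaviour has an explicit counterpart in the new code shown here. If the base engine does not already attach the key, the `headers()` hook added to the engine classes earlier in this diff is the natural place to restore the old headers. A hedged sketch (hypothetical helper, not part of the diff):

```ts
// Hypothetical helper; not part of this diff. Reproduces the auth headers the
// deleted helpers/sse.ts used to send with every request.
const openAIHeaders = (apiKey: string): HeadersInit => ({
  'Authorization': `Bearer ${apiKey}`,
  'api-key': apiKey, // Azure OpenAI-style header, kept for parity with the old helper
})
```

Inside the extension this could back an override such as `headers() { return openAIHeaders(this.apiKey) }`, assuming the base class does not already do the equivalent.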
- * @param model - The model to use for the inference. - * @returns An Observable that emits the generated response as a string. - */ -export function requestInference( - recentMessages: any[], - engine: EngineSettings, - model: Model, - controller?: AbortController -): Observable { - return new Observable((subscriber) => { - const text_input = recentMessages.map((message) => message.text).join('\n') - const requestBody = JSON.stringify({ - text_input: text_input, - max_tokens: 4096, - temperature: 0, - bad_words: '', - stop_words: '[DONE]', - stream: true, - }) - fetch(`${engine.base_url}/v2/models/ensemble/generate_stream`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'Accept': 'text/event-stream', - 'Access-Control-Allow-Origin': '*', - }, - body: requestBody, - signal: controller?.signal, - }) - .then(async (response) => { - const stream = response.body - const decoder = new TextDecoder('utf-8') - const reader = stream?.getReader() - let content = '' - - while (true && reader) { - const { done, value } = await reader.read() - if (done) { - break - } - const text = decoder.decode(value) - const lines = text.trim().split('\n') - for (const line of lines) { - if (line.startsWith('data: ') && !line.includes('data: [DONE]')) { - const data = JSON.parse(line.replace('data: ', '')) - content += data.choices[0]?.delta?.content ?? '' - subscriber.next(content) - } - } - } - subscriber.complete() - }) - .catch((err) => subscriber.error(err)) - }) -} diff --git a/extensions/inference-triton-trtllm-extension/src/index.ts b/extensions/inference-triton-trtllm-extension/src/index.ts index ae1d9315f..0df64acf1 100644 --- a/extensions/inference-triton-trtllm-extension/src/index.ts +++ b/extensions/inference-triton-trtllm-extension/src/index.ts @@ -7,212 +7,76 @@ */ import { - ChatCompletionRole, - ContentType, - MessageRequest, - MessageStatus, - ModelSettingParams, - ThreadContent, - ThreadMessage, + AppConfigurationEventName, events, fs, + joinPath, Model, - BaseExtension, - MessageEvent, - ModelEvent, + RemoteOAIEngine, } from '@janhq/core' -import { requestInference } from './helpers/sse' -import { ulid } from 'ulidx' import { join } from 'path' -import { EngineSettings } from './@types/global' /** * A class that implements the InferenceExtension interface from the @janhq/core package. * The class provides methods for initializing and stopping a model, and for making inference requests. * It also subscribes to events emitted by the @janhq/core package and handles new message requests. */ -export default class JanInferenceTritonTrtLLMExtension extends BaseExtension { - private static readonly _homeDir = 'file://engines' - private static readonly _engineMetadataFileName = 'triton_trtllm.json' +export default class JanInferenceTritonTrtLLMExtension extends RemoteOAIEngine { + private readonly _engineDir = 'file://engines' + private readonly _engineMetadataFileName = 'triton_trtllm.json' - static _currentModel: Model + inferenceUrl: string = '' + provider: string = 'triton_trtllm' + apiKey: string = '' - static _engineSettings: EngineSettings = { - base_url: '', + _engineSettings: { + base_url: '' + api_key: '' } - controller = new AbortController() - isCancelled = false - /** * Subscribes to events emitted by the @janhq/core package. 
*/ async onLoad() { - if (!(await fs.existsSync(JanInferenceTritonTrtLLMExtension._homeDir))) - JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings() + super.onLoad() + if (!(await fs.existsSync(this._engineDir))) { + await fs.mkdirSync(this._engineDir).catch((err) => console.debug(err)) + } + + this.writeDefaultEngineSettings() + + const settingsFilePath = await joinPath([ + this._engineDir, + this._engineMetadataFileName, + ]) // Events subscription - events.on(MessageEvent.OnMessageSent, (data) => - JanInferenceTritonTrtLLMExtension.handleMessageRequest(data, this) + events.on( + AppConfigurationEventName.OnConfigurationUpdate, + (settingsKey: string) => { + // Update settings on changes + if (settingsKey === settingsFilePath) this.writeDefaultEngineSettings() + } ) - - events.on(ModelEvent.OnModelInit, (model: Model) => { - JanInferenceTritonTrtLLMExtension.handleModelInit(model) - }) - - events.on(ModelEvent.OnModelStop, (model: Model) => { - JanInferenceTritonTrtLLMExtension.handleModelStop(model) - }) } - /** - * Stops the model inference. - */ - onUnload(): void {} - - /** - * Initializes the model with the specified file name. - * @param {string} modelId - The ID of the model to initialize. - * @returns {Promise} A promise that resolves when the model is initialized. - */ - async initModel( - modelId: string, - settings?: ModelSettingParams - ): Promise { - return - } - - static async writeDefaultEngineSettings() { + async writeDefaultEngineSettings() { try { - const engine_json = join( - JanInferenceTritonTrtLLMExtension._homeDir, - JanInferenceTritonTrtLLMExtension._engineMetadataFileName - ) + const engine_json = join(this._engineDir, this._engineMetadataFileName) if (await fs.existsSync(engine_json)) { const engine = await fs.readFileSync(engine_json, 'utf-8') - JanInferenceTritonTrtLLMExtension._engineSettings = + this._engineSettings = typeof engine === 'object' ? engine : JSON.parse(engine) + this.inferenceUrl = this._engineSettings.base_url + this.apiKey = this._engineSettings.api_key } else { await fs.writeFileSync( engine_json, - JSON.stringify( - JanInferenceTritonTrtLLMExtension._engineSettings, - null, - 2 - ) + JSON.stringify(this._engineSettings, null, 2) ) } } catch (err) { console.error(err) } } - /** - * Stops the model. - * @returns {Promise} A promise that resolves when the model is stopped. - */ - async stopModel(): Promise {} - - /** - * Stops streaming inference. - * @returns {Promise} A promise that resolves when the streaming is stopped. - */ - async stopInference(): Promise { - this.isCancelled = true - this.controller?.abort() - } - - private static async handleModelInit(model: Model) { - if (model.engine !== 'triton_trtllm') { - return - } else { - JanInferenceTritonTrtLLMExtension._currentModel = model - JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings() - // Todo: Check model list with API key - events.emit(ModelEvent.OnModelReady, model) - } - } - - private static async handleModelStop(model: Model) { - if (model.engine !== 'triton_trtllm') { - return - } - events.emit(ModelEvent.OnModelStopped, model) - } - - /** - * Handles a new message request by making an inference request and emitting events. - * Function registered in event manager, should be static to avoid binding issues. - * Pass instance as a reference. - * @param {MessageRequest} data - The data for the new message request. 
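One thing worth flagging in the Triton TRT-LLM rewrite above: `_engineSettings: { base_url: '' api_key: '' }` parses as a bare type annotation built from empty string literal types, not an initialized field. That is unlikely to be the intent, since `writeDefaultEngineSettings` both assigns parsed JSON to `this._engineSettings` and serializes it as the default file contents, which would stringify `undefined` on first run. A minimal sketch of the presumed intent, mirroring the OpenAI extension's field (an assumption, not taken from the diff):

```ts
// Inside JanInferenceTritonTrtLLMExtension: presumed intent is an initialized
// field with defaults, as in the OpenAI extension, rather than a type annotation.
_engineSettings = {
  base_url: '',
  api_key: '',
}
```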
- */ - private static async handleMessageRequest( - data: MessageRequest, - instance: JanInferenceTritonTrtLLMExtension - ) { - if (data.model.engine !== 'triton_trtllm') { - return - } - - const timestamp = Date.now() - const message: ThreadMessage = { - id: ulid(), - thread_id: data.threadId, - assistant_id: data.assistantId, - role: ChatCompletionRole.Assistant, - content: [], - status: MessageStatus.Pending, - created: timestamp, - updated: timestamp, - object: 'thread.message', - } - events.emit(MessageEvent.OnMessageResponse, message) - - instance.isCancelled = false - instance.controller = new AbortController() - - requestInference( - data?.messages ?? [], - this._engineSettings, - { - ...JanInferenceTritonTrtLLMExtension._currentModel, - parameters: data.model.parameters, - }, - instance.controller - ).subscribe({ - next: (content) => { - const messageContent: ThreadContent = { - type: ContentType.Text, - text: { - value: content.trim(), - annotations: [], - }, - } - message.content = [messageContent] - events.emit(MessageEvent.OnMessageUpdate, message) - }, - complete: async () => { - message.status = message.content.length - ? MessageStatus.Ready - : MessageStatus.Error - events.emit(MessageEvent.OnMessageUpdate, message) - }, - error: async (err) => { - if (instance.isCancelled || message.content.length) { - message.status = MessageStatus.Error - events.emit(MessageEvent.OnMessageUpdate, message) - return - } - const messageContent: ThreadContent = { - type: ContentType.Text, - text: { - value: 'An error occurred. ' + err.message, - annotations: [], - }, - } - message.content = [messageContent] - message.status = MessageStatus.Ready - events.emit(MessageEvent.OnMessageUpdate, message) - }, - }) - } } diff --git a/extensions/tensorrt-llm-extension/src/index.ts b/extensions/tensorrt-llm-extension/src/index.ts index d2d08e8a7..de5199b7d 100644 --- a/extensions/tensorrt-llm-extension/src/index.ts +++ b/extensions/tensorrt-llm-extension/src/index.ts @@ -43,14 +43,14 @@ export default class TensorRTLLMExtension extends LocalOAIEngine { private supportedPlatform = ['win32', 'linux'] private isUpdateAvailable = false - compatibility() { + override compatibility() { return COMPATIBILITY as unknown as Compatibility } /** * models implemented by the extension * define pre-populated models */ - async models(): Promise { + override async models(): Promise { if ((await this.installationState()) === 'Installed') return models as unknown as Model[] return [] @@ -160,11 +160,11 @@ export default class TensorRTLLMExtension extends LocalOAIEngine { events.emit(ModelEvent.OnModelsUpdate, {}) } - async onModelInit(model: Model): Promise { + override async loadModel(model: Model): Promise { if (model.engine !== this.provider) return if ((await this.installationState()) === 'Installed') - return super.onModelInit(model) + return super.loadModel(model) else { events.emit(ModelEvent.OnModelFail, { ...model, @@ -175,7 +175,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine { } } - updatable() { + override updatable() { return this.isUpdateAvailable } @@ -241,8 +241,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine { return (await fs.existsSync(enginePath)) ? 'Installed' : 'NotInstalled' } - override onInferenceStopped() { - if (!this.isRunning) return + override stopInference() { showToast( 'Unable to Stop Inference', 'The model does not support stopping inference.' 
@@ -250,8 +249,8 @@ export default class TensorRTLLMExtension extends LocalOAIEngine { return Promise.resolve() } - inference(data: MessageRequest): void { - if (!this.isRunning) return + override inference(data: MessageRequest): void { + if (!this.loadedModel) return // TensorRT LLM Extension supports streaming only if (data.model) data.model.parameters.stream = true super.inference(data) diff --git a/models/groq-llama2-70b/model.json b/models/groq-llama2-70b/model.json index 454591379..c2b925425 100644 --- a/models/groq-llama2-70b/model.json +++ b/models/groq-llama2-70b/model.json @@ -1,27 +1,26 @@ { - "sources": [ - { - "url": "https://groq.com" - } - ], - "id": "llama2-70b-4096", - "object": "model", - "name": "Groq Llama 2 70b", - "version": "1.0", - "description": "Groq Llama 2 70b with supercharged speed!", - "format": "api", - "settings": {}, - "parameters": { - "max_tokens": 4096, - "temperature": 0.7, - "top_p": 1, - "stop": null, - "stream": true - }, - "metadata": { - "author": "Meta", - "tags": ["General", "Big Context Length"] - }, - "engine": "groq" - } - \ No newline at end of file + "sources": [ + { + "url": "https://groq.com" + } + ], + "id": "llama2-70b-4096", + "object": "model", + "name": "Groq Llama 2 70b", + "version": "1.0", + "description": "Groq Llama 2 70b with supercharged speed!", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "top_p": 1, + "stop": null, + "stream": true + }, + "metadata": { + "author": "Meta", + "tags": ["General", "Big Context Length"] + }, + "engine": "groq" +} diff --git a/models/groq-mixtral-8x7b-instruct/model.json b/models/groq-mixtral-8x7b-instruct/model.json index 241e7ae17..ae1cbbf80 100644 --- a/models/groq-mixtral-8x7b-instruct/model.json +++ b/models/groq-mixtral-8x7b-instruct/model.json @@ -1,27 +1,26 @@ { - "sources": [ - { - "url": "https://groq.com" - } - ], - "id": "mixtral-8x7b-32768", - "object": "model", - "name": "Groq Mixtral 8x7b Instruct", - "version": "1.0", - "description": "Groq Mixtral 8x7b Instruct is Mixtral with supercharged speed!", - "format": "api", - "settings": {}, - "parameters": { - "max_tokens": 4096, - "temperature": 0.7, - "top_p": 1, - "stop": null, - "stream": true - }, - "metadata": { - "author": "Mistral", - "tags": ["General", "Big Context Length"] - }, - "engine": "groq" - } - \ No newline at end of file + "sources": [ + { + "url": "https://groq.com" + } + ], + "id": "mixtral-8x7b-32768", + "object": "model", + "name": "Groq Mixtral 8x7b Instruct", + "version": "1.0", + "description": "Groq Mixtral 8x7b Instruct is Mixtral with supercharged speed!", + "format": "api", + "settings": {}, + "parameters": { + "max_tokens": 4096, + "temperature": 0.7, + "top_p": 1, + "stop": null, + "stream": true + }, + "metadata": { + "author": "Mistral", + "tags": ["General", "Big Context Length"] + }, + "engine": "groq" +} diff --git a/web/containers/DropdownListSidebar/index.tsx b/web/containers/DropdownListSidebar/index.tsx index fb51c521d..5022c83f1 100644 --- a/web/containers/DropdownListSidebar/index.tsx +++ b/web/containers/DropdownListSidebar/index.tsx @@ -75,12 +75,14 @@ const DropdownListSidebar = ({ // TODO: Update filter condition for the local model const localModel = downloadedModels.filter( - (model) => model.engine !== InferenceEngine.openai + (model) => + model.engine === InferenceEngine.nitro || + model.engine === InferenceEngine.nitro_tensorrt_llm ) const remoteModel = downloadedModels.filter( (model) => - model.engine === 
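Back in the TensorRT-LLM extension hunk above (`@@ -250,8 +249,8 @@`): `inference()` now gates on `this.loadedModel`, but the `if (!this.isRunning) return` guard that used to precede the toast in `onInferenceStopped()` was dropped rather than translated, so `stopInference()` will show the 'Unable to Stop Inference' toast even when no model is loaded. If that is unintended, the equivalent guard would be (an assumption about intent, not part of the diff):

```ts
// Inside TensorRTLLMExtension; hypothetical guard mirroring the loadedModel
// check that inference() now performs, so the toast only appears while a
// model is actually loaded.
override stopInference() {
  if (!this.loadedModel) return Promise.resolve()
  showToast(
    'Unable to Stop Inference',
    'The model does not support stopping inference.'
  )
  return Promise.resolve()
}
```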
InferenceEngine.openai || - model.engine === InferenceEngine.groq + model.engine !== InferenceEngine.nitro && + model.engine !== InferenceEngine.nitro_tensorrt_llm ) const modelOptions = isTabActive === 0 ? localModel : remoteModel diff --git a/web/screens/Settings/Models/Row.tsx b/web/screens/Settings/Models/Row.tsx index b929c85f9..9707f6194 100644 --- a/web/screens/Settings/Models/Row.tsx +++ b/web/screens/Settings/Models/Row.tsx @@ -48,9 +48,8 @@ export default function RowModel(props: RowModelProps) { const [serverEnabled, setServerEnabled] = useAtom(serverEnabledAtom) const isRemoteModel = - props.data.engine === InferenceEngine.openai || - props.data.engine === InferenceEngine.groq || - props.data.engine === InferenceEngine.triton_trtllm + props.data.engine !== InferenceEngine.nitro && + props.data.engine !== InferenceEngine.nitro_tensorrt_llm const onModelActionClick = (modelId: string) => { if (activeModel && activeModel.id === modelId) { diff --git a/web/services/extensionService.ts b/web/services/extensionService.ts index 6ae4f78f0..975b226b9 100644 --- a/web/services/extensionService.ts +++ b/web/services/extensionService.ts @@ -8,7 +8,6 @@ export const isCoreExtensionInstalled = () => { if (!extensionManager.get(ExtensionTypeEnum.Conversational)) { return false } - if (!extensionManager.get(ExtensionTypeEnum.Inference)) return false if (!extensionManager.get(ExtensionTypeEnum.Model)) { return false } @@ -22,7 +21,6 @@ export const setupBaseExtensions = async () => { if ( !extensionManager.get(ExtensionTypeEnum.Conversational) || - !extensionManager.get(ExtensionTypeEnum.Inference) || !extensionManager.get(ExtensionTypeEnum.Model) ) { const installed = await extensionManager.install(baseExtensions)
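The UI-side changes invert the old allow-list: in both `DropdownListSidebar` and `Settings/Models/Row`, any engine other than `nitro` or `nitro_tensorrt_llm` is now treated as remote, and the Inference extension type is no longer required by `extensionService`. Since the same two-engine check is now duplicated in two components, a small shared predicate could keep the definition of 'local engine' in one place. A sketch (hypothetical helper and file name, not part of this diff):

```ts
// Hypothetical shared helper, e.g. web/utils/modelEngine.ts; not part of this diff.
import { InferenceEngine } from '@janhq/core'

const LOCAL_ENGINES: InferenceEngine[] = [
  InferenceEngine.nitro,
  InferenceEngine.nitro_tensorrt_llm,
]

export const isLocalEngine = (engine?: InferenceEngine): boolean =>
  engine !== undefined && LOCAL_ENGINES.includes(engine)
```

With that, `DropdownListSidebar` could filter with `isLocalEngine(model.engine)` and `Row.tsx` could compute `isRemoteModel` as `!isLocalEngine(props.data.engine)`, so a future engine only needs to be classified once.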