fix: refactor inference engines to extend AIEngine (#2347)

* fix: refactor Nitro to extend LocalOAIEngine

* fix: refactor OpenAI extension

* chore: refactor Groq extension

* chore: refactor Triton TensorRT extension

* chore: add tests

* chore: refactor engines
Louis 2024-03-22 09:35:14 +07:00 committed by GitHub
parent b8e4a029a4
commit acbec78dbf
35 changed files with 625 additions and 1292 deletions
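
Taken together, the diffs below collapse the per-provider event plumbing into a small hierarchy of engine base classes in `@janhq/core`. A rough outline of who owns what after the refactor (a summary of the diffs, not the full source):

```typescript
// Engine hierarchy after the refactor (summarized from the diffs below):
//
// BaseExtension
// └─ AIEngine             // modelFolder = 'models'; models() now defaults to Promise.resolve([])
//    └─ OAIEngine         // abstract inferenceUrl; handles MessageEvent.OnMessageSent,
//       │                 // streams via requestInference(..., this.headers()), stopInference()
//       ├─ LocalOAIEngine  // abstract nodeModule; loadModel()/unloadModel() call
//       │                  // executeOnMain(nodeModule, loadModelFunctionName, ...)
//       └─ RemoteOAIEngine // abstract apiKey; loadModel()/unloadModel() only emit events;
//                          // headers() adds Authorization / api-key
```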

View File

@ -56,7 +56,8 @@ const openFileExplorer: (path: string) => Promise<any> = (path) =>
* @param paths - The paths to join.
* @returns {Promise<string>} A promise that resolves with the joined path.
*/
const joinPath: (paths: string[]) => Promise<string> = (paths) => globalThis.core.api?.joinPath(paths)
const joinPath: (paths: string[]) => Promise<string> = (paths) =>
globalThis.core.api?.joinPath(paths)
/**
* Retrieve the basename from a URL.

View File

@ -14,7 +14,9 @@ export abstract class AIEngine extends BaseExtension {
// The model folder
modelFolder: string = 'models'
abstract models(): Promise<Model[]>
models(): Promise<Model[]> {
return Promise.resolve([])
}
/**
* On extension load, subscribe to events.

View File

@ -9,9 +9,9 @@ import { OAIEngine } from './OAIEngine'
*/
export abstract class LocalOAIEngine extends OAIEngine {
// The inference engine
abstract nodeModule: string
loadModelFunctionName: string = 'loadModel'
unloadModelFunctionName: string = 'unloadModel'
isRunning: boolean = false
/**
* On extension load, subscribe to events.
@ -19,22 +19,27 @@ export abstract class LocalOAIEngine extends OAIEngine {
onLoad() {
super.onLoad()
// These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.onModelInit(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.onModelStop(model))
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
}
/**
* Load the model.
*/
async onModelInit(model: Model) {
async loadModel(model: Model) {
if (model.engine.toString() !== this.provider) return
const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id])
const systemInfo = await systemInformation()
const res = await executeOnMain(this.nodeModule, this.loadModelFunctionName, {
modelFolder,
model,
}, systemInfo)
const res = await executeOnMain(
this.nodeModule,
this.loadModelFunctionName,
{
modelFolder,
model,
},
systemInfo
)
if (res?.error) {
events.emit(ModelEvent.OnModelFail, {
@ -45,16 +50,14 @@ export abstract class LocalOAIEngine extends OAIEngine {
} else {
this.loadedModel = model
events.emit(ModelEvent.OnModelReady, model)
this.isRunning = true
}
}
/**
* Stops the model.
*/
onModelStop(model: Model) {
if (model.engine?.toString() !== this.provider) return
this.isRunning = false
unloadModel(model: Model) {
if (model.engine && model.engine?.toString() !== this.provider) return
this.loadedModel = undefined
executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
events.emit(ModelEvent.OnModelStopped, {})
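
With this base class in place, a concrete local provider only needs to supply its node module, provider id, and endpoint; the Nitro extension further down does exactly this. A minimal sketch, with the class name, module name, and URL being illustrative rather than taken from the diff:

```typescript
import { LocalOAIEngine } from '@janhq/core'

// Hypothetical local provider: the base class wires the ModelEvent subscriptions
// and forwards loadModel/unloadModel to the named node module via executeOnMain().
export default class ExampleLocalEngine extends LocalOAIEngine {
  nodeModule: string = 'example-node-module' // main-process module exporting loadModel/unloadModel
  provider: string = 'example'               // matched against model.engine when handling events
  inferenceUrl: string = 'http://127.0.0.1:3928/chat/completions' // assumed local endpoint
}
```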

View File

@ -23,7 +23,6 @@ import { events } from '../../events'
export abstract class OAIEngine extends AIEngine {
// The inference engine
abstract inferenceUrl: string
abstract nodeModule: string
// Controller to handle stop requests
controller = new AbortController()
@ -38,7 +37,7 @@ export abstract class OAIEngine extends AIEngine {
onLoad() {
super.onLoad()
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
events.on(InferenceEvent.OnInferenceStopped, () => this.onInferenceStopped())
events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
}
/**
@ -78,7 +77,13 @@ export abstract class OAIEngine extends AIEngine {
...data.model,
}
requestInference(this.inferenceUrl, data.messages ?? [], model, this.controller).subscribe({
requestInference(
this.inferenceUrl,
data.messages ?? [],
model,
this.controller,
this.headers()
).subscribe({
next: (content: any) => {
const messageContent: ThreadContent = {
type: ContentType.Text,
@ -109,8 +114,15 @@ export abstract class OAIEngine extends AIEngine {
/**
* Stops the inference.
*/
onInferenceStopped() {
stopInference() {
this.isCancelled = true
this.controller?.abort()
}
/**
* Headers for the inference request
*/
headers(): HeadersInit {
return {}
}
}

View File

@ -0,0 +1,46 @@
import { events } from '../../events'
import { Model, ModelEvent } from '../../types'
import { OAIEngine } from './OAIEngine'
/**
* Base OAI Remote Inference Provider
* Adds the implementation of loading and unloading a model (applicable to remote inference providers)
*/
export abstract class RemoteOAIEngine extends OAIEngine {
// The inference engine
abstract apiKey: string
/**
* On extension load, subscribe to events.
*/
onLoad() {
super.onLoad()
// These events are applicable to remote inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
}
/**
* Load the model.
*/
async loadModel(model: Model) {
if (model.engine.toString() !== this.provider) return
events.emit(ModelEvent.OnModelReady, model)
}
/**
* Stops the model.
*/
unloadModel(model: Model) {
if (model.engine && model.engine.toString() !== this.provider) return
events.emit(ModelEvent.OnModelStopped, {})
}
/**
* Headers for the inference request
*/
override headers(): HeadersInit {
return {
'Authorization': `Bearer ${this.apiKey}`,
'api-key': `${this.apiKey}`,
}
}
}
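
The remote counterpart is even thinner: a provider supplies its id, completion endpoint, and API key, and inherits the auth headers above. The OpenAI and Groq extensions below follow this shape; here is a minimal sketch with illustrative values:

```typescript
import { RemoteOAIEngine } from '@janhq/core'

// Hypothetical remote provider built on RemoteOAIEngine.
export default class ExampleRemoteEngine extends RemoteOAIEngine {
  provider: string = 'example' // matched against model.engine when handling events
  inferenceUrl: string = 'https://api.example.com/v1/chat/completions' // assumed OpenAI-compatible endpoint
  apiKey: string = '' // injected into headers() as a Bearer token by the base class
}
```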

View File

@ -12,7 +12,8 @@ export function requestInference(
id: string
parameters: ModelRuntimeParams
},
controller?: AbortController
controller?: AbortController,
headers?: HeadersInit
): Observable<string> {
return new Observable((subscriber) => {
const requestBody = JSON.stringify({
@ -27,6 +28,7 @@ export function requestInference(
'Content-Type': 'application/json',
'Access-Control-Allow-Origin': '*',
'Accept': model.parameters.stream ? 'text/event-stream' : 'application/json',
...headers,
},
body: requestBody,
signal: controller?.signal,
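
The new optional `headers` parameter is what lets `OAIEngine.inference()` pass `this.headers()` through to the fetch call. A hedged usage sketch of the updated signature, assuming the helper above is in scope; the endpoint, model id, and header value are illustrative:

```typescript
const controller = new AbortController()

requestInference(
  'https://api.example.com/v1/chat/completions',          // illustrative endpoint
  [{ role: 'user', content: 'Hello' }],                   // recent messages
  { id: 'example-model', parameters: { stream: true } },  // model id + runtime parameters
  controller,
  { Authorization: 'Bearer <api key>' }                   // merged into the fetch headers
).subscribe({
  next: (content) => console.log(content),                // cumulative streamed text so far
  complete: () => console.log('stream finished'),
  error: (err) => console.error(err),
})
```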

View File

@ -1,3 +1,4 @@
export * from './AIEngine'
export * from './OAIEngine'
export * from './LocalOAIEngine'
export * from './RemoteOAIEngine'

View File

@ -25,7 +25,7 @@
"@janhq/core": "file:../../core",
"fetch-retry": "^5.0.6",
"path-browserify": "^1.0.1",
"ulid": "^2.3.0"
"ulidx": "^2.3.0"
},
"engines": {
"node": ">=18.0.0"

View File

@ -1,16 +0,0 @@
declare const MODULE: string
declare const GROQ_DOMAIN: string
declare interface EngineSettings {
full_url?: string
api_key?: string
}
enum GroqChatCompletionModelName {
'mixtral-8x7b-32768' = 'mixtral-8x7b-32768',
'llama2-70b-4096' = 'llama2-70b-4096',
}
declare type GroqModel = Omit<Model, 'id'> & {
id: GroqChatCompletionModelName
}

View File

@ -1,83 +0,0 @@
import { ErrorCode } from '@janhq/core'
import { Observable } from 'rxjs'
/**
* Sends a request to the inference server to generate a response based on the recent messages.
* @param recentMessages - An array of recent messages to use as context for the inference.
* @param engine - The engine settings to use for the inference.
* @param model - The model to use for the inference.
* @returns An Observable that emits the generated response as a string.
*/
export function requestInference(
recentMessages: any[],
engine: EngineSettings,
model: GroqModel,
controller?: AbortController
): Observable<string> {
return new Observable((subscriber) => {
// let model_id: string = model.id
const requestBody = JSON.stringify({
messages: recentMessages,
stream: true,
model: model.id,
...model.parameters,
})
fetch(`${engine.full_url}`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': model.parameters.stream
? 'text/event-stream'
: 'application/json',
'Access-Control-Allow-Origin': '*',
'Authorization': `Bearer ${engine.api_key}`,
// 'api-key': `${engine.api_key}`,
},
body: requestBody,
signal: controller?.signal,
})
.then(async (response) => {
if (!response.ok) {
const data = await response.json()
const error = {
message: data.error?.message ?? 'An error occurred.',
code: data.error?.code ?? ErrorCode.Unknown,
}
subscriber.error(error)
subscriber.complete()
return
}
if (model.parameters.stream === false) {
const data = await response.json()
subscriber.next(data.choices[0]?.message?.content ?? '')
} else {
const stream = response.body
const decoder = new TextDecoder('utf-8')
const reader = stream?.getReader()
let content = ''
while (true && reader) {
const { done, value } = await reader.read()
if (done) {
break
}
const text = decoder.decode(value)
const lines = text.trim().split('\n')
for (const line of lines) {
if (line.startsWith('data: ') && !line.includes('data: [DONE]')) {
const data = JSON.parse(line.replace('data: ', ''))
content += data.choices[0]?.delta?.content ?? ''
if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '')
}
subscriber.next(content)
}
}
}
}
subscriber.complete()
})
.catch((err) => subscriber.error(err))
})
}

View File

@ -7,218 +7,77 @@
*/
import {
ChatCompletionRole,
ContentType,
MessageRequest,
MessageStatus,
ThreadContent,
ThreadMessage,
events,
fs,
InferenceEngine,
BaseExtension,
MessageEvent,
MessageRequestType,
ModelEvent,
InferenceEvent,
AppConfigurationEventName,
joinPath,
RemoteOAIEngine,
} from '@janhq/core'
import { requestInference } from './helpers/sse'
import { ulid } from 'ulid'
import { join } from 'path'
declare const COMPLETION_URL: string
/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/
export default class JanInferenceGroqExtension extends BaseExtension {
private static readonly _engineDir = 'file://engines'
private static readonly _engineMetadataFileName = 'groq.json'
export default class JanInferenceGroqExtension extends RemoteOAIEngine {
private readonly _engineDir = 'file://engines'
private readonly _engineMetadataFileName = 'groq.json'
private static _currentModel: GroqModel
inferenceUrl: string = COMPLETION_URL
provider = 'groq'
apiKey = ''
private static _engineSettings: EngineSettings = {
full_url: 'https://api.groq.com/openai/v1/chat/completions',
private _engineSettings = {
full_url: COMPLETION_URL,
api_key: 'gsk-<your key here>',
}
controller = new AbortController()
isCancelled = false
/**
* Subscribes to events emitted by the @janhq/core package.
*/
async onLoad() {
if (!(await fs.existsSync(JanInferenceGroqExtension._engineDir))) {
await fs
.mkdirSync(JanInferenceGroqExtension._engineDir)
.catch((err) => console.debug(err))
super.onLoad()
if (!(await fs.existsSync(this._engineDir))) {
await fs.mkdirSync(this._engineDir).catch((err) => console.debug(err))
}
JanInferenceGroqExtension.writeDefaultEngineSettings()
// Events subscription
events.on(MessageEvent.OnMessageSent, (data) =>
JanInferenceGroqExtension.handleMessageRequest(data, this)
)
events.on(ModelEvent.OnModelInit, (model: GroqModel) => {
JanInferenceGroqExtension.handleModelInit(model)
})
events.on(ModelEvent.OnModelStop, (model: GroqModel) => {
JanInferenceGroqExtension.handleModelStop(model)
})
events.on(InferenceEvent.OnInferenceStopped, () => {
JanInferenceGroqExtension.handleInferenceStopped(this)
})
this.writeDefaultEngineSettings()
const settingsFilePath = await joinPath([
JanInferenceGroqExtension._engineDir,
JanInferenceGroqExtension._engineMetadataFileName,
this._engineDir,
this._engineMetadataFileName,
])
// Events subscription
events.on(
AppConfigurationEventName.OnConfigurationUpdate,
(settingsKey: string) => {
// Update settings on changes
if (settingsKey === settingsFilePath)
JanInferenceGroqExtension.writeDefaultEngineSettings()
if (settingsKey === settingsFilePath) this.writeDefaultEngineSettings()
}
)
}
/**
* Stops the model inference.
*/
onUnload(): void {}
static async writeDefaultEngineSettings() {
async writeDefaultEngineSettings() {
try {
const engineFile = join(
JanInferenceGroqExtension._engineDir,
JanInferenceGroqExtension._engineMetadataFileName
)
const engineFile = join(this._engineDir, this._engineMetadataFileName)
if (await fs.existsSync(engineFile)) {
const engine = await fs.readFileSync(engineFile, 'utf-8')
JanInferenceGroqExtension._engineSettings =
this._engineSettings =
typeof engine === 'object' ? engine : JSON.parse(engine)
this.inferenceUrl = this._engineSettings.full_url
this.apiKey = this._engineSettings.api_key
} else {
await fs.writeFileSync(
engineFile,
JSON.stringify(JanInferenceGroqExtension._engineSettings, null, 2)
JSON.stringify(this._engineSettings, null, 2)
)
}
} catch (err) {
console.error(err)
}
}
private static async handleModelInit(model: GroqModel) {
if (model.engine !== InferenceEngine.groq) {
return
} else {
JanInferenceGroqExtension._currentModel = model
JanInferenceGroqExtension.writeDefaultEngineSettings()
// Todo: Check model list with API key
events.emit(ModelEvent.OnModelReady, model)
}
}
private static async handleModelStop(model: GroqModel) {
if (model.engine !== 'groq') {
return
}
events.emit(ModelEvent.OnModelStopped, model)
}
private static async handleInferenceStopped(
instance: JanInferenceGroqExtension
) {
instance.isCancelled = true
instance.controller?.abort()
}
/**
* Handles a new message request by making an inference request and emitting events.
* Function registered in event manager, should be static to avoid binding issues.
* Pass instance as a reference.
* @param {MessageRequest} data - The data for the new message request.
*/
private static async handleMessageRequest(
data: MessageRequest,
instance: JanInferenceGroqExtension
) {
if (data.model.engine !== 'groq') {
return
}
const timestamp = Date.now()
const message: ThreadMessage = {
id: ulid(),
thread_id: data.threadId,
type: data.type,
assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant,
content: [],
status: MessageStatus.Pending,
created: timestamp,
updated: timestamp,
object: 'thread.message',
}
if (data.type !== MessageRequestType.Summary) {
events.emit(MessageEvent.OnMessageResponse, message)
}
instance.isCancelled = false
instance.controller = new AbortController()
requestInference(
data?.messages ?? [],
this._engineSettings,
{
...JanInferenceGroqExtension._currentModel,
parameters: data.model.parameters,
},
instance.controller
).subscribe({
next: (content) => {
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: content.trim(),
annotations: [],
},
}
message.content = [messageContent]
events.emit(MessageEvent.OnMessageUpdate, message)
},
complete: async () => {
message.status = message.content.length
? MessageStatus.Ready
: MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, message)
},
error: async (err) => {
if (instance.isCancelled || message.content.length > 0) {
message.status = MessageStatus.Stopped
events.emit(MessageEvent.OnMessageUpdate, message)
return
}
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: 'An error occurred. ' + err.message,
annotations: [],
},
}
message.content = [messageContent]
message.status = MessageStatus.Error
message.error_code = err.code
events.emit(MessageEvent.OnMessageUpdate, message)
},
})
}
}

View File

@ -18,7 +18,7 @@ module.exports = {
plugins: [
new webpack.DefinePlugin({
MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
GROQ_DOMAIN: JSON.stringify('api.groq.com'),
COMPLETION_URL: JSON.stringify('https://api.groq.com/openai/v1/chat/completions'),
}),
],
output: {
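
`COMPLETION_URL` is never imported in the extension source; it is declared with `declare const` and inlined by webpack's `DefinePlugin` at build time. A small sketch of the pairing, using the value from this diff:

```typescript
// In the extension source, the constant only exists for the type checker:
declare const COMPLETION_URL: string

// DefinePlugin (see the webpack config above) textually replaces it during bundling,
// so the built extension behaves as if it read the literal URL directly:
class GroqEngineSketch {
  inferenceUrl: string = COMPLETION_URL
  // effectively: 'https://api.groq.com/openai/v1/chat/completions'
}
```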

View File

@ -0,0 +1,5 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
};

View File

@ -7,6 +7,7 @@
"author": "Jan <service@jan.ai>",
"license": "AGPL-3.0",
"scripts": {
"test": "jest",
"build": "tsc --module commonjs && rollup -c rollup.config.ts",
"downloadnitro:linux": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64.tar.gz -e --strip 1 -o ./bin/linux-cpu && chmod +x ./bin/linux-cpu/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda-12-0.tar.gz -e --strip 1 -o ./bin/linux-cuda-12-0 && chmod +x ./bin/linux-cuda-12-0/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda-11-7.tar.gz -e --strip 1 -o ./bin/linux-cuda-11-7 && chmod +x ./bin/linux-cuda-11-7/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-vulkan.tar.gz -e --strip 1 -o ./bin/linux-vulkan && chmod +x ./bin/linux-vulkan/nitro",
"downloadnitro:darwin": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-arm64.tar.gz -e --strip 1 -o ./bin/mac-arm64 && chmod +x ./bin/mac-arm64/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-amd64.tar.gz -e --strip 1 -o ./bin/mac-x64 && chmod +x ./bin/mac-x64/nitro",
@ -15,29 +16,34 @@
"build:publish:darwin": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && ../../.github/scripts/auto-sign.sh && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
"build:publish:win32": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
"build:publish:linux": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
"build:publish": "run-script-os"
"build:publish": "yarn test && run-script-os"
},
"exports": {
".": "./dist/index.js",
"./main": "./dist/node/index.cjs.js"
},
"devDependencies": {
"@babel/preset-typescript": "^7.24.1",
"@jest/globals": "^29.7.0",
"@rollup/plugin-commonjs": "^25.0.7",
"@rollup/plugin-json": "^6.1.0",
"@rollup/plugin-node-resolve": "^15.2.3",
"@rollup/plugin-replace": "^5.0.5",
"@types/jest": "^29.5.12",
"@types/node": "^20.11.4",
"@types/os-utils": "^0.0.4",
"@types/tcp-port-used": "^1.0.4",
"cpx": "^1.5.0",
"download-cli": "^1.1.1",
"jest": "^29.7.0",
"rimraf": "^3.0.2",
"rollup": "^2.38.5",
"rollup-plugin-define": "^1.0.1",
"rollup-plugin-sourcemaps": "^0.6.3",
"rollup-plugin-typescript2": "^0.36.0",
"run-script-os": "^1.1.6",
"typescript": "^5.3.3",
"@types/os-utils": "^0.0.4",
"@rollup/plugin-replace": "^5.0.5"
"ts-jest": "^29.1.2",
"typescript": "^5.3.3"
},
"dependencies": {
"@janhq/core": "file:../../core",

View File

@ -0,0 +1,6 @@
module.exports = {
presets: [
['@babel/preset-env', { targets: { node: 'current' } }],
'@babel/preset-typescript',
],
}

View File

@ -1,66 +0,0 @@
import { Model } from '@janhq/core'
import { Observable } from 'rxjs'
/**
* Sends a request to the inference server to generate a response based on the recent messages.
* @param recentMessages - An array of recent messages to use as context for the inference.
* @returns An Observable that emits the generated response as a string.
*/
export function requestInference(
inferenceUrl: string,
recentMessages: any[],
model: Model,
controller?: AbortController
): Observable<string> {
return new Observable((subscriber) => {
const requestBody = JSON.stringify({
messages: recentMessages,
model: model.id,
stream: true,
...model.parameters,
})
fetch(inferenceUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Access-Control-Allow-Origin': '*',
'Accept': model.parameters.stream
? 'text/event-stream'
: 'application/json',
},
body: requestBody,
signal: controller?.signal,
})
.then(async (response) => {
if (model.parameters.stream === false) {
const data = await response.json()
subscriber.next(data.choices[0]?.message?.content ?? '')
} else {
const stream = response.body
const decoder = new TextDecoder('utf-8')
const reader = stream?.getReader()
let content = ''
while (true && reader) {
const { done, value } = await reader.read()
if (done) {
break
}
const text = decoder.decode(value)
const lines = text.trim().split('\n')
for (const line of lines) {
if (line.startsWith('data: ') && !line.includes('data: [DONE]')) {
const data = JSON.parse(line.replace('data: ', ''))
content += data.choices[0]?.delta?.content ?? ''
if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '')
}
subscriber.next(content)
}
}
}
}
subscriber.complete()
})
.catch((err) => subscriber.error(err))
})
}

View File

@ -7,58 +7,31 @@
*/
import {
ChatCompletionRole,
ContentType,
MessageRequest,
MessageRequestType,
MessageStatus,
ThreadContent,
ThreadMessage,
events,
executeOnMain,
fs,
Model,
joinPath,
InferenceExtension,
log,
InferenceEngine,
MessageEvent,
ModelEvent,
InferenceEvent,
ModelSettingParams,
getJanDataFolderPath,
LocalOAIEngine,
} from '@janhq/core'
import { requestInference } from './helpers/sse'
import { ulid } from 'ulidx'
/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/
export default class JanInferenceNitroExtension extends InferenceExtension {
private static readonly _homeDir = 'file://engines'
private static readonly _settingsDir = 'file://settings'
private static readonly _engineMetadataFileName = 'nitro.json'
export default class JanInferenceNitroExtension extends LocalOAIEngine {
nodeModule: string = NODE
provider: string = 'nitro'
models(): Promise<Model[]> {
return Promise.resolve([])
}
/**
* Checks the health of the Nitro process every 5 seconds.
*/
private static readonly _intervalHealthCheck = 5 * 1000
private _currentModel: Model | undefined
private _engineSettings: ModelSettingParams = {
ctx_len: 2048,
ngl: 100,
cpu_threads: 1,
cont_batching: false,
embedding: true,
}
controller = new AbortController()
isCancelled = false
/**
* The interval id for the health check. Used to stop the health check.
*/
@ -69,114 +42,30 @@ export default class JanInferenceNitroExtension extends InferenceExtension {
*/
private nitroProcessInfo: any = undefined
private inferenceUrl = ''
/**
* The URL for making inference requests.
*/
inferenceUrl = ''
/**
* Subscribes to events emitted by the @janhq/core package.
*/
async onLoad() {
if (!(await fs.existsSync(JanInferenceNitroExtension._homeDir))) {
try {
await fs.mkdirSync(JanInferenceNitroExtension._homeDir)
} catch (e) {
console.debug(e)
}
}
// init inference url
// @ts-ignore
const electronApi = window?.electronAPI
this.inferenceUrl = INFERENCE_URL
if (!electronApi) {
// If the extension is running in the browser, use the base API URL from the core package.
if (!('electronAPI' in window)) {
this.inferenceUrl = `${window.core?.api?.baseApiUrl}/v1/chat/completions`
}
console.debug('Inference url: ', this.inferenceUrl)
if (!(await fs.existsSync(JanInferenceNitroExtension._settingsDir)))
await fs.mkdirSync(JanInferenceNitroExtension._settingsDir)
this.writeDefaultEngineSettings()
// Events subscription
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
this.onMessageRequest(data)
)
events.on(ModelEvent.OnModelInit, (model: Model) => this.onModelInit(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.onModelStop(model))
events.on(InferenceEvent.OnInferenceStopped, () =>
this.onInferenceStopped()
)
}
/**
* Stops the model inference.
*/
onUnload(): void {}
private async writeDefaultEngineSettings() {
try {
const engineFile = await joinPath([
JanInferenceNitroExtension._homeDir,
JanInferenceNitroExtension._engineMetadataFileName,
])
if (await fs.existsSync(engineFile)) {
const engine = await fs.readFileSync(engineFile, 'utf-8')
this._engineSettings =
typeof engine === 'object' ? engine : JSON.parse(engine)
} else {
await fs.writeFileSync(
engineFile,
JSON.stringify(this._engineSettings, null, 2)
)
}
} catch (err) {
console.error(err)
}
}
private async onModelInit(model: Model) {
if (model.engine !== InferenceEngine.nitro) return
const modelFolder = await joinPath([
await getJanDataFolderPath(),
'models',
model.id,
])
this._currentModel = model
const nitroInitResult = await executeOnMain(NODE, 'runModel', {
modelFolder,
model,
})
if (nitroInitResult?.error) {
events.emit(ModelEvent.OnModelFail, {
...model,
error: nitroInitResult.error,
})
return
}
events.emit(ModelEvent.OnModelReady, model)
this.getNitroProcesHealthIntervalId = setInterval(
() => this.periodicallyGetNitroHealth(),
JanInferenceNitroExtension._intervalHealthCheck
)
}
private async onModelStop(model: Model) {
if (model.engine !== 'nitro') return
await executeOnMain(NODE, 'stopModel')
events.emit(ModelEvent.OnModelStopped, {})
// stop the periodic health check
if (this.getNitroProcesHealthIntervalId) {
clearInterval(this.getNitroProcesHealthIntervalId)
this.getNitroProcesHealthIntervalId = undefined
}
super.onLoad()
}
/**
@ -193,118 +82,24 @@ export default class JanInferenceNitroExtension extends InferenceExtension {
this.nitroProcessInfo = health
}
private async onInferenceStopped() {
this.isCancelled = true
this.controller?.abort()
override loadModel(model: Model): Promise<void> {
if (model.engine !== this.provider) return Promise.resolve()
this.getNitroProcesHealthIntervalId = setInterval(
() => this.periodicallyGetNitroHealth(),
JanInferenceNitroExtension._intervalHealthCheck
)
return super.loadModel(model)
}
/**
* Makes a single response inference request.
* @param {MessageRequest} data - The data for the inference request.
* @returns {Promise<any>} A promise that resolves with the inference response.
*/
async inference(data: MessageRequest): Promise<ThreadMessage> {
const timestamp = Date.now()
const message: ThreadMessage = {
thread_id: data.threadId,
created: timestamp,
updated: timestamp,
status: MessageStatus.Ready,
id: '',
role: ChatCompletionRole.Assistant,
object: 'thread.message',
content: [],
override unloadModel(model: Model): void {
super.unloadModel(model)
if (model.engine && model.engine !== this.provider) return
// stop the periodic health check
if (this.getNitroProcesHealthIntervalId) {
clearInterval(this.getNitroProcesHealthIntervalId)
this.getNitroProcesHealthIntervalId = undefined
}
return new Promise(async (resolve, reject) => {
if (!this._currentModel) return Promise.reject('No model loaded')
requestInference(
this.inferenceUrl,
data.messages ?? [],
this._currentModel
).subscribe({
next: (_content: any) => {},
complete: async () => {
resolve(message)
},
error: async (err: any) => {
reject(err)
},
})
})
}
/**
* Handles a new message request by making an inference request and emitting events.
* Function registered in event manager, should be static to avoid binding issues.
* Pass instance as a reference.
* @param {MessageRequest} data - The data for the new message request.
*/
private async onMessageRequest(data: MessageRequest) {
if (data.model?.engine !== InferenceEngine.nitro || !this._currentModel) {
return
}
const timestamp = Date.now()
const message: ThreadMessage = {
id: ulid(),
thread_id: data.threadId,
type: data.type,
assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant,
content: [],
status: MessageStatus.Pending,
created: timestamp,
updated: timestamp,
object: 'thread.message',
}
if (data.type !== MessageRequestType.Summary) {
events.emit(MessageEvent.OnMessageResponse, message)
}
this.isCancelled = false
this.controller = new AbortController()
// @ts-ignore
const model: Model = {
...(this._currentModel || {}),
...(data.model || {}),
}
requestInference(
this.inferenceUrl,
data.messages ?? [],
model,
this.controller
).subscribe({
next: (content: any) => {
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: content.trim(),
annotations: [],
},
}
message.content = [messageContent]
events.emit(MessageEvent.OnMessageUpdate, message)
},
complete: async () => {
message.status = message.content.length
? MessageStatus.Ready
: MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, message)
},
error: async (err: any) => {
if (this.isCancelled || message.content.length) {
message.status = MessageStatus.Stopped
events.emit(MessageEvent.OnMessageUpdate, message)
return
}
message.status = MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, message)
log(`[APP]::Error: ${err.message}`)
},
})
}
}

View File

@ -0,0 +1,233 @@
import { describe, expect, it } from '@jest/globals'
import { executableNitroFile } from './execute'
import { GpuSetting } from '@janhq/core'
import { sep } from 'path'
let testSettings: GpuSetting = {
run_mode: 'cpu',
vulkan: false,
cuda: {
exist: false,
version: '11',
},
gpu_highest_vram: '0',
gpus: [],
gpus_in_use: [],
is_initial: false,
notify: true,
nvidia_driver: {
exist: false,
version: '11',
},
}
const originalPlatform = process.platform
describe('test executable nitro file', () => {
afterAll(function () {
Object.defineProperty(process, 'platform', {
value: originalPlatform,
})
})
it('executes on MacOS ARM', () => {
Object.defineProperty(process, 'platform', {
value: 'darwin',
})
Object.defineProperty(process, 'arch', {
value: 'arm64',
})
expect(executableNitroFile(testSettings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`mac-arm64${sep}nitro`),
cudaVisibleDevices: '',
vkVisibleDevices: '',
})
)
})
it('executes on MacOS Intel', () => {
Object.defineProperty(process, 'platform', {
value: 'darwin',
})
Object.defineProperty(process, 'arch', {
value: 'x64',
})
expect(executableNitroFile(testSettings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`mac-x64${sep}nitro`),
cudaVisibleDevices: '',
vkVisibleDevices: '',
})
)
})
it('executes on Windows CPU', () => {
Object.defineProperty(process, 'platform', {
value: 'win32',
})
const settings: GpuSetting = {
...testSettings,
run_mode: 'cpu',
cuda: {
exist: true,
version: '11',
},
}
expect(executableNitroFile(settings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`win-cpu${sep}nitro.exe`),
cudaVisibleDevices: '',
vkVisibleDevices: '',
})
)
})
it('executes on Windows Cuda 11', () => {
Object.defineProperty(process, 'platform', {
value: 'win32',
})
const settings: GpuSetting = {
...testSettings,
run_mode: 'gpu',
cuda: {
exist: true,
version: '11',
},
nvidia_driver: {
exist: true,
version: '12',
},
gpus_in_use: ['0'],
gpus: [
{
id: '0',
name: 'NVIDIA GeForce GTX 1080',
vram: '80000000',
},
],
}
expect(executableNitroFile(settings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`win-cuda-11-7${sep}nitro.exe`),
cudaVisibleDevices: '0',
vkVisibleDevices: '0',
})
)
})
it('executes on Windows Cuda 12', () => {
Object.defineProperty(process, 'platform', {
value: 'win32',
})
const settings: GpuSetting = {
...testSettings,
run_mode: 'gpu',
cuda: {
exist: true,
version: '12',
},
nvidia_driver: {
exist: true,
version: '12',
},
gpus_in_use: ['0'],
gpus: [
{
id: '0',
name: 'NVIDIA GeForce GTX 1080',
vram: '80000000',
},
],
}
expect(executableNitroFile(settings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`win-cuda-12-0${sep}nitro.exe`),
cudaVisibleDevices: '0',
vkVisibleDevices: '0',
})
)
})
it('executes on Linux CPU', () => {
Object.defineProperty(process, 'platform', {
value: 'linux',
})
const settings: GpuSetting = {
...testSettings,
run_mode: 'cpu',
}
expect(executableNitroFile(settings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`linux-cpu${sep}nitro`),
cudaVisibleDevices: '',
vkVisibleDevices: '',
})
)
})
it('executes on Linux Cuda 11', () => {
Object.defineProperty(process, 'platform', {
value: 'linux',
})
const settings: GpuSetting = {
...testSettings,
run_mode: 'gpu',
cuda: {
exist: true,
version: '11',
},
nvidia_driver: {
exist: true,
version: '12',
},
gpus_in_use: ['0'],
gpus: [
{
id: '0',
name: 'NVIDIA GeForce GTX 1080',
vram: '80000000',
},
],
}
expect(executableNitroFile(settings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`linux-cuda-11-7${sep}nitro`),
cudaVisibleDevices: '0',
vkVisibleDevices: '0',
})
)
})
it('executes on Linux Cuda 12', () => {
Object.defineProperty(process, 'platform', {
value: 'linux',
})
const settings: GpuSetting = {
...testSettings,
run_mode: 'gpu',
cuda: {
exist: true,
version: '12',
},
nvidia_driver: {
exist: true,
version: '12',
},
gpus_in_use: ['0'],
gpus: [
{
id: '0',
name: 'NVIDIA GeForce GTX 1080',
vram: '80000000',
},
],
}
expect(executableNitroFile(settings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`linux-cuda-12-0${sep}nitro`),
cudaVisibleDevices: '0',
vkVisibleDevices: '0',
})
)
})
})

View File

@ -1,5 +1,4 @@
import { getJanDataFolderPath } from '@janhq/core/node'
import { readFileSync } from 'fs'
import { GpuSetting, SystemInformation } from '@janhq/core'
import * as path from 'path'
export interface NitroExecutableOptions {
@ -7,79 +6,56 @@ export interface NitroExecutableOptions {
cudaVisibleDevices: string
vkVisibleDevices: string
}
const runMode = (settings?: GpuSetting): string => {
if (process.platform === 'darwin')
// macOS uses arch instead of cpu / cuda
return process.arch === 'arm64' ? 'arm64' : 'x64'
export const GPU_INFO_FILE = path.join(
getJanDataFolderPath(),
'settings',
'settings.json'
)
if (!settings) return 'cpu'
return settings.vulkan === true
? 'vulkan'
: settings.run_mode === 'cpu'
? 'cpu'
: 'cuda'
}
const os = (): string => {
return process.platform === 'win32'
? 'win'
: process.platform === 'darwin'
? 'mac'
: 'linux'
}
const extension = (): '.exe' | '' => {
return process.platform === 'win32' ? '.exe' : ''
}
const cudaVersion = (settings?: GpuSetting): '11-7' | '12-0' | undefined => {
const isUsingCuda =
settings?.vulkan !== true && settings?.run_mode === 'gpu' && os() !== 'mac'
if (!isUsingCuda) return undefined
return settings?.cuda?.version === '11' ? '11-7' : '12-0'
}
/**
* Find which executable file to run based on the current platform.
* @returns The name of the executable file to run.
*/
export const executableNitroFile = (): NitroExecutableOptions => {
let binaryFolder = path.join(__dirname, '..', 'bin') // Current directory by default
let cudaVisibleDevices = ''
let vkVisibleDevices = ''
let binaryName = 'nitro'
/**
* The binary folder is different for each platform.
*/
if (process.platform === 'win32') {
/**
* For Windows: win-cpu, win-vulkan, win-cuda-11-7, win-cuda-12-0
*/
let gpuInfo = JSON.parse(readFileSync(GPU_INFO_FILE, 'utf-8'))
if (gpuInfo['run_mode'] === 'cpu') {
binaryFolder = path.join(binaryFolder, 'win-cpu')
} else {
if (gpuInfo['cuda']?.version === '11') {
binaryFolder = path.join(binaryFolder, 'win-cuda-11-7')
} else {
binaryFolder = path.join(binaryFolder, 'win-cuda-12-0')
}
cudaVisibleDevices = gpuInfo['gpus_in_use'].join(',')
}
if (gpuInfo['vulkan'] === true) {
binaryFolder = path.join(__dirname, '..', 'bin')
binaryFolder = path.join(binaryFolder, 'win-vulkan')
vkVisibleDevices = gpuInfo['gpus_in_use'].toString()
}
binaryName = 'nitro.exe'
} else if (process.platform === 'darwin') {
/**
* For MacOS: mac-arm64 (Silicon), mac-x64 (InteL)
*/
if (process.arch === 'arm64') {
binaryFolder = path.join(binaryFolder, 'mac-arm64')
} else {
binaryFolder = path.join(binaryFolder, 'mac-x64')
}
} else {
/**
* For Linux: linux-cpu, linux-vulkan, linux-cuda-11-7, linux-cuda-12-0
*/
let gpuInfo = JSON.parse(readFileSync(GPU_INFO_FILE, 'utf-8'))
if (gpuInfo['run_mode'] === 'cpu') {
binaryFolder = path.join(binaryFolder, 'linux-cpu')
} else {
if (gpuInfo['cuda']?.version === '11') {
binaryFolder = path.join(binaryFolder, 'linux-cuda-11-7')
} else {
binaryFolder = path.join(binaryFolder, 'linux-cuda-12-0')
}
cudaVisibleDevices = gpuInfo['gpus_in_use'].join(',')
}
export const executableNitroFile = (
gpuSetting?: GpuSetting
): NitroExecutableOptions => {
let binaryFolder = [os(), runMode(gpuSetting), cudaVersion(gpuSetting)]
.filter((e) => !!e)
.join('-')
let cudaVisibleDevices = gpuSetting?.gpus_in_use.join(',') ?? ''
let vkVisibleDevices = gpuSetting?.gpus_in_use.join(',') ?? ''
let binaryName = `nitro${extension()}`
if (gpuInfo['vulkan'] === true) {
binaryFolder = path.join(__dirname, '..', 'bin')
binaryFolder = path.join(binaryFolder, 'linux-vulkan')
vkVisibleDevices = gpuInfo['gpus_in_use'].toString()
}
}
return {
executablePath: path.join(binaryFolder, binaryName),
executablePath: path.join(__dirname, '..', 'bin', binaryFolder, binaryName),
cudaVisibleDevices,
vkVisibleDevices,
}
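
The refactored `executableNitroFile` now derives the binary folder purely from the `GpuSetting` passed down from the renderer (via `systemInformation()` in `LocalOAIEngine.loadModel`), instead of reading `settings.json` from disk. A sketch of the resulting selection, with the setting object assembled from the same fields the test file above uses:

```typescript
import { executableNitroFile } from './execute'
import { GpuSetting } from '@janhq/core'

// Illustrative GpuSetting: Linux box, CUDA 12 driver, one GPU selected.
const setting: GpuSetting = {
  run_mode: 'gpu',
  vulkan: false,
  cuda: { exist: true, version: '12' },
  nvidia_driver: { exist: true, version: '12' },
  gpu_highest_vram: '0',
  gpus: [{ id: '0', name: 'NVIDIA GeForce GTX 1080', vram: '80000000' }],
  gpus_in_use: ['0'],
  is_initial: false,
  notify: true,
}

const options = executableNitroFile(setting)
// On Linux this resolves to .../bin/linux-cuda-12-0/nitro with
// cudaVisibleDevices === '0' and vkVisibleDevices === '0' (see the tests above).
```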

View File

@ -10,6 +10,7 @@ import {
InferenceEngine,
ModelSettingParams,
PromptTemplate,
SystemInformation,
} from '@janhq/core/node'
import { executableNitroFile } from './execute'
@ -51,7 +52,7 @@ let currentSettings: ModelSettingParams | undefined = undefined
* @param wrapper - The model wrapper.
* @returns A Promise that resolves when the subprocess is terminated successfully, or rejects with an error message if the subprocess fails to terminate.
*/
function stopModel(): Promise<void> {
function unloadModel(): Promise<void> {
return killSubprocess()
}
@ -61,46 +62,47 @@ function stopModel(): Promise<void> {
* @returns A Promise that resolves when the model is loaded successfully, or rejects with an error message if the model is not found or fails to load.
* TODO: Should pass the absolute path of the model file instead of just the name, so we can modularize module.ts into an npm package
*/
async function runModel(
wrapper: ModelInitOptions
async function loadModel(
params: ModelInitOptions,
systemInfo?: SystemInformation
): Promise<ModelOperationResponse | void> {
if (wrapper.model.engine !== InferenceEngine.nitro) {
if (params.model.engine !== InferenceEngine.nitro) {
// Not a nitro model
return Promise.resolve()
}
if (wrapper.model.engine !== InferenceEngine.nitro) {
if (params.model.engine !== InferenceEngine.nitro) {
return Promise.reject('Not a nitro model')
} else {
const nitroResourceProbe = await getSystemResourceInfo()
// Convert settings.prompt_template to system_prompt, user_prompt, ai_prompt
if (wrapper.model.settings.prompt_template) {
const promptTemplate = wrapper.model.settings.prompt_template
if (params.model.settings.prompt_template) {
const promptTemplate = params.model.settings.prompt_template
const prompt = promptTemplateConverter(promptTemplate)
if (prompt?.error) {
return Promise.reject(prompt.error)
}
wrapper.model.settings.system_prompt = prompt.system_prompt
wrapper.model.settings.user_prompt = prompt.user_prompt
wrapper.model.settings.ai_prompt = prompt.ai_prompt
params.model.settings.system_prompt = prompt.system_prompt
params.model.settings.user_prompt = prompt.user_prompt
params.model.settings.ai_prompt = prompt.ai_prompt
}
// modelFolder is the absolute path to the running model folder
// e.g. ~/jan/models/llama-2
let modelFolder = wrapper.modelFolder
let modelFolder = params.modelFolder
let llama_model_path = wrapper.model.settings.llama_model_path
let llama_model_path = params.model.settings.llama_model_path
// Absolute model path support
if (
wrapper.model?.sources.length &&
wrapper.model.sources.every((e) => fs.existsSync(e.url))
params.model?.sources.length &&
params.model.sources.every((e) => fs.existsSync(e.url))
) {
llama_model_path =
wrapper.model.sources.length === 1
? wrapper.model.sources[0].url
: wrapper.model.sources.find((e) =>
e.url.includes(llama_model_path ?? wrapper.model.id)
params.model.sources.length === 1
? params.model.sources[0].url
: params.model.sources.find((e) =>
e.url.includes(llama_model_path ?? params.model.id)
)?.url
}
@ -114,7 +116,7 @@ async function runModel(
// 2. Prioritize GGUF File (manual import)
file.toLowerCase().includes(SUPPORTED_MODEL_FORMAT) ||
// 3. Fallback Model ID (for backward compatibility)
file === wrapper.model.id
file === params.model.id
)
if (ggufBinFile) llama_model_path = path.join(modelFolder, ggufBinFile)
}
@ -124,17 +126,17 @@ async function runModel(
if (!llama_model_path) return Promise.reject('No GGUF model file found')
currentSettings = {
...wrapper.model.settings,
...params.model.settings,
llama_model_path,
// This is critical and requires real CPU physical core count (or performance core)
cpu_threads: Math.max(1, nitroResourceProbe.numCpuPhysicalCore),
...(wrapper.model.settings.mmproj && {
mmproj: path.isAbsolute(wrapper.model.settings.mmproj)
? wrapper.model.settings.mmproj
: path.join(modelFolder, wrapper.model.settings.mmproj),
...(params.model.settings.mmproj && {
mmproj: path.isAbsolute(params.model.settings.mmproj)
? params.model.settings.mmproj
: path.join(modelFolder, params.model.settings.mmproj),
}),
}
return runNitroAndLoadModel()
return runNitroAndLoadModel(systemInfo)
}
}
@ -144,7 +146,7 @@ async function runModel(
* 3. Validate model status
* @returns
*/
async function runNitroAndLoadModel() {
async function runNitroAndLoadModel(systemInfo?: SystemInformation) {
// Gather system information for CPU physical cores and memory
return killSubprocess()
.then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000))
@ -160,7 +162,7 @@ async function runNitroAndLoadModel() {
return Promise.resolve()
}
})
.then(spawnNitroProcess)
.then(() => spawnNitroProcess(systemInfo))
.then(() => loadLLMModel(currentSettings))
.then(validateModelStatus)
.catch((err) => {
@ -325,12 +327,12 @@ async function killSubprocess(): Promise<void> {
* Spawns a Nitro subprocess.
* @returns A promise that resolves when the Nitro subprocess is started.
*/
function spawnNitroProcess(): Promise<any> {
function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
log(`[NITRO]::Debug: Spawning Nitro subprocess...`)
return new Promise<void>(async (resolve, reject) => {
let binaryFolder = path.join(__dirname, '..', 'bin') // Current directory by default
let executableOptions = executableNitroFile()
let executableOptions = executableNitroFile(systemInfo?.gpuSetting)
const args: string[] = ['1', LOCAL_HOST, PORT.toString()]
// Execute the binary
@ -402,9 +404,8 @@ const getCurrentNitroProcessInfo = (): NitroProcessInfo => {
}
export default {
runModel,
stopModel,
killSubprocess,
loadModel,
unloadModel,
dispose,
getCurrentNitroProcessInfo,
}
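
These exports are the main-process half of the `LocalOAIEngine` contract: the renderer-side `executeOnMain(this.nodeModule, this.loadModelFunctionName, …, systemInfo)` call shown earlier ends up invoking the `loadModel`/`unloadModel` exported here. A rough sketch of the call path; how the core resolves the module to its exports is an assumption, the rest follows the diffs:

```typescript
// Call-path sketch (simplified; module resolution by the core is assumed):
//
//  renderer: LocalOAIEngine.loadModel(model)
//    └─ executeOnMain(nodeModule, 'loadModel', { modelFolder, model }, systemInfo)
//         └─ main process: loadModel(params, systemInfo)
//              └─ runNitroAndLoadModel(systemInfo)
//                   └─ killSubprocess → spawnNitroProcess(systemInfo)
//                        → loadLLMModel(currentSettings) → validateModelStatus
//
//  renderer: LocalOAIEngine.unloadModel(model)
//    └─ executeOnMain(nodeModule, 'unloadModel')
//         └─ main process: unloadModel() → killSubprocess()
```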

View File

@ -1,14 +1,14 @@
# Jan inference plugin
# OpenAI Engine Extension
Created using Jan app example
Created using Jan extension example
# Create a Jan Plugin using Typescript
# Create a Jan Extension using Typescript
Use this template to bootstrap the creation of a TypeScript Jan plugin. 🚀
Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀
## Create Your Own Plugin
## Create Your Own Extension
To create your own plugin, you can use this repository as a template! Just follow the below instructions:
To create your own extension, you can use this repository as a template! Just follow the below instructions:
1. Click the Use this template button at the top of the repository
2. Select Create a new repository
@ -18,7 +18,7 @@ To create your own plugin, you can use this repository as a template! Just follo
## Initial Setup
After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your plugin.
After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension.
> [!NOTE]
>
@ -43,36 +43,37 @@ After you've cloned the repository to your local machine or codespace, you'll ne
1. :white_check_mark: Check your artifact
There will be a tgz file in your plugin directory now
There will be a tgz file in your extension directory now
## Update the Plugin Metadata
## Update the Extension Metadata
The [`package.json`](package.json) file defines metadata about your plugin, such as
plugin name, main entry, description and version.
The [`package.json`](package.json) file defines metadata about your extension, such as
extension name, main entry, description and version.
When you copy this repository, update `package.json` with the name, description for your plugin.
When you copy this repository, update `package.json` with the name, description for your extension.
## Update the Plugin Code
## Update the Extension Code
The [`src/`](./src/) directory is the heart of your plugin! This contains the
source code that will be run when your plugin extension functions are invoked. You can replace the
The [`src/`](./src/) directory is the heart of your extension! This contains the
source code that will be run when your extension functions are invoked. You can replace the
contents of this directory with your own code.
There are a few things to keep in mind when writing your plugin code:
There are a few things to keep in mind when writing your extension code:
- Most Jan Plugin Extension functions are processed asynchronously.
- Most Jan Extension functions are processed asynchronously.
In `index.ts`, you will see that the extension function will return a `Promise<any>`.
```typescript
import { core } from "@janhq/core";
import { events, MessageEvent, MessageRequest } from '@janhq/core'
function onStart(): Promise<any> {
return core.invokePluginFunc(MODULE_PATH, "run", 0);
return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
this.inference(data)
)
}
```
For more information about the Jan Plugin Core module, see the
For more information about the Jan Extension Core module, see the
[documentation](https://github.com/janhq/jan/blob/main/core/README.md).
So, what are you waiting for? Go ahead and start customizing your plugin!
So, what are you waiting for? Go ahead and start customizing your extension!

View File

@ -4,6 +4,7 @@
"description": "This extension enables OpenAI chat completion API calls",
"main": "dist/index.js",
"module": "dist/module.js",
"engine": "openai",
"author": "Jan <service@jan.ai>",
"license": "AGPL-3.0",
"scripts": {

View File

@ -1,26 +0,0 @@
declare const MODULE: string
declare const OPENAI_DOMAIN: string
declare interface EngineSettings {
full_url?: string
api_key?: string
}
enum OpenAIChatCompletionModelName {
'gpt-3.5-turbo-instruct' = 'gpt-3.5-turbo-instruct',
'gpt-3.5-turbo-instruct-0914' = 'gpt-3.5-turbo-instruct-0914',
'gpt-4-1106-preview' = 'gpt-4-1106-preview',
'gpt-3.5-turbo-0613' = 'gpt-3.5-turbo-0613',
'gpt-3.5-turbo-0301' = 'gpt-3.5-turbo-0301',
'gpt-3.5-turbo' = 'gpt-3.5-turbo',
'gpt-3.5-turbo-16k-0613' = 'gpt-3.5-turbo-16k-0613',
'gpt-3.5-turbo-1106' = 'gpt-3.5-turbo-1106',
'gpt-4-vision-preview' = 'gpt-4-vision-preview',
'gpt-4' = 'gpt-4',
'gpt-4-0314' = 'gpt-4-0314',
'gpt-4-0613' = 'gpt-4-0613',
}
declare type OpenAIModel = Omit<Model, 'id'> & {
id: OpenAIChatCompletionModelName
}

View File

@ -1,85 +0,0 @@
import { ErrorCode } from '@janhq/core'
import { Observable } from 'rxjs'
/**
* Sends a request to the inference server to generate a response based on the recent messages.
* @param recentMessages - An array of recent messages to use as context for the inference.
* @param engine - The engine settings to use for the inference.
* @param model - The model to use for the inference.
* @returns An Observable that emits the generated response as a string.
*/
export function requestInference(
recentMessages: any[],
engine: EngineSettings,
model: OpenAIModel,
controller?: AbortController
): Observable<string> {
return new Observable((subscriber) => {
let model_id: string = model.id
if (engine.full_url.includes(OPENAI_DOMAIN)) {
model_id = engine.full_url.split('/')[5]
}
const requestBody = JSON.stringify({
messages: recentMessages,
stream: true,
model: model_id,
...model.parameters,
})
fetch(`${engine.full_url}`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': model.parameters.stream
? 'text/event-stream'
: 'application/json',
'Access-Control-Allow-Origin': '*',
'Authorization': `Bearer ${engine.api_key}`,
'api-key': `${engine.api_key}`,
},
body: requestBody,
signal: controller?.signal,
})
.then(async (response) => {
if (!response.ok) {
const data = await response.json()
const error = {
message: data.error?.message ?? 'An error occurred.',
code: data.error?.code ?? ErrorCode.Unknown,
}
subscriber.error(error)
subscriber.complete()
return
}
if (model.parameters.stream === false) {
const data = await response.json()
subscriber.next(data.choices[0]?.message?.content ?? '')
} else {
const stream = response.body
const decoder = new TextDecoder('utf-8')
const reader = stream?.getReader()
let content = ''
while (true && reader) {
const { done, value } = await reader.read()
if (done) {
break
}
const text = decoder.decode(value)
const lines = text.trim().split('\n')
for (const line of lines) {
if (line.startsWith('data: ') && !line.includes('data: [DONE]')) {
const data = JSON.parse(line.replace('data: ', ''))
content += data.choices[0]?.delta?.content ?? ''
if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '')
}
subscriber.next(content)
}
}
}
}
subscriber.complete()
})
.catch((err) => subscriber.error(err))
})
}

View File

@ -5,75 +5,52 @@
* @version 1.0.0
* @module inference-openai-extension/src/index
*/
declare const ENGINE: string
import {
ChatCompletionRole,
ContentType,
MessageRequest,
MessageStatus,
ThreadContent,
ThreadMessage,
events,
fs,
InferenceEngine,
BaseExtension,
MessageEvent,
MessageRequestType,
ModelEvent,
InferenceEvent,
AppConfigurationEventName,
joinPath,
RemoteOAIEngine,
} from '@janhq/core'
import { requestInference } from './helpers/sse'
import { ulid } from 'ulidx'
import { join } from 'path'
declare const COMPLETION_URL: string
/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/
export default class JanInferenceOpenAIExtension extends BaseExtension {
export default class JanInferenceOpenAIExtension extends RemoteOAIEngine {
private static readonly _engineDir = 'file://engines'
private static readonly _engineMetadataFileName = 'openai.json'
private static readonly _engineMetadataFileName = `${ENGINE}.json`
private static _currentModel: OpenAIModel
private static _engineSettings: EngineSettings = {
full_url: 'https://api.openai.com/v1/chat/completions',
private _engineSettings = {
full_url: COMPLETION_URL,
api_key: 'sk-<your key here>',
}
controller = new AbortController()
isCancelled = false
inferenceUrl: string = COMPLETION_URL
provider: string = 'openai'
apiKey: string = ''
// TODO: Just use registerSettings from BaseExtension
// Remove these methods
/**
* Subscribes to events emitted by the @janhq/core package.
*/
async onLoad() {
super.onLoad()
if (!(await fs.existsSync(JanInferenceOpenAIExtension._engineDir))) {
await fs
.mkdirSync(JanInferenceOpenAIExtension._engineDir)
.catch((err) => console.debug(err))
}
JanInferenceOpenAIExtension.writeDefaultEngineSettings()
// Events subscription
events.on(MessageEvent.OnMessageSent, (data) =>
JanInferenceOpenAIExtension.handleMessageRequest(data, this)
)
events.on(ModelEvent.OnModelInit, (model: OpenAIModel) => {
JanInferenceOpenAIExtension.handleModelInit(model)
})
events.on(ModelEvent.OnModelStop, (model: OpenAIModel) => {
JanInferenceOpenAIExtension.handleModelStop(model)
})
events.on(InferenceEvent.OnInferenceStopped, () => {
JanInferenceOpenAIExtension.handleInferenceStopped(this)
})
this.writeDefaultEngineSettings()
const settingsFilePath = await joinPath([
JanInferenceOpenAIExtension._engineDir,
@ -84,18 +61,12 @@ export default class JanInferenceOpenAIExtension extends BaseExtension {
AppConfigurationEventName.OnConfigurationUpdate,
(settingsKey: string) => {
// Update settings on changes
if (settingsKey === settingsFilePath)
JanInferenceOpenAIExtension.writeDefaultEngineSettings()
if (settingsKey === settingsFilePath) this.writeDefaultEngineSettings()
}
)
}
/**
* Stops the model inference.
*/
onUnload(): void {}
static async writeDefaultEngineSettings() {
async writeDefaultEngineSettings() {
try {
const engineFile = join(
JanInferenceOpenAIExtension._engineDir,
@ -103,122 +74,18 @@ export default class JanInferenceOpenAIExtension extends BaseExtension {
)
if (await fs.existsSync(engineFile)) {
const engine = await fs.readFileSync(engineFile, 'utf-8')
JanInferenceOpenAIExtension._engineSettings =
this._engineSettings =
typeof engine === 'object' ? engine : JSON.parse(engine)
this.inferenceUrl = this._engineSettings.full_url
this.apiKey = this._engineSettings.api_key
} else {
await fs.writeFileSync(
engineFile,
JSON.stringify(JanInferenceOpenAIExtension._engineSettings, null, 2)
JSON.stringify(this._engineSettings, null, 2)
)
}
} catch (err) {
console.error(err)
}
}
private static async handleModelInit(model: OpenAIModel) {
if (model.engine !== InferenceEngine.openai) {
return
} else {
JanInferenceOpenAIExtension._currentModel = model
JanInferenceOpenAIExtension.writeDefaultEngineSettings()
// Todo: Check model list with API key
events.emit(ModelEvent.OnModelReady, model)
}
}
private static async handleModelStop(model: OpenAIModel) {
if (model.engine !== 'openai') {
return
}
events.emit(ModelEvent.OnModelStopped, model)
}
private static async handleInferenceStopped(
instance: JanInferenceOpenAIExtension
) {
instance.isCancelled = true
instance.controller?.abort()
}
/**
* Handles a new message request by making an inference request and emitting events.
* Function registered in event manager, should be static to avoid binding issues.
* Pass instance as a reference.
* @param {MessageRequest} data - The data for the new message request.
*/
private static async handleMessageRequest(
data: MessageRequest,
instance: JanInferenceOpenAIExtension
) {
if (data.model.engine !== 'openai') {
return
}
const timestamp = Date.now()
const message: ThreadMessage = {
id: ulid(),
thread_id: data.threadId,
type: data.type,
assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant,
content: [],
status: MessageStatus.Pending,
created: timestamp,
updated: timestamp,
object: 'thread.message',
}
if (data.type !== MessageRequestType.Summary) {
events.emit(MessageEvent.OnMessageResponse, message)
}
instance.isCancelled = false
instance.controller = new AbortController()
requestInference(
data?.messages ?? [],
this._engineSettings,
{
...JanInferenceOpenAIExtension._currentModel,
parameters: data.model.parameters,
},
instance.controller
).subscribe({
next: (content) => {
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: content.trim(),
annotations: [],
},
}
message.content = [messageContent]
events.emit(MessageEvent.OnMessageUpdate, message)
},
complete: async () => {
message.status = message.content.length
? MessageStatus.Ready
: MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, message)
},
error: async (err) => {
if (instance.isCancelled || message.content.length > 0) {
message.status = MessageStatus.Stopped
events.emit(MessageEvent.OnMessageUpdate, message)
return
}
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: 'An error occurred. ' + err.message,
annotations: [],
},
}
message.content = [messageContent]
message.status = MessageStatus.Error
message.error_code = err.code
events.emit(MessageEvent.OnMessageUpdate, message)
},
})
}
}

View File

@ -17,8 +17,8 @@ module.exports = {
},
plugins: [
new webpack.DefinePlugin({
MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
OPENAI_DOMAIN: JSON.stringify('openai.azure.com'),
ENGINE: JSON.stringify(packageJson.engine),
COMPLETION_URL: JSON.stringify('https://api.openai.com/v1/chat/completions'),
}),
],
output: {

View File

@ -1,5 +0,0 @@
import { Model } from '@janhq/core'
declare interface EngineSettings {
base_url?: string
}

View File

@ -1,63 +0,0 @@
import { Observable } from 'rxjs'
import { EngineSettings } from '../@types/global'
import { Model } from '@janhq/core'
/**
* Sends a request to the inference server to generate a response based on the recent messages.
* @param recentMessages - An array of recent messages to use as context for the inference.
* @param engine - The engine settings to use for the inference.
* @param model - The model to use for the inference.
* @returns An Observable that emits the generated response as a string.
*/
export function requestInference(
recentMessages: any[],
engine: EngineSettings,
model: Model,
controller?: AbortController
): Observable<string> {
return new Observable((subscriber) => {
const text_input = recentMessages.map((message) => message.text).join('\n')
const requestBody = JSON.stringify({
text_input: text_input,
max_tokens: 4096,
temperature: 0,
bad_words: '',
stop_words: '[DONE]',
stream: true,
})
fetch(`${engine.base_url}/v2/models/ensemble/generate_stream`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': 'text/event-stream',
'Access-Control-Allow-Origin': '*',
},
body: requestBody,
signal: controller?.signal,
})
.then(async (response) => {
const stream = response.body
const decoder = new TextDecoder('utf-8')
const reader = stream?.getReader()
let content = ''
while (true && reader) {
const { done, value } = await reader.read()
if (done) {
break
}
const text = decoder.decode(value)
const lines = text.trim().split('\n')
for (const line of lines) {
if (line.startsWith('data: ') && !line.includes('data: [DONE]')) {
const data = JSON.parse(line.replace('data: ', ''))
content += data.choices[0]?.delta?.content ?? ''
subscriber.next(content)
}
}
}
subscriber.complete()
})
.catch((err) => subscriber.error(err))
})
}
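
For context, a caller consumed this (now deleted) helper roughly as follows. This is an illustrative sketch only; the message shape and base_url are placeholders.

import { Model } from '@janhq/core'
import { requestInference } from './helpers/sse'

// Illustrative: stream tokens from a Triton ensemble endpoint and log the
// accumulated text as it grows. The helper only reads `.text` from messages
// and `base_url` from the engine settings.
const controller = new AbortController()
requestInference(
  [{ text: 'Hello, world' }],
  { base_url: 'http://localhost:8000' }, // placeholder server URL
  {} as Model,                           // the model argument is not used by this helper
  controller
).subscribe({
  next: (partial) => console.log(partial),
  complete: () => console.log('stream finished'),
  error: (err) => console.error(err),
})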

View File

@ -7,212 +7,76 @@
*/
import {
ChatCompletionRole,
ContentType,
MessageRequest,
MessageStatus,
ModelSettingParams,
ThreadContent,
ThreadMessage,
AppConfigurationEventName,
events,
fs,
joinPath,
Model,
BaseExtension,
MessageEvent,
ModelEvent,
RemoteOAIEngine,
} from '@janhq/core'
import { requestInference } from './helpers/sse'
import { ulid } from 'ulidx'
import { join } from 'path'
import { EngineSettings } from './@types/global'
/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/
export default class JanInferenceTritonTrtLLMExtension extends BaseExtension {
private static readonly _homeDir = 'file://engines'
private static readonly _engineMetadataFileName = 'triton_trtllm.json'
export default class JanInferenceTritonTrtLLMExtension extends RemoteOAIEngine {
private readonly _engineDir = 'file://engines'
private readonly _engineMetadataFileName = 'triton_trtllm.json'
static _currentModel: Model
inferenceUrl: string = ''
provider: string = 'triton_trtllm'
apiKey: string = ''
static _engineSettings: EngineSettings = {
base_url: '',
_engineSettings: {
base_url: '',
api_key: ''
}
controller = new AbortController()
isCancelled = false
/**
* Subscribes to events emitted by the @janhq/core package.
*/
async onLoad() {
if (!(await fs.existsSync(JanInferenceTritonTrtLLMExtension._homeDir)))
JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings()
super.onLoad()
if (!(await fs.existsSync(this._engineDir))) {
await fs.mkdirSync(this._engineDir).catch((err) => console.debug(err))
}
this.writeDefaultEngineSettings()
const settingsFilePath = await joinPath([
this._engineDir,
this._engineMetadataFileName,
])
// Events subscription
events.on(MessageEvent.OnMessageSent, (data) =>
JanInferenceTritonTrtLLMExtension.handleMessageRequest(data, this)
events.on(
AppConfigurationEventName.OnConfigurationUpdate,
(settingsKey: string) => {
// Update settings on changes
if (settingsKey === settingsFilePath) this.writeDefaultEngineSettings()
}
)
events.on(ModelEvent.OnModelInit, (model: Model) => {
JanInferenceTritonTrtLLMExtension.handleModelInit(model)
})
events.on(ModelEvent.OnModelStop, (model: Model) => {
JanInferenceTritonTrtLLMExtension.handleModelStop(model)
})
}
/**
* Stops the model inference.
*/
onUnload(): void {}
/**
* Initializes the model with the specified file name.
* @param {string} modelId - The ID of the model to initialize.
* @returns {Promise<void>} A promise that resolves when the model is initialized.
*/
async initModel(
modelId: string,
settings?: ModelSettingParams
): Promise<void> {
return
}
static async writeDefaultEngineSettings() {
async writeDefaultEngineSettings() {
try {
const engine_json = join(
JanInferenceTritonTrtLLMExtension._homeDir,
JanInferenceTritonTrtLLMExtension._engineMetadataFileName
)
const engine_json = join(this._engineDir, this._engineMetadataFileName)
if (await fs.existsSync(engine_json)) {
const engine = await fs.readFileSync(engine_json, 'utf-8')
JanInferenceTritonTrtLLMExtension._engineSettings =
this._engineSettings =
typeof engine === 'object' ? engine : JSON.parse(engine)
this.inferenceUrl = this._engineSettings.base_url
this.apiKey = this._engineSettings.api_key
} else {
await fs.writeFileSync(
engine_json,
JSON.stringify(
JanInferenceTritonTrtLLMExtension._engineSettings,
null,
2
)
JSON.stringify(this._engineSettings, null, 2)
)
}
} catch (err) {
console.error(err)
}
}
/**
* Stops the model.
* @returns {Promise<void>} A promise that resolves when the model is stopped.
*/
async stopModel(): Promise<void> {}
/**
* Stops streaming inference.
* @returns {Promise<void>} A promise that resolves when the streaming is stopped.
*/
async stopInference(): Promise<void> {
this.isCancelled = true
this.controller?.abort()
}
private static async handleModelInit(model: Model) {
if (model.engine !== 'triton_trtllm') {
return
} else {
JanInferenceTritonTrtLLMExtension._currentModel = model
JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings()
// Todo: Check model list with API key
events.emit(ModelEvent.OnModelReady, model)
}
}
private static async handleModelStop(model: Model) {
if (model.engine !== 'triton_trtllm') {
return
}
events.emit(ModelEvent.OnModelStopped, model)
}
/**
* Handles a new message request by making an inference request and emitting events.
* Registered with the event manager; kept static to avoid `this`-binding issues,
* so the instance is passed in explicitly as a reference.
* @param {MessageRequest} data - The data for the new message request.
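* @param {JanInferenceTritonTrtLLMExtension} instance - The extension instance that holds the AbortController and cancellation flag.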
*/
private static async handleMessageRequest(
data: MessageRequest,
instance: JanInferenceTritonTrtLLMExtension
) {
if (data.model.engine !== 'triton_trtllm') {
return
}
const timestamp = Date.now()
const message: ThreadMessage = {
id: ulid(),
thread_id: data.threadId,
assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant,
content: [],
status: MessageStatus.Pending,
created: timestamp,
updated: timestamp,
object: 'thread.message',
}
events.emit(MessageEvent.OnMessageResponse, message)
instance.isCancelled = false
instance.controller = new AbortController()
requestInference(
data?.messages ?? [],
this._engineSettings,
{
...JanInferenceTritonTrtLLMExtension._currentModel,
parameters: data.model.parameters,
},
instance.controller
).subscribe({
next: (content) => {
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: content.trim(),
annotations: [],
},
}
message.content = [messageContent]
events.emit(MessageEvent.OnMessageUpdate, message)
},
complete: async () => {
message.status = message.content.length
? MessageStatus.Ready
: MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, message)
},
error: async (err) => {
if (instance.isCancelled || message.content.length) {
message.status = MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, message)
return
}
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: 'An error occurred. ' + err.message,
annotations: [],
},
}
message.content = [messageContent]
message.status = MessageStatus.Ready
events.emit(MessageEvent.OnMessageUpdate, message)
},
})
}
}
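
After this refactor, a remote provider reduces to the members shown above. A minimal sketch of that shape follows; the provider id, endpoint and key are placeholder assumptions, and request handling is inherited from RemoteOAIEngine.

import { RemoteOAIEngine } from '@janhq/core'

// Minimal sketch of a remote OpenAI-compatible provider after the refactor.
export default class ExampleRemoteEngine extends RemoteOAIEngine {
  provider = 'example_provider'                                // placeholder provider id
  inferenceUrl = 'https://api.example.com/v1/chat/completions' // placeholder endpoint
  apiKey = ''

  async onLoad() {
    super.onLoad()
    // Read inferenceUrl / apiKey from a settings file here, as the Triton
    // extension above does with triton_trtllm.json.
  }
}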

View File

@ -43,14 +43,14 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
private supportedPlatform = ['win32', 'linux']
private isUpdateAvailable = false
compatibility() {
override compatibility() {
return COMPATIBILITY as unknown as Compatibility
}
/**
* models implemented by the extension
* define pre-populated models
*/
async models(): Promise<Model[]> {
override async models(): Promise<Model[]> {
if ((await this.installationState()) === 'Installed')
return models as unknown as Model[]
return []
@ -160,11 +160,11 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
events.emit(ModelEvent.OnModelsUpdate, {})
}
async onModelInit(model: Model): Promise<void> {
override async loadModel(model: Model): Promise<void> {
if (model.engine !== this.provider) return
if ((await this.installationState()) === 'Installed')
return super.onModelInit(model)
return super.loadModel(model)
else {
events.emit(ModelEvent.OnModelFail, {
...model,
@ -175,7 +175,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
}
}
updatable() {
override updatable() {
return this.isUpdateAvailable
}
@ -241,8 +241,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
return (await fs.existsSync(enginePath)) ? 'Installed' : 'NotInstalled'
}
override onInferenceStopped() {
if (!this.isRunning) return
override stopInference() {
showToast(
'Unable to Stop Inference',
'The model does not support stopping inference.'
@ -250,8 +249,8 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
return Promise.resolve()
}
inference(data: MessageRequest): void {
if (!this.isRunning) return
override inference(data: MessageRequest): void {
if (!this.loadedModel) return
// TensorRT LLM Extension supports streaming only
if (data.model) data.model.parameters.stream = true
super.inference(data)
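
The same override pattern generalises to any local provider. A hedged sketch follows; the module name, provider id and endpoint are placeholders, assuming the LocalOAIEngine members used elsewhere in this commit.

import { LocalOAIEngine, Model } from '@janhq/core'

// Illustrative local provider: gate loading on a provider-specific check,
// then delegate to the base class exactly as TensorRTLLMExtension does.
export default class ExampleLocalEngine extends LocalOAIEngine {
  nodeModule = 'example-node-module'                          // placeholder bundled module
  provider = 'example_local'                                  // placeholder provider id
  inferenceUrl = 'http://127.0.0.1:3928/v1/chat/completions'  // placeholder local endpoint

  override async loadModel(model: Model): Promise<void> {
    if (model.engine !== this.provider) return
    // Provider-specific preconditions (e.g. binaries installed) would go here.
    return super.loadModel(model)
  }
}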

View File

@ -1,27 +1,26 @@
{
  "sources": [
    {
      "url": "https://groq.com"
    }
  ],
  "id": "llama2-70b-4096",
  "object": "model",
  "name": "Groq Llama 2 70b",
  "version": "1.0",
  "description": "Groq Llama 2 70b with supercharged speed!",
  "format": "api",
  "settings": {},
  "parameters": {
    "max_tokens": 4096,
    "temperature": 0.7,
    "top_p": 1,
    "stop": null,
    "stream": true
  },
  "metadata": {
    "author": "Meta",
    "tags": ["General", "Big Context Length"]
  },
  "engine": "groq"
}

View File

@ -1,27 +1,26 @@
{
  "sources": [
    {
      "url": "https://groq.com"
    }
  ],
  "id": "mixtral-8x7b-32768",
  "object": "model",
  "name": "Groq Mixtral 8x7b Instruct",
  "version": "1.0",
  "description": "Groq Mixtral 8x7b Instruct is Mixtral with supercharged speed!",
  "format": "api",
  "settings": {},
  "parameters": {
    "max_tokens": 4096,
    "temperature": 0.7,
    "top_p": 1,
    "stop": null,
    "stream": true
  },
  "metadata": {
    "author": "Mistral",
    "tags": ["General", "Big Context Length"]
  },
  "engine": "groq"
}
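
JSON definitions like the two Groq models above are what an engine's models() hook surfaces to the UI as pre-populated models. A minimal sketch of that wiring follows; the file paths and JSON-import style are assumptions.

import { Model } from '@janhq/core'
// Placeholder paths; resolveJsonModule (or a bundler raw import) is assumed.
import llama2 from './resources/llama2-70b-4096.json'
import mixtral from './resources/mixtral-8x7b-32768.json'

// Expose the bundled Groq model definitions as pre-populated models.
export const models = (): Promise<Model[]> =>
  Promise.resolve([llama2, mixtral] as unknown as Model[])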

View File

@ -75,12 +75,14 @@ const DropdownListSidebar = ({
// TODO: Update filter condition for the local model
const localModel = downloadedModels.filter(
(model) => model.engine !== InferenceEngine.openai
(model) =>
model.engine === InferenceEngine.nitro ||
model.engine === InferenceEngine.nitro_tensorrt_llm
)
const remoteModel = downloadedModels.filter(
(model) =>
model.engine === InferenceEngine.openai ||
model.engine === InferenceEngine.groq
model.engine !== InferenceEngine.nitro &&
model.engine !== InferenceEngine.nitro_tensorrt_llm
)
const modelOptions = isTabActive === 0 ? localModel : remoteModel

View File

@ -48,9 +48,8 @@ export default function RowModel(props: RowModelProps) {
const [serverEnabled, setServerEnabled] = useAtom(serverEnabledAtom)
const isRemoteModel =
props.data.engine === InferenceEngine.openai ||
props.data.engine === InferenceEngine.groq ||
props.data.engine === InferenceEngine.triton_trtllm
props.data.engine !== InferenceEngine.nitro &&
props.data.engine !== InferenceEngine.nitro_tensorrt_llm
const onModelActionClick = (modelId: string) => {
if (activeModel && activeModel.id === modelId) {

View File

@ -8,7 +8,6 @@ export const isCoreExtensionInstalled = () => {
if (!extensionManager.get(ExtensionTypeEnum.Conversational)) {
return false
}
if (!extensionManager.get(ExtensionTypeEnum.Inference)) return false
if (!extensionManager.get(ExtensionTypeEnum.Model)) {
return false
}
@ -22,7 +21,6 @@ export const setupBaseExtensions = async () => {
if (
!extensionManager.get(ExtensionTypeEnum.Conversational) ||
!extensionManager.get(ExtensionTypeEnum.Inference) ||
!extensionManager.get(ExtensionTypeEnum.Model)
) {
const installed = await extensionManager.install(baseExtensions)