fix: refactor inference engines to extend AIEngine (#2347)

* fix: refactor nitro to extend LocalOAIEngine

* fix: refactor openai extension

* chore: refactor groq extension

* chore: refactor triton tensorrt extension

* chore: add tests

* chore: refactor engines
Louis 2024-03-22 09:35:14 +07:00 committed by GitHub
parent b8e4a029a4
commit acbec78dbf
35 changed files with 625 additions and 1292 deletions

View File

@ -56,7 +56,8 @@ const openFileExplorer: (path: string) => Promise<any> = (path) =>
* @param paths - The paths to join. * @param paths - The paths to join.
* @returns {Promise<string>} A promise that resolves with the joined path. * @returns {Promise<string>} A promise that resolves with the joined path.
*/ */
const joinPath: (paths: string[]) => Promise<string> = (paths) => globalThis.core.api?.joinPath(paths) const joinPath: (paths: string[]) => Promise<string> = (paths) =>
globalThis.core.api?.joinPath(paths)
/** /**
* Retrieve the basename from a URL. * Retrieve the basename from a URL.

View File

@ -14,7 +14,9 @@ export abstract class AIEngine extends BaseExtension {
// The model folder // The model folder
modelFolder: string = 'models' modelFolder: string = 'models'
abstract models(): Promise<Model[]> models(): Promise<Model[]> {
return Promise.resolve([])
}
/** /**
* On extension load, subscribe to events. * On extension load, subscribe to events.
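With `models()` now defaulting to an empty list, engines without a local model catalogue no longer need to implement it. A caller-side sketch (the helper is hypothetical, and it assumes `AIEngine` is re-exported from `@janhq/core` like the other engine base classes in this PR):

```typescript
import { AIEngine, Model } from '@janhq/core'

// Any AIEngine can be queried uniformly; engines that do not override
// models() simply resolve to an empty array.
async function listEngineModels(engine: AIEngine): Promise<Model[]> {
  return engine.models()
}
```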

View File

@ -9,9 +9,9 @@ import { OAIEngine } from './OAIEngine'
*/ */
export abstract class LocalOAIEngine extends OAIEngine { export abstract class LocalOAIEngine extends OAIEngine {
// The inference engine // The inference engine
abstract nodeModule: string
loadModelFunctionName: string = 'loadModel' loadModelFunctionName: string = 'loadModel'
unloadModelFunctionName: string = 'unloadModel' unloadModelFunctionName: string = 'unloadModel'
isRunning: boolean = false
/** /**
* On extension load, subscribe to events. * On extension load, subscribe to events.
@ -19,22 +19,27 @@ export abstract class LocalOAIEngine extends OAIEngine {
onLoad() { onLoad() {
super.onLoad() super.onLoad()
// These events are applicable to local inference providers // These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.onModelInit(model)) events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.onModelStop(model)) events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
} }
/** /**
* Load the model. * Load the model.
*/ */
async onModelInit(model: Model) { async loadModel(model: Model) {
if (model.engine.toString() !== this.provider) return if (model.engine.toString() !== this.provider) return
const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id]) const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id])
const systemInfo = await systemInformation() const systemInfo = await systemInformation()
const res = await executeOnMain(this.nodeModule, this.loadModelFunctionName, { const res = await executeOnMain(
modelFolder, this.nodeModule,
model, this.loadModelFunctionName,
}, systemInfo) {
modelFolder,
model,
},
systemInfo
)
if (res?.error) { if (res?.error) {
events.emit(ModelEvent.OnModelFail, { events.emit(ModelEvent.OnModelFail, {
@ -45,16 +50,14 @@ export abstract class LocalOAIEngine extends OAIEngine {
} else { } else {
this.loadedModel = model this.loadedModel = model
events.emit(ModelEvent.OnModelReady, model) events.emit(ModelEvent.OnModelReady, model)
this.isRunning = true
} }
} }
/** /**
* Stops the model. * Stops the model.
*/ */
onModelStop(model: Model) { unloadModel(model: Model) {
if (model.engine?.toString() !== this.provider) return if (model.engine && model.engine?.toString() !== this.provider) return
this.loadedModel = undefined
this.isRunning = false
executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => { executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
events.emit(ModelEvent.OnModelStopped, {}) events.emit(ModelEvent.OnModelStopped, {})
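For reference, a concrete local engine on top of the refactored LocalOAIEngine only needs to name its provider, its node module, and its inference endpoint. The class name and build-time constants below are illustrative, not part of this PR:

```typescript
import { LocalOAIEngine } from '@janhq/core'

// Build-time constants, injected the same way the nitro extension does it.
declare const NODE: string
declare const INFERENCE_URL: string

export default class ExampleLocalEngine extends LocalOAIEngine {
  nodeModule: string = NODE
  provider: string = 'example-local'
  inferenceUrl: string = INFERENCE_URL
  // loadModel()/unloadModel() are inherited: they react to ModelEvent.OnModelInit
  // and ModelEvent.OnModelStop and call the node module's exported functions.
}
```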

View File

@ -23,7 +23,6 @@ import { events } from '../../events'
export abstract class OAIEngine extends AIEngine { export abstract class OAIEngine extends AIEngine {
// The OpenAI-compatible inference endpoint // The OpenAI-compatible inference endpoint
abstract inferenceUrl: string abstract inferenceUrl: string
abstract nodeModule: string
// Controller to handle stop requests // Controller to handle stop requests
controller = new AbortController() controller = new AbortController()
@ -38,7 +37,7 @@ export abstract class OAIEngine extends AIEngine {
onLoad() { onLoad() {
super.onLoad() super.onLoad()
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data)) events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
events.on(InferenceEvent.OnInferenceStopped, () => this.onInferenceStopped()) events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
} }
/** /**
@ -78,7 +77,13 @@ export abstract class OAIEngine extends AIEngine {
...data.model, ...data.model,
} }
requestInference(this.inferenceUrl, data.messages ?? [], model, this.controller).subscribe({ requestInference(
this.inferenceUrl,
data.messages ?? [],
model,
this.controller,
this.headers()
).subscribe({
next: (content: any) => { next: (content: any) => {
const messageContent: ThreadContent = { const messageContent: ThreadContent = {
type: ContentType.Text, type: ContentType.Text,
@ -109,8 +114,15 @@ export abstract class OAIEngine extends AIEngine {
/** /**
* Stops the inference. * Stops the inference.
*/ */
onInferenceStopped() { stopInference() {
this.isCancelled = true this.isCancelled = true
this.controller?.abort() this.controller?.abort()
} }
/**
* Headers for the inference request
*/
headers(): HeadersInit {
return {}
}
} }
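Because stopping is now wired through events in onLoad(), any caller can cancel the in-flight completion without holding a reference to the engine. A sketch (the empty payload is an assumption):

```typescript
import { events, InferenceEvent } from '@janhq/core'

// Emitting the event triggers OAIEngine.stopInference(), which sets
// isCancelled and aborts the shared AbortController.
const cancelCurrentInference = () => {
  events.emit(InferenceEvent.OnInferenceStopped, {})
}
```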

View File

@ -0,0 +1,46 @@
import { events } from '../../events'
import { Model, ModelEvent } from '../../types'
import { OAIEngine } from './OAIEngine'
/**
* Base OAI Remote Inference Provider
* Implements model loading and unloading for remote inference providers
*/
export abstract class RemoteOAIEngine extends OAIEngine {
// The API key for the remote provider
abstract apiKey: string
/**
* On extension load, subscribe to events.
*/
onLoad() {
super.onLoad()
// These events are applicable to remote inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
}
/**
* Load the model.
*/
async loadModel(model: Model) {
if (model.engine.toString() !== this.provider) return
events.emit(ModelEvent.OnModelReady, model)
}
/**
* Stops the model.
*/
unloadModel(model: Model) {
if (model.engine && model.engine.toString() !== this.provider) return
events.emit(ModelEvent.OnModelStopped, {})
}
/**
* Headers for the inference request
*/
override headers(): HeadersInit {
return {
'Authorization': `Bearer ${this.apiKey}`,
'api-key': `${this.apiKey}`,
}
}
}
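A concrete remote provider built on the new base class can be as small as the sketch below; the endpoint and provider id are illustrative, and real extensions (such as the Groq one later in this diff) load them from an engine settings file instead:

```typescript
import { RemoteOAIEngine } from '@janhq/core'

export default class ExampleRemoteEngine extends RemoteOAIEngine {
  inferenceUrl = 'https://api.example.com/v1/chat/completions' // illustrative
  provider = 'example-remote'
  apiKey = '' // populated from user settings in a real extension
  // The inherited headers() attaches Authorization / api-key on every request.
}
```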

View File

@ -12,7 +12,8 @@ export function requestInference(
id: string id: string
parameters: ModelRuntimeParams parameters: ModelRuntimeParams
}, },
controller?: AbortController controller?: AbortController,
headers?: HeadersInit
): Observable<string> { ): Observable<string> {
return new Observable((subscriber) => { return new Observable((subscriber) => {
const requestBody = JSON.stringify({ const requestBody = JSON.stringify({
@ -27,6 +28,7 @@ export function requestInference(
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'Access-Control-Allow-Origin': '*', 'Access-Control-Allow-Origin': '*',
'Accept': model.parameters.stream ? 'text/event-stream' : 'application/json', 'Accept': model.parameters.stream ? 'text/event-stream' : 'application/json',
...headers,
}, },
body: requestBody, body: requestBody,
signal: controller?.signal, signal: controller?.signal,
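A usage sketch for the new optional `headers` argument; because the extra headers are spread after the defaults, they can add authorization or override Accept. The import path, endpoint, key, and model values below are placeholders:

```typescript
import { requestInference } from './sse' // path assumed

const controller = new AbortController()
const apiKey = 'sk-<your key here>' // placeholder

requestInference(
  'https://api.example.com/v1/chat/completions', // placeholder endpoint
  [{ role: 'user', content: 'Hello' }],
  { id: 'example-model', parameters: { stream: true } },
  controller,
  { Authorization: `Bearer ${apiKey}` } // merged after the default headers
).subscribe({
  next: (content) => console.log(content),
  error: (err) => console.error(err),
  complete: () => console.log('done'),
})
```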

View File

@ -1,3 +1,4 @@
export * from './AIEngine' export * from './AIEngine'
export * from './OAIEngine' export * from './OAIEngine'
export * from './LocalOAIEngine' export * from './LocalOAIEngine'
export * from './RemoteOAIEngine'

View File

@ -25,7 +25,7 @@
"@janhq/core": "file:../../core", "@janhq/core": "file:../../core",
"fetch-retry": "^5.0.6", "fetch-retry": "^5.0.6",
"path-browserify": "^1.0.1", "path-browserify": "^1.0.1",
"ulid": "^2.3.0" "ulidx": "^2.3.0"
}, },
"engines": { "engines": {
"node": ">=18.0.0" "node": ">=18.0.0"

View File

@ -1,16 +0,0 @@
declare const MODULE: string
declare const GROQ_DOMAIN: string
declare interface EngineSettings {
full_url?: string
api_key?: string
}
enum GroqChatCompletionModelName {
'mixtral-8x7b-32768' = 'mixtral-8x7b-32768',
'llama2-70b-4096' = 'llama2-70b-4096',
}
declare type GroqModel = Omit<Model, 'id'> & {
id: GroqChatCompletionModelName
}

View File

@ -1,83 +0,0 @@
import { ErrorCode } from '@janhq/core'
import { Observable } from 'rxjs'
/**
* Sends a request to the inference server to generate a response based on the recent messages.
* @param recentMessages - An array of recent messages to use as context for the inference.
* @param engine - The engine settings to use for the inference.
* @param model - The model to use for the inference.
* @returns An Observable that emits the generated response as a string.
*/
export function requestInference(
recentMessages: any[],
engine: EngineSettings,
model: GroqModel,
controller?: AbortController
): Observable<string> {
return new Observable((subscriber) => {
// let model_id: string = model.id
const requestBody = JSON.stringify({
messages: recentMessages,
stream: true,
model: model.id,
...model.parameters,
})
fetch(`${engine.full_url}`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': model.parameters.stream
? 'text/event-stream'
: 'application/json',
'Access-Control-Allow-Origin': '*',
'Authorization': `Bearer ${engine.api_key}`,
// 'api-key': `${engine.api_key}`,
},
body: requestBody,
signal: controller?.signal,
})
.then(async (response) => {
if (!response.ok) {
const data = await response.json()
const error = {
message: data.error?.message ?? 'An error occurred.',
code: data.error?.code ?? ErrorCode.Unknown,
}
subscriber.error(error)
subscriber.complete()
return
}
if (model.parameters.stream === false) {
const data = await response.json()
subscriber.next(data.choices[0]?.message?.content ?? '')
} else {
const stream = response.body
const decoder = new TextDecoder('utf-8')
const reader = stream?.getReader()
let content = ''
while (true && reader) {
const { done, value } = await reader.read()
if (done) {
break
}
const text = decoder.decode(value)
const lines = text.trim().split('\n')
for (const line of lines) {
if (line.startsWith('data: ') && !line.includes('data: [DONE]')) {
const data = JSON.parse(line.replace('data: ', ''))
content += data.choices[0]?.delta?.content ?? ''
if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '')
}
subscriber.next(content)
}
}
}
}
subscriber.complete()
})
.catch((err) => subscriber.error(err))
})
}

View File

@ -7,218 +7,77 @@
*/ */
import { import {
ChatCompletionRole,
ContentType,
MessageRequest,
MessageStatus,
ThreadContent,
ThreadMessage,
events, events,
fs, fs,
InferenceEngine,
BaseExtension,
MessageEvent,
MessageRequestType,
ModelEvent,
InferenceEvent,
AppConfigurationEventName, AppConfigurationEventName,
joinPath, joinPath,
RemoteOAIEngine,
} from '@janhq/core' } from '@janhq/core'
import { requestInference } from './helpers/sse'
import { ulid } from 'ulid'
import { join } from 'path' import { join } from 'path'
declare const COMPLETION_URL: string
/** /**
* A class that implements the InferenceExtension interface from the @janhq/core package. * A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests. * The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests. * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/ */
export default class JanInferenceGroqExtension extends BaseExtension { export default class JanInferenceGroqExtension extends RemoteOAIEngine {
private static readonly _engineDir = 'file://engines' private readonly _engineDir = 'file://engines'
private static readonly _engineMetadataFileName = 'groq.json' private readonly _engineMetadataFileName = 'groq.json'
private static _currentModel: GroqModel inferenceUrl: string = COMPLETION_URL
provider = 'groq'
apiKey = ''
private static _engineSettings: EngineSettings = { private _engineSettings = {
full_url: 'https://api.groq.com/openai/v1/chat/completions', full_url: COMPLETION_URL,
api_key: 'gsk-<your key here>', api_key: 'gsk-<your key here>',
} }
controller = new AbortController()
isCancelled = false
/** /**
* Subscribes to events emitted by the @janhq/core package. * Subscribes to events emitted by the @janhq/core package.
*/ */
async onLoad() { async onLoad() {
if (!(await fs.existsSync(JanInferenceGroqExtension._engineDir))) { super.onLoad()
await fs
.mkdirSync(JanInferenceGroqExtension._engineDir) if (!(await fs.existsSync(this._engineDir))) {
.catch((err) => console.debug(err)) await fs.mkdirSync(this._engineDir).catch((err) => console.debug(err))
} }
JanInferenceGroqExtension.writeDefaultEngineSettings() this.writeDefaultEngineSettings()
// Events subscription
events.on(MessageEvent.OnMessageSent, (data) =>
JanInferenceGroqExtension.handleMessageRequest(data, this)
)
events.on(ModelEvent.OnModelInit, (model: GroqModel) => {
JanInferenceGroqExtension.handleModelInit(model)
})
events.on(ModelEvent.OnModelStop, (model: GroqModel) => {
JanInferenceGroqExtension.handleModelStop(model)
})
events.on(InferenceEvent.OnInferenceStopped, () => {
JanInferenceGroqExtension.handleInferenceStopped(this)
})
const settingsFilePath = await joinPath([ const settingsFilePath = await joinPath([
JanInferenceGroqExtension._engineDir, this._engineDir,
JanInferenceGroqExtension._engineMetadataFileName, this._engineMetadataFileName,
]) ])
// Events subscription
events.on( events.on(
AppConfigurationEventName.OnConfigurationUpdate, AppConfigurationEventName.OnConfigurationUpdate,
(settingsKey: string) => { (settingsKey: string) => {
// Update settings on changes // Update settings on changes
if (settingsKey === settingsFilePath) if (settingsKey === settingsFilePath) this.writeDefaultEngineSettings()
JanInferenceGroqExtension.writeDefaultEngineSettings()
} }
) )
} }
/** async writeDefaultEngineSettings() {
* Stops the model inference.
*/
onUnload(): void {}
static async writeDefaultEngineSettings() {
try { try {
const engineFile = join( const engineFile = join(this._engineDir, this._engineMetadataFileName)
JanInferenceGroqExtension._engineDir,
JanInferenceGroqExtension._engineMetadataFileName
)
if (await fs.existsSync(engineFile)) { if (await fs.existsSync(engineFile)) {
const engine = await fs.readFileSync(engineFile, 'utf-8') const engine = await fs.readFileSync(engineFile, 'utf-8')
JanInferenceGroqExtension._engineSettings = this._engineSettings =
typeof engine === 'object' ? engine : JSON.parse(engine) typeof engine === 'object' ? engine : JSON.parse(engine)
this.inferenceUrl = this._engineSettings.full_url
this.apiKey = this._engineSettings.api_key
} else { } else {
await fs.writeFileSync( await fs.writeFileSync(
engineFile, engineFile,
JSON.stringify(JanInferenceGroqExtension._engineSettings, null, 2) JSON.stringify(this._engineSettings, null, 2)
) )
} }
} catch (err) { } catch (err) {
console.error(err) console.error(err)
} }
} }
private static async handleModelInit(model: GroqModel) {
if (model.engine !== InferenceEngine.groq) {
return
} else {
JanInferenceGroqExtension._currentModel = model
JanInferenceGroqExtension.writeDefaultEngineSettings()
// Todo: Check model list with API key
events.emit(ModelEvent.OnModelReady, model)
}
}
private static async handleModelStop(model: GroqModel) {
if (model.engine !== 'groq') {
return
}
events.emit(ModelEvent.OnModelStopped, model)
}
private static async handleInferenceStopped(
instance: JanInferenceGroqExtension
) {
instance.isCancelled = true
instance.controller?.abort()
}
/**
* Handles a new message request by making an inference request and emitting events.
* Function registered in event manager, should be static to avoid binding issues.
* Pass instance as a reference.
* @param {MessageRequest} data - The data for the new message request.
*/
private static async handleMessageRequest(
data: MessageRequest,
instance: JanInferenceGroqExtension
) {
if (data.model.engine !== 'groq') {
return
}
const timestamp = Date.now()
const message: ThreadMessage = {
id: ulid(),
thread_id: data.threadId,
type: data.type,
assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant,
content: [],
status: MessageStatus.Pending,
created: timestamp,
updated: timestamp,
object: 'thread.message',
}
if (data.type !== MessageRequestType.Summary) {
events.emit(MessageEvent.OnMessageResponse, message)
}
instance.isCancelled = false
instance.controller = new AbortController()
requestInference(
data?.messages ?? [],
this._engineSettings,
{
...JanInferenceGroqExtension._currentModel,
parameters: data.model.parameters,
},
instance.controller
).subscribe({
next: (content) => {
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: content.trim(),
annotations: [],
},
}
message.content = [messageContent]
events.emit(MessageEvent.OnMessageUpdate, message)
},
complete: async () => {
message.status = message.content.length
? MessageStatus.Ready
: MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, message)
},
error: async (err) => {
if (instance.isCancelled || message.content.length > 0) {
message.status = MessageStatus.Stopped
events.emit(MessageEvent.OnMessageUpdate, message)
return
}
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: 'An error occurred. ' + err.message,
annotations: [],
},
}
message.content = [messageContent]
message.status = MessageStatus.Error
message.error_code = err.code
events.emit(MessageEvent.OnMessageUpdate, message)
},
})
}
} }

View File

@ -18,7 +18,7 @@ module.exports = {
plugins: [ plugins: [
new webpack.DefinePlugin({ new webpack.DefinePlugin({
MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`), MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
GROQ_DOMAIN: JSON.stringify('api.groq.com'), COMPLETION_URL: JSON.stringify('https://api.groq.com/openai/v1/chat/completions'),
}), }),
], ],
output: { output: {
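On the TypeScript side the constant only needs an ambient declaration; webpack's DefinePlugin substitutes the literal at build time, which is how the Groq index.ts above can reference COMPLETION_URL directly:

```typescript
// Ambient declaration matching the DefinePlugin entry above.
declare const COMPLETION_URL: string

console.debug('Groq completion endpoint:', COMPLETION_URL)
```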

View File

@ -0,0 +1,5 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
};

View File

@ -7,6 +7,7 @@
"author": "Jan <service@jan.ai>", "author": "Jan <service@jan.ai>",
"license": "AGPL-3.0", "license": "AGPL-3.0",
"scripts": { "scripts": {
"test": "jest",
"build": "tsc --module commonjs && rollup -c rollup.config.ts", "build": "tsc --module commonjs && rollup -c rollup.config.ts",
"downloadnitro:linux": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64.tar.gz -e --strip 1 -o ./bin/linux-cpu && chmod +x ./bin/linux-cpu/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda-12-0.tar.gz -e --strip 1 -o ./bin/linux-cuda-12-0 && chmod +x ./bin/linux-cuda-12-0/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda-11-7.tar.gz -e --strip 1 -o ./bin/linux-cuda-11-7 && chmod +x ./bin/linux-cuda-11-7/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-vulkan.tar.gz -e --strip 1 -o ./bin/linux-vulkan && chmod +x ./bin/linux-vulkan/nitro", "downloadnitro:linux": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64.tar.gz -e --strip 1 -o ./bin/linux-cpu && chmod +x ./bin/linux-cpu/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda-12-0.tar.gz -e --strip 1 -o ./bin/linux-cuda-12-0 && chmod +x ./bin/linux-cuda-12-0/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda-11-7.tar.gz -e --strip 1 -o ./bin/linux-cuda-11-7 && chmod +x ./bin/linux-cuda-11-7/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-vulkan.tar.gz -e --strip 1 -o ./bin/linux-vulkan && chmod +x ./bin/linux-vulkan/nitro",
"downloadnitro:darwin": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-arm64.tar.gz -e --strip 1 -o ./bin/mac-arm64 && chmod +x ./bin/mac-arm64/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-amd64.tar.gz -e --strip 1 -o ./bin/mac-x64 && chmod +x ./bin/mac-x64/nitro", "downloadnitro:darwin": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-arm64.tar.gz -e --strip 1 -o ./bin/mac-arm64 && chmod +x ./bin/mac-arm64/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-amd64.tar.gz -e --strip 1 -o ./bin/mac-x64 && chmod +x ./bin/mac-x64/nitro",
@ -15,29 +16,34 @@
"build:publish:darwin": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && ../../.github/scripts/auto-sign.sh && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", "build:publish:darwin": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && ../../.github/scripts/auto-sign.sh && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
"build:publish:win32": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", "build:publish:win32": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
"build:publish:linux": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install", "build:publish:linux": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
"build:publish": "run-script-os" "build:publish": "yarn test && run-script-os"
}, },
"exports": { "exports": {
".": "./dist/index.js", ".": "./dist/index.js",
"./main": "./dist/node/index.cjs.js" "./main": "./dist/node/index.cjs.js"
}, },
"devDependencies": { "devDependencies": {
"@babel/preset-typescript": "^7.24.1",
"@jest/globals": "^29.7.0",
"@rollup/plugin-commonjs": "^25.0.7", "@rollup/plugin-commonjs": "^25.0.7",
"@rollup/plugin-json": "^6.1.0", "@rollup/plugin-json": "^6.1.0",
"@rollup/plugin-node-resolve": "^15.2.3", "@rollup/plugin-node-resolve": "^15.2.3",
"@rollup/plugin-replace": "^5.0.5",
"@types/jest": "^29.5.12",
"@types/node": "^20.11.4", "@types/node": "^20.11.4",
"@types/os-utils": "^0.0.4",
"@types/tcp-port-used": "^1.0.4", "@types/tcp-port-used": "^1.0.4",
"cpx": "^1.5.0", "cpx": "^1.5.0",
"download-cli": "^1.1.1", "download-cli": "^1.1.1",
"jest": "^29.7.0",
"rimraf": "^3.0.2", "rimraf": "^3.0.2",
"rollup": "^2.38.5", "rollup": "^2.38.5",
"rollup-plugin-define": "^1.0.1", "rollup-plugin-define": "^1.0.1",
"rollup-plugin-sourcemaps": "^0.6.3", "rollup-plugin-sourcemaps": "^0.6.3",
"rollup-plugin-typescript2": "^0.36.0", "rollup-plugin-typescript2": "^0.36.0",
"run-script-os": "^1.1.6", "run-script-os": "^1.1.6",
"typescript": "^5.3.3", "ts-jest": "^29.1.2",
"@types/os-utils": "^0.0.4", "typescript": "^5.3.3"
"@rollup/plugin-replace": "^5.0.5"
}, },
"dependencies": { "dependencies": {
"@janhq/core": "file:../../core", "@janhq/core": "file:../../core",

View File

@ -0,0 +1,6 @@
module.exports = {
presets: [
['@babel/preset-env', { targets: { node: 'current' } }],
'@babel/preset-typescript',
],
}

View File

@ -1,66 +0,0 @@
import { Model } from '@janhq/core'
import { Observable } from 'rxjs'
/**
* Sends a request to the inference server to generate a response based on the recent messages.
* @param recentMessages - An array of recent messages to use as context for the inference.
* @returns An Observable that emits the generated response as a string.
*/
export function requestInference(
inferenceUrl: string,
recentMessages: any[],
model: Model,
controller?: AbortController
): Observable<string> {
return new Observable((subscriber) => {
const requestBody = JSON.stringify({
messages: recentMessages,
model: model.id,
stream: true,
...model.parameters,
})
fetch(inferenceUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Access-Control-Allow-Origin': '*',
'Accept': model.parameters.stream
? 'text/event-stream'
: 'application/json',
},
body: requestBody,
signal: controller?.signal,
})
.then(async (response) => {
if (model.parameters.stream === false) {
const data = await response.json()
subscriber.next(data.choices[0]?.message?.content ?? '')
} else {
const stream = response.body
const decoder = new TextDecoder('utf-8')
const reader = stream?.getReader()
let content = ''
while (true && reader) {
const { done, value } = await reader.read()
if (done) {
break
}
const text = decoder.decode(value)
const lines = text.trim().split('\n')
for (const line of lines) {
if (line.startsWith('data: ') && !line.includes('data: [DONE]')) {
const data = JSON.parse(line.replace('data: ', ''))
content += data.choices[0]?.delta?.content ?? ''
if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '')
}
subscriber.next(content)
}
}
}
}
subscriber.complete()
})
.catch((err) => subscriber.error(err))
})
}

View File

@ -7,58 +7,31 @@
*/ */
import { import {
ChatCompletionRole,
ContentType,
MessageRequest,
MessageRequestType,
MessageStatus,
ThreadContent,
ThreadMessage,
events, events,
executeOnMain, executeOnMain,
fs,
Model, Model,
joinPath,
InferenceExtension,
log,
InferenceEngine,
MessageEvent,
ModelEvent, ModelEvent,
InferenceEvent, LocalOAIEngine,
ModelSettingParams,
getJanDataFolderPath,
} from '@janhq/core' } from '@janhq/core'
import { requestInference } from './helpers/sse'
import { ulid } from 'ulidx'
/** /**
* A class that implements the InferenceExtension interface from the @janhq/core package. * A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests. * The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests. * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/ */
export default class JanInferenceNitroExtension extends InferenceExtension { export default class JanInferenceNitroExtension extends LocalOAIEngine {
private static readonly _homeDir = 'file://engines' nodeModule: string = NODE
private static readonly _settingsDir = 'file://settings' provider: string = 'nitro'
private static readonly _engineMetadataFileName = 'nitro.json'
models(): Promise<Model[]> {
return Promise.resolve([])
}
/** /**
* Checks the health of the Nitro process every 5 seconds. * Checks the health of the Nitro process every 5 seconds.
*/ */
private static readonly _intervalHealthCheck = 5 * 1000 private static readonly _intervalHealthCheck = 5 * 1000
private _currentModel: Model | undefined
private _engineSettings: ModelSettingParams = {
ctx_len: 2048,
ngl: 100,
cpu_threads: 1,
cont_batching: false,
embedding: true,
}
controller = new AbortController()
isCancelled = false
/** /**
* The interval id for the health check. Used to stop the health check. * The interval id for the health check. Used to stop the health check.
*/ */
@ -69,114 +42,30 @@ export default class JanInferenceNitroExtension extends InferenceExtension {
*/ */
private nitroProcessInfo: any = undefined private nitroProcessInfo: any = undefined
private inferenceUrl = '' /**
* The URL for making inference requests.
*/
inferenceUrl = ''
/** /**
* Subscribes to events emitted by the @janhq/core package. * Subscribes to events emitted by the @janhq/core package.
*/ */
async onLoad() { async onLoad() {
if (!(await fs.existsSync(JanInferenceNitroExtension._homeDir))) {
try {
await fs.mkdirSync(JanInferenceNitroExtension._homeDir)
} catch (e) {
console.debug(e)
}
}
// init inference url
// @ts-ignore
const electronApi = window?.electronAPI
this.inferenceUrl = INFERENCE_URL this.inferenceUrl = INFERENCE_URL
if (!electronApi) {
// If the extension is running in the browser, use the base API URL from the core package.
if (!('electronAPI' in window)) {
this.inferenceUrl = `${window.core?.api?.baseApiUrl}/v1/chat/completions` this.inferenceUrl = `${window.core?.api?.baseApiUrl}/v1/chat/completions`
} }
console.debug('Inference url: ', this.inferenceUrl) console.debug('Inference url: ', this.inferenceUrl)
if (!(await fs.existsSync(JanInferenceNitroExtension._settingsDir)))
await fs.mkdirSync(JanInferenceNitroExtension._settingsDir)
this.writeDefaultEngineSettings()
// Events subscription
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
this.onMessageRequest(data)
)
events.on(ModelEvent.OnModelInit, (model: Model) => this.onModelInit(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.onModelStop(model))
events.on(InferenceEvent.OnInferenceStopped, () =>
this.onInferenceStopped()
)
}
/**
* Stops the model inference.
*/
onUnload(): void {}
private async writeDefaultEngineSettings() {
try {
const engineFile = await joinPath([
JanInferenceNitroExtension._homeDir,
JanInferenceNitroExtension._engineMetadataFileName,
])
if (await fs.existsSync(engineFile)) {
const engine = await fs.readFileSync(engineFile, 'utf-8')
this._engineSettings =
typeof engine === 'object' ? engine : JSON.parse(engine)
} else {
await fs.writeFileSync(
engineFile,
JSON.stringify(this._engineSettings, null, 2)
)
}
} catch (err) {
console.error(err)
}
}
private async onModelInit(model: Model) {
if (model.engine !== InferenceEngine.nitro) return
const modelFolder = await joinPath([
await getJanDataFolderPath(),
'models',
model.id,
])
this._currentModel = model
const nitroInitResult = await executeOnMain(NODE, 'runModel', {
modelFolder,
model,
})
if (nitroInitResult?.error) {
events.emit(ModelEvent.OnModelFail, {
...model,
error: nitroInitResult.error,
})
return
}
events.emit(ModelEvent.OnModelReady, model)
this.getNitroProcesHealthIntervalId = setInterval( this.getNitroProcesHealthIntervalId = setInterval(
() => this.periodicallyGetNitroHealth(), () => this.periodicallyGetNitroHealth(),
JanInferenceNitroExtension._intervalHealthCheck JanInferenceNitroExtension._intervalHealthCheck
) )
}
private async onModelStop(model: Model) { super.onLoad()
if (model.engine !== 'nitro') return
await executeOnMain(NODE, 'stopModel')
events.emit(ModelEvent.OnModelStopped, {})
// stop the periodic health check
if (this.getNitroProcesHealthIntervalId) {
clearInterval(this.getNitroProcesHealthIntervalId)
this.getNitroProcesHealthIntervalId = undefined
}
} }
/** /**
@ -193,118 +82,24 @@ export default class JanInferenceNitroExtension extends InferenceExtension {
this.nitroProcessInfo = health this.nitroProcessInfo = health
} }
private async onInferenceStopped() { override loadModel(model: Model): Promise<void> {
this.isCancelled = true if (model.engine !== this.provider) return Promise.resolve()
this.controller?.abort() this.getNitroProcesHealthIntervalId = setInterval(
() => this.periodicallyGetNitroHealth(),
JanInferenceNitroExtension._intervalHealthCheck
)
return super.loadModel(model)
} }
/** override unloadModel(model: Model): void {
* Makes a single response inference request. super.unloadModel(model)
* @param {MessageRequest} data - The data for the inference request.
* @returns {Promise<any>} A promise that resolves with the inference response. if (model.engine && model.engine !== this.provider) return
*/
async inference(data: MessageRequest): Promise<ThreadMessage> { // stop the periodic health check
const timestamp = Date.now() if (this.getNitroProcesHealthIntervalId) {
const message: ThreadMessage = { clearInterval(this.getNitroProcesHealthIntervalId)
thread_id: data.threadId, this.getNitroProcesHealthIntervalId = undefined
created: timestamp,
updated: timestamp,
status: MessageStatus.Ready,
id: '',
role: ChatCompletionRole.Assistant,
object: 'thread.message',
content: [],
} }
return new Promise(async (resolve, reject) => {
if (!this._currentModel) return Promise.reject('No model loaded')
requestInference(
this.inferenceUrl,
data.messages ?? [],
this._currentModel
).subscribe({
next: (_content: any) => {},
complete: async () => {
resolve(message)
},
error: async (err: any) => {
reject(err)
},
})
})
}
/**
* Handles a new message request by making an inference request and emitting events.
* Function registered in event manager, should be static to avoid binding issues.
* Pass instance as a reference.
* @param {MessageRequest} data - The data for the new message request.
*/
private async onMessageRequest(data: MessageRequest) {
if (data.model?.engine !== InferenceEngine.nitro || !this._currentModel) {
return
}
const timestamp = Date.now()
const message: ThreadMessage = {
id: ulid(),
thread_id: data.threadId,
type: data.type,
assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant,
content: [],
status: MessageStatus.Pending,
created: timestamp,
updated: timestamp,
object: 'thread.message',
}
if (data.type !== MessageRequestType.Summary) {
events.emit(MessageEvent.OnMessageResponse, message)
}
this.isCancelled = false
this.controller = new AbortController()
// @ts-ignore
const model: Model = {
...(this._currentModel || {}),
...(data.model || {}),
}
requestInference(
this.inferenceUrl,
data.messages ?? [],
model,
this.controller
).subscribe({
next: (content: any) => {
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: content.trim(),
annotations: [],
},
}
message.content = [messageContent]
events.emit(MessageEvent.OnMessageUpdate, message)
},
complete: async () => {
message.status = message.content.length
? MessageStatus.Ready
: MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, message)
},
error: async (err: any) => {
if (this.isCancelled || message.content.length) {
message.status = MessageStatus.Stopped
events.emit(MessageEvent.OnMessageUpdate, message)
return
}
message.status = MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, message)
log(`[APP]::Error: ${err.message}`)
},
})
} }
} }

View File

@ -0,0 +1,233 @@
import { describe, expect, it } from '@jest/globals'
import { executableNitroFile } from './execute'
import { GpuSetting } from '@janhq/core'
import { sep } from 'path'
let testSettings: GpuSetting = {
run_mode: 'cpu',
vulkan: false,
cuda: {
exist: false,
version: '11',
},
gpu_highest_vram: '0',
gpus: [],
gpus_in_use: [],
is_initial: false,
notify: true,
nvidia_driver: {
exist: false,
version: '11',
},
}
const originalPlatform = process.platform
describe('test executable nitro file', () => {
afterAll(function () {
Object.defineProperty(process, 'platform', {
value: originalPlatform,
})
})
it('executes on MacOS ARM', () => {
Object.defineProperty(process, 'platform', {
value: 'darwin',
})
Object.defineProperty(process, 'arch', {
value: 'arm64',
})
expect(executableNitroFile(testSettings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`mac-arm64${sep}nitro`),
cudaVisibleDevices: '',
vkVisibleDevices: '',
})
)
})
it('executes on MacOS Intel', () => {
Object.defineProperty(process, 'platform', {
value: 'darwin',
})
Object.defineProperty(process, 'arch', {
value: 'x64',
})
expect(executableNitroFile(testSettings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`mac-x64${sep}nitro`),
cudaVisibleDevices: '',
vkVisibleDevices: '',
})
)
})
it('executes on Windows CPU', () => {
Object.defineProperty(process, 'platform', {
value: 'win32',
})
const settings: GpuSetting = {
...testSettings,
run_mode: 'cpu',
cuda: {
exist: true,
version: '11',
},
}
expect(executableNitroFile(settings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`win-cpu${sep}nitro.exe`),
cudaVisibleDevices: '',
vkVisibleDevices: '',
})
)
})
it('executes on Windows Cuda 11', () => {
Object.defineProperty(process, 'platform', {
value: 'win32',
})
const settings: GpuSetting = {
...testSettings,
run_mode: 'gpu',
cuda: {
exist: true,
version: '11',
},
nvidia_driver: {
exist: true,
version: '12',
},
gpus_in_use: ['0'],
gpus: [
{
id: '0',
name: 'NVIDIA GeForce GTX 1080',
vram: '80000000',
},
],
}
expect(executableNitroFile(settings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`win-cuda-11-7${sep}nitro.exe`),
cudaVisibleDevices: '0',
vkVisibleDevices: '0',
})
)
})
it('executes on Windows Cuda 12', () => {
Object.defineProperty(process, 'platform', {
value: 'win32',
})
const settings: GpuSetting = {
...testSettings,
run_mode: 'gpu',
cuda: {
exist: true,
version: '12',
},
nvidia_driver: {
exist: true,
version: '12',
},
gpus_in_use: ['0'],
gpus: [
{
id: '0',
name: 'NVIDIA GeForce GTX 1080',
vram: '80000000',
},
],
}
expect(executableNitroFile(settings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`win-cuda-12-0${sep}nitro.exe`),
cudaVisibleDevices: '0',
vkVisibleDevices: '0',
})
)
})
it('executes on Linux CPU', () => {
Object.defineProperty(process, 'platform', {
value: 'linux',
})
const settings: GpuSetting = {
...testSettings,
run_mode: 'cpu',
}
expect(executableNitroFile(settings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`linux-cpu${sep}nitro`),
cudaVisibleDevices: '',
vkVisibleDevices: '',
})
)
})
it('executes on Linux Cuda 11', () => {
Object.defineProperty(process, 'platform', {
value: 'linux',
})
const settings: GpuSetting = {
...testSettings,
run_mode: 'gpu',
cuda: {
exist: true,
version: '11',
},
nvidia_driver: {
exist: true,
version: '12',
},
gpus_in_use: ['0'],
gpus: [
{
id: '0',
name: 'NVIDIA GeForce GTX 1080',
vram: '80000000',
},
],
}
expect(executableNitroFile(settings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`linux-cuda-11-7${sep}nitro`),
cudaVisibleDevices: '0',
vkVisibleDevices: '0',
})
)
})
it('executes on Linux Cuda 12', () => {
Object.defineProperty(process, 'platform', {
value: 'linux',
})
const settings: GpuSetting = {
...testSettings,
run_mode: 'gpu',
cuda: {
exist: true,
version: '12',
},
nvidia_driver: {
exist: true,
version: '12',
},
gpus_in_use: ['0'],
gpus: [
{
id: '0',
name: 'NVIDIA GeForce GTX 1080',
vram: '80000000',
},
],
}
expect(executableNitroFile(settings)).toEqual(
expect.objectContaining({
executablePath: expect.stringContaining(`linux-cuda-12-0${sep}nitro`),
cudaVisibleDevices: '0',
vkVisibleDevices: '0',
})
)
})
})

View File

@ -1,5 +1,4 @@
import { getJanDataFolderPath } from '@janhq/core/node' import { GpuSetting, SystemInformation } from '@janhq/core'
import { readFileSync } from 'fs'
import * as path from 'path' import * as path from 'path'
export interface NitroExecutableOptions { export interface NitroExecutableOptions {
@ -7,79 +6,56 @@ export interface NitroExecutableOptions {
cudaVisibleDevices: string cudaVisibleDevices: string
vkVisibleDevices: string vkVisibleDevices: string
} }
const runMode = (settings?: GpuSetting): string => {
if (process.platform === 'darwin')
// MacOS use arch instead of cpu / cuda
return process.arch === 'arm64' ? 'arm64' : 'x64'
export const GPU_INFO_FILE = path.join( if (!settings) return 'cpu'
getJanDataFolderPath(),
'settings', return settings.vulkan === true
'settings.json' ? 'vulkan'
) : settings.run_mode === 'cpu'
? 'cpu'
: 'cuda'
}
const os = (): string => {
return process.platform === 'win32'
? 'win'
: process.platform === 'darwin'
? 'mac'
: 'linux'
}
const extension = (): '.exe' | '' => {
return process.platform === 'win32' ? '.exe' : ''
}
const cudaVersion = (settings?: GpuSetting): '11-7' | '12-0' | undefined => {
const isUsingCuda =
settings?.vulkan !== true && settings?.run_mode === 'gpu' && os() !== 'mac'
if (!isUsingCuda) return undefined
return settings?.cuda?.version === '11' ? '11-7' : '12-0'
}
/** /**
* Find which executable file to run based on the current platform. * Find which executable file to run based on the current platform.
* @returns The name of the executable file to run. * @returns The name of the executable file to run.
*/ */
export const executableNitroFile = (): NitroExecutableOptions => { export const executableNitroFile = (
let binaryFolder = path.join(__dirname, '..', 'bin') // Current directory by default gpuSetting?: GpuSetting
let cudaVisibleDevices = '' ): NitroExecutableOptions => {
let vkVisibleDevices = '' let binaryFolder = [os(), runMode(gpuSetting), cudaVersion(gpuSetting)]
let binaryName = 'nitro' .filter((e) => !!e)
/** .join('-')
* The binary folder is different for each platform. let cudaVisibleDevices = gpuSetting?.gpus_in_use.join(',') ?? ''
*/ let vkVisibleDevices = gpuSetting?.gpus_in_use.join(',') ?? ''
if (process.platform === 'win32') { let binaryName = `nitro${extension()}`
/**
* For Windows: win-cpu, win-vulkan, win-cuda-11-7, win-cuda-12-0
*/
let gpuInfo = JSON.parse(readFileSync(GPU_INFO_FILE, 'utf-8'))
if (gpuInfo['run_mode'] === 'cpu') {
binaryFolder = path.join(binaryFolder, 'win-cpu')
} else {
if (gpuInfo['cuda']?.version === '11') {
binaryFolder = path.join(binaryFolder, 'win-cuda-11-7')
} else {
binaryFolder = path.join(binaryFolder, 'win-cuda-12-0')
}
cudaVisibleDevices = gpuInfo['gpus_in_use'].join(',')
}
if (gpuInfo['vulkan'] === true) {
binaryFolder = path.join(__dirname, '..', 'bin')
binaryFolder = path.join(binaryFolder, 'win-vulkan')
vkVisibleDevices = gpuInfo['gpus_in_use'].toString()
}
binaryName = 'nitro.exe'
} else if (process.platform === 'darwin') {
/**
* For MacOS: mac-arm64 (Silicon), mac-x64 (Intel)
*/
if (process.arch === 'arm64') {
binaryFolder = path.join(binaryFolder, 'mac-arm64')
} else {
binaryFolder = path.join(binaryFolder, 'mac-x64')
}
} else {
/**
* For Linux: linux-cpu, linux-vulkan, linux-cuda-11-7, linux-cuda-12-0
*/
let gpuInfo = JSON.parse(readFileSync(GPU_INFO_FILE, 'utf-8'))
if (gpuInfo['run_mode'] === 'cpu') {
binaryFolder = path.join(binaryFolder, 'linux-cpu')
} else {
if (gpuInfo['cuda']?.version === '11') {
binaryFolder = path.join(binaryFolder, 'linux-cuda-11-7')
} else {
binaryFolder = path.join(binaryFolder, 'linux-cuda-12-0')
}
cudaVisibleDevices = gpuInfo['gpus_in_use'].join(',')
}
if (gpuInfo['vulkan'] === true) {
binaryFolder = path.join(__dirname, '..', 'bin')
binaryFolder = path.join(binaryFolder, 'linux-vulkan')
vkVisibleDevices = gpuInfo['gpus_in_use'].toString()
}
}
return { return {
executablePath: path.join(binaryFolder, binaryName), executablePath: path.join(__dirname, '..', 'bin', binaryFolder, binaryName),
cudaVisibleDevices, cudaVisibleDevices,
vkVisibleDevices, vkVisibleDevices,
} }
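A usage sketch of the refactored helper, matching one of the unit-test cases earlier in this diff; the GpuSetting literal mirrors the shape used in those tests:

```typescript
import { GpuSetting } from '@janhq/core'
import { executableNitroFile } from './execute'

// On Linux with CUDA 12 and GPU 0 in use, the folder name is built as
// [os, runMode, cudaVersion].join('-') => 'linux-cuda-12-0'.
const settings: GpuSetting = {
  run_mode: 'gpu',
  vulkan: false,
  cuda: { exist: true, version: '12' },
  nvidia_driver: { exist: true, version: '12' },
  gpu_highest_vram: '0',
  gpus: [],
  gpus_in_use: ['0'],
  is_initial: false,
  notify: true,
}

const { executablePath, cudaVisibleDevices } = executableNitroFile(settings)
console.log(executablePath, cudaVisibleDevices) // .../bin/linux-cuda-12-0/nitro, '0'
```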

View File

@ -10,6 +10,7 @@ import {
InferenceEngine, InferenceEngine,
ModelSettingParams, ModelSettingParams,
PromptTemplate, PromptTemplate,
SystemInformation,
} from '@janhq/core/node' } from '@janhq/core/node'
import { executableNitroFile } from './execute' import { executableNitroFile } from './execute'
@ -51,7 +52,7 @@ let currentSettings: ModelSettingParams | undefined = undefined
* @param wrapper - The model wrapper. * @param wrapper - The model wrapper.
* @returns A Promise that resolves when the subprocess is terminated successfully, or rejects with an error message if the subprocess fails to terminate. * @returns A Promise that resolves when the subprocess is terminated successfully, or rejects with an error message if the subprocess fails to terminate.
*/ */
function stopModel(): Promise<void> { function unloadModel(): Promise<void> {
return killSubprocess() return killSubprocess()
} }
@ -61,46 +62,47 @@ function stopModel(): Promise<void> {
* @returns A Promise that resolves when the model is loaded successfully, or rejects with an error message if the model is not found or fails to load. * @returns A Promise that resolves when the model is loaded successfully, or rejects with an error message if the model is not found or fails to load.
* TODO: Should pass the absolute path of the model file instead of just the name - so we can modularize module.ts into an npm package * TODO: Should pass the absolute path of the model file instead of just the name - so we can modularize module.ts into an npm package
*/ */
async function runModel( async function loadModel(
wrapper: ModelInitOptions params: ModelInitOptions,
systemInfo?: SystemInformation
): Promise<ModelOperationResponse | void> { ): Promise<ModelOperationResponse | void> {
if (wrapper.model.engine !== InferenceEngine.nitro) { if (params.model.engine !== InferenceEngine.nitro) {
// Not a nitro model // Not a nitro model
return Promise.resolve() return Promise.resolve()
} }
if (wrapper.model.engine !== InferenceEngine.nitro) { if (params.model.engine !== InferenceEngine.nitro) {
return Promise.reject('Not a nitro model') return Promise.reject('Not a nitro model')
} else { } else {
const nitroResourceProbe = await getSystemResourceInfo() const nitroResourceProbe = await getSystemResourceInfo()
// Convert settings.prompt_template to system_prompt, user_prompt, ai_prompt // Convert settings.prompt_template to system_prompt, user_prompt, ai_prompt
if (wrapper.model.settings.prompt_template) { if (params.model.settings.prompt_template) {
const promptTemplate = wrapper.model.settings.prompt_template const promptTemplate = params.model.settings.prompt_template
const prompt = promptTemplateConverter(promptTemplate) const prompt = promptTemplateConverter(promptTemplate)
if (prompt?.error) { if (prompt?.error) {
return Promise.reject(prompt.error) return Promise.reject(prompt.error)
} }
wrapper.model.settings.system_prompt = prompt.system_prompt params.model.settings.system_prompt = prompt.system_prompt
wrapper.model.settings.user_prompt = prompt.user_prompt params.model.settings.user_prompt = prompt.user_prompt
wrapper.model.settings.ai_prompt = prompt.ai_prompt params.model.settings.ai_prompt = prompt.ai_prompt
} }
// modelFolder is the absolute path to the running model folder // modelFolder is the absolute path to the running model folder
// e.g. ~/jan/models/llama-2 // e.g. ~/jan/models/llama-2
let modelFolder = wrapper.modelFolder let modelFolder = params.modelFolder
let llama_model_path = wrapper.model.settings.llama_model_path let llama_model_path = params.model.settings.llama_model_path
// Absolute model path support // Absolute model path support
if ( if (
wrapper.model?.sources.length && params.model?.sources.length &&
wrapper.model.sources.every((e) => fs.existsSync(e.url)) params.model.sources.every((e) => fs.existsSync(e.url))
) { ) {
llama_model_path = llama_model_path =
wrapper.model.sources.length === 1 params.model.sources.length === 1
? wrapper.model.sources[0].url ? params.model.sources[0].url
: wrapper.model.sources.find((e) => : params.model.sources.find((e) =>
e.url.includes(llama_model_path ?? wrapper.model.id) e.url.includes(llama_model_path ?? params.model.id)
)?.url )?.url
} }
@ -114,7 +116,7 @@ async function runModel(
// 2. Prioritize GGUF File (manual import) // 2. Prioritize GGUF File (manual import)
file.toLowerCase().includes(SUPPORTED_MODEL_FORMAT) || file.toLowerCase().includes(SUPPORTED_MODEL_FORMAT) ||
// 3. Fallback Model ID (for backward compatibility) // 3. Fallback Model ID (for backward compatibility)
file === wrapper.model.id file === params.model.id
) )
if (ggufBinFile) llama_model_path = path.join(modelFolder, ggufBinFile) if (ggufBinFile) llama_model_path = path.join(modelFolder, ggufBinFile)
} }
@ -124,17 +126,17 @@ async function runModel(
if (!llama_model_path) return Promise.reject('No GGUF model file found') if (!llama_model_path) return Promise.reject('No GGUF model file found')
currentSettings = { currentSettings = {
...wrapper.model.settings, ...params.model.settings,
llama_model_path, llama_model_path,
// This is critical and requires real CPU physical core count (or performance core) // This is critical and requires real CPU physical core count (or performance core)
cpu_threads: Math.max(1, nitroResourceProbe.numCpuPhysicalCore), cpu_threads: Math.max(1, nitroResourceProbe.numCpuPhysicalCore),
...(wrapper.model.settings.mmproj && { ...(params.model.settings.mmproj && {
mmproj: path.isAbsolute(wrapper.model.settings.mmproj) mmproj: path.isAbsolute(params.model.settings.mmproj)
? wrapper.model.settings.mmproj ? params.model.settings.mmproj
: path.join(modelFolder, wrapper.model.settings.mmproj), : path.join(modelFolder, params.model.settings.mmproj),
}), }),
} }
return runNitroAndLoadModel() return runNitroAndLoadModel(systemInfo)
} }
} }
@ -144,7 +146,7 @@ async function runModel(
* 3. Validate model status * 3. Validate model status
* @returns * @returns
*/ */
async function runNitroAndLoadModel() { async function runNitroAndLoadModel(systemInfo?: SystemInformation) {
// Gather system information for CPU physical cores and memory // Gather system information for CPU physical cores and memory
return killSubprocess() return killSubprocess()
.then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000)) .then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000))
@ -160,7 +162,7 @@ async function runNitroAndLoadModel() {
return Promise.resolve() return Promise.resolve()
} }
}) })
.then(spawnNitroProcess) .then(() => spawnNitroProcess(systemInfo))
.then(() => loadLLMModel(currentSettings)) .then(() => loadLLMModel(currentSettings))
.then(validateModelStatus) .then(validateModelStatus)
.catch((err) => { .catch((err) => {
@ -325,12 +327,12 @@ async function killSubprocess(): Promise<void> {
* Spawns a Nitro subprocess. * Spawns a Nitro subprocess.
* @returns A promise that resolves when the Nitro subprocess is started. * @returns A promise that resolves when the Nitro subprocess is started.
*/ */
function spawnNitroProcess(): Promise<any> { function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
log(`[NITRO]::Debug: Spawning Nitro subprocess...`) log(`[NITRO]::Debug: Spawning Nitro subprocess...`)
return new Promise<void>(async (resolve, reject) => { return new Promise<void>(async (resolve, reject) => {
let binaryFolder = path.join(__dirname, '..', 'bin') // Current directory by default let binaryFolder = path.join(__dirname, '..', 'bin') // Current directory by default
let executableOptions = executableNitroFile() let executableOptions = executableNitroFile(systemInfo?.gpuSetting)
const args: string[] = ['1', LOCAL_HOST, PORT.toString()] const args: string[] = ['1', LOCAL_HOST, PORT.toString()]
// Execute the binary // Execute the binary
@ -402,9 +404,8 @@ const getCurrentNitroProcessInfo = (): NitroProcessInfo => {
} }
export default { export default {
runModel, loadModel,
stopModel, unloadModel,
killSubprocess,
dispose, dispose,
getCurrentNitroProcessInfo, getCurrentNitroProcessInfo,
} }
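The export rename is what makes the defaults on LocalOAIEngine line up: `loadModelFunctionName`/`unloadModelFunctionName` default to `'loadModel'`/`'unloadModel'`, so `executeOnMain` resolves the renamed functions directly. A sketch of roughly the call it ends up making (the wrapper function and imports are assumptions):

```typescript
import { executeOnMain, systemInformation, Model } from '@janhq/core'

declare const NODE: string // build-time module id, as in the nitro extension

// Roughly what LocalOAIEngine.loadModel() does once the node module exports
// loadModel/unloadModel under the default names.
async function loadOnMain(modelFolder: string, model: Model) {
  const systemInfo = await systemInformation()
  return executeOnMain(NODE, 'loadModel', { modelFolder, model }, systemInfo)
}
```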

View File

@ -1,14 +1,14 @@
# Jan inference plugin # OpenAI Engine Extension
Created using Jan app example Created using Jan extension example
# Create a Jan Plugin using Typescript # Create a Jan Extension using Typescript
Use this template to bootstrap the creation of a TypeScript Jan plugin. 🚀 Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀
## Create Your Own Plugin ## Create Your Own Extension
To create your own plugin, you can use this repository as a template! Just follow the below instructions: To create your own extension, you can use this repository as a template! Just follow the below instructions:
1. Click the Use this template button at the top of the repository 1. Click the Use this template button at the top of the repository
2. Select Create a new repository 2. Select Create a new repository
@ -18,7 +18,7 @@ To create your own plugin, you can use this repository as a template! Just follo
## Initial Setup ## Initial Setup
After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your plugin. After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension.
> [!NOTE] > [!NOTE]
> >
@ -43,36 +43,37 @@ After you've cloned the repository to your local machine or codespace, you'll ne
1. :white_check_mark: Check your artifact 1. :white_check_mark: Check your artifact
There will be a tgz file in your plugin directory now There will be a tgz file in your extension directory now
## Update the Plugin Metadata ## Update the Extension Metadata
The [`package.json`](package.json) file defines metadata about your plugin, such as The [`package.json`](package.json) file defines metadata about your extension, such as
plugin name, main entry, description and version. extension name, main entry, description and version.
When you copy this repository, update `package.json` with the name, description for your plugin. When you copy this repository, update `package.json` with the name, description for your extension.
## Update the Plugin Code ## Update the Extension Code
The [`src/`](./src/) directory is the heart of your plugin! This contains the The [`src/`](./src/) directory is the heart of your extension! This contains the
source code that will be run when your plugin extension functions are invoked. You can replace the source code that will be run when your extension functions are invoked. You can replace the
contents of this directory with your own code. contents of this directory with your own code.
There are a few things to keep in mind when writing your plugin code: There are a few things to keep in mind when writing your extension code:
- Most Jan Plugin Extension functions are processed asynchronously. - Most Jan Extension functions are processed asynchronously.
In `index.ts`, you will see that the extension function will return a `Promise<any>`. In `index.ts`, you will see that the extension function will return a `Promise<any>`.
```typescript ```typescript
import { core } from "@janhq/core"; import { events, MessageEvent, MessageRequest } from '@janhq/core'
function onStart(): Promise<any> { function onStart(): Promise<any> {
return core.invokePluginFunc(MODULE_PATH, "run", 0); return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
this.inference(data)
)
} }
``` ```
For more information about the Jan Plugin Core module, see the For more information about the Jan Extension Core module, see the
[documentation](https://github.com/janhq/jan/blob/main/core/README.md). [documentation](https://github.com/janhq/jan/blob/main/core/README.md).
So, what are you waiting for? Go ahead and start customizing your plugin! So, what are you waiting for? Go ahead and start customizing your extension!

View File

@ -4,6 +4,7 @@
"description": "This extension enables OpenAI chat completion API calls", "description": "This extension enables OpenAI chat completion API calls",
"main": "dist/index.js", "main": "dist/index.js",
"module": "dist/module.js", "module": "dist/module.js",
"engine": "openai",
"author": "Jan <service@jan.ai>", "author": "Jan <service@jan.ai>",
"license": "AGPL-3.0", "license": "AGPL-3.0",
"scripts": { "scripts": {

View File

@ -1,26 +0,0 @@
declare const MODULE: string
declare const OPENAI_DOMAIN: string
declare interface EngineSettings {
full_url?: string
api_key?: string
}
enum OpenAIChatCompletionModelName {
'gpt-3.5-turbo-instruct' = 'gpt-3.5-turbo-instruct',
'gpt-3.5-turbo-instruct-0914' = 'gpt-3.5-turbo-instruct-0914',
'gpt-4-1106-preview' = 'gpt-4-1106-preview',
'gpt-3.5-turbo-0613' = 'gpt-3.5-turbo-0613',
'gpt-3.5-turbo-0301' = 'gpt-3.5-turbo-0301',
'gpt-3.5-turbo' = 'gpt-3.5-turbo',
'gpt-3.5-turbo-16k-0613' = 'gpt-3.5-turbo-16k-0613',
'gpt-3.5-turbo-1106' = 'gpt-3.5-turbo-1106',
'gpt-4-vision-preview' = 'gpt-4-vision-preview',
'gpt-4' = 'gpt-4',
'gpt-4-0314' = 'gpt-4-0314',
'gpt-4-0613' = 'gpt-4-0613',
}
declare type OpenAIModel = Omit<Model, 'id'> & {
id: OpenAIChatCompletionModelName
}
@ -1,85 +0,0 @@
import { ErrorCode } from '@janhq/core'
import { Observable } from 'rxjs'
/**
* Sends a request to the inference server to generate a response based on the recent messages.
* @param recentMessages - An array of recent messages to use as context for the inference.
* @param engine - The engine settings to use for the inference.
* @param model - The model to use for the inference.
* @returns An Observable that emits the generated response as a string.
*/
export function requestInference(
recentMessages: any[],
engine: EngineSettings,
model: OpenAIModel,
controller?: AbortController
): Observable<string> {
return new Observable((subscriber) => {
let model_id: string = model.id
if (engine.full_url.includes(OPENAI_DOMAIN)) {
model_id = engine.full_url.split('/')[5]
}
const requestBody = JSON.stringify({
messages: recentMessages,
stream: true,
model: model_id,
...model.parameters,
})
fetch(`${engine.full_url}`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': model.parameters.stream
? 'text/event-stream'
: 'application/json',
'Access-Control-Allow-Origin': '*',
'Authorization': `Bearer ${engine.api_key}`,
'api-key': `${engine.api_key}`,
},
body: requestBody,
signal: controller?.signal,
})
.then(async (response) => {
if (!response.ok) {
const data = await response.json()
const error = {
message: data.error?.message ?? 'An error occurred.',
code: data.error?.code ?? ErrorCode.Unknown,
}
subscriber.error(error)
subscriber.complete()
return
}
if (model.parameters.stream === false) {
const data = await response.json()
subscriber.next(data.choices[0]?.message?.content ?? '')
} else {
const stream = response.body
const decoder = new TextDecoder('utf-8')
const reader = stream?.getReader()
let content = ''
while (true && reader) {
const { done, value } = await reader.read()
if (done) {
break
}
const text = decoder.decode(value)
const lines = text.trim().split('\n')
for (const line of lines) {
if (line.startsWith('data: ') && !line.includes('data: [DONE]')) {
const data = JSON.parse(line.replace('data: ', ''))
content += data.choices[0]?.delta?.content ?? ''
if (content.startsWith('assistant: ')) {
content = content.replace('assistant: ', '')
}
subscriber.next(content)
}
}
}
}
subscriber.complete()
})
.catch((err) => subscriber.error(err))
})
}
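For context, this deleted helper was consumed by the extension roughly as in the sketch below; the key, model fields and messages are placeholder values.

```typescript
import { requestInference } from './helpers/sse'

// Illustrative usage only: key, model fields and messages are placeholders.
const controller = new AbortController()
const engine: EngineSettings = {
  full_url: 'https://api.openai.com/v1/chat/completions',
  api_key: 'sk-<your key here>',
}
const model = {
  id: 'gpt-3.5-turbo',
  parameters: { stream: true },
} as unknown as OpenAIModel

requestInference([{ role: 'user', content: 'Hello!' }], engine, model, controller).subscribe({
  next: (content) => console.log(content), // accumulated response so far
  error: (err) => console.error(err),
  complete: () => console.log('stream finished'),
})
```

After this refactor, that plumbing moves into the shared `RemoteOAIEngine` base class, as the updated extension below shows.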
@ -5,75 +5,52 @@
* @version 1.0.0 * @version 1.0.0
* @module inference-openai-extension/src/index * @module inference-openai-extension/src/index
*/ */
declare const ENGINE: string
import { import {
ChatCompletionRole,
ContentType,
MessageRequest,
MessageStatus,
ThreadContent,
ThreadMessage,
events, events,
fs, fs,
InferenceEngine,
BaseExtension,
MessageEvent,
MessageRequestType,
ModelEvent,
InferenceEvent,
AppConfigurationEventName, AppConfigurationEventName,
joinPath, joinPath,
RemoteOAIEngine,
} from '@janhq/core' } from '@janhq/core'
import { requestInference } from './helpers/sse'
import { ulid } from 'ulidx'
import { join } from 'path' import { join } from 'path'
declare const COMPLETION_URL: string
/** /**
* A class that implements the InferenceExtension interface from the @janhq/core package. * A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests. * The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests. * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/ */
export default class JanInferenceOpenAIExtension extends BaseExtension { export default class JanInferenceOpenAIExtension extends RemoteOAIEngine {
private static readonly _engineDir = 'file://engines' private static readonly _engineDir = 'file://engines'
private static readonly _engineMetadataFileName = 'openai.json' private static readonly _engineMetadataFileName = `${ENGINE}.json`
private static _currentModel: OpenAIModel private _engineSettings = {
full_url: COMPLETION_URL,
private static _engineSettings: EngineSettings = {
full_url: 'https://api.openai.com/v1/chat/completions',
api_key: 'sk-<your key here>', api_key: 'sk-<your key here>',
} }
controller = new AbortController() inferenceUrl: string = COMPLETION_URL
isCancelled = false provider: string = 'openai'
apiKey: string = ''
// TODO: Just use registerSettings from BaseExtension
// Remove these methods
/** /**
* Subscribes to events emitted by the @janhq/core package. * Subscribes to events emitted by the @janhq/core package.
*/ */
async onLoad() { async onLoad() {
super.onLoad()
if (!(await fs.existsSync(JanInferenceOpenAIExtension._engineDir))) { if (!(await fs.existsSync(JanInferenceOpenAIExtension._engineDir))) {
await fs await fs
.mkdirSync(JanInferenceOpenAIExtension._engineDir) .mkdirSync(JanInferenceOpenAIExtension._engineDir)
.catch((err) => console.debug(err)) .catch((err) => console.debug(err))
} }
JanInferenceOpenAIExtension.writeDefaultEngineSettings() this.writeDefaultEngineSettings()
// Events subscription
events.on(MessageEvent.OnMessageSent, (data) =>
JanInferenceOpenAIExtension.handleMessageRequest(data, this)
)
events.on(ModelEvent.OnModelInit, (model: OpenAIModel) => {
JanInferenceOpenAIExtension.handleModelInit(model)
})
events.on(ModelEvent.OnModelStop, (model: OpenAIModel) => {
JanInferenceOpenAIExtension.handleModelStop(model)
})
events.on(InferenceEvent.OnInferenceStopped, () => {
JanInferenceOpenAIExtension.handleInferenceStopped(this)
})
const settingsFilePath = await joinPath([ const settingsFilePath = await joinPath([
JanInferenceOpenAIExtension._engineDir, JanInferenceOpenAIExtension._engineDir,
@ -84,18 +61,12 @@ export default class JanInferenceOpenAIExtension extends BaseExtension {
AppConfigurationEventName.OnConfigurationUpdate, AppConfigurationEventName.OnConfigurationUpdate,
(settingsKey: string) => { (settingsKey: string) => {
// Update settings on changes // Update settings on changes
if (settingsKey === settingsFilePath) if (settingsKey === settingsFilePath) this.writeDefaultEngineSettings()
JanInferenceOpenAIExtension.writeDefaultEngineSettings()
} }
) )
} }
/** async writeDefaultEngineSettings() {
* Stops the model inference.
*/
onUnload(): void {}
static async writeDefaultEngineSettings() {
try { try {
const engineFile = join( const engineFile = join(
JanInferenceOpenAIExtension._engineDir, JanInferenceOpenAIExtension._engineDir,
@ -103,122 +74,18 @@ export default class JanInferenceOpenAIExtension extends BaseExtension {
) )
if (await fs.existsSync(engineFile)) { if (await fs.existsSync(engineFile)) {
const engine = await fs.readFileSync(engineFile, 'utf-8') const engine = await fs.readFileSync(engineFile, 'utf-8')
JanInferenceOpenAIExtension._engineSettings = this._engineSettings =
typeof engine === 'object' ? engine : JSON.parse(engine) typeof engine === 'object' ? engine : JSON.parse(engine)
this.inferenceUrl = this._engineSettings.full_url
this.apiKey = this._engineSettings.api_key
} else { } else {
await fs.writeFileSync( await fs.writeFileSync(
engineFile, engineFile,
JSON.stringify(JanInferenceOpenAIExtension._engineSettings, null, 2) JSON.stringify(this._engineSettings, null, 2)
) )
} }
} catch (err) { } catch (err) {
console.error(err) console.error(err)
} }
} }
private static async handleModelInit(model: OpenAIModel) {
if (model.engine !== InferenceEngine.openai) {
return
} else {
JanInferenceOpenAIExtension._currentModel = model
JanInferenceOpenAIExtension.writeDefaultEngineSettings()
// Todo: Check model list with API key
events.emit(ModelEvent.OnModelReady, model)
}
}
private static async handleModelStop(model: OpenAIModel) {
if (model.engine !== 'openai') {
return
}
events.emit(ModelEvent.OnModelStopped, model)
}
private static async handleInferenceStopped(
instance: JanInferenceOpenAIExtension
) {
instance.isCancelled = true
instance.controller?.abort()
}
/**
* Handles a new message request by making an inference request and emitting events.
* Function registered in event manager, should be static to avoid binding issues.
* Pass instance as a reference.
* @param {MessageRequest} data - The data for the new message request.
*/
private static async handleMessageRequest(
data: MessageRequest,
instance: JanInferenceOpenAIExtension
) {
if (data.model.engine !== 'openai') {
return
}
const timestamp = Date.now()
const message: ThreadMessage = {
id: ulid(),
thread_id: data.threadId,
type: data.type,
assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant,
content: [],
status: MessageStatus.Pending,
created: timestamp,
updated: timestamp,
object: 'thread.message',
}
if (data.type !== MessageRequestType.Summary) {
events.emit(MessageEvent.OnMessageResponse, message)
}
instance.isCancelled = false
instance.controller = new AbortController()
requestInference(
data?.messages ?? [],
this._engineSettings,
{
...JanInferenceOpenAIExtension._currentModel,
parameters: data.model.parameters,
},
instance.controller
).subscribe({
next: (content) => {
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: content.trim(),
annotations: [],
},
}
message.content = [messageContent]
events.emit(MessageEvent.OnMessageUpdate, message)
},
complete: async () => {
message.status = message.content.length
? MessageStatus.Ready
: MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, message)
},
error: async (err) => {
if (instance.isCancelled || message.content.length > 0) {
message.status = MessageStatus.Stopped
events.emit(MessageEvent.OnMessageUpdate, message)
return
}
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: 'An error occurred. ' + err.message,
annotations: [],
},
}
message.content = [messageContent]
message.status = MessageStatus.Error
message.error_code = err.code
events.emit(MessageEvent.OnMessageUpdate, message)
},
})
}
} }
@ -17,8 +17,8 @@ module.exports = {
}, },
plugins: [ plugins: [
new webpack.DefinePlugin({ new webpack.DefinePlugin({
MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`), ENGINE: JSON.stringify(packageJson.engine),
OPENAI_DOMAIN: JSON.stringify('openai.azure.com'), COMPLETION_URL: JSON.stringify('https://api.openai.com/v1/chat/completions'),
}), }),
], ],
output: { output: {
@ -1,5 +0,0 @@
import { Model } from '@janhq/core'
declare interface EngineSettings {
base_url?: string
}
@ -1,63 +0,0 @@
import { Observable } from 'rxjs'
import { EngineSettings } from '../@types/global'
import { Model } from '@janhq/core'
/**
* Sends a request to the inference server to generate a response based on the recent messages.
* @param recentMessages - An array of recent messages to use as context for the inference.
* @param engine - The engine settings to use for the inference.
* @param model - The model to use for the inference.
* @returns An Observable that emits the generated response as a string.
*/
export function requestInference(
recentMessages: any[],
engine: EngineSettings,
model: Model,
controller?: AbortController
): Observable<string> {
return new Observable((subscriber) => {
const text_input = recentMessages.map((message) => message.text).join('\n')
const requestBody = JSON.stringify({
text_input: text_input,
max_tokens: 4096,
temperature: 0,
bad_words: '',
stop_words: '[DONE]',
stream: true,
})
fetch(`${engine.base_url}/v2/models/ensemble/generate_stream`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Accept': 'text/event-stream',
'Access-Control-Allow-Origin': '*',
},
body: requestBody,
signal: controller?.signal,
})
.then(async (response) => {
const stream = response.body
const decoder = new TextDecoder('utf-8')
const reader = stream?.getReader()
let content = ''
while (true && reader) {
const { done, value } = await reader.read()
if (done) {
break
}
const text = decoder.decode(value)
const lines = text.trim().split('\n')
for (const line of lines) {
if (line.startsWith('data: ') && !line.includes('data: [DONE]')) {
const data = JSON.parse(line.replace('data: ', ''))
content += data.choices[0]?.delta?.content ?? ''
subscriber.next(content)
}
}
}
subscriber.complete()
})
.catch((err) => subscriber.error(err))
})
}
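Both of the deleted helpers decode the event stream the same way; stripped of the fetch plumbing, the loop amounts to the simplified sketch below (error handling omitted).

```typescript
// Simplified sketch of the SSE decoding shared by both deleted helpers:
// each decoded chunk may carry several `data: {...}` lines, and the
// incremental `delta.content` pieces are appended to the running response.
function appendSSEChunk(chunk: string, content: string): string {
  for (const line of chunk.trim().split('\n')) {
    if (line.startsWith('data: ') && !line.includes('data: [DONE]')) {
      const data = JSON.parse(line.replace('data: ', ''))
      content += data.choices[0]?.delta?.content ?? ''
    }
  }
  return content
}
```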
@ -7,212 +7,76 @@
*/ */
import { import {
ChatCompletionRole, AppConfigurationEventName,
ContentType,
MessageRequest,
MessageStatus,
ModelSettingParams,
ThreadContent,
ThreadMessage,
events, events,
fs, fs,
joinPath,
Model, Model,
BaseExtension, RemoteOAIEngine,
MessageEvent,
ModelEvent,
} from '@janhq/core' } from '@janhq/core'
import { requestInference } from './helpers/sse'
import { ulid } from 'ulidx'
import { join } from 'path' import { join } from 'path'
import { EngineSettings } from './@types/global'
/** /**
* A class that implements the InferenceExtension interface from the @janhq/core package. * A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests. * The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests. * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/ */
export default class JanInferenceTritonTrtLLMExtension extends BaseExtension { export default class JanInferenceTritonTrtLLMExtension extends RemoteOAIEngine {
private static readonly _homeDir = 'file://engines' private readonly _engineDir = 'file://engines'
private static readonly _engineMetadataFileName = 'triton_trtllm.json' private readonly _engineMetadataFileName = 'triton_trtllm.json'
static _currentModel: Model inferenceUrl: string = ''
provider: string = 'triton_trtllm'
apiKey: string = ''
static _engineSettings: EngineSettings = { _engineSettings: {
base_url: '', base_url: ''
api_key: ''
} }
controller = new AbortController()
isCancelled = false
/** /**
* Subscribes to events emitted by the @janhq/core package. * Subscribes to events emitted by the @janhq/core package.
*/ */
async onLoad() { async onLoad() {
if (!(await fs.existsSync(JanInferenceTritonTrtLLMExtension._homeDir))) super.onLoad()
JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings() if (!(await fs.existsSync(this._engineDir))) {
await fs.mkdirSync(this._engineDir).catch((err) => console.debug(err))
}
this.writeDefaultEngineSettings()
const settingsFilePath = await joinPath([
this._engineDir,
this._engineMetadataFileName,
])
// Events subscription // Events subscription
events.on(MessageEvent.OnMessageSent, (data) => events.on(
JanInferenceTritonTrtLLMExtension.handleMessageRequest(data, this) AppConfigurationEventName.OnConfigurationUpdate,
(settingsKey: string) => {
// Update settings on changes
if (settingsKey === settingsFilePath) this.writeDefaultEngineSettings()
}
) )
events.on(ModelEvent.OnModelInit, (model: Model) => {
JanInferenceTritonTrtLLMExtension.handleModelInit(model)
})
events.on(ModelEvent.OnModelStop, (model: Model) => {
JanInferenceTritonTrtLLMExtension.handleModelStop(model)
})
} }
/** async writeDefaultEngineSettings() {
* Stops the model inference.
*/
onUnload(): void {}
/**
* Initializes the model with the specified file name.
* @param {string} modelId - The ID of the model to initialize.
* @returns {Promise<void>} A promise that resolves when the model is initialized.
*/
async initModel(
modelId: string,
settings?: ModelSettingParams
): Promise<void> {
return
}
static async writeDefaultEngineSettings() {
try { try {
const engine_json = join( const engine_json = join(this._engineDir, this._engineMetadataFileName)
JanInferenceTritonTrtLLMExtension._homeDir,
JanInferenceTritonTrtLLMExtension._engineMetadataFileName
)
if (await fs.existsSync(engine_json)) { if (await fs.existsSync(engine_json)) {
const engine = await fs.readFileSync(engine_json, 'utf-8') const engine = await fs.readFileSync(engine_json, 'utf-8')
JanInferenceTritonTrtLLMExtension._engineSettings = this._engineSettings =
typeof engine === 'object' ? engine : JSON.parse(engine) typeof engine === 'object' ? engine : JSON.parse(engine)
this.inferenceUrl = this._engineSettings.base_url
this.apiKey = this._engineSettings.api_key
} else { } else {
await fs.writeFileSync( await fs.writeFileSync(
engine_json, engine_json,
JSON.stringify( JSON.stringify(this._engineSettings, null, 2)
JanInferenceTritonTrtLLMExtension._engineSettings,
null,
2
)
) )
} }
} catch (err) { } catch (err) {
console.error(err) console.error(err)
} }
} }
/**
* Stops the model.
* @returns {Promise<void>} A promise that resolves when the model is stopped.
*/
async stopModel(): Promise<void> {}
/**
* Stops streaming inference.
* @returns {Promise<void>} A promise that resolves when the streaming is stopped.
*/
async stopInference(): Promise<void> {
this.isCancelled = true
this.controller?.abort()
}
private static async handleModelInit(model: Model) {
if (model.engine !== 'triton_trtllm') {
return
} else {
JanInferenceTritonTrtLLMExtension._currentModel = model
JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings()
// Todo: Check model list with API key
events.emit(ModelEvent.OnModelReady, model)
}
}
private static async handleModelStop(model: Model) {
if (model.engine !== 'triton_trtllm') {
return
}
events.emit(ModelEvent.OnModelStopped, model)
}
/**
* Handles a new message request by making an inference request and emitting events.
* Function registered in event manager, should be static to avoid binding issues.
* Pass instance as a reference.
* @param {MessageRequest} data - The data for the new message request.
*/
private static async handleMessageRequest(
data: MessageRequest,
instance: JanInferenceTritonTrtLLMExtension
) {
if (data.model.engine !== 'triton_trtllm') {
return
}
const timestamp = Date.now()
const message: ThreadMessage = {
id: ulid(),
thread_id: data.threadId,
assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant,
content: [],
status: MessageStatus.Pending,
created: timestamp,
updated: timestamp,
object: 'thread.message',
}
events.emit(MessageEvent.OnMessageResponse, message)
instance.isCancelled = false
instance.controller = new AbortController()
requestInference(
data?.messages ?? [],
this._engineSettings,
{
...JanInferenceTritonTrtLLMExtension._currentModel,
parameters: data.model.parameters,
},
instance.controller
).subscribe({
next: (content) => {
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: content.trim(),
annotations: [],
},
}
message.content = [messageContent]
events.emit(MessageEvent.OnMessageUpdate, message)
},
complete: async () => {
message.status = message.content.length
? MessageStatus.Ready
: MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, message)
},
error: async (err) => {
if (instance.isCancelled || message.content.length) {
message.status = MessageStatus.Error
events.emit(MessageEvent.OnMessageUpdate, message)
return
}
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {
value: 'An error occurred. ' + err.message,
annotations: [],
},
}
message.content = [messageContent]
message.status = MessageStatus.Ready
events.emit(MessageEvent.OnMessageUpdate, message)
},
})
}
} }
@ -43,14 +43,14 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
private supportedPlatform = ['win32', 'linux'] private supportedPlatform = ['win32', 'linux']
private isUpdateAvailable = false private isUpdateAvailable = false
compatibility() { override compatibility() {
return COMPATIBILITY as unknown as Compatibility return COMPATIBILITY as unknown as Compatibility
} }
/** /**
* models implemented by the extension * models implemented by the extension
* define pre-populated models * define pre-populated models
*/ */
async models(): Promise<Model[]> { override async models(): Promise<Model[]> {
if ((await this.installationState()) === 'Installed') if ((await this.installationState()) === 'Installed')
return models as unknown as Model[] return models as unknown as Model[]
return [] return []
@ -160,11 +160,11 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
events.emit(ModelEvent.OnModelsUpdate, {}) events.emit(ModelEvent.OnModelsUpdate, {})
} }
async onModelInit(model: Model): Promise<void> { override async loadModel(model: Model): Promise<void> {
if (model.engine !== this.provider) return if (model.engine !== this.provider) return
if ((await this.installationState()) === 'Installed') if ((await this.installationState()) === 'Installed')
return super.onModelInit(model) return super.loadModel(model)
else { else {
events.emit(ModelEvent.OnModelFail, { events.emit(ModelEvent.OnModelFail, {
...model, ...model,
@ -175,7 +175,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
} }
} }
updatable() { override updatable() {
return this.isUpdateAvailable return this.isUpdateAvailable
} }
@ -241,8 +241,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
return (await fs.existsSync(enginePath)) ? 'Installed' : 'NotInstalled' return (await fs.existsSync(enginePath)) ? 'Installed' : 'NotInstalled'
} }
override onInferenceStopped() { override stopInference() {
if (!this.isRunning) return
showToast( showToast(
'Unable to Stop Inference', 'Unable to Stop Inference',
'The model does not support stopping inference.' 'The model does not support stopping inference.'
@ -250,8 +249,8 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
return Promise.resolve() return Promise.resolve()
} }
inference(data: MessageRequest): void { override inference(data: MessageRequest): void {
if (!this.isRunning) return if (!this.loadedModel) return
// TensorRT LLM Extension supports streaming only // TensorRT LLM Extension supports streaming only
if (data.model) data.model.parameters.stream = true if (data.model) data.model.parameters.stream = true
super.inference(data) super.inference(data)
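The same override pattern applies to any engine built on `LocalOAIEngine`. A minimal sketch, assuming a hypothetical provider (the provider id, module path and readiness check below are placeholders, not part of this commit):

```typescript
import { events, LocalOAIEngine, Model, ModelEvent } from '@janhq/core'

// Hypothetical local engine: provider id, module path and the readiness
// check are placeholders.
export default class ExampleLocalEngine extends LocalOAIEngine {
  nodeModule: string = 'dist/module.js' // node module bundled with the extension
  provider: string = 'example_local'

  override async loadModel(model: Model): Promise<void> {
    if (model.engine !== this.provider) return
    if (!(await this.isEngineInstalled())) {
      // Report the failure so the UI can react instead of attempting to load.
      events.emit(ModelEvent.OnModelFail, { ...model })
      return
    }
    return super.loadModel(model)
  }

  private async isEngineInstalled(): Promise<boolean> {
    return true // placeholder readiness check
  }
}
```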
@ -1,27 +1,26 @@
{ {
"sources": [ "sources": [
{ {
"url": "https://groq.com" "url": "https://groq.com"
} }
], ],
"id": "llama2-70b-4096", "id": "llama2-70b-4096",
"object": "model", "object": "model",
"name": "Groq Llama 2 70b", "name": "Groq Llama 2 70b",
"version": "1.0", "version": "1.0",
"description": "Groq Llama 2 70b with supercharged speed!", "description": "Groq Llama 2 70b with supercharged speed!",
"format": "api", "format": "api",
"settings": {}, "settings": {},
"parameters": { "parameters": {
"max_tokens": 4096, "max_tokens": 4096,
"temperature": 0.7, "temperature": 0.7,
"top_p": 1, "top_p": 1,
"stop": null, "stop": null,
"stream": true "stream": true
}, },
"metadata": { "metadata": {
"author": "Meta", "author": "Meta",
"tags": ["General", "Big Context Length"] "tags": ["General", "Big Context Length"]
}, },
"engine": "groq" "engine": "groq"
} }
@ -1,27 +1,26 @@
{ {
"sources": [ "sources": [
{ {
"url": "https://groq.com" "url": "https://groq.com"
} }
], ],
"id": "mixtral-8x7b-32768", "id": "mixtral-8x7b-32768",
"object": "model", "object": "model",
"name": "Groq Mixtral 8x7b Instruct", "name": "Groq Mixtral 8x7b Instruct",
"version": "1.0", "version": "1.0",
"description": "Groq Mixtral 8x7b Instruct is Mixtral with supercharged speed!", "description": "Groq Mixtral 8x7b Instruct is Mixtral with supercharged speed!",
"format": "api", "format": "api",
"settings": {}, "settings": {},
"parameters": { "parameters": {
"max_tokens": 4096, "max_tokens": 4096,
"temperature": 0.7, "temperature": 0.7,
"top_p": 1, "top_p": 1,
"stop": null, "stop": null,
"stream": true "stream": true
}, },
"metadata": { "metadata": {
"author": "Mistral", "author": "Mistral",
"tags": ["General", "Big Context Length"] "tags": ["General", "Big Context Length"]
}, },
"engine": "groq" "engine": "groq"
} }
@ -75,12 +75,14 @@ const DropdownListSidebar = ({
// TODO: Update filter condition for the local model // TODO: Update filter condition for the local model
const localModel = downloadedModels.filter( const localModel = downloadedModels.filter(
(model) => model.engine !== InferenceEngine.openai (model) =>
model.engine === InferenceEngine.nitro ||
model.engine === InferenceEngine.nitro_tensorrt_llm
) )
const remoteModel = downloadedModels.filter( const remoteModel = downloadedModels.filter(
(model) => (model) =>
model.engine === InferenceEngine.openai || model.engine !== InferenceEngine.nitro &&
model.engine === InferenceEngine.groq model.engine !== InferenceEngine.nitro_tensorrt_llm
) )
const modelOptions = isTabActive === 0 ? localModel : remoteModel const modelOptions = isTabActive === 0 ? localModel : remoteModel
@ -48,9 +48,8 @@ export default function RowModel(props: RowModelProps) {
const [serverEnabled, setServerEnabled] = useAtom(serverEnabledAtom) const [serverEnabled, setServerEnabled] = useAtom(serverEnabledAtom)
const isRemoteModel = const isRemoteModel =
props.data.engine === InferenceEngine.openai || props.data.engine !== InferenceEngine.nitro &&
props.data.engine === InferenceEngine.groq || props.data.engine !== InferenceEngine.nitro_tensorrt_llm
props.data.engine === InferenceEngine.triton_trtllm
const onModelActionClick = (modelId: string) => { const onModelActionClick = (modelId: string) => {
if (activeModel && activeModel.id === modelId) { if (activeModel && activeModel.id === modelId) {
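Both UI hunks above encode the same rule for splitting engines; as a sketch (the helper name is ours, not part of the diff):

```typescript
import { InferenceEngine, Model } from '@janhq/core'

// Sketch of the shared rule: only Nitro-based engines count as local,
// every other engine (OpenAI, Groq, Triton, ...) is treated as remote.
const isLocalEngine = (model: Model): boolean =>
  model.engine === InferenceEngine.nitro ||
  model.engine === InferenceEngine.nitro_tensorrt_llm

export const splitModels = (models: Model[]) => ({
  local: models.filter(isLocalEngine),
  remote: models.filter((model) => !isLocalEngine(model)),
})
```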
@ -8,7 +8,6 @@ export const isCoreExtensionInstalled = () => {
if (!extensionManager.get(ExtensionTypeEnum.Conversational)) { if (!extensionManager.get(ExtensionTypeEnum.Conversational)) {
return false return false
} }
if (!extensionManager.get(ExtensionTypeEnum.Inference)) return false
if (!extensionManager.get(ExtensionTypeEnum.Model)) { if (!extensionManager.get(ExtensionTypeEnum.Model)) {
return false return false
} }
@ -22,7 +21,6 @@ export const setupBaseExtensions = async () => {
if ( if (
!extensionManager.get(ExtensionTypeEnum.Conversational) || !extensionManager.get(ExtensionTypeEnum.Conversational) ||
!extensionManager.get(ExtensionTypeEnum.Inference) ||
!extensionManager.get(ExtensionTypeEnum.Model) !extensionManager.get(ExtensionTypeEnum.Model)
) { ) {
const installed = await extensionManager.install(baseExtensions) const installed = await extensionManager.install(baseExtensions)