fix: refactor inference engines to extends AIEngine (#2347)
* fix: refactor nitro to extends localoaiengine * fix: refactor openai extension * chore: refactor groq extension * chore: refactor triton tensorrt extension * chore: add tests * chore: refactor engines
This commit is contained in:
parent
b8e4a029a4
commit
acbec78dbf
@ -56,7 +56,8 @@ const openFileExplorer: (path: string) => Promise<any> = (path) =>
|
|||||||
* @param paths - The paths to join.
|
* @param paths - The paths to join.
|
||||||
* @returns {Promise<string>} A promise that resolves with the joined path.
|
* @returns {Promise<string>} A promise that resolves with the joined path.
|
||||||
*/
|
*/
|
||||||
const joinPath: (paths: string[]) => Promise<string> = (paths) => globalThis.core.api?.joinPath(paths)
|
const joinPath: (paths: string[]) => Promise<string> = (paths) =>
|
||||||
|
globalThis.core.api?.joinPath(paths)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Retrive the basename from an url.
|
* Retrive the basename from an url.
|
||||||
|
|||||||
@ -14,7 +14,9 @@ export abstract class AIEngine extends BaseExtension {
|
|||||||
// The model folder
|
// The model folder
|
||||||
modelFolder: string = 'models'
|
modelFolder: string = 'models'
|
||||||
|
|
||||||
abstract models(): Promise<Model[]>
|
models(): Promise<Model[]> {
|
||||||
|
return Promise.resolve([])
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* On extension load, subscribe to events.
|
* On extension load, subscribe to events.
|
||||||
|
|||||||
@ -9,9 +9,9 @@ import { OAIEngine } from './OAIEngine'
|
|||||||
*/
|
*/
|
||||||
export abstract class LocalOAIEngine extends OAIEngine {
|
export abstract class LocalOAIEngine extends OAIEngine {
|
||||||
// The inference engine
|
// The inference engine
|
||||||
|
abstract nodeModule: string
|
||||||
loadModelFunctionName: string = 'loadModel'
|
loadModelFunctionName: string = 'loadModel'
|
||||||
unloadModelFunctionName: string = 'unloadModel'
|
unloadModelFunctionName: string = 'unloadModel'
|
||||||
isRunning: boolean = false
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* On extension load, subscribe to events.
|
* On extension load, subscribe to events.
|
||||||
@ -19,22 +19,27 @@ export abstract class LocalOAIEngine extends OAIEngine {
|
|||||||
onLoad() {
|
onLoad() {
|
||||||
super.onLoad()
|
super.onLoad()
|
||||||
// These events are applicable to local inference providers
|
// These events are applicable to local inference providers
|
||||||
events.on(ModelEvent.OnModelInit, (model: Model) => this.onModelInit(model))
|
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
|
||||||
events.on(ModelEvent.OnModelStop, (model: Model) => this.onModelStop(model))
|
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load the model.
|
* Load the model.
|
||||||
*/
|
*/
|
||||||
async onModelInit(model: Model) {
|
async loadModel(model: Model) {
|
||||||
if (model.engine.toString() !== this.provider) return
|
if (model.engine.toString() !== this.provider) return
|
||||||
|
|
||||||
const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id])
|
const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id])
|
||||||
const systemInfo = await systemInformation()
|
const systemInfo = await systemInformation()
|
||||||
const res = await executeOnMain(this.nodeModule, this.loadModelFunctionName, {
|
const res = await executeOnMain(
|
||||||
modelFolder,
|
this.nodeModule,
|
||||||
model,
|
this.loadModelFunctionName,
|
||||||
}, systemInfo)
|
{
|
||||||
|
modelFolder,
|
||||||
|
model,
|
||||||
|
},
|
||||||
|
systemInfo
|
||||||
|
)
|
||||||
|
|
||||||
if (res?.error) {
|
if (res?.error) {
|
||||||
events.emit(ModelEvent.OnModelFail, {
|
events.emit(ModelEvent.OnModelFail, {
|
||||||
@ -45,16 +50,14 @@ export abstract class LocalOAIEngine extends OAIEngine {
|
|||||||
} else {
|
} else {
|
||||||
this.loadedModel = model
|
this.loadedModel = model
|
||||||
events.emit(ModelEvent.OnModelReady, model)
|
events.emit(ModelEvent.OnModelReady, model)
|
||||||
this.isRunning = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* Stops the model.
|
* Stops the model.
|
||||||
*/
|
*/
|
||||||
onModelStop(model: Model) {
|
unloadModel(model: Model) {
|
||||||
if (model.engine?.toString() !== this.provider) return
|
if (model.engine && model.engine?.toString() !== this.provider) return
|
||||||
|
this.loadedModel = undefined
|
||||||
this.isRunning = false
|
|
||||||
|
|
||||||
executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
|
executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
|
||||||
events.emit(ModelEvent.OnModelStopped, {})
|
events.emit(ModelEvent.OnModelStopped, {})
|
||||||
|
|||||||
@ -23,7 +23,6 @@ import { events } from '../../events'
|
|||||||
export abstract class OAIEngine extends AIEngine {
|
export abstract class OAIEngine extends AIEngine {
|
||||||
// The inference engine
|
// The inference engine
|
||||||
abstract inferenceUrl: string
|
abstract inferenceUrl: string
|
||||||
abstract nodeModule: string
|
|
||||||
|
|
||||||
// Controller to handle stop requests
|
// Controller to handle stop requests
|
||||||
controller = new AbortController()
|
controller = new AbortController()
|
||||||
@ -38,7 +37,7 @@ export abstract class OAIEngine extends AIEngine {
|
|||||||
onLoad() {
|
onLoad() {
|
||||||
super.onLoad()
|
super.onLoad()
|
||||||
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
|
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
|
||||||
events.on(InferenceEvent.OnInferenceStopped, () => this.onInferenceStopped())
|
events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -78,7 +77,13 @@ export abstract class OAIEngine extends AIEngine {
|
|||||||
...data.model,
|
...data.model,
|
||||||
}
|
}
|
||||||
|
|
||||||
requestInference(this.inferenceUrl, data.messages ?? [], model, this.controller).subscribe({
|
requestInference(
|
||||||
|
this.inferenceUrl,
|
||||||
|
data.messages ?? [],
|
||||||
|
model,
|
||||||
|
this.controller,
|
||||||
|
this.headers()
|
||||||
|
).subscribe({
|
||||||
next: (content: any) => {
|
next: (content: any) => {
|
||||||
const messageContent: ThreadContent = {
|
const messageContent: ThreadContent = {
|
||||||
type: ContentType.Text,
|
type: ContentType.Text,
|
||||||
@ -109,8 +114,15 @@ export abstract class OAIEngine extends AIEngine {
|
|||||||
/**
|
/**
|
||||||
* Stops the inference.
|
* Stops the inference.
|
||||||
*/
|
*/
|
||||||
onInferenceStopped() {
|
stopInference() {
|
||||||
this.isCancelled = true
|
this.isCancelled = true
|
||||||
this.controller?.abort()
|
this.controller?.abort()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Headers for the inference request
|
||||||
|
*/
|
||||||
|
headers(): HeadersInit {
|
||||||
|
return {}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
46
core/src/extensions/ai-engines/RemoteOAIEngine.ts
Normal file
46
core/src/extensions/ai-engines/RemoteOAIEngine.ts
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import { events } from '../../events'
|
||||||
|
import { Model, ModelEvent } from '../../types'
|
||||||
|
import { OAIEngine } from './OAIEngine'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Base OAI Remote Inference Provider
|
||||||
|
* Added the implementation of loading and unloading model (applicable to local inference providers)
|
||||||
|
*/
|
||||||
|
export abstract class RemoteOAIEngine extends OAIEngine {
|
||||||
|
// The inference engine
|
||||||
|
abstract apiKey: string
|
||||||
|
/**
|
||||||
|
* On extension load, subscribe to events.
|
||||||
|
*/
|
||||||
|
onLoad() {
|
||||||
|
super.onLoad()
|
||||||
|
// These events are applicable to local inference providers
|
||||||
|
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
|
||||||
|
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load the model.
|
||||||
|
*/
|
||||||
|
async loadModel(model: Model) {
|
||||||
|
if (model.engine.toString() !== this.provider) return
|
||||||
|
events.emit(ModelEvent.OnModelReady, model)
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Stops the model.
|
||||||
|
*/
|
||||||
|
unloadModel(model: Model) {
|
||||||
|
if (model.engine && model.engine.toString() !== this.provider) return
|
||||||
|
events.emit(ModelEvent.OnModelStopped, {})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Headers for the inference request
|
||||||
|
*/
|
||||||
|
override headers(): HeadersInit {
|
||||||
|
return {
|
||||||
|
'Authorization': `Bearer ${this.apiKey}`,
|
||||||
|
'api-key': `${this.apiKey}`,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -12,7 +12,8 @@ export function requestInference(
|
|||||||
id: string
|
id: string
|
||||||
parameters: ModelRuntimeParams
|
parameters: ModelRuntimeParams
|
||||||
},
|
},
|
||||||
controller?: AbortController
|
controller?: AbortController,
|
||||||
|
headers?: HeadersInit
|
||||||
): Observable<string> {
|
): Observable<string> {
|
||||||
return new Observable((subscriber) => {
|
return new Observable((subscriber) => {
|
||||||
const requestBody = JSON.stringify({
|
const requestBody = JSON.stringify({
|
||||||
@ -27,6 +28,7 @@ export function requestInference(
|
|||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
'Access-Control-Allow-Origin': '*',
|
'Access-Control-Allow-Origin': '*',
|
||||||
'Accept': model.parameters.stream ? 'text/event-stream' : 'application/json',
|
'Accept': model.parameters.stream ? 'text/event-stream' : 'application/json',
|
||||||
|
...headers,
|
||||||
},
|
},
|
||||||
body: requestBody,
|
body: requestBody,
|
||||||
signal: controller?.signal,
|
signal: controller?.signal,
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
export * from './AIEngine'
|
export * from './AIEngine'
|
||||||
export * from './OAIEngine'
|
export * from './OAIEngine'
|
||||||
export * from './LocalOAIEngine'
|
export * from './LocalOAIEngine'
|
||||||
|
export * from './RemoteOAIEngine'
|
||||||
|
|||||||
@ -25,7 +25,7 @@
|
|||||||
"@janhq/core": "file:../../core",
|
"@janhq/core": "file:../../core",
|
||||||
"fetch-retry": "^5.0.6",
|
"fetch-retry": "^5.0.6",
|
||||||
"path-browserify": "^1.0.1",
|
"path-browserify": "^1.0.1",
|
||||||
"ulid": "^2.3.0"
|
"ulidx": "^2.3.0"
|
||||||
},
|
},
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">=18.0.0"
|
"node": ">=18.0.0"
|
||||||
|
|||||||
@ -1,16 +0,0 @@
|
|||||||
declare const MODULE: string
|
|
||||||
declare const GROQ_DOMAIN: string
|
|
||||||
|
|
||||||
declare interface EngineSettings {
|
|
||||||
full_url?: string
|
|
||||||
api_key?: string
|
|
||||||
}
|
|
||||||
|
|
||||||
enum GroqChatCompletionModelName {
|
|
||||||
'mixtral-8x7b-32768' = 'mixtral-8x7b-32768',
|
|
||||||
'llama2-70b-4096' = 'llama2-70b-4096',
|
|
||||||
}
|
|
||||||
|
|
||||||
declare type GroqModel = Omit<Model, 'id'> & {
|
|
||||||
id: GroqChatCompletionModelName
|
|
||||||
}
|
|
||||||
@ -1,83 +0,0 @@
|
|||||||
import { ErrorCode } from '@janhq/core'
|
|
||||||
import { Observable } from 'rxjs'
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Sends a request to the inference server to generate a response based on the recent messages.
|
|
||||||
* @param recentMessages - An array of recent messages to use as context for the inference.
|
|
||||||
* @param engine - The engine settings to use for the inference.
|
|
||||||
* @param model - The model to use for the inference.
|
|
||||||
* @returns An Observable that emits the generated response as a string.
|
|
||||||
*/
|
|
||||||
export function requestInference(
|
|
||||||
recentMessages: any[],
|
|
||||||
engine: EngineSettings,
|
|
||||||
model: GroqModel,
|
|
||||||
controller?: AbortController
|
|
||||||
): Observable<string> {
|
|
||||||
return new Observable((subscriber) => {
|
|
||||||
// let model_id: string = model.id
|
|
||||||
|
|
||||||
const requestBody = JSON.stringify({
|
|
||||||
messages: recentMessages,
|
|
||||||
stream: true,
|
|
||||||
model: model.id,
|
|
||||||
...model.parameters,
|
|
||||||
})
|
|
||||||
fetch(`${engine.full_url}`, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Accept': model.parameters.stream
|
|
||||||
? 'text/event-stream'
|
|
||||||
: 'application/json',
|
|
||||||
'Access-Control-Allow-Origin': '*',
|
|
||||||
'Authorization': `Bearer ${engine.api_key}`,
|
|
||||||
// 'api-key': `${engine.api_key}`,
|
|
||||||
},
|
|
||||||
body: requestBody,
|
|
||||||
signal: controller?.signal,
|
|
||||||
})
|
|
||||||
.then(async (response) => {
|
|
||||||
if (!response.ok) {
|
|
||||||
const data = await response.json()
|
|
||||||
const error = {
|
|
||||||
message: data.error?.message ?? 'An error occurred.',
|
|
||||||
code: data.error?.code ?? ErrorCode.Unknown,
|
|
||||||
}
|
|
||||||
subscriber.error(error)
|
|
||||||
subscriber.complete()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if (model.parameters.stream === false) {
|
|
||||||
const data = await response.json()
|
|
||||||
subscriber.next(data.choices[0]?.message?.content ?? '')
|
|
||||||
} else {
|
|
||||||
const stream = response.body
|
|
||||||
const decoder = new TextDecoder('utf-8')
|
|
||||||
const reader = stream?.getReader()
|
|
||||||
let content = ''
|
|
||||||
|
|
||||||
while (true && reader) {
|
|
||||||
const { done, value } = await reader.read()
|
|
||||||
if (done) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
const text = decoder.decode(value)
|
|
||||||
const lines = text.trim().split('\n')
|
|
||||||
for (const line of lines) {
|
|
||||||
if (line.startsWith('data: ') && !line.includes('data: [DONE]')) {
|
|
||||||
const data = JSON.parse(line.replace('data: ', ''))
|
|
||||||
content += data.choices[0]?.delta?.content ?? ''
|
|
||||||
if (content.startsWith('assistant: ')) {
|
|
||||||
content = content.replace('assistant: ', '')
|
|
||||||
}
|
|
||||||
subscriber.next(content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
subscriber.complete()
|
|
||||||
})
|
|
||||||
.catch((err) => subscriber.error(err))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@ -7,218 +7,77 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import {
|
import {
|
||||||
ChatCompletionRole,
|
|
||||||
ContentType,
|
|
||||||
MessageRequest,
|
|
||||||
MessageStatus,
|
|
||||||
ThreadContent,
|
|
||||||
ThreadMessage,
|
|
||||||
events,
|
events,
|
||||||
fs,
|
fs,
|
||||||
InferenceEngine,
|
|
||||||
BaseExtension,
|
|
||||||
MessageEvent,
|
|
||||||
MessageRequestType,
|
|
||||||
ModelEvent,
|
|
||||||
InferenceEvent,
|
|
||||||
AppConfigurationEventName,
|
AppConfigurationEventName,
|
||||||
joinPath,
|
joinPath,
|
||||||
|
RemoteOAIEngine,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
import { requestInference } from './helpers/sse'
|
|
||||||
import { ulid } from 'ulid'
|
|
||||||
import { join } from 'path'
|
import { join } from 'path'
|
||||||
|
|
||||||
|
declare const COMPLETION_URL: string
|
||||||
/**
|
/**
|
||||||
* A class that implements the InferenceExtension interface from the @janhq/core package.
|
* A class that implements the InferenceExtension interface from the @janhq/core package.
|
||||||
* The class provides methods for initializing and stopping a model, and for making inference requests.
|
* The class provides methods for initializing and stopping a model, and for making inference requests.
|
||||||
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
|
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
|
||||||
*/
|
*/
|
||||||
export default class JanInferenceGroqExtension extends BaseExtension {
|
export default class JanInferenceGroqExtension extends RemoteOAIEngine {
|
||||||
private static readonly _engineDir = 'file://engines'
|
private readonly _engineDir = 'file://engines'
|
||||||
private static readonly _engineMetadataFileName = 'groq.json'
|
private readonly _engineMetadataFileName = 'groq.json'
|
||||||
|
|
||||||
private static _currentModel: GroqModel
|
inferenceUrl: string = COMPLETION_URL
|
||||||
|
provider = 'groq'
|
||||||
|
apiKey = ''
|
||||||
|
|
||||||
private static _engineSettings: EngineSettings = {
|
private _engineSettings = {
|
||||||
full_url: 'https://api.groq.com/openai/v1/chat/completions',
|
full_url: COMPLETION_URL,
|
||||||
api_key: 'gsk-<your key here>',
|
api_key: 'gsk-<your key here>',
|
||||||
}
|
}
|
||||||
|
|
||||||
controller = new AbortController()
|
|
||||||
isCancelled = false
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subscribes to events emitted by the @janhq/core package.
|
* Subscribes to events emitted by the @janhq/core package.
|
||||||
*/
|
*/
|
||||||
async onLoad() {
|
async onLoad() {
|
||||||
if (!(await fs.existsSync(JanInferenceGroqExtension._engineDir))) {
|
super.onLoad()
|
||||||
await fs
|
|
||||||
.mkdirSync(JanInferenceGroqExtension._engineDir)
|
if (!(await fs.existsSync(this._engineDir))) {
|
||||||
.catch((err) => console.debug(err))
|
await fs.mkdirSync(this._engineDir).catch((err) => console.debug(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
JanInferenceGroqExtension.writeDefaultEngineSettings()
|
this.writeDefaultEngineSettings()
|
||||||
|
|
||||||
// Events subscription
|
|
||||||
events.on(MessageEvent.OnMessageSent, (data) =>
|
|
||||||
JanInferenceGroqExtension.handleMessageRequest(data, this)
|
|
||||||
)
|
|
||||||
|
|
||||||
events.on(ModelEvent.OnModelInit, (model: GroqModel) => {
|
|
||||||
JanInferenceGroqExtension.handleModelInit(model)
|
|
||||||
})
|
|
||||||
|
|
||||||
events.on(ModelEvent.OnModelStop, (model: GroqModel) => {
|
|
||||||
JanInferenceGroqExtension.handleModelStop(model)
|
|
||||||
})
|
|
||||||
events.on(InferenceEvent.OnInferenceStopped, () => {
|
|
||||||
JanInferenceGroqExtension.handleInferenceStopped(this)
|
|
||||||
})
|
|
||||||
|
|
||||||
const settingsFilePath = await joinPath([
|
const settingsFilePath = await joinPath([
|
||||||
JanInferenceGroqExtension._engineDir,
|
this._engineDir,
|
||||||
JanInferenceGroqExtension._engineMetadataFileName,
|
this._engineMetadataFileName,
|
||||||
])
|
])
|
||||||
|
|
||||||
|
// Events subscription
|
||||||
events.on(
|
events.on(
|
||||||
AppConfigurationEventName.OnConfigurationUpdate,
|
AppConfigurationEventName.OnConfigurationUpdate,
|
||||||
(settingsKey: string) => {
|
(settingsKey: string) => {
|
||||||
// Update settings on changes
|
// Update settings on changes
|
||||||
if (settingsKey === settingsFilePath)
|
if (settingsKey === settingsFilePath) this.writeDefaultEngineSettings()
|
||||||
JanInferenceGroqExtension.writeDefaultEngineSettings()
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
async writeDefaultEngineSettings() {
|
||||||
* Stops the model inference.
|
|
||||||
*/
|
|
||||||
onUnload(): void {}
|
|
||||||
|
|
||||||
static async writeDefaultEngineSettings() {
|
|
||||||
try {
|
try {
|
||||||
const engineFile = join(
|
const engineFile = join(this._engineDir, this._engineMetadataFileName)
|
||||||
JanInferenceGroqExtension._engineDir,
|
|
||||||
JanInferenceGroqExtension._engineMetadataFileName
|
|
||||||
)
|
|
||||||
if (await fs.existsSync(engineFile)) {
|
if (await fs.existsSync(engineFile)) {
|
||||||
const engine = await fs.readFileSync(engineFile, 'utf-8')
|
const engine = await fs.readFileSync(engineFile, 'utf-8')
|
||||||
JanInferenceGroqExtension._engineSettings =
|
this._engineSettings =
|
||||||
typeof engine === 'object' ? engine : JSON.parse(engine)
|
typeof engine === 'object' ? engine : JSON.parse(engine)
|
||||||
|
this.inferenceUrl = this._engineSettings.full_url
|
||||||
|
this.apiKey = this._engineSettings.api_key
|
||||||
} else {
|
} else {
|
||||||
await fs.writeFileSync(
|
await fs.writeFileSync(
|
||||||
engineFile,
|
engineFile,
|
||||||
JSON.stringify(JanInferenceGroqExtension._engineSettings, null, 2)
|
JSON.stringify(this._engineSettings, null, 2)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error(err)
|
console.error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
private static async handleModelInit(model: GroqModel) {
|
|
||||||
if (model.engine !== InferenceEngine.groq) {
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
JanInferenceGroqExtension._currentModel = model
|
|
||||||
JanInferenceGroqExtension.writeDefaultEngineSettings()
|
|
||||||
// Todo: Check model list with API key
|
|
||||||
events.emit(ModelEvent.OnModelReady, model)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static async handleModelStop(model: GroqModel) {
|
|
||||||
if (model.engine !== 'groq') {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
events.emit(ModelEvent.OnModelStopped, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
private static async handleInferenceStopped(
|
|
||||||
instance: JanInferenceGroqExtension
|
|
||||||
) {
|
|
||||||
instance.isCancelled = true
|
|
||||||
instance.controller?.abort()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Handles a new message request by making an inference request and emitting events.
|
|
||||||
* Function registered in event manager, should be static to avoid binding issues.
|
|
||||||
* Pass instance as a reference.
|
|
||||||
* @param {MessageRequest} data - The data for the new message request.
|
|
||||||
*/
|
|
||||||
private static async handleMessageRequest(
|
|
||||||
data: MessageRequest,
|
|
||||||
instance: JanInferenceGroqExtension
|
|
||||||
) {
|
|
||||||
if (data.model.engine !== 'groq') {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
const timestamp = Date.now()
|
|
||||||
const message: ThreadMessage = {
|
|
||||||
id: ulid(),
|
|
||||||
thread_id: data.threadId,
|
|
||||||
type: data.type,
|
|
||||||
assistant_id: data.assistantId,
|
|
||||||
role: ChatCompletionRole.Assistant,
|
|
||||||
content: [],
|
|
||||||
status: MessageStatus.Pending,
|
|
||||||
created: timestamp,
|
|
||||||
updated: timestamp,
|
|
||||||
object: 'thread.message',
|
|
||||||
}
|
|
||||||
|
|
||||||
if (data.type !== MessageRequestType.Summary) {
|
|
||||||
events.emit(MessageEvent.OnMessageResponse, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
instance.isCancelled = false
|
|
||||||
instance.controller = new AbortController()
|
|
||||||
|
|
||||||
requestInference(
|
|
||||||
data?.messages ?? [],
|
|
||||||
this._engineSettings,
|
|
||||||
{
|
|
||||||
...JanInferenceGroqExtension._currentModel,
|
|
||||||
parameters: data.model.parameters,
|
|
||||||
},
|
|
||||||
instance.controller
|
|
||||||
).subscribe({
|
|
||||||
next: (content) => {
|
|
||||||
const messageContent: ThreadContent = {
|
|
||||||
type: ContentType.Text,
|
|
||||||
text: {
|
|
||||||
value: content.trim(),
|
|
||||||
annotations: [],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
message.content = [messageContent]
|
|
||||||
events.emit(MessageEvent.OnMessageUpdate, message)
|
|
||||||
},
|
|
||||||
complete: async () => {
|
|
||||||
message.status = message.content.length
|
|
||||||
? MessageStatus.Ready
|
|
||||||
: MessageStatus.Error
|
|
||||||
events.emit(MessageEvent.OnMessageUpdate, message)
|
|
||||||
},
|
|
||||||
error: async (err) => {
|
|
||||||
if (instance.isCancelled || message.content.length > 0) {
|
|
||||||
message.status = MessageStatus.Stopped
|
|
||||||
events.emit(MessageEvent.OnMessageUpdate, message)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
const messageContent: ThreadContent = {
|
|
||||||
type: ContentType.Text,
|
|
||||||
text: {
|
|
||||||
value: 'An error occurred. ' + err.message,
|
|
||||||
annotations: [],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
message.content = [messageContent]
|
|
||||||
message.status = MessageStatus.Error
|
|
||||||
message.error_code = err.code
|
|
||||||
events.emit(MessageEvent.OnMessageUpdate, message)
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,7 +18,7 @@ module.exports = {
|
|||||||
plugins: [
|
plugins: [
|
||||||
new webpack.DefinePlugin({
|
new webpack.DefinePlugin({
|
||||||
MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
|
MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
|
||||||
GROQ_DOMAIN: JSON.stringify('api.groq.com'),
|
COMPLETION_URL: JSON.stringify('https://api.groq.com/openai/v1/chat/completions'),
|
||||||
}),
|
}),
|
||||||
],
|
],
|
||||||
output: {
|
output: {
|
||||||
|
|||||||
5
extensions/inference-nitro-extension/jest.config.js
Normal file
5
extensions/inference-nitro-extension/jest.config.js
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
/** @type {import('ts-jest').JestConfigWithTsJest} */
|
||||||
|
module.exports = {
|
||||||
|
preset: 'ts-jest',
|
||||||
|
testEnvironment: 'node',
|
||||||
|
};
|
||||||
@ -7,6 +7,7 @@
|
|||||||
"author": "Jan <service@jan.ai>",
|
"author": "Jan <service@jan.ai>",
|
||||||
"license": "AGPL-3.0",
|
"license": "AGPL-3.0",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
|
"test": "jest",
|
||||||
"build": "tsc --module commonjs && rollup -c rollup.config.ts",
|
"build": "tsc --module commonjs && rollup -c rollup.config.ts",
|
||||||
"downloadnitro:linux": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64.tar.gz -e --strip 1 -o ./bin/linux-cpu && chmod +x ./bin/linux-cpu/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda-12-0.tar.gz -e --strip 1 -o ./bin/linux-cuda-12-0 && chmod +x ./bin/linux-cuda-12-0/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda-11-7.tar.gz -e --strip 1 -o ./bin/linux-cuda-11-7 && chmod +x ./bin/linux-cuda-11-7/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-vulkan.tar.gz -e --strip 1 -o ./bin/linux-vulkan && chmod +x ./bin/linux-vulkan/nitro",
|
"downloadnitro:linux": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64.tar.gz -e --strip 1 -o ./bin/linux-cpu && chmod +x ./bin/linux-cpu/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda-12-0.tar.gz -e --strip 1 -o ./bin/linux-cuda-12-0 && chmod +x ./bin/linux-cuda-12-0/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-cuda-11-7.tar.gz -e --strip 1 -o ./bin/linux-cuda-11-7 && chmod +x ./bin/linux-cuda-11-7/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-linux-amd64-vulkan.tar.gz -e --strip 1 -o ./bin/linux-vulkan && chmod +x ./bin/linux-vulkan/nitro",
|
||||||
"downloadnitro:darwin": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-arm64.tar.gz -e --strip 1 -o ./bin/mac-arm64 && chmod +x ./bin/mac-arm64/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-amd64.tar.gz -e --strip 1 -o ./bin/mac-x64 && chmod +x ./bin/mac-x64/nitro",
|
"downloadnitro:darwin": "NITRO_VERSION=$(cat ./bin/version.txt) && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-arm64.tar.gz -e --strip 1 -o ./bin/mac-arm64 && chmod +x ./bin/mac-arm64/nitro && download https://github.com/janhq/nitro/releases/download/v${NITRO_VERSION}/nitro-${NITRO_VERSION}-mac-amd64.tar.gz -e --strip 1 -o ./bin/mac-x64 && chmod +x ./bin/mac-x64/nitro",
|
||||||
@ -15,29 +16,34 @@
|
|||||||
"build:publish:darwin": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && ../../.github/scripts/auto-sign.sh && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
|
"build:publish:darwin": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && ../../.github/scripts/auto-sign.sh && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
|
||||||
"build:publish:win32": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
|
"build:publish:win32": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
|
||||||
"build:publish:linux": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
|
"build:publish:linux": "rimraf *.tgz --glob && yarn build && npm run downloadnitro && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
|
||||||
"build:publish": "run-script-os"
|
"build:publish": "yarn test && run-script-os"
|
||||||
},
|
},
|
||||||
"exports": {
|
"exports": {
|
||||||
".": "./dist/index.js",
|
".": "./dist/index.js",
|
||||||
"./main": "./dist/node/index.cjs.js"
|
"./main": "./dist/node/index.cjs.js"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
|
"@babel/preset-typescript": "^7.24.1",
|
||||||
|
"@jest/globals": "^29.7.0",
|
||||||
"@rollup/plugin-commonjs": "^25.0.7",
|
"@rollup/plugin-commonjs": "^25.0.7",
|
||||||
"@rollup/plugin-json": "^6.1.0",
|
"@rollup/plugin-json": "^6.1.0",
|
||||||
"@rollup/plugin-node-resolve": "^15.2.3",
|
"@rollup/plugin-node-resolve": "^15.2.3",
|
||||||
|
"@rollup/plugin-replace": "^5.0.5",
|
||||||
|
"@types/jest": "^29.5.12",
|
||||||
"@types/node": "^20.11.4",
|
"@types/node": "^20.11.4",
|
||||||
|
"@types/os-utils": "^0.0.4",
|
||||||
"@types/tcp-port-used": "^1.0.4",
|
"@types/tcp-port-used": "^1.0.4",
|
||||||
"cpx": "^1.5.0",
|
"cpx": "^1.5.0",
|
||||||
"download-cli": "^1.1.1",
|
"download-cli": "^1.1.1",
|
||||||
|
"jest": "^29.7.0",
|
||||||
"rimraf": "^3.0.2",
|
"rimraf": "^3.0.2",
|
||||||
"rollup": "^2.38.5",
|
"rollup": "^2.38.5",
|
||||||
"rollup-plugin-define": "^1.0.1",
|
"rollup-plugin-define": "^1.0.1",
|
||||||
"rollup-plugin-sourcemaps": "^0.6.3",
|
"rollup-plugin-sourcemaps": "^0.6.3",
|
||||||
"rollup-plugin-typescript2": "^0.36.0",
|
"rollup-plugin-typescript2": "^0.36.0",
|
||||||
"run-script-os": "^1.1.6",
|
"run-script-os": "^1.1.6",
|
||||||
"typescript": "^5.3.3",
|
"ts-jest": "^29.1.2",
|
||||||
"@types/os-utils": "^0.0.4",
|
"typescript": "^5.3.3"
|
||||||
"@rollup/plugin-replace": "^5.0.5"
|
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@janhq/core": "file:../../core",
|
"@janhq/core": "file:../../core",
|
||||||
|
|||||||
6
extensions/inference-nitro-extension/src/babel.config.js
Normal file
6
extensions/inference-nitro-extension/src/babel.config.js
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
module.exports = {
|
||||||
|
presets: [
|
||||||
|
['@babel/preset-env', { targets: { node: 'current' } }],
|
||||||
|
'@babel/preset-typescript',
|
||||||
|
],
|
||||||
|
}
|
||||||
@ -1,66 +0,0 @@
|
|||||||
import { Model } from '@janhq/core'
|
|
||||||
import { Observable } from 'rxjs'
|
|
||||||
/**
|
|
||||||
* Sends a request to the inference server to generate a response based on the recent messages.
|
|
||||||
* @param recentMessages - An array of recent messages to use as context for the inference.
|
|
||||||
* @returns An Observable that emits the generated response as a string.
|
|
||||||
*/
|
|
||||||
export function requestInference(
|
|
||||||
inferenceUrl: string,
|
|
||||||
recentMessages: any[],
|
|
||||||
model: Model,
|
|
||||||
controller?: AbortController
|
|
||||||
): Observable<string> {
|
|
||||||
return new Observable((subscriber) => {
|
|
||||||
const requestBody = JSON.stringify({
|
|
||||||
messages: recentMessages,
|
|
||||||
model: model.id,
|
|
||||||
stream: true,
|
|
||||||
...model.parameters,
|
|
||||||
})
|
|
||||||
fetch(inferenceUrl, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Access-Control-Allow-Origin': '*',
|
|
||||||
'Accept': model.parameters.stream
|
|
||||||
? 'text/event-stream'
|
|
||||||
: 'application/json',
|
|
||||||
},
|
|
||||||
body: requestBody,
|
|
||||||
signal: controller?.signal,
|
|
||||||
})
|
|
||||||
.then(async (response) => {
|
|
||||||
if (model.parameters.stream === false) {
|
|
||||||
const data = await response.json()
|
|
||||||
subscriber.next(data.choices[0]?.message?.content ?? '')
|
|
||||||
} else {
|
|
||||||
const stream = response.body
|
|
||||||
const decoder = new TextDecoder('utf-8')
|
|
||||||
const reader = stream?.getReader()
|
|
||||||
let content = ''
|
|
||||||
|
|
||||||
while (true && reader) {
|
|
||||||
const { done, value } = await reader.read()
|
|
||||||
if (done) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
const text = decoder.decode(value)
|
|
||||||
const lines = text.trim().split('\n')
|
|
||||||
for (const line of lines) {
|
|
||||||
if (line.startsWith('data: ') && !line.includes('data: [DONE]')) {
|
|
||||||
const data = JSON.parse(line.replace('data: ', ''))
|
|
||||||
content += data.choices[0]?.delta?.content ?? ''
|
|
||||||
if (content.startsWith('assistant: ')) {
|
|
||||||
content = content.replace('assistant: ', '')
|
|
||||||
}
|
|
||||||
subscriber.next(content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
subscriber.complete()
|
|
||||||
})
|
|
||||||
.catch((err) => subscriber.error(err))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@ -7,58 +7,31 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import {
|
import {
|
||||||
ChatCompletionRole,
|
|
||||||
ContentType,
|
|
||||||
MessageRequest,
|
|
||||||
MessageRequestType,
|
|
||||||
MessageStatus,
|
|
||||||
ThreadContent,
|
|
||||||
ThreadMessage,
|
|
||||||
events,
|
events,
|
||||||
executeOnMain,
|
executeOnMain,
|
||||||
fs,
|
|
||||||
Model,
|
Model,
|
||||||
joinPath,
|
|
||||||
InferenceExtension,
|
|
||||||
log,
|
|
||||||
InferenceEngine,
|
|
||||||
MessageEvent,
|
|
||||||
ModelEvent,
|
ModelEvent,
|
||||||
InferenceEvent,
|
LocalOAIEngine,
|
||||||
ModelSettingParams,
|
|
||||||
getJanDataFolderPath,
|
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
import { requestInference } from './helpers/sse'
|
|
||||||
import { ulid } from 'ulidx'
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A class that implements the InferenceExtension interface from the @janhq/core package.
|
* A class that implements the InferenceExtension interface from the @janhq/core package.
|
||||||
* The class provides methods for initializing and stopping a model, and for making inference requests.
|
* The class provides methods for initializing and stopping a model, and for making inference requests.
|
||||||
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
|
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
|
||||||
*/
|
*/
|
||||||
export default class JanInferenceNitroExtension extends InferenceExtension {
|
export default class JanInferenceNitroExtension extends LocalOAIEngine {
|
||||||
private static readonly _homeDir = 'file://engines'
|
nodeModule: string = NODE
|
||||||
private static readonly _settingsDir = 'file://settings'
|
provider: string = 'nitro'
|
||||||
private static readonly _engineMetadataFileName = 'nitro.json'
|
|
||||||
|
models(): Promise<Model[]> {
|
||||||
|
return Promise.resolve([])
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Checking the health for Nitro's process each 5 secs.
|
* Checking the health for Nitro's process each 5 secs.
|
||||||
*/
|
*/
|
||||||
private static readonly _intervalHealthCheck = 5 * 1000
|
private static readonly _intervalHealthCheck = 5 * 1000
|
||||||
|
|
||||||
private _currentModel: Model | undefined
|
|
||||||
|
|
||||||
private _engineSettings: ModelSettingParams = {
|
|
||||||
ctx_len: 2048,
|
|
||||||
ngl: 100,
|
|
||||||
cpu_threads: 1,
|
|
||||||
cont_batching: false,
|
|
||||||
embedding: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
controller = new AbortController()
|
|
||||||
isCancelled = false
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The interval id for the health check. Used to stop the health check.
|
* The interval id for the health check. Used to stop the health check.
|
||||||
*/
|
*/
|
||||||
@ -69,114 +42,30 @@ export default class JanInferenceNitroExtension extends InferenceExtension {
|
|||||||
*/
|
*/
|
||||||
private nitroProcessInfo: any = undefined
|
private nitroProcessInfo: any = undefined
|
||||||
|
|
||||||
private inferenceUrl = ''
|
/**
|
||||||
|
* The URL for making inference requests.
|
||||||
|
*/
|
||||||
|
inferenceUrl = ''
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Subscribes to events emitted by the @janhq/core package.
|
* Subscribes to events emitted by the @janhq/core package.
|
||||||
*/
|
*/
|
||||||
async onLoad() {
|
async onLoad() {
|
||||||
if (!(await fs.existsSync(JanInferenceNitroExtension._homeDir))) {
|
|
||||||
try {
|
|
||||||
await fs.mkdirSync(JanInferenceNitroExtension._homeDir)
|
|
||||||
} catch (e) {
|
|
||||||
console.debug(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// init inference url
|
|
||||||
// @ts-ignore
|
|
||||||
const electronApi = window?.electronAPI
|
|
||||||
this.inferenceUrl = INFERENCE_URL
|
this.inferenceUrl = INFERENCE_URL
|
||||||
if (!electronApi) {
|
|
||||||
|
// If the extension is running in the browser, use the base API URL from the core package.
|
||||||
|
if (!('electronAPI' in window)) {
|
||||||
this.inferenceUrl = `${window.core?.api?.baseApiUrl}/v1/chat/completions`
|
this.inferenceUrl = `${window.core?.api?.baseApiUrl}/v1/chat/completions`
|
||||||
}
|
}
|
||||||
|
|
||||||
console.debug('Inference url: ', this.inferenceUrl)
|
console.debug('Inference url: ', this.inferenceUrl)
|
||||||
|
|
||||||
if (!(await fs.existsSync(JanInferenceNitroExtension._settingsDir)))
|
|
||||||
await fs.mkdirSync(JanInferenceNitroExtension._settingsDir)
|
|
||||||
this.writeDefaultEngineSettings()
|
|
||||||
|
|
||||||
// Events subscription
|
|
||||||
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
|
|
||||||
this.onMessageRequest(data)
|
|
||||||
)
|
|
||||||
|
|
||||||
events.on(ModelEvent.OnModelInit, (model: Model) => this.onModelInit(model))
|
|
||||||
|
|
||||||
events.on(ModelEvent.OnModelStop, (model: Model) => this.onModelStop(model))
|
|
||||||
|
|
||||||
events.on(InferenceEvent.OnInferenceStopped, () =>
|
|
||||||
this.onInferenceStopped()
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Stops the model inference.
|
|
||||||
*/
|
|
||||||
onUnload(): void {}
|
|
||||||
|
|
||||||
private async writeDefaultEngineSettings() {
|
|
||||||
try {
|
|
||||||
const engineFile = await joinPath([
|
|
||||||
JanInferenceNitroExtension._homeDir,
|
|
||||||
JanInferenceNitroExtension._engineMetadataFileName,
|
|
||||||
])
|
|
||||||
if (await fs.existsSync(engineFile)) {
|
|
||||||
const engine = await fs.readFileSync(engineFile, 'utf-8')
|
|
||||||
this._engineSettings =
|
|
||||||
typeof engine === 'object' ? engine : JSON.parse(engine)
|
|
||||||
} else {
|
|
||||||
await fs.writeFileSync(
|
|
||||||
engineFile,
|
|
||||||
JSON.stringify(this._engineSettings, null, 2)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} catch (err) {
|
|
||||||
console.error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private async onModelInit(model: Model) {
|
|
||||||
if (model.engine !== InferenceEngine.nitro) return
|
|
||||||
|
|
||||||
const modelFolder = await joinPath([
|
|
||||||
await getJanDataFolderPath(),
|
|
||||||
'models',
|
|
||||||
model.id,
|
|
||||||
])
|
|
||||||
this._currentModel = model
|
|
||||||
const nitroInitResult = await executeOnMain(NODE, 'runModel', {
|
|
||||||
modelFolder,
|
|
||||||
model,
|
|
||||||
})
|
|
||||||
|
|
||||||
if (nitroInitResult?.error) {
|
|
||||||
events.emit(ModelEvent.OnModelFail, {
|
|
||||||
...model,
|
|
||||||
error: nitroInitResult.error,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
events.emit(ModelEvent.OnModelReady, model)
|
|
||||||
|
|
||||||
this.getNitroProcesHealthIntervalId = setInterval(
|
this.getNitroProcesHealthIntervalId = setInterval(
|
||||||
() => this.periodicallyGetNitroHealth(),
|
() => this.periodicallyGetNitroHealth(),
|
||||||
JanInferenceNitroExtension._intervalHealthCheck
|
JanInferenceNitroExtension._intervalHealthCheck
|
||||||
)
|
)
|
||||||
}
|
|
||||||
|
|
||||||
private async onModelStop(model: Model) {
|
super.onLoad()
|
||||||
if (model.engine !== 'nitro') return
|
|
||||||
|
|
||||||
await executeOnMain(NODE, 'stopModel')
|
|
||||||
events.emit(ModelEvent.OnModelStopped, {})
|
|
||||||
|
|
||||||
// stop the periocally health check
|
|
||||||
if (this.getNitroProcesHealthIntervalId) {
|
|
||||||
clearInterval(this.getNitroProcesHealthIntervalId)
|
|
||||||
this.getNitroProcesHealthIntervalId = undefined
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -193,118 +82,24 @@ export default class JanInferenceNitroExtension extends InferenceExtension {
|
|||||||
this.nitroProcessInfo = health
|
this.nitroProcessInfo = health
|
||||||
}
|
}
|
||||||
|
|
||||||
private async onInferenceStopped() {
|
override loadModel(model: Model): Promise<void> {
|
||||||
this.isCancelled = true
|
if (model.engine !== this.provider) return Promise.resolve()
|
||||||
this.controller?.abort()
|
this.getNitroProcesHealthIntervalId = setInterval(
|
||||||
|
() => this.periodicallyGetNitroHealth(),
|
||||||
|
JanInferenceNitroExtension._intervalHealthCheck
|
||||||
|
)
|
||||||
|
return super.loadModel(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
override unloadModel(model: Model): void {
|
||||||
* Makes a single response inference request.
|
super.unloadModel(model)
|
||||||
* @param {MessageRequest} data - The data for the inference request.
|
|
||||||
* @returns {Promise<any>} A promise that resolves with the inference response.
|
if (model.engine && model.engine !== this.provider) return
|
||||||
*/
|
|
||||||
async inference(data: MessageRequest): Promise<ThreadMessage> {
|
// stop the periocally health check
|
||||||
const timestamp = Date.now()
|
if (this.getNitroProcesHealthIntervalId) {
|
||||||
const message: ThreadMessage = {
|
clearInterval(this.getNitroProcesHealthIntervalId)
|
||||||
thread_id: data.threadId,
|
this.getNitroProcesHealthIntervalId = undefined
|
||||||
created: timestamp,
|
|
||||||
updated: timestamp,
|
|
||||||
status: MessageStatus.Ready,
|
|
||||||
id: '',
|
|
||||||
role: ChatCompletionRole.Assistant,
|
|
||||||
object: 'thread.message',
|
|
||||||
content: [],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return new Promise(async (resolve, reject) => {
|
|
||||||
if (!this._currentModel) return Promise.reject('No model loaded')
|
|
||||||
|
|
||||||
requestInference(
|
|
||||||
this.inferenceUrl,
|
|
||||||
data.messages ?? [],
|
|
||||||
this._currentModel
|
|
||||||
).subscribe({
|
|
||||||
next: (_content: any) => {},
|
|
||||||
complete: async () => {
|
|
||||||
resolve(message)
|
|
||||||
},
|
|
||||||
error: async (err: any) => {
|
|
||||||
reject(err)
|
|
||||||
},
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Handles a new message request by making an inference request and emitting events.
|
|
||||||
* Function registered in event manager, should be static to avoid binding issues.
|
|
||||||
* Pass instance as a reference.
|
|
||||||
* @param {MessageRequest} data - The data for the new message request.
|
|
||||||
*/
|
|
||||||
private async onMessageRequest(data: MessageRequest) {
|
|
||||||
if (data.model?.engine !== InferenceEngine.nitro || !this._currentModel) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
const timestamp = Date.now()
|
|
||||||
const message: ThreadMessage = {
|
|
||||||
id: ulid(),
|
|
||||||
thread_id: data.threadId,
|
|
||||||
type: data.type,
|
|
||||||
assistant_id: data.assistantId,
|
|
||||||
role: ChatCompletionRole.Assistant,
|
|
||||||
content: [],
|
|
||||||
status: MessageStatus.Pending,
|
|
||||||
created: timestamp,
|
|
||||||
updated: timestamp,
|
|
||||||
object: 'thread.message',
|
|
||||||
}
|
|
||||||
|
|
||||||
if (data.type !== MessageRequestType.Summary) {
|
|
||||||
events.emit(MessageEvent.OnMessageResponse, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
this.isCancelled = false
|
|
||||||
this.controller = new AbortController()
|
|
||||||
|
|
||||||
// @ts-ignore
|
|
||||||
const model: Model = {
|
|
||||||
...(this._currentModel || {}),
|
|
||||||
...(data.model || {}),
|
|
||||||
}
|
|
||||||
requestInference(
|
|
||||||
this.inferenceUrl,
|
|
||||||
data.messages ?? [],
|
|
||||||
model,
|
|
||||||
this.controller
|
|
||||||
).subscribe({
|
|
||||||
next: (content: any) => {
|
|
||||||
const messageContent: ThreadContent = {
|
|
||||||
type: ContentType.Text,
|
|
||||||
text: {
|
|
||||||
value: content.trim(),
|
|
||||||
annotations: [],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
message.content = [messageContent]
|
|
||||||
events.emit(MessageEvent.OnMessageUpdate, message)
|
|
||||||
},
|
|
||||||
complete: async () => {
|
|
||||||
message.status = message.content.length
|
|
||||||
? MessageStatus.Ready
|
|
||||||
: MessageStatus.Error
|
|
||||||
events.emit(MessageEvent.OnMessageUpdate, message)
|
|
||||||
},
|
|
||||||
error: async (err: any) => {
|
|
||||||
if (this.isCancelled || message.content.length) {
|
|
||||||
message.status = MessageStatus.Stopped
|
|
||||||
events.emit(MessageEvent.OnMessageUpdate, message)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
message.status = MessageStatus.Error
|
|
||||||
events.emit(MessageEvent.OnMessageUpdate, message)
|
|
||||||
log(`[APP]::Error: ${err.message}`)
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
233
extensions/inference-nitro-extension/src/node/execute.test.ts
Normal file
233
extensions/inference-nitro-extension/src/node/execute.test.ts
Normal file
@ -0,0 +1,233 @@
|
|||||||
|
import { describe, expect, it } from '@jest/globals'
|
||||||
|
import { executableNitroFile } from './execute'
|
||||||
|
import { GpuSetting } from '@janhq/core'
|
||||||
|
import { sep } from 'path'
|
||||||
|
|
||||||
|
let testSettings: GpuSetting = {
|
||||||
|
run_mode: 'cpu',
|
||||||
|
vulkan: false,
|
||||||
|
cuda: {
|
||||||
|
exist: false,
|
||||||
|
version: '11',
|
||||||
|
},
|
||||||
|
gpu_highest_vram: '0',
|
||||||
|
gpus: [],
|
||||||
|
gpus_in_use: [],
|
||||||
|
is_initial: false,
|
||||||
|
notify: true,
|
||||||
|
nvidia_driver: {
|
||||||
|
exist: false,
|
||||||
|
version: '11',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
const originalPlatform = process.platform
|
||||||
|
|
||||||
|
describe('test executable nitro file', () => {
|
||||||
|
afterAll(function () {
|
||||||
|
Object.defineProperty(process, 'platform', {
|
||||||
|
value: originalPlatform,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('executes on MacOS ARM', () => {
|
||||||
|
Object.defineProperty(process, 'platform', {
|
||||||
|
value: 'darwin',
|
||||||
|
})
|
||||||
|
Object.defineProperty(process, 'arch', {
|
||||||
|
value: 'arm64',
|
||||||
|
})
|
||||||
|
expect(executableNitroFile(testSettings)).toEqual(
|
||||||
|
expect.objectContaining({
|
||||||
|
executablePath: expect.stringContaining(`mac-arm64${sep}nitro`),
|
||||||
|
cudaVisibleDevices: '',
|
||||||
|
vkVisibleDevices: '',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('executes on MacOS Intel', () => {
|
||||||
|
Object.defineProperty(process, 'platform', {
|
||||||
|
value: 'darwin',
|
||||||
|
})
|
||||||
|
Object.defineProperty(process, 'arch', {
|
||||||
|
value: 'x64',
|
||||||
|
})
|
||||||
|
expect(executableNitroFile(testSettings)).toEqual(
|
||||||
|
expect.objectContaining({
|
||||||
|
executablePath: expect.stringContaining(`mac-x64${sep}nitro`),
|
||||||
|
cudaVisibleDevices: '',
|
||||||
|
vkVisibleDevices: '',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('executes on Windows CPU', () => {
|
||||||
|
Object.defineProperty(process, 'platform', {
|
||||||
|
value: 'win32',
|
||||||
|
})
|
||||||
|
const settings: GpuSetting = {
|
||||||
|
...testSettings,
|
||||||
|
run_mode: 'cpu',
|
||||||
|
cuda: {
|
||||||
|
exist: true,
|
||||||
|
version: '11',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
expect(executableNitroFile(settings)).toEqual(
|
||||||
|
expect.objectContaining({
|
||||||
|
executablePath: expect.stringContaining(`win-cpu${sep}nitro.exe`),
|
||||||
|
cudaVisibleDevices: '',
|
||||||
|
vkVisibleDevices: '',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('executes on Windows Cuda 11', () => {
|
||||||
|
Object.defineProperty(process, 'platform', {
|
||||||
|
value: 'win32',
|
||||||
|
})
|
||||||
|
const settings: GpuSetting = {
|
||||||
|
...testSettings,
|
||||||
|
run_mode: 'gpu',
|
||||||
|
cuda: {
|
||||||
|
exist: true,
|
||||||
|
version: '11',
|
||||||
|
},
|
||||||
|
nvidia_driver: {
|
||||||
|
exist: true,
|
||||||
|
version: '12',
|
||||||
|
},
|
||||||
|
gpus_in_use: ['0'],
|
||||||
|
gpus: [
|
||||||
|
{
|
||||||
|
id: '0',
|
||||||
|
name: 'NVIDIA GeForce GTX 1080',
|
||||||
|
vram: '80000000',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
expect(executableNitroFile(settings)).toEqual(
|
||||||
|
expect.objectContaining({
|
||||||
|
executablePath: expect.stringContaining(`win-cuda-11-7${sep}nitro.exe`),
|
||||||
|
cudaVisibleDevices: '0',
|
||||||
|
vkVisibleDevices: '0',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('executes on Windows Cuda 12', () => {
|
||||||
|
Object.defineProperty(process, 'platform', {
|
||||||
|
value: 'win32',
|
||||||
|
})
|
||||||
|
const settings: GpuSetting = {
|
||||||
|
...testSettings,
|
||||||
|
run_mode: 'gpu',
|
||||||
|
cuda: {
|
||||||
|
exist: true,
|
||||||
|
version: '12',
|
||||||
|
},
|
||||||
|
nvidia_driver: {
|
||||||
|
exist: true,
|
||||||
|
version: '12',
|
||||||
|
},
|
||||||
|
gpus_in_use: ['0'],
|
||||||
|
gpus: [
|
||||||
|
{
|
||||||
|
id: '0',
|
||||||
|
name: 'NVIDIA GeForce GTX 1080',
|
||||||
|
vram: '80000000',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
expect(executableNitroFile(settings)).toEqual(
|
||||||
|
expect.objectContaining({
|
||||||
|
executablePath: expect.stringContaining(`win-cuda-12-0${sep}nitro.exe`),
|
||||||
|
cudaVisibleDevices: '0',
|
||||||
|
vkVisibleDevices: '0',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('executes on Linux CPU', () => {
|
||||||
|
Object.defineProperty(process, 'platform', {
|
||||||
|
value: 'linux',
|
||||||
|
})
|
||||||
|
const settings: GpuSetting = {
|
||||||
|
...testSettings,
|
||||||
|
run_mode: 'cpu',
|
||||||
|
}
|
||||||
|
expect(executableNitroFile(settings)).toEqual(
|
||||||
|
expect.objectContaining({
|
||||||
|
executablePath: expect.stringContaining(`linux-cpu${sep}nitro`),
|
||||||
|
cudaVisibleDevices: '',
|
||||||
|
vkVisibleDevices: '',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('executes on Linux Cuda 11', () => {
|
||||||
|
Object.defineProperty(process, 'platform', {
|
||||||
|
value: 'linux',
|
||||||
|
})
|
||||||
|
const settings: GpuSetting = {
|
||||||
|
...testSettings,
|
||||||
|
run_mode: 'gpu',
|
||||||
|
cuda: {
|
||||||
|
exist: true,
|
||||||
|
version: '11',
|
||||||
|
},
|
||||||
|
nvidia_driver: {
|
||||||
|
exist: true,
|
||||||
|
version: '12',
|
||||||
|
},
|
||||||
|
gpus_in_use: ['0'],
|
||||||
|
gpus: [
|
||||||
|
{
|
||||||
|
id: '0',
|
||||||
|
name: 'NVIDIA GeForce GTX 1080',
|
||||||
|
vram: '80000000',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
expect(executableNitroFile(settings)).toEqual(
|
||||||
|
expect.objectContaining({
|
||||||
|
executablePath: expect.stringContaining(`linux-cuda-11-7${sep}nitro`),
|
||||||
|
cudaVisibleDevices: '0',
|
||||||
|
vkVisibleDevices: '0',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('executes on Linux Cuda 12', () => {
|
||||||
|
Object.defineProperty(process, 'platform', {
|
||||||
|
value: 'linux',
|
||||||
|
})
|
||||||
|
const settings: GpuSetting = {
|
||||||
|
...testSettings,
|
||||||
|
run_mode: 'gpu',
|
||||||
|
cuda: {
|
||||||
|
exist: true,
|
||||||
|
version: '12',
|
||||||
|
},
|
||||||
|
nvidia_driver: {
|
||||||
|
exist: true,
|
||||||
|
version: '12',
|
||||||
|
},
|
||||||
|
gpus_in_use: ['0'],
|
||||||
|
gpus: [
|
||||||
|
{
|
||||||
|
id: '0',
|
||||||
|
name: 'NVIDIA GeForce GTX 1080',
|
||||||
|
vram: '80000000',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
expect(executableNitroFile(settings)).toEqual(
|
||||||
|
expect.objectContaining({
|
||||||
|
executablePath: expect.stringContaining(`linux-cuda-12-0${sep}nitro`),
|
||||||
|
cudaVisibleDevices: '0',
|
||||||
|
vkVisibleDevices: '0',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -1,5 +1,4 @@
|
|||||||
import { getJanDataFolderPath } from '@janhq/core/node'
|
import { GpuSetting, SystemInformation } from '@janhq/core'
|
||||||
import { readFileSync } from 'fs'
|
|
||||||
import * as path from 'path'
|
import * as path from 'path'

export interface NitroExecutableOptions {
@ -7,79 +6,56 @@ export interface NitroExecutableOptions {
  cudaVisibleDevices: string
  vkVisibleDevices: string
}

-export const GPU_INFO_FILE = path.join(
-  getJanDataFolderPath(),
-  'settings',
-  'settings.json'
-)
+const runMode = (settings?: GpuSetting): string => {
+  if (process.platform === 'darwin')
+    // MacOS use arch instead of cpu / cuda
+    return process.arch === 'arm64' ? 'arm64' : 'x64'
+
+  if (!settings) return 'cpu'
+
+  return settings.vulkan === true
+    ? 'vulkan'
+    : settings.run_mode === 'cpu'
+      ? 'cpu'
+      : 'cuda'
+}
+
+const os = (): string => {
+  return process.platform === 'win32'
+    ? 'win'
+    : process.platform === 'darwin'
+      ? 'mac'
+      : 'linux'
+}
+
+const extension = (): '.exe' | '' => {
+  return process.platform === 'win32' ? '.exe' : ''
+}
+
+const cudaVersion = (settings?: GpuSetting): '11-7' | '12-0' | undefined => {
+  const isUsingCuda =
+    settings?.vulkan !== true && settings?.run_mode === 'gpu' && os() !== 'mac'
+
+  if (!isUsingCuda) return undefined
+  return settings?.cuda?.version === '11' ? '11-7' : '12-0'
+}

/**
 * Find which executable file to run based on the current platform.
 * @returns The name of the executable file to run.
 */
-export const executableNitroFile = (): NitroExecutableOptions => {
-  let binaryFolder = path.join(__dirname, '..', 'bin') // Current directory by default
-  let cudaVisibleDevices = ''
-  let vkVisibleDevices = ''
-  let binaryName = 'nitro'
-  /**
-   * The binary folder is different for each platform.
-   */
-  if (process.platform === 'win32') {
-    /**
-     * For Windows: win-cpu, win-vulkan, win-cuda-11-7, win-cuda-12-0
-     */
-    let gpuInfo = JSON.parse(readFileSync(GPU_INFO_FILE, 'utf-8'))
-    if (gpuInfo['run_mode'] === 'cpu') {
-      binaryFolder = path.join(binaryFolder, 'win-cpu')
-    } else {
-      if (gpuInfo['cuda']?.version === '11') {
-        binaryFolder = path.join(binaryFolder, 'win-cuda-11-7')
-      } else {
-        binaryFolder = path.join(binaryFolder, 'win-cuda-12-0')
-      }
-      cudaVisibleDevices = gpuInfo['gpus_in_use'].join(',')
-    }
-    if (gpuInfo['vulkan'] === true) {
-      binaryFolder = path.join(__dirname, '..', 'bin')
-      binaryFolder = path.join(binaryFolder, 'win-vulkan')
-      vkVisibleDevices = gpuInfo['gpus_in_use'].toString()
-    }
-    binaryName = 'nitro.exe'
-  } else if (process.platform === 'darwin') {
-    /**
-     * For MacOS: mac-arm64 (Silicon), mac-x64 (InteL)
-     */
-    if (process.arch === 'arm64') {
-      binaryFolder = path.join(binaryFolder, 'mac-arm64')
-    } else {
-      binaryFolder = path.join(binaryFolder, 'mac-x64')
-    }
-  } else {
-    /**
-     * For Linux: linux-cpu, linux-vulkan, linux-cuda-11-7, linux-cuda-12-0
-     */
-    let gpuInfo = JSON.parse(readFileSync(GPU_INFO_FILE, 'utf-8'))
-    if (gpuInfo['run_mode'] === 'cpu') {
-      binaryFolder = path.join(binaryFolder, 'linux-cpu')
-    } else {
-      if (gpuInfo['cuda']?.version === '11') {
-        binaryFolder = path.join(binaryFolder, 'linux-cuda-11-7')
-      } else {
-        binaryFolder = path.join(binaryFolder, 'linux-cuda-12-0')
-      }
-      cudaVisibleDevices = gpuInfo['gpus_in_use'].join(',')
-    }
-
-    if (gpuInfo['vulkan'] === true) {
-      binaryFolder = path.join(__dirname, '..', 'bin')
-      binaryFolder = path.join(binaryFolder, 'linux-vulkan')
-      vkVisibleDevices = gpuInfo['gpus_in_use'].toString()
-    }
-  }
+export const executableNitroFile = (
+  gpuSetting?: GpuSetting
+): NitroExecutableOptions => {
+  let binaryFolder = [os(), runMode(gpuSetting), cudaVersion(gpuSetting)]
+    .filter((e) => !!e)
+    .join('-')
+  let cudaVisibleDevices = gpuSetting?.gpus_in_use.join(',') ?? ''
+  let vkVisibleDevices = gpuSetting?.gpus_in_use.join(',') ?? ''
+  let binaryName = `nitro${extension()}`

  return {
-    executablePath: path.join(binaryFolder, binaryName),
+    executablePath: path.join(__dirname, '..', 'bin', binaryFolder, binaryName),
    cudaVisibleDevices,
    vkVisibleDevices,
  }
}
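The refactor above replaces the on-disk settings lookup with a pure function of the `GpuSetting` the caller passes in: the binary folder name is simply `os-runMode[-cudaVersion]`. Below is a minimal sketch of the resulting folder names for win/linux, assuming a simplified `GpuSettingLike` shape that covers only the fields the diff reads (on macOS the real helper keys off `process.arch` instead):

```typescript
// Sketch only: mirrors the folder-name logic above so the mapping is easy to see.
// GpuSettingLike is an assumed shape; the real GpuSetting type comes from @janhq/core.
type GpuSettingLike = {
  run_mode?: 'cpu' | 'gpu'
  vulkan?: boolean
  cuda?: { version?: string }
  gpus_in_use?: string[]
}

const folderFor = (os: 'win' | 'linux', s?: GpuSettingLike): string => {
  const runMode =
    s?.vulkan === true ? 'vulkan' : !s || s.run_mode === 'cpu' ? 'cpu' : 'cuda'
  const cuda =
    runMode === 'cuda' ? (s?.cuda?.version === '11' ? '11-7' : '12-0') : undefined
  return [os, runMode, cuda].filter((e) => !!e).join('-')
}

console.log(folderFor('win', { run_mode: 'gpu', cuda: { version: '12' } })) // win-cuda-12-0
console.log(folderFor('linux', { run_mode: 'gpu', vulkan: true }))          // linux-vulkan
console.log(folderFor('win'))                                               // win-cpu
```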
@ -10,6 +10,7 @@ import {
  InferenceEngine,
  ModelSettingParams,
  PromptTemplate,
+  SystemInformation,
} from '@janhq/core/node'
import { executableNitroFile } from './execute'

@ -51,7 +52,7 @@ let currentSettings: ModelSettingParams | undefined = undefined
 * @param wrapper - The model wrapper.
 * @returns A Promise that resolves when the subprocess is terminated successfully, or rejects with an error message if the subprocess fails to terminate.
 */
-function stopModel(): Promise<void> {
+function unloadModel(): Promise<void> {
  return killSubprocess()
}

@ -61,46 +62,47 @@ function stopModel(): Promise<void> {
 * @returns A Promise that resolves when the model is loaded successfully, or rejects with an error message if the model is not found or fails to load.
 * TODO: Should pass absolute of the model file instead of just the name - So we can modurize the module.ts to npm package
 */
-async function runModel(
-  wrapper: ModelInitOptions
+async function loadModel(
+  params: ModelInitOptions,
+  systemInfo?: SystemInformation
): Promise<ModelOperationResponse | void> {
-  if (wrapper.model.engine !== InferenceEngine.nitro) {
+  if (params.model.engine !== InferenceEngine.nitro) {
    // Not a nitro model
    return Promise.resolve()
  }

-  if (wrapper.model.engine !== InferenceEngine.nitro) {
+  if (params.model.engine !== InferenceEngine.nitro) {
    return Promise.reject('Not a nitro model')
  } else {
    const nitroResourceProbe = await getSystemResourceInfo()
    // Convert settings.prompt_template to system_prompt, user_prompt, ai_prompt
-    if (wrapper.model.settings.prompt_template) {
-      const promptTemplate = wrapper.model.settings.prompt_template
+    if (params.model.settings.prompt_template) {
+      const promptTemplate = params.model.settings.prompt_template
      const prompt = promptTemplateConverter(promptTemplate)
      if (prompt?.error) {
        return Promise.reject(prompt.error)
      }
-      wrapper.model.settings.system_prompt = prompt.system_prompt
-      wrapper.model.settings.user_prompt = prompt.user_prompt
-      wrapper.model.settings.ai_prompt = prompt.ai_prompt
+      params.model.settings.system_prompt = prompt.system_prompt
+      params.model.settings.user_prompt = prompt.user_prompt
+      params.model.settings.ai_prompt = prompt.ai_prompt
    }

    // modelFolder is the absolute path to the running model folder
    // e.g. ~/jan/models/llama-2
-    let modelFolder = wrapper.modelFolder
+    let modelFolder = params.modelFolder

-    let llama_model_path = wrapper.model.settings.llama_model_path
+    let llama_model_path = params.model.settings.llama_model_path

    // Absolute model path support
    if (
-      wrapper.model?.sources.length &&
-      wrapper.model.sources.every((e) => fs.existsSync(e.url))
+      params.model?.sources.length &&
+      params.model.sources.every((e) => fs.existsSync(e.url))
    ) {
      llama_model_path =
-        wrapper.model.sources.length === 1
-          ? wrapper.model.sources[0].url
-          : wrapper.model.sources.find((e) =>
-              e.url.includes(llama_model_path ?? wrapper.model.id)
+        params.model.sources.length === 1
+          ? params.model.sources[0].url
+          : params.model.sources.find((e) =>
+              e.url.includes(llama_model_path ?? params.model.id)
            )?.url
    }

@ -114,7 +116,7 @@ async function runModel(
        // 2. Prioritize GGUF File (manual import)
        file.toLowerCase().includes(SUPPORTED_MODEL_FORMAT) ||
        // 3. Fallback Model ID (for backward compatibility)
-        file === wrapper.model.id
+        file === params.model.id
      )
      if (ggufBinFile) llama_model_path = path.join(modelFolder, ggufBinFile)
    }
@ -124,17 +126,17 @@ async function runModel(
    if (!llama_model_path) return Promise.reject('No GGUF model file found')

    currentSettings = {
-      ...wrapper.model.settings,
+      ...params.model.settings,
      llama_model_path,
      // This is critical and requires real CPU physical core count (or performance core)
      cpu_threads: Math.max(1, nitroResourceProbe.numCpuPhysicalCore),
-      ...(wrapper.model.settings.mmproj && {
-        mmproj: path.isAbsolute(wrapper.model.settings.mmproj)
-          ? wrapper.model.settings.mmproj
-          : path.join(modelFolder, wrapper.model.settings.mmproj),
+      ...(params.model.settings.mmproj && {
+        mmproj: path.isAbsolute(params.model.settings.mmproj)
+          ? params.model.settings.mmproj
+          : path.join(modelFolder, params.model.settings.mmproj),
      }),
    }
-    return runNitroAndLoadModel()
+    return runNitroAndLoadModel(systemInfo)
  }
}

@ -144,7 +146,7 @@ async function runModel(
 * 3. Validate model status
 * @returns
 */
-async function runNitroAndLoadModel() {
+async function runNitroAndLoadModel(systemInfo?: SystemInformation) {
  // Gather system information for CPU physical cores and memory
  return killSubprocess()
    .then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000))
@ -160,7 +162,7 @@ async function runNitroAndLoadModel() {
        return Promise.resolve()
      }
    })
-    .then(spawnNitroProcess)
+    .then(() => spawnNitroProcess(systemInfo))
    .then(() => loadLLMModel(currentSettings))
    .then(validateModelStatus)
    .catch((err) => {
@ -325,12 +327,12 @@ async function killSubprocess(): Promise<void> {
 * Spawns a Nitro subprocess.
 * @returns A promise that resolves when the Nitro subprocess is started.
 */
-function spawnNitroProcess(): Promise<any> {
+function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
  log(`[NITRO]::Debug: Spawning Nitro subprocess...`)

  return new Promise<void>(async (resolve, reject) => {
    let binaryFolder = path.join(__dirname, '..', 'bin') // Current directory by default
-    let executableOptions = executableNitroFile()
+    let executableOptions = executableNitroFile(systemInfo?.gpuSetting)

    const args: string[] = ['1', LOCAL_HOST, PORT.toString()]
    // Execute the binary
@ -402,9 +404,8 @@ const getCurrentNitroProcessInfo = (): NitroProcessInfo => {
}

export default {
-  runModel,
-  stopModel,
-  killSubprocess,
+  loadModel,
+  unloadModel,
  dispose,
  getCurrentNitroProcessInfo,
}
@ -1,14 +1,14 @@
-# Jan inference plugin
+# OpenAI Engine Extension

-Created using Jan app example
+Created using Jan extension example

-# Create a Jan Plugin using Typescript
+# Create a Jan Extension using Typescript

-Use this template to bootstrap the creation of a TypeScript Jan plugin. 🚀
+Use this template to bootstrap the creation of a TypeScript Jan extension. 🚀

-## Create Your Own Plugin
+## Create Your Own Extension

-To create your own plugin, you can use this repository as a template! Just follow the below instructions:
+To create your own extension, you can use this repository as a template! Just follow the below instructions:

1. Click the Use this template button at the top of the repository
2. Select Create a new repository
@ -18,7 +18,7 @@ To create your own plugin, you can use this repository as a template! Just follo

## Initial Setup

-After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your plugin.
+After you've cloned the repository to your local machine or codespace, you'll need to perform some initial setup steps before you can develop your extension.

> [!NOTE]
>
@ -43,36 +43,37 @@ After you've cloned the repository to your local machine or codespace, you'll ne

1. :white_check_mark: Check your artifact

-   There will be a tgz file in your plugin directory now
+   There will be a tgz file in your extension directory now

-## Update the Plugin Metadata
+## Update the Extension Metadata

-The [`package.json`](package.json) file defines metadata about your plugin, such as
-plugin name, main entry, description and version.
+The [`package.json`](package.json) file defines metadata about your extension, such as
+extension name, main entry, description and version.

-When you copy this repository, update `package.json` with the name, description for your plugin.
+When you copy this repository, update `package.json` with the name, description for your extension.

-## Update the Plugin Code
+## Update the Extension Code

-The [`src/`](./src/) directory is the heart of your plugin! This contains the
-source code that will be run when your plugin extension functions are invoked. You can replace the
+The [`src/`](./src/) directory is the heart of your extension! This contains the
+source code that will be run when your extension functions are invoked. You can replace the
contents of this directory with your own code.

-There are a few things to keep in mind when writing your plugin code:
+There are a few things to keep in mind when writing your extension code:

-- Most Jan Plugin Extension functions are processed asynchronously.
+- Most Jan Extension functions are processed asynchronously.
  In `index.ts`, you will see that the extension function will return a `Promise<any>`.

  ```typescript
-  import { core } from "@janhq/core";
+  import { events, MessageEvent, MessageRequest } from '@janhq/core'

  function onStart(): Promise<any> {
-    return core.invokePluginFunc(MODULE_PATH, "run", 0);
+    return events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
+      this.inference(data)
+    )
  }
  ```

-For more information about the Jan Plugin Core module, see the
+For more information about the Jan Extension Core module, see the
[documentation](https://github.com/janhq/jan/blob/main/core/README.md).

-So, what are you waiting for? Go ahead and start customizing your plugin!
+So, what are you waiting for? Go ahead and start customizing your extension!
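The updated README snippet registers `events.on(MessageEvent.OnMessageSent, ...)` inside a plain `onStart` function, where `this.inference(data)` has nothing to bind to. In an actual extension class the same subscription would sit on an instance method, roughly as in this sketch; `SampleExtension` and its `inference` method are illustrative names, not part of the template:

```typescript
import { events, MessageEvent, MessageRequest } from '@janhq/core'

// Sketch: the README's subscription placed on a class instance, so that
// `this.inference(data)` has something to bind to.
class SampleExtension {
  onStart() {
    // The arrow function keeps `this` bound to the extension instance.
    events.on(MessageEvent.OnMessageSent, (data: MessageRequest) =>
      this.inference(data)
    )
  }

  inference(data: MessageRequest) {
    console.log('handling message for thread', data.threadId)
  }
}
```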
@ -4,6 +4,7 @@
  "description": "This extension enables OpenAI chat completion API calls",
  "main": "dist/index.js",
  "module": "dist/module.js",
+  "engine": "openai",
  "author": "Jan <service@jan.ai>",
  "license": "AGPL-3.0",
  "scripts": {
@ -1,26 +0,0 @@
-declare const MODULE: string
-declare const OPENAI_DOMAIN: string
-
-declare interface EngineSettings {
-  full_url?: string
-  api_key?: string
-}
-
-enum OpenAIChatCompletionModelName {
-  'gpt-3.5-turbo-instruct' = 'gpt-3.5-turbo-instruct',
-  'gpt-3.5-turbo-instruct-0914' = 'gpt-3.5-turbo-instruct-0914',
-  'gpt-4-1106-preview' = 'gpt-4-1106-preview',
-  'gpt-3.5-turbo-0613' = 'gpt-3.5-turbo-0613',
-  'gpt-3.5-turbo-0301' = 'gpt-3.5-turbo-0301',
-  'gpt-3.5-turbo' = 'gpt-3.5-turbo',
-  'gpt-3.5-turbo-16k-0613' = 'gpt-3.5-turbo-16k-0613',
-  'gpt-3.5-turbo-1106' = 'gpt-3.5-turbo-1106',
-  'gpt-4-vision-preview' = 'gpt-4-vision-preview',
-  'gpt-4' = 'gpt-4',
-  'gpt-4-0314' = 'gpt-4-0314',
-  'gpt-4-0613' = 'gpt-4-0613',
-}
-
-declare type OpenAIModel = Omit<Model, 'id'> & {
-  id: OpenAIChatCompletionModelName
-}
@ -1,85 +0,0 @@
-import { ErrorCode } from '@janhq/core'
-import { Observable } from 'rxjs'
-
-/**
- * Sends a request to the inference server to generate a response based on the recent messages.
- * @param recentMessages - An array of recent messages to use as context for the inference.
- * @param engine - The engine settings to use for the inference.
- * @param model - The model to use for the inference.
- * @returns An Observable that emits the generated response as a string.
- */
-export function requestInference(
-  recentMessages: any[],
-  engine: EngineSettings,
-  model: OpenAIModel,
-  controller?: AbortController
-): Observable<string> {
-  return new Observable((subscriber) => {
-    let model_id: string = model.id
-    if (engine.full_url.includes(OPENAI_DOMAIN)) {
-      model_id = engine.full_url.split('/')[5]
-    }
-    const requestBody = JSON.stringify({
-      messages: recentMessages,
-      stream: true,
-      model: model_id,
-      ...model.parameters,
-    })
-    fetch(`${engine.full_url}`, {
-      method: 'POST',
-      headers: {
-        'Content-Type': 'application/json',
-        'Accept': model.parameters.stream
-          ? 'text/event-stream'
-          : 'application/json',
-        'Access-Control-Allow-Origin': '*',
-        'Authorization': `Bearer ${engine.api_key}`,
-        'api-key': `${engine.api_key}`,
-      },
-      body: requestBody,
-      signal: controller?.signal,
-    })
-      .then(async (response) => {
-        if (!response.ok) {
-          const data = await response.json()
-          const error = {
-            message: data.error?.message ?? 'An error occurred.',
-            code: data.error?.code ?? ErrorCode.Unknown,
-          }
-          subscriber.error(error)
-          subscriber.complete()
-          return
-        }
-        if (model.parameters.stream === false) {
-          const data = await response.json()
-          subscriber.next(data.choices[0]?.message?.content ?? '')
-        } else {
-          const stream = response.body
-          const decoder = new TextDecoder('utf-8')
-          const reader = stream?.getReader()
-          let content = ''
-
-          while (true && reader) {
-            const { done, value } = await reader.read()
-            if (done) {
-              break
-            }
-            const text = decoder.decode(value)
-            const lines = text.trim().split('\n')
-            for (const line of lines) {
-              if (line.startsWith('data: ') && !line.includes('data: [DONE]')) {
-                const data = JSON.parse(line.replace('data: ', ''))
-                content += data.choices[0]?.delta?.content ?? ''
-                if (content.startsWith('assistant: ')) {
-                  content = content.replace('assistant: ', '')
-                }
-                subscriber.next(content)
-              }
-            }
-          }
-        }
-        subscriber.complete()
-      })
-      .catch((err) => subscriber.error(err))
-  })
-}
@ -5,75 +5,52 @@
 * @version 1.0.0
 * @module inference-openai-extension/src/index
 */
+declare const ENGINE: string

import {
-  ChatCompletionRole,
-  ContentType,
-  MessageRequest,
-  MessageStatus,
-  ThreadContent,
-  ThreadMessage,
  events,
  fs,
-  InferenceEngine,
-  BaseExtension,
-  MessageEvent,
-  MessageRequestType,
-  ModelEvent,
-  InferenceEvent,
  AppConfigurationEventName,
  joinPath,
+  RemoteOAIEngine,
} from '@janhq/core'
-import { requestInference } from './helpers/sse'
-import { ulid } from 'ulidx'
import { join } from 'path'

+declare const COMPLETION_URL: string

/**
 * A class that implements the InferenceExtension interface from the @janhq/core package.
 * The class provides methods for initializing and stopping a model, and for making inference requests.
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 */
-export default class JanInferenceOpenAIExtension extends BaseExtension {
+export default class JanInferenceOpenAIExtension extends RemoteOAIEngine {
  private static readonly _engineDir = 'file://engines'
-  private static readonly _engineMetadataFileName = 'openai.json'
+  private static readonly _engineMetadataFileName = `${ENGINE}.json`

-  private static _currentModel: OpenAIModel
-  private static _engineSettings: EngineSettings = {
-    full_url: 'https://api.openai.com/v1/chat/completions',
+  private _engineSettings = {
+    full_url: COMPLETION_URL,
    api_key: 'sk-<your key here>',
  }

-  controller = new AbortController()
-  isCancelled = false
+  inferenceUrl: string = COMPLETION_URL
+  provider: string = 'openai'
+  apiKey: string = ''

+  // TODO: Just use registerSettings from BaseExtension
+  // Remove these methods
  /**
   * Subscribes to events emitted by the @janhq/core package.
   */
  async onLoad() {
+    super.onLoad()
+
    if (!(await fs.existsSync(JanInferenceOpenAIExtension._engineDir))) {
      await fs
        .mkdirSync(JanInferenceOpenAIExtension._engineDir)
        .catch((err) => console.debug(err))
    }

-    JanInferenceOpenAIExtension.writeDefaultEngineSettings()
+    this.writeDefaultEngineSettings()

-    // Events subscription
-    events.on(MessageEvent.OnMessageSent, (data) =>
-      JanInferenceOpenAIExtension.handleMessageRequest(data, this)
-    )
-
-    events.on(ModelEvent.OnModelInit, (model: OpenAIModel) => {
-      JanInferenceOpenAIExtension.handleModelInit(model)
-    })
-
-    events.on(ModelEvent.OnModelStop, (model: OpenAIModel) => {
-      JanInferenceOpenAIExtension.handleModelStop(model)
-    })
-    events.on(InferenceEvent.OnInferenceStopped, () => {
-      JanInferenceOpenAIExtension.handleInferenceStopped(this)
-    })
-
    const settingsFilePath = await joinPath([
      JanInferenceOpenAIExtension._engineDir,
@ -84,18 +61,12 @@ export default class JanInferenceOpenAIExtension extends BaseExtension {
      AppConfigurationEventName.OnConfigurationUpdate,
      (settingsKey: string) => {
        // Update settings on changes
-        if (settingsKey === settingsFilePath)
-          JanInferenceOpenAIExtension.writeDefaultEngineSettings()
+        if (settingsKey === settingsFilePath) this.writeDefaultEngineSettings()
      }
    )
  }

-  /**
-   * Stops the model inference.
-   */
-  onUnload(): void {}
-
-  static async writeDefaultEngineSettings() {
+  async writeDefaultEngineSettings() {
    try {
      const engineFile = join(
        JanInferenceOpenAIExtension._engineDir,
@ -103,122 +74,18 @@ export default class JanInferenceOpenAIExtension extends BaseExtension {
      )
      if (await fs.existsSync(engineFile)) {
        const engine = await fs.readFileSync(engineFile, 'utf-8')
-        JanInferenceOpenAIExtension._engineSettings =
+        this._engineSettings =
          typeof engine === 'object' ? engine : JSON.parse(engine)
+        this.inferenceUrl = this._engineSettings.full_url
+        this.apiKey = this._engineSettings.api_key
      } else {
        await fs.writeFileSync(
          engineFile,
-          JSON.stringify(JanInferenceOpenAIExtension._engineSettings, null, 2)
+          JSON.stringify(this._engineSettings, null, 2)
        )
      }
    } catch (err) {
      console.error(err)
    }
  }
-  private static async handleModelInit(model: OpenAIModel) {
-    if (model.engine !== InferenceEngine.openai) {
-      return
-    } else {
-      JanInferenceOpenAIExtension._currentModel = model
-      JanInferenceOpenAIExtension.writeDefaultEngineSettings()
-      // Todo: Check model list with API key
-      events.emit(ModelEvent.OnModelReady, model)
-    }
-  }
-
-  private static async handleModelStop(model: OpenAIModel) {
-    if (model.engine !== 'openai') {
-      return
-    }
-    events.emit(ModelEvent.OnModelStopped, model)
-  }
-
-  private static async handleInferenceStopped(
-    instance: JanInferenceOpenAIExtension
-  ) {
-    instance.isCancelled = true
-    instance.controller?.abort()
-  }
-
-  /**
-   * Handles a new message request by making an inference request and emitting events.
-   * Function registered in event manager, should be static to avoid binding issues.
-   * Pass instance as a reference.
-   * @param {MessageRequest} data - The data for the new message request.
-   */
-  private static async handleMessageRequest(
-    data: MessageRequest,
-    instance: JanInferenceOpenAIExtension
-  ) {
-    if (data.model.engine !== 'openai') {
-      return
-    }
-
-    const timestamp = Date.now()
-    const message: ThreadMessage = {
-      id: ulid(),
-      thread_id: data.threadId,
-      type: data.type,
-      assistant_id: data.assistantId,
-      role: ChatCompletionRole.Assistant,
-      content: [],
-      status: MessageStatus.Pending,
-      created: timestamp,
-      updated: timestamp,
-      object: 'thread.message',
-    }
-
-    if (data.type !== MessageRequestType.Summary) {
-      events.emit(MessageEvent.OnMessageResponse, message)
-    }
-
-    instance.isCancelled = false
-    instance.controller = new AbortController()
-
-    requestInference(
-      data?.messages ?? [],
-      this._engineSettings,
-      {
-        ...JanInferenceOpenAIExtension._currentModel,
-        parameters: data.model.parameters,
-      },
-      instance.controller
-    ).subscribe({
-      next: (content) => {
-        const messageContent: ThreadContent = {
-          type: ContentType.Text,
-          text: {
-            value: content.trim(),
-            annotations: [],
-          },
-        }
-        message.content = [messageContent]
-        events.emit(MessageEvent.OnMessageUpdate, message)
-      },
-      complete: async () => {
-        message.status = message.content.length
-          ? MessageStatus.Ready
-          : MessageStatus.Error
-        events.emit(MessageEvent.OnMessageUpdate, message)
-      },
-      error: async (err) => {
-        if (instance.isCancelled || message.content.length > 0) {
-          message.status = MessageStatus.Stopped
-          events.emit(MessageEvent.OnMessageUpdate, message)
-          return
-        }
-        const messageContent: ThreadContent = {
-          type: ContentType.Text,
-          text: {
-            value: 'An error occurred. ' + err.message,
-            annotations: [],
-          },
-        }
-        message.content = [messageContent]
-        message.status = MessageStatus.Error
-        message.error_code = err.code
-        events.emit(MessageEvent.OnMessageUpdate, message)
-      },
-    })
-  }
}
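With the SSE helper and the per-extension message handlers deleted, the OpenAI extension reduces to declaring its identity and endpoint and letting `RemoteOAIEngine` handle requests, streaming, and message events. A minimal sketch of that shape, with placeholder values and the settings-file plumbing elided:

```typescript
import { RemoteOAIEngine } from '@janhq/core'

// Sketch of the pattern this commit introduces: a remote provider only declares
// its provider id, endpoint, and API key; everything request-related is inherited.
// The values below are placeholders, not the extension's real configuration.
export default class ExampleRemoteEngine extends RemoteOAIEngine {
  provider: string = 'example'
  inferenceUrl: string = 'https://api.example.com/v1/chat/completions'
  apiKey: string = ''

  async onLoad() {
    super.onLoad()
    // A real extension reads inferenceUrl/apiKey from its engines/*.json here,
    // as writeDefaultEngineSettings() does in the diff above.
  }
}
```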
@ -17,8 +17,8 @@ module.exports = {
  },
  plugins: [
    new webpack.DefinePlugin({
-      MODULE: JSON.stringify(`${packageJson.name}/${packageJson.module}`),
-      OPENAI_DOMAIN: JSON.stringify('openai.azure.com'),
+      ENGINE: JSON.stringify(packageJson.engine),
+      COMPLETION_URL: JSON.stringify('https://api.openai.com/v1/chat/completions'),
    }),
  ],
  output: {
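The new `ENGINE` and `COMPLETION_URL` entries in the DefinePlugin pair with the `declare const` statements added to the extension source; webpack substitutes the literals at bundle time. A small sketch of the effect, with the values taken from the `package.json` and config changes above:

```typescript
// Sketch of how the build-time constants line up after this change.
// The source only declares the globals:
declare const ENGINE: string          // replaced with "openai" (packageJson.engine)
declare const COMPLETION_URL: string  // replaced with the chat completions URL above

// so expressions such as this one resolve to 'openai.json' in the bundled output.
const engineMetadataFileName = `${ENGINE}.json`
```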
@ -1,5 +0,0 @@
-import { Model } from '@janhq/core'
-
-declare interface EngineSettings {
-  base_url?: string
-}
@ -1,63 +0,0 @@
-import { Observable } from 'rxjs'
-import { EngineSettings } from '../@types/global'
-import { Model } from '@janhq/core'
-
-/**
- * Sends a request to the inference server to generate a response based on the recent messages.
- * @param recentMessages - An array of recent messages to use as context for the inference.
- * @param engine - The engine settings to use for the inference.
- * @param model - The model to use for the inference.
- * @returns An Observable that emits the generated response as a string.
- */
-export function requestInference(
-  recentMessages: any[],
-  engine: EngineSettings,
-  model: Model,
-  controller?: AbortController
-): Observable<string> {
-  return new Observable((subscriber) => {
-    const text_input = recentMessages.map((message) => message.text).join('\n')
-    const requestBody = JSON.stringify({
-      text_input: text_input,
-      max_tokens: 4096,
-      temperature: 0,
-      bad_words: '',
-      stop_words: '[DONE]',
-      stream: true,
-    })
-    fetch(`${engine.base_url}/v2/models/ensemble/generate_stream`, {
-      method: 'POST',
-      headers: {
-        'Content-Type': 'application/json',
-        'Accept': 'text/event-stream',
-        'Access-Control-Allow-Origin': '*',
-      },
-      body: requestBody,
-      signal: controller?.signal,
-    })
-      .then(async (response) => {
-        const stream = response.body
-        const decoder = new TextDecoder('utf-8')
-        const reader = stream?.getReader()
-        let content = ''
-
-        while (true && reader) {
-          const { done, value } = await reader.read()
-          if (done) {
-            break
-          }
-          const text = decoder.decode(value)
-          const lines = text.trim().split('\n')
-          for (const line of lines) {
-            if (line.startsWith('data: ') && !line.includes('data: [DONE]')) {
-              const data = JSON.parse(line.replace('data: ', ''))
-              content += data.choices[0]?.delta?.content ?? ''
-              subscriber.next(content)
-            }
-          }
-        }
-        subscriber.complete()
-      })
-      .catch((err) => subscriber.error(err))
-  })
-}
@ -7,212 +7,76 @@
 */

import {
-  ChatCompletionRole,
-  ContentType,
-  MessageRequest,
-  MessageStatus,
-  ModelSettingParams,
-  ThreadContent,
-  ThreadMessage,
+  AppConfigurationEventName,
  events,
  fs,
+  joinPath,
  Model,
-  BaseExtension,
-  MessageEvent,
-  ModelEvent,
+  RemoteOAIEngine,
} from '@janhq/core'
-import { requestInference } from './helpers/sse'
-import { ulid } from 'ulidx'
import { join } from 'path'
-import { EngineSettings } from './@types/global'

/**
 * A class that implements the InferenceExtension interface from the @janhq/core package.
 * The class provides methods for initializing and stopping a model, and for making inference requests.
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 */
-export default class JanInferenceTritonTrtLLMExtension extends BaseExtension {
-  private static readonly _homeDir = 'file://engines'
-  private static readonly _engineMetadataFileName = 'triton_trtllm.json'
+export default class JanInferenceTritonTrtLLMExtension extends RemoteOAIEngine {
+  private readonly _engineDir = 'file://engines'
+  private readonly _engineMetadataFileName = 'triton_trtllm.json'

-  static _currentModel: Model
+  inferenceUrl: string = ''
+  provider: string = 'triton_trtllm'
+  apiKey: string = ''

-  static _engineSettings: EngineSettings = {
-    base_url: '',
+  _engineSettings: {
+    base_url: ''
+    api_key: ''
  }

-  controller = new AbortController()
-  isCancelled = false
-
  /**
   * Subscribes to events emitted by the @janhq/core package.
   */
  async onLoad() {
-    if (!(await fs.existsSync(JanInferenceTritonTrtLLMExtension._homeDir)))
-      JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings()
+    super.onLoad()
+    if (!(await fs.existsSync(this._engineDir))) {
+      await fs.mkdirSync(this._engineDir).catch((err) => console.debug(err))
+    }
+
+    this.writeDefaultEngineSettings()
+
+    const settingsFilePath = await joinPath([
+      this._engineDir,
+      this._engineMetadataFileName,
+    ])

    // Events subscription
-    events.on(MessageEvent.OnMessageSent, (data) =>
-      JanInferenceTritonTrtLLMExtension.handleMessageRequest(data, this)
+    events.on(
+      AppConfigurationEventName.OnConfigurationUpdate,
+      (settingsKey: string) => {
+        // Update settings on changes
+        if (settingsKey === settingsFilePath) this.writeDefaultEngineSettings()
+      }
    )
-
-    events.on(ModelEvent.OnModelInit, (model: Model) => {
-      JanInferenceTritonTrtLLMExtension.handleModelInit(model)
-    })
-
-    events.on(ModelEvent.OnModelStop, (model: Model) => {
-      JanInferenceTritonTrtLLMExtension.handleModelStop(model)
-    })
  }

-  /**
-   * Stops the model inference.
-   */
-  onUnload(): void {}
-
-  /**
-   * Initializes the model with the specified file name.
-   * @param {string} modelId - The ID of the model to initialize.
-   * @returns {Promise<void>} A promise that resolves when the model is initialized.
-   */
-  async initModel(
-    modelId: string,
-    settings?: ModelSettingParams
-  ): Promise<void> {
-    return
-  }
-
-  static async writeDefaultEngineSettings() {
+  async writeDefaultEngineSettings() {
    try {
-      const engine_json = join(
-        JanInferenceTritonTrtLLMExtension._homeDir,
-        JanInferenceTritonTrtLLMExtension._engineMetadataFileName
-      )
+      const engine_json = join(this._engineDir, this._engineMetadataFileName)
      if (await fs.existsSync(engine_json)) {
        const engine = await fs.readFileSync(engine_json, 'utf-8')
-        JanInferenceTritonTrtLLMExtension._engineSettings =
+        this._engineSettings =
          typeof engine === 'object' ? engine : JSON.parse(engine)
+        this.inferenceUrl = this._engineSettings.base_url
+        this.apiKey = this._engineSettings.api_key
      } else {
        await fs.writeFileSync(
          engine_json,
-          JSON.stringify(
-            JanInferenceTritonTrtLLMExtension._engineSettings,
-            null,
-            2
-          )
+          JSON.stringify(this._engineSettings, null, 2)
        )
      }
    } catch (err) {
      console.error(err)
    }
  }
-  /**
-   * Stops the model.
-   * @returns {Promise<void>} A promise that resolves when the model is stopped.
-   */
-  async stopModel(): Promise<void> {}
-
-  /**
-   * Stops streaming inference.
-   * @returns {Promise<void>} A promise that resolves when the streaming is stopped.
-   */
-  async stopInference(): Promise<void> {
-    this.isCancelled = true
-    this.controller?.abort()
-  }
-
-  private static async handleModelInit(model: Model) {
-    if (model.engine !== 'triton_trtllm') {
-      return
-    } else {
-      JanInferenceTritonTrtLLMExtension._currentModel = model
-      JanInferenceTritonTrtLLMExtension.writeDefaultEngineSettings()
-      // Todo: Check model list with API key
-      events.emit(ModelEvent.OnModelReady, model)
-    }
-  }
-
-  private static async handleModelStop(model: Model) {
-    if (model.engine !== 'triton_trtllm') {
-      return
-    }
-    events.emit(ModelEvent.OnModelStopped, model)
-  }
-
-  /**
-   * Handles a new message request by making an inference request and emitting events.
-   * Function registered in event manager, should be static to avoid binding issues.
-   * Pass instance as a reference.
-   * @param {MessageRequest} data - The data for the new message request.
-   */
-  private static async handleMessageRequest(
-    data: MessageRequest,
-    instance: JanInferenceTritonTrtLLMExtension
-  ) {
-    if (data.model.engine !== 'triton_trtllm') {
-      return
-    }
-
-    const timestamp = Date.now()
-    const message: ThreadMessage = {
-      id: ulid(),
-      thread_id: data.threadId,
-      assistant_id: data.assistantId,
-      role: ChatCompletionRole.Assistant,
-      content: [],
-      status: MessageStatus.Pending,
-      created: timestamp,
-      updated: timestamp,
-      object: 'thread.message',
-    }
-    events.emit(MessageEvent.OnMessageResponse, message)
-
-    instance.isCancelled = false
-    instance.controller = new AbortController()
-
-    requestInference(
-      data?.messages ?? [],
-      this._engineSettings,
-      {
-        ...JanInferenceTritonTrtLLMExtension._currentModel,
-        parameters: data.model.parameters,
-      },
-      instance.controller
-    ).subscribe({
-      next: (content) => {
-        const messageContent: ThreadContent = {
-          type: ContentType.Text,
-          text: {
-            value: content.trim(),
-            annotations: [],
-          },
-        }
-        message.content = [messageContent]
-        events.emit(MessageEvent.OnMessageUpdate, message)
-      },
-      complete: async () => {
-        message.status = message.content.length
-          ? MessageStatus.Ready
-          : MessageStatus.Error
-        events.emit(MessageEvent.OnMessageUpdate, message)
-      },
-      error: async (err) => {
-        if (instance.isCancelled || message.content.length) {
-          message.status = MessageStatus.Error
-          events.emit(MessageEvent.OnMessageUpdate, message)
-          return
-        }
-        const messageContent: ThreadContent = {
-          type: ContentType.Text,
-          text: {
-            value: 'An error occurred. ' + err.message,
-            annotations: [],
-          },
-        }
-        message.content = [messageContent]
-        message.status = MessageStatus.Ready
-        events.emit(MessageEvent.OnMessageUpdate, message)
-      },
-    })
-  }
}
@ -43,14 +43,14 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
  private supportedPlatform = ['win32', 'linux']
  private isUpdateAvailable = false

-  compatibility() {
+  override compatibility() {
    return COMPATIBILITY as unknown as Compatibility
  }
  /**
   * models implemented by the extension
   * define pre-populated models
   */
-  async models(): Promise<Model[]> {
+  override async models(): Promise<Model[]> {
    if ((await this.installationState()) === 'Installed')
      return models as unknown as Model[]
    return []
@ -160,11 +160,11 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
    events.emit(ModelEvent.OnModelsUpdate, {})
  }

-  async onModelInit(model: Model): Promise<void> {
+  override async loadModel(model: Model): Promise<void> {
    if (model.engine !== this.provider) return

    if ((await this.installationState()) === 'Installed')
-      return super.onModelInit(model)
+      return super.loadModel(model)
    else {
      events.emit(ModelEvent.OnModelFail, {
        ...model,
@ -175,7 +175,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
    }
  }

-  updatable() {
+  override updatable() {
    return this.isUpdateAvailable
  }

@ -241,8 +241,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
    return (await fs.existsSync(enginePath)) ? 'Installed' : 'NotInstalled'
  }

-  override onInferenceStopped() {
-    if (!this.isRunning) return
+  override stopInference() {
    showToast(
      'Unable to Stop Inference',
      'The model does not support stopping inference.'
@ -250,8 +249,8 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
    return Promise.resolve()
  }

-  inference(data: MessageRequest): void {
-    if (!this.isRunning) return
+  override inference(data: MessageRequest): void {
+    if (!this.loadedModel) return
    // TensorRT LLM Extension supports streaming only
    if (data.model) data.model.parameters.stream = true
    super.inference(data)
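After this rename the TensorRT-LLM extension overrides the `LocalOAIEngine` hooks (`loadModel`, `stopInference`, `inference`) instead of the older `onModelInit`/`onInferenceStopped` callbacks. A sketch of that override surface under the names shown in this diff; it assumes `LocalOAIEngine` is exported from `@janhq/core` like `RemoteOAIEngine`, and the class name, provider id, and `nodeModule` value are placeholders:

```typescript
import { LocalOAIEngine, Model, MessageRequest } from '@janhq/core'

// Sketch only: shows the lifecycle hooks a local engine overrides after the rename.
export default class ExampleLocalEngine extends LocalOAIEngine {
  nodeModule: string = 'node' // placeholder: points at the extension's bundled node module
  provider: string = 'example'

  override async loadModel(model: Model): Promise<void> {
    // Only handle models that belong to this provider, as the diff above does.
    if (model.engine !== this.provider) return
    return super.loadModel(model)
  }

  override inference(data: MessageRequest): void {
    // Streaming-only engines force stream mode before delegating, as above.
    if (data.model) data.model.parameters.stream = true
    super.inference(data)
  }
}
```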
@ -1,27 +1,26 @@
{
  "sources": [
    {
      "url": "https://groq.com"
    }
  ],
  "id": "llama2-70b-4096",
  "object": "model",
  "name": "Groq Llama 2 70b",
  "version": "1.0",
  "description": "Groq Llama 2 70b with supercharged speed!",
  "format": "api",
  "settings": {},
  "parameters": {
    "max_tokens": 4096,
    "temperature": 0.7,
    "top_p": 1,
    "stop": null,
    "stream": true
  },
  "metadata": {
    "author": "Meta",
    "tags": ["General", "Big Context Length"]
  },
  "engine": "groq"
}
-
|
|||||||
{
|
{
|
||||||
"sources": [
|
"sources": [
|
||||||
{
|
{
|
||||||
"url": "https://groq.com"
|
"url": "https://groq.com"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"id": "mixtral-8x7b-32768",
|
"id": "mixtral-8x7b-32768",
|
||||||
"object": "model",
|
"object": "model",
|
||||||
"name": "Groq Mixtral 8x7b Instruct",
|
"name": "Groq Mixtral 8x7b Instruct",
|
||||||
"version": "1.0",
|
"version": "1.0",
|
||||||
"description": "Groq Mixtral 8x7b Instruct is Mixtral with supercharged speed!",
|
"description": "Groq Mixtral 8x7b Instruct is Mixtral with supercharged speed!",
|
||||||
"format": "api",
|
"format": "api",
|
||||||
"settings": {},
|
"settings": {},
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
"top_p": 1,
|
"top_p": 1,
|
||||||
"stop": null,
|
"stop": null,
|
||||||
"stream": true
|
"stream": true
|
||||||
},
|
},
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"author": "Mistral",
|
"author": "Mistral",
|
||||||
"tags": ["General", "Big Context Length"]
|
"tags": ["General", "Big Context Length"]
|
||||||
},
|
},
|
||||||
"engine": "groq"
|
"engine": "groq"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -75,12 +75,14 @@ const DropdownListSidebar = ({

  // TODO: Update filter condition for the local model
  const localModel = downloadedModels.filter(
-    (model) => model.engine !== InferenceEngine.openai
+    (model) =>
+      model.engine === InferenceEngine.nitro ||
+      model.engine === InferenceEngine.nitro_tensorrt_llm
  )
  const remoteModel = downloadedModels.filter(
    (model) =>
-      model.engine === InferenceEngine.openai ||
-      model.engine === InferenceEngine.groq
+      model.engine !== InferenceEngine.nitro &&
+      model.engine !== InferenceEngine.nitro_tensorrt_llm
  )

  const modelOptions = isTabActive === 0 ? localModel : remoteModel
@ -48,9 +48,8 @@ export default function RowModel(props: RowModelProps) {
  const [serverEnabled, setServerEnabled] = useAtom(serverEnabledAtom)

  const isRemoteModel =
-    props.data.engine === InferenceEngine.openai ||
-    props.data.engine === InferenceEngine.groq ||
-    props.data.engine === InferenceEngine.triton_trtllm
+    props.data.engine !== InferenceEngine.nitro &&
+    props.data.engine !== InferenceEngine.nitro_tensorrt_llm

  const onModelActionClick = (modelId: string) => {
    if (activeModel && activeModel.id === modelId) {
@ -8,7 +8,6 @@ export const isCoreExtensionInstalled = () => {
  if (!extensionManager.get(ExtensionTypeEnum.Conversational)) {
    return false
  }
-  if (!extensionManager.get(ExtensionTypeEnum.Inference)) return false
  if (!extensionManager.get(ExtensionTypeEnum.Model)) {
    return false
  }
@ -22,7 +21,6 @@ export const setupBaseExtensions = async () => {

  if (
    !extensionManager.get(ExtensionTypeEnum.Conversational) ||
-    !extensionManager.get(ExtensionTypeEnum.Inference) ||
    !extensionManager.get(ExtensionTypeEnum.Model)
  ) {
    const installed = await extensionManager.install(baseExtensions)