chore: load, unload model and inference synchronously

Louis 2024-03-22 22:29:14 +07:00
parent 1ad794c8a7
commit 9551996e34
16 changed files with 226 additions and 173 deletions
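At a glance: instead of emitting an event and waiting for the engine to react, callers now resolve the engine through the new EngineManager and await the call. A minimal before/after sketch, using only names that appear in the diffs below (the enclosing async module context and the model value are assumed):

import { events, EngineManager, Model, ModelEvent } from '@janhq/core'

declare const model: Model // stand-in for a model resolved by the caller

// Before: fire-and-forget over the event bus
events.emit(ModelEvent.OnModelInit, model)

// After: resolve the engine for the model's provider and await it directly
const engine = EngineManager.instance()?.get(model.engine)
await engine?.loadModel(model)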

View File

@@ -1,46 +0,0 @@
import { events } from '../../events'
import { Model, ModelEvent } from '../../types'
import { OAIEngine } from './OAIEngine'
/**
* Base OAI Remote Inference Provider
* Added the implementation of loading and unloading model (applicable to local inference providers)
*/
export abstract class RemoteOAIEngine extends OAIEngine {
// The inference engine
abstract apiKey: string
/**
* On extension load, subscribe to events.
*/
onLoad() {
super.onLoad()
// These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
}
/**
* Load the model.
*/
async loadModel(model: Model) {
if (model.engine.toString() !== this.provider) return
events.emit(ModelEvent.OnModelReady, model)
}
/**
* Stops the model.
*/
unloadModel(model: Model) {
if (model.engine && model.engine.toString() !== this.provider) return
events.emit(ModelEvent.OnModelStopped, {})
}
/**
* Headers for the inference request
*/
override headers(): HeadersInit {
return {
'Authorization': `Bearer ${this.apiKey}`,
'api-key': `${this.apiKey}`,
}
}
}

View File

@@ -2,7 +2,8 @@ import { getJanDataFolderPath, joinPath } from '../../core'
import { events } from '../../events'
import { BaseExtension } from '../../extension'
import { fs } from '../../fs'
import { Model, ModelEvent } from '../../types'
import { MessageRequest, Model, ModelEvent } from '../../types'
import { EngineManager } from './EngineManager'
/**
* Base AIEngine
@@ -11,30 +12,71 @@ import { Model, ModelEvent } from '../../types'
export abstract class AIEngine extends BaseExtension {
// The inference engine
abstract provider: string
// The model folder
modelFolder: string = 'models'
/**
* On extension load, subscribe to events.
*/
override onLoad() {
this.registerEngine()
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
this.prePopulateModels()
}
/**
* Defines models
*/
models(): Promise<Model[]> {
return Promise.resolve([])
}
/**
* On extension load, subscribe to events.
* Registers AI Engines
*/
onLoad() {
this.prePopulateModels()
registerEngine() {
EngineManager.instance()?.register(this)
}
/**
* Loads the model.
*/
async loadModel(model: Model): Promise<any> {
if (model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
}
/**
* Stops the model.
*/
async unloadModel(model?: Model): Promise<any> {
if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelStopped, model ?? {})
return Promise.resolve()
}
/*
* Inference request
*/
inference(data: MessageRequest) {}
/**
* Stop inference
*/
stopInference() {}
/**
* Pre-populates models into the App Data Folder
*/
prePopulateModels(): Promise<void> {
const modelFolder = 'models'
return this.models().then((models) => {
const prePopulateOperations = models.map((model) =>
getJanDataFolderPath()
.then((janDataFolder) =>
// Attempt to create the model folder
joinPath([janDataFolder, this.modelFolder, model.id]).then((path) =>
joinPath([janDataFolder, modelFolder, model.id]).then((path) =>
fs
.mkdir(path)
.catch()
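The base class above supplies default load/unload handlers and self-registration, so a concrete provider mainly contributes a provider id and an inference implementation. A hypothetical subclass sketch (the class name, provider id, and body are illustrative only; BaseExtension members outside this hunk are omitted):

import { AIEngine, MessageRequest } from '@janhq/core'

export class EchoEngine extends AIEngine {
  provider = 'echo' // hypothetical provider id

  // AIEngine.onLoad() registers this instance with EngineManager
  // and wires the default OnModelInit/OnModelStop handlers.
  override inference(data: MessageRequest) {
    // A real engine would call its runtime here; this sketch only logs.
    console.log(`inference requested for model ${data.model?.id}`)
  }
}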

View File

@@ -0,0 +1,34 @@
import { log } from '../../core'
import { AIEngine } from './AIEngine'
/**
* Manages the registration and retrieval of inference engines.
*/
export class EngineManager {
public engines = new Map<string, AIEngine>()
/**
* Registers an engine.
* @param engine - The engine to register.
*/
register<T extends AIEngine>(engine: T) {
this.engines.set(engine.provider, engine)
}
/**
* Retrieves an engine by provider.
* @param provider - The name of the engine to retrieve.
* @returns The engine, if found.
*/
get<T extends AIEngine>(provider: string): T | undefined {
return this.engines.get(provider) as T | undefined
}
static instance(): EngineManager | undefined {
return window.core?.engineManager as EngineManager
}
}
/**
* The singleton instance of the EngineManager.
*/
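A short usage sketch (this assumes the singleton has been attached to window.core by setupCoreServices, as the last hunk of this commit does; restartModel is a hypothetical helper):

import { EngineManager, Model } from '@janhq/core'

async function restartModel(model: Model) {
  const engine = EngineManager.instance()?.get(model.engine)
  if (!engine) throw new Error(`No engine registered for ${model.engine}`)
  await engine.unloadModel(model) // resolves once the engine has stopped the model
  await engine.loadModel(model)   // rejects if the engine fails to load the model
}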

View File

@@ -16,7 +16,7 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* On extension load, subscribe to events.
*/
onLoad() {
override onLoad() {
super.onLoad()
// These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
@@ -26,10 +26,10 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* Load the model.
*/
async loadModel(model: Model) {
override async loadModel(model: Model): Promise<void> {
if (model.engine.toString() !== this.provider) return
const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id])
const modelFolderName = 'models'
const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
const systemInfo = await systemInformation()
const res = await executeOnMain(
this.nodeModule,
@@ -42,24 +42,22 @@ export abstract class LocalOAIEngine extends OAIEngine {
)
if (res?.error) {
events.emit(ModelEvent.OnModelFail, {
...model,
error: res.error,
})
return
events.emit(ModelEvent.OnModelFail, { error: res.error })
return Promise.reject(res.error)
} else {
this.loadedModel = model
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
}
}
/**
* Stops the model.
*/
unloadModel(model: Model) {
if (model.engine && model.engine?.toString() !== this.provider) return
this.loadedModel = undefined
override async unloadModel(model?: Model): Promise<void> {
if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve()
executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
this.loadedModel = undefined
return executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
events.emit(ModelEvent.OnModelStopped, {})
})
}

View File

@@ -34,7 +34,7 @@ export abstract class OAIEngine extends AIEngine {
/**
* On extension load, subscribe to events.
*/
onLoad() {
override onLoad() {
super.onLoad()
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
@@ -43,12 +43,12 @@ export abstract class OAIEngine extends AIEngine {
/**
* On extension unload
*/
onUnload(): void {}
override onUnload(): void {}
/*
* Inference request
*/
inference(data: MessageRequest) {
override inference(data: MessageRequest) {
if (data.model?.engine?.toString() !== this.provider) return
const timestamp = Date.now()
@@ -114,7 +114,7 @@ export abstract class OAIEngine extends AIEngine {
/**
* Stops the inference.
*/
stopInference() {
override stopInference() {
this.isCancelled = true
this.controller?.abort()
}

View File

@@ -0,0 +1,26 @@
import { OAIEngine } from './OAIEngine'
/**
* Base OAI Remote Inference Provider
* Model loading and unloading are not implemented here; those apply only to local inference providers
*/
export abstract class RemoteOAIEngine extends OAIEngine {
// The API key for authorizing requests
abstract apiKey: string
/**
* On extension load, subscribe to events.
*/
override onLoad() {
super.onLoad()
}
/**
* Headers for the inference request
*/
override headers(): HeadersInit {
return {
'Authorization': `Bearer ${this.apiKey}`,
'api-key': `${this.apiKey}`,
}
}
}
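The headers() hook above is what the shared OAIEngine request path attaches to outgoing calls. A sketch of how the returned HeadersInit might be consumed (the fetch transport, url, and payload are illustrative, not the engine's actual request code):

import { RemoteOAIEngine } from '@janhq/core'

async function postCompletion(engine: RemoteOAIEngine, url: string, payload: object) {
  const headers = new Headers(engine.headers()) // Authorization + api-key from the engine
  headers.set('Content-Type', 'application/json')
  const res = await fetch(url, { method: 'POST', headers, body: JSON.stringify(payload) })
  if (!res.ok) throw new Error(`Request failed with status ${res.status}`)
  return res.json()
}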

View File

@@ -2,3 +2,4 @@ export * from './AIEngine'
export * from './OAIEngine'
export * from './LocalOAIEngine'
export * from './RemoteOAIEngine'
export * from './EngineManager'

View File

@@ -32,4 +32,4 @@ export { HuggingFaceExtension } from './huggingface'
/**
* Base AI Engines.
*/
export * from './ai-engines'
export * from './engines'

View File

@@ -91,15 +91,14 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
return super.loadModel(model)
}
override unloadModel(model: Model): void {
super.unloadModel(model)
if (model.engine && model.engine !== this.provider) return
override async unloadModel(model?: Model) {
if (model?.engine && model.engine !== this.provider) return
// stop the periodic health check
if (this.getNitroProcesHealthIntervalId) {
clearInterval(this.getNitroProcesHealthIntervalId)
this.getNitroProcesHealthIntervalId = undefined
}
return super.unloadModel(model)
}
}

View File

@@ -8,26 +8,17 @@ import {
ExtensionTypeEnum,
MessageStatus,
MessageRequest,
Model,
ConversationalExtension,
MessageEvent,
MessageRequestType,
ModelEvent,
Thread,
ModelInitFailed,
EngineManager,
} from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
import { ulid } from 'ulidx'
import {
activeModelAtom,
loadModelErrorAtom,
stateModelAtom,
} from '@/hooks/useActiveModel'
import { queuedMessageAtom } from '@/hooks/useSendChatMessage'
import { toaster } from '../Toast'
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
import { extensionManager } from '@/extension'
import {
@@ -51,8 +42,6 @@ export default function EventHandler({ children }: { children: ReactNode }) {
const activeModel = useAtomValue(activeModelAtom)
const setActiveModel = useSetAtom(activeModelAtom)
const setStateModel = useSetAtom(stateModelAtom)
const setQueuedMessage = useSetAtom(queuedMessageAtom)
const setLoadModelError = useSetAtom(loadModelErrorAtom)
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
const threads = useAtomValue(threadsAtom)
@@ -88,44 +77,11 @@
[addNewMessage]
)
const onModelReady = useCallback(
(model: Model) => {
setActiveModel(model)
toaster({
title: 'Success!',
description: `Model ${model.id} has been started.`,
type: 'success',
})
setStateModel(() => ({
state: 'stop',
loading: false,
model: model.id,
}))
},
[setActiveModel, setStateModel]
)
const onModelStopped = useCallback(() => {
setTimeout(() => {
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
}, 500)
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
}, [setActiveModel, setStateModel])
const onModelInitFailed = useCallback(
(res: ModelInitFailed) => {
console.error('Failed to load model: ', res.error.message)
setStateModel(() => ({
state: 'start',
loading: false,
model: res.id,
}))
setLoadModelError(res.error.message)
setQueuedMessage(false)
},
[setStateModel, setQueuedMessage, setLoadModelError]
)
const updateThreadTitle = useCallback(
(message: ThreadMessage) => {
// Update only when it's finished
@@ -274,7 +230,10 @@
// 2. Update the title with the result of the inference
setTimeout(() => {
events.emit(MessageEvent.OnMessageSent, messageRequest)
const engine = EngineManager.instance()?.get(
messageRequest.model?.engine ?? activeModelRef.current?.engine ?? ''
)
engine?.inference(messageRequest)
}, 1000)
}
}
@@ -283,23 +242,16 @@
if (window.core?.events) {
events.on(MessageEvent.OnMessageResponse, onNewMessageResponse)
events.on(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
events.on(ModelEvent.OnModelReady, onModelReady)
events.on(ModelEvent.OnModelFail, onModelInitFailed)
events.on(ModelEvent.OnModelStopped, onModelStopped)
}
}, [
onNewMessageResponse,
onMessageResponseUpdate,
onModelReady,
onModelInitFailed,
onModelStopped,
])
}, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])
useEffect(() => {
return () => {
events.off(MessageEvent.OnMessageResponse, onNewMessageResponse)
events.off(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
events.off(ModelEvent.OnModelStopped, onModelStopped)
}
}, [onNewMessageResponse, onMessageResponseUpdate])
}, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])
return <Fragment>{children}</Fragment>
}

View File

@@ -1,6 +1,6 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { BaseExtension, ExtensionTypeEnum } from '@janhq/core'
import { AIEngine, BaseExtension, ExtensionTypeEnum } from '@janhq/core'
import Extension from './Extension'
@@ -8,14 +8,26 @@ import Extension from './Extension'
* Manages the registration and retrieval of extensions.
*/
export class ExtensionManager {
// Registered extensions
private extensions = new Map<string, BaseExtension>()
// Registered inference engines
private engines = new Map<string, AIEngine>()
/**
* Registers an extension.
* @param extension - The extension to register.
*/
register<T extends BaseExtension>(name: string, extension: T) {
this.extensions.set(extension.type() ?? name, extension)
// Register AI Engines
if ('provider' in extension && typeof extension.provider === 'string') {
this.engines.set(
extension.provider as unknown as string,
extension as unknown as AIEngine
)
}
}
/**
@@ -29,6 +41,15 @@
return this.extensions.get(type) as T | undefined
}
/**
* Retrieves an engine by its name.
* @param engine - The engine name to retrieve.
* @returns The engine, if found.
*/
getEngine<T extends AIEngine>(engine: string): T | undefined {
return this.engines.get(engine) as T | undefined
}
/**
* Loads all registered extensions.
*/

View File

@@ -1,12 +1,13 @@
import { useCallback, useEffect, useRef } from 'react'
import { events, Model, ModelEvent } from '@janhq/core'
import { EngineManager, Model } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { toaster } from '@/containers/Toast'
import { LAST_USED_MODEL_ID } from './useRecommendedModel'
import { extensionManager } from '@/extension'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
@@ -38,19 +39,13 @@
(stateModel.model === modelId && stateModel.loading)
) {
console.debug(`Model ${modelId} is already initialized. Ignore..`)
return
return Promise.resolve()
}
let model = downloadedModelsRef?.current.find((e) => e.id === modelId)
// Switch between engines
if (model && activeModel && activeModel.engine !== model.engine) {
stopModel()
// TODO: Refactor inference provider would address this
await new Promise((res) => setTimeout(res, 1000))
}
await stopModel().catch(() => {})
// TODO: incase we have multiple assistants, the configuration will be from assistant
setLoadModelError(undefined)
setActiveModel(undefined)
@@ -68,7 +63,8 @@
loading: false,
model: '',
}))
return
return Promise.reject(`Model ${modelId} not found!`)
}
/// Apply thread model settings
@@ -83,15 +79,52 @@
}
localStorage.setItem(LAST_USED_MODEL_ID, model.id)
events.emit(ModelEvent.OnModelInit, model)
const engine = EngineManager.instance()?.get(model.engine)
return engine
?.loadModel(model)
.then(() => {
setActiveModel(model)
setStateModel(() => ({
state: 'stop',
loading: false,
model: model.id,
}))
toaster({
title: 'Success!',
description: `Model ${model.id} has been started.`,
type: 'success',
})
})
.catch((error) => {
console.error('Failed to load model: ', error)
setStateModel(() => ({
state: 'start',
loading: false,
model: model.id,
}))
toaster({
title: 'Failed!',
description: `Model ${model.id} failed to start.`,
type: 'error',
})
setLoadModelError(error)
})
}
const stopModel = useCallback(async () => {
if (activeModel) {
setStateModel({ state: 'stop', loading: true, model: activeModel.id })
events.emit(ModelEvent.OnModelStop, activeModel)
const engine = EngineManager.instance()?.get(activeModel.engine)
await engine
?.unloadModel(activeModel)
.catch(() => {})
.then(() => {
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
})
}
}, [activeModel, setStateModel])
}, [activeModel, setActiveModel, setStateModel])
return { activeModel, startModel, stopModel, stateModel }
}
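Since startModel and stopModel now return the underlying engine promises, callers can sequence on them instead of polling, which is exactly what useSendChatMessage does in the next file. A hypothetical caller sketch, inside a component that already uses this hook:

const { startModel } = useActiveModel()

const onStart = async () => {
  // Rejects if the model is not found; engine load failures are surfaced
  // through the toaster and loadModelErrorAtom rather than a rejection.
  await startModel('some-model-id') // hypothetical model id
}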

View File

@@ -11,13 +11,12 @@ import {
ExtensionTypeEnum,
Thread,
ThreadMessage,
events,
Model,
ConversationalExtension,
MessageEvent,
InferenceEngine,
ChatCompletionMessageContentType,
AssistantTool,
EngineManager,
} from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
@@ -65,7 +64,6 @@ export default function useSendChatMessage() {
const currentMessages = useAtomValue(getCurrentChatMessagesAtom)
const selectedModel = useAtomValue(selectedModelAtom)
const { activeModel, startModel } = useActiveModel()
const setQueuedMessage = useSetAtom(queuedMessageAtom)
const loadModelFailed = useAtomValue(loadModelErrorAtom)
const modelRef = useRef<Model | undefined>()
@@ -78,6 +76,7 @@
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
const activeThreadRef = useRef<Thread | undefined>()
const setQueuedMessage = useSetAtom(queuedMessageAtom)
const selectedModelRef = useRef<Model | undefined>()
@@ -142,10 +141,7 @@
activeThreadRef.current.assistants[0].model.id
if (modelRef.current?.id !== modelId) {
setQueuedMessage(true)
startModel(modelId)
await waitForModelStarting(modelId)
setQueuedMessage(false)
await startModel(modelId)
}
setIsGeneratingResponse(true)
if (currentMessage.role !== ChatCompletionRole.User) {
@@ -160,7 +156,10 @@
)
}
}
events.emit(MessageEvent.OnMessageSent, messageRequest)
const engine = EngineManager.instance()?.get(
messageRequest.model?.engine ?? selectedModelRef.current?.engine ?? ''
)
engine?.inference(messageRequest)
}
const sendChatMessage = async (message: string) => {
@@ -364,30 +363,20 @@
if (modelRef.current?.id !== modelId) {
setQueuedMessage(true)
startModel(modelId)
await waitForModelStarting(modelId)
await startModel(modelId)
setQueuedMessage(false)
}
setIsGeneratingResponse(true)
events.emit(MessageEvent.OnMessageSent, messageRequest)
const engine = EngineManager.instance()?.get(
messageRequest.model?.engine ?? modelRequest.engine ?? ''
)
engine?.inference(messageRequest)
setReloadModel(false)
setEngineParamsUpdate(false)
}
const waitForModelStarting = async (modelId: string) => {
return new Promise<void>((resolve) => {
setTimeout(async () => {
if (modelRef.current?.id !== modelId && !loadModelFailedRef.current) {
await waitForModelStarting(modelId)
resolve()
} else {
resolve()
}
}, 200)
})
}
return {
sendChatMessage,
resendChatMessage,

View File

@@ -74,7 +74,8 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
</p>
<ModalTroubleShooting />
</div>
) : loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
) : loadModelError &&
loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
<div
key={message.id}
className="flex w-full flex-col items-center text-center text-sm font-medium text-gray-500"

View File

@@ -1,3 +1,5 @@
import { EngineManager } from '@janhq/core'
import { appService } from './appService'
import { EventEmitter } from './eventsService'
import { restAPI } from './restService'
@@ -12,6 +14,7 @@ export const setupCoreServices = () => {
if (!window.core) {
window.core = {
events: new EventEmitter(),
engineManager: new EngineManager(),
api: {
...(window.electronAPI ? window.electronAPI : restAPI),
...appService,