chore: load, unload model and inference synchronously
This commit is contained in:
parent
1ad794c8a7
commit
9551996e34
@ -1,46 +0,0 @@
|
|||||||
import { events } from '../../events'
|
|
||||||
import { Model, ModelEvent } from '../../types'
|
|
||||||
import { OAIEngine } from './OAIEngine'
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Base OAI Remote Inference Provider
|
|
||||||
* Added the implementation of loading and unloading model (applicable to local inference providers)
|
|
||||||
*/
|
|
||||||
export abstract class RemoteOAIEngine extends OAIEngine {
|
|
||||||
// The inference engine
|
|
||||||
abstract apiKey: string
|
|
||||||
/**
|
|
||||||
* On extension load, subscribe to events.
|
|
||||||
*/
|
|
||||||
onLoad() {
|
|
||||||
super.onLoad()
|
|
||||||
// These events are applicable to local inference providers
|
|
||||||
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
|
|
||||||
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Load the model.
|
|
||||||
*/
|
|
||||||
async loadModel(model: Model) {
|
|
||||||
if (model.engine.toString() !== this.provider) return
|
|
||||||
events.emit(ModelEvent.OnModelReady, model)
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
* Stops the model.
|
|
||||||
*/
|
|
||||||
unloadModel(model: Model) {
|
|
||||||
if (model.engine && model.engine.toString() !== this.provider) return
|
|
||||||
events.emit(ModelEvent.OnModelStopped, {})
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Headers for the inference request
|
|
||||||
*/
|
|
||||||
override headers(): HeadersInit {
|
|
||||||
return {
|
|
||||||
'Authorization': `Bearer ${this.apiKey}`,
|
|
||||||
'api-key': `${this.apiKey}`,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -2,7 +2,8 @@ import { getJanDataFolderPath, joinPath } from '../../core'
|
|||||||
import { events } from '../../events'
|
import { events } from '../../events'
|
||||||
import { BaseExtension } from '../../extension'
|
import { BaseExtension } from '../../extension'
|
||||||
import { fs } from '../../fs'
|
import { fs } from '../../fs'
|
||||||
import { Model, ModelEvent } from '../../types'
|
import { MessageRequest, Model, ModelEvent } from '../../types'
|
||||||
|
import { EngineManager } from './EngineManager'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base AIEngine
|
* Base AIEngine
|
||||||
@ -11,30 +12,71 @@ import { Model, ModelEvent } from '../../types'
|
|||||||
export abstract class AIEngine extends BaseExtension {
|
export abstract class AIEngine extends BaseExtension {
|
||||||
// The inference engine
|
// The inference engine
|
||||||
abstract provider: string
|
abstract provider: string
|
||||||
// The model folder
|
|
||||||
modelFolder: string = 'models'
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* On extension load, subscribe to events.
|
||||||
|
*/
|
||||||
|
override onLoad() {
|
||||||
|
this.registerEngine()
|
||||||
|
|
||||||
|
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
|
||||||
|
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
|
||||||
|
|
||||||
|
this.prePopulateModels()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Defines models
|
||||||
|
*/
|
||||||
models(): Promise<Model[]> {
|
models(): Promise<Model[]> {
|
||||||
return Promise.resolve([])
|
return Promise.resolve([])
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* On extension load, subscribe to events.
|
* Registers AI Engines
|
||||||
*/
|
*/
|
||||||
onLoad() {
|
registerEngine() {
|
||||||
this.prePopulateModels()
|
EngineManager.instance()?.register(this)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads the model.
|
||||||
|
*/
|
||||||
|
async loadModel(model: Model): Promise<any> {
|
||||||
|
if (model.engine.toString() !== this.provider) return Promise.resolve()
|
||||||
|
events.emit(ModelEvent.OnModelReady, model)
|
||||||
|
return Promise.resolve()
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* Stops the model.
|
||||||
|
*/
|
||||||
|
async unloadModel(model?: Model): Promise<any> {
|
||||||
|
if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve()
|
||||||
|
events.emit(ModelEvent.OnModelStopped, model ?? {})
|
||||||
|
return Promise.resolve()
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Inference request
|
||||||
|
*/
|
||||||
|
inference(data: MessageRequest) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stop inference
|
||||||
|
*/
|
||||||
|
stopInference() {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Pre-populate models to App Data Folder
|
* Pre-populate models to App Data Folder
|
||||||
*/
|
*/
|
||||||
prePopulateModels(): Promise<void> {
|
prePopulateModels(): Promise<void> {
|
||||||
|
const modelFolder = 'models'
|
||||||
return this.models().then((models) => {
|
return this.models().then((models) => {
|
||||||
const prePoluateOperations = models.map((model) =>
|
const prePoluateOperations = models.map((model) =>
|
||||||
getJanDataFolderPath()
|
getJanDataFolderPath()
|
||||||
.then((janDataFolder) =>
|
.then((janDataFolder) =>
|
||||||
// Attempt to create the model folder
|
// Attempt to create the model folder
|
||||||
joinPath([janDataFolder, this.modelFolder, model.id]).then((path) =>
|
joinPath([janDataFolder, modelFolder, model.id]).then((path) =>
|
||||||
fs
|
fs
|
||||||
.mkdir(path)
|
.mkdir(path)
|
||||||
.catch()
|
.catch()
|
||||||
34
core/src/extensions/engines/EngineManager.ts
Normal file
34
core/src/extensions/engines/EngineManager.ts
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import { log } from '../../core'
|
||||||
|
import { AIEngine } from './AIEngine'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Manages the registration and retrieval of inference engines.
|
||||||
|
*/
|
||||||
|
export class EngineManager {
|
||||||
|
public engines = new Map<string, AIEngine>()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Registers an engine.
|
||||||
|
* @param engine - The engine to register.
|
||||||
|
*/
|
||||||
|
register<T extends AIEngine>(engine: T) {
|
||||||
|
this.engines.set(engine.provider, engine)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retrieves a engine by provider.
|
||||||
|
* @param provider - The name of the engine to retrieve.
|
||||||
|
* @returns The engine, if found.
|
||||||
|
*/
|
||||||
|
get<T extends AIEngine>(provider: string): T | undefined {
|
||||||
|
return this.engines.get(provider) as T | undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
static instance(): EngineManager | undefined {
|
||||||
|
return window.core?.engineManager as EngineManager
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The singleton instance of the ExtensionManager.
|
||||||
|
*/
|
||||||
@ -16,7 +16,7 @@ export abstract class LocalOAIEngine extends OAIEngine {
|
|||||||
/**
|
/**
|
||||||
* On extension load, subscribe to events.
|
* On extension load, subscribe to events.
|
||||||
*/
|
*/
|
||||||
onLoad() {
|
override onLoad() {
|
||||||
super.onLoad()
|
super.onLoad()
|
||||||
// These events are applicable to local inference providers
|
// These events are applicable to local inference providers
|
||||||
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
|
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
|
||||||
@ -26,10 +26,10 @@ export abstract class LocalOAIEngine extends OAIEngine {
|
|||||||
/**
|
/**
|
||||||
* Load the model.
|
* Load the model.
|
||||||
*/
|
*/
|
||||||
async loadModel(model: Model) {
|
override async loadModel(model: Model): Promise<void> {
|
||||||
if (model.engine.toString() !== this.provider) return
|
if (model.engine.toString() !== this.provider) return
|
||||||
|
const modelFolderName = 'models'
|
||||||
const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id])
|
const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
|
||||||
const systemInfo = await systemInformation()
|
const systemInfo = await systemInformation()
|
||||||
const res = await executeOnMain(
|
const res = await executeOnMain(
|
||||||
this.nodeModule,
|
this.nodeModule,
|
||||||
@ -42,24 +42,22 @@ export abstract class LocalOAIEngine extends OAIEngine {
|
|||||||
)
|
)
|
||||||
|
|
||||||
if (res?.error) {
|
if (res?.error) {
|
||||||
events.emit(ModelEvent.OnModelFail, {
|
events.emit(ModelEvent.OnModelFail, { error: res.error })
|
||||||
...model,
|
return Promise.reject(res.error)
|
||||||
error: res.error,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
} else {
|
} else {
|
||||||
this.loadedModel = model
|
this.loadedModel = model
|
||||||
events.emit(ModelEvent.OnModelReady, model)
|
events.emit(ModelEvent.OnModelReady, model)
|
||||||
|
return Promise.resolve()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* Stops the model.
|
* Stops the model.
|
||||||
*/
|
*/
|
||||||
unloadModel(model: Model) {
|
override async unloadModel(model?: Model): Promise<void> {
|
||||||
if (model.engine && model.engine?.toString() !== this.provider) return
|
if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve()
|
||||||
this.loadedModel = undefined
|
|
||||||
|
|
||||||
executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
|
this.loadedModel = undefined
|
||||||
|
return executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
|
||||||
events.emit(ModelEvent.OnModelStopped, {})
|
events.emit(ModelEvent.OnModelStopped, {})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -34,7 +34,7 @@ export abstract class OAIEngine extends AIEngine {
|
|||||||
/**
|
/**
|
||||||
* On extension load, subscribe to events.
|
* On extension load, subscribe to events.
|
||||||
*/
|
*/
|
||||||
onLoad() {
|
override onLoad() {
|
||||||
super.onLoad()
|
super.onLoad()
|
||||||
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
|
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
|
||||||
events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
|
events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
|
||||||
@ -43,12 +43,12 @@ export abstract class OAIEngine extends AIEngine {
|
|||||||
/**
|
/**
|
||||||
* On extension unload
|
* On extension unload
|
||||||
*/
|
*/
|
||||||
onUnload(): void {}
|
override onUnload(): void {}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Inference request
|
* Inference request
|
||||||
*/
|
*/
|
||||||
inference(data: MessageRequest) {
|
override inference(data: MessageRequest) {
|
||||||
if (data.model?.engine?.toString() !== this.provider) return
|
if (data.model?.engine?.toString() !== this.provider) return
|
||||||
|
|
||||||
const timestamp = Date.now()
|
const timestamp = Date.now()
|
||||||
@ -114,7 +114,7 @@ export abstract class OAIEngine extends AIEngine {
|
|||||||
/**
|
/**
|
||||||
* Stops the inference.
|
* Stops the inference.
|
||||||
*/
|
*/
|
||||||
stopInference() {
|
override stopInference() {
|
||||||
this.isCancelled = true
|
this.isCancelled = true
|
||||||
this.controller?.abort()
|
this.controller?.abort()
|
||||||
}
|
}
|
||||||
26
core/src/extensions/engines/RemoteOAIEngine.ts
Normal file
26
core/src/extensions/engines/RemoteOAIEngine.ts
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import { OAIEngine } from './OAIEngine'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Base OAI Remote Inference Provider
|
||||||
|
* Added the implementation of loading and unloading model (applicable to local inference providers)
|
||||||
|
*/
|
||||||
|
export abstract class RemoteOAIEngine extends OAIEngine {
|
||||||
|
// The inference engine
|
||||||
|
abstract apiKey: string
|
||||||
|
/**
|
||||||
|
* On extension load, subscribe to events.
|
||||||
|
*/
|
||||||
|
override onLoad() {
|
||||||
|
super.onLoad()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Headers for the inference request
|
||||||
|
*/
|
||||||
|
override headers(): HeadersInit {
|
||||||
|
return {
|
||||||
|
'Authorization': `Bearer ${this.apiKey}`,
|
||||||
|
'api-key': `${this.apiKey}`,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -2,3 +2,4 @@ export * from './AIEngine'
|
|||||||
export * from './OAIEngine'
|
export * from './OAIEngine'
|
||||||
export * from './LocalOAIEngine'
|
export * from './LocalOAIEngine'
|
||||||
export * from './RemoteOAIEngine'
|
export * from './RemoteOAIEngine'
|
||||||
|
export * from './EngineManager'
|
||||||
@ -32,4 +32,4 @@ export { HuggingFaceExtension } from './huggingface'
|
|||||||
/**
|
/**
|
||||||
* Base AI Engines.
|
* Base AI Engines.
|
||||||
*/
|
*/
|
||||||
export * from './ai-engines'
|
export * from './engines'
|
||||||
|
|||||||
@ -91,15 +91,14 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
|
|||||||
return super.loadModel(model)
|
return super.loadModel(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
override unloadModel(model: Model): void {
|
override async unloadModel(model?: Model) {
|
||||||
super.unloadModel(model)
|
if (model?.engine && model.engine !== this.provider) return
|
||||||
|
|
||||||
if (model.engine && model.engine !== this.provider) return
|
|
||||||
|
|
||||||
// stop the periocally health check
|
// stop the periocally health check
|
||||||
if (this.getNitroProcesHealthIntervalId) {
|
if (this.getNitroProcesHealthIntervalId) {
|
||||||
clearInterval(this.getNitroProcesHealthIntervalId)
|
clearInterval(this.getNitroProcesHealthIntervalId)
|
||||||
this.getNitroProcesHealthIntervalId = undefined
|
this.getNitroProcesHealthIntervalId = undefined
|
||||||
}
|
}
|
||||||
|
return super.unloadModel(model)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,26 +8,17 @@ import {
|
|||||||
ExtensionTypeEnum,
|
ExtensionTypeEnum,
|
||||||
MessageStatus,
|
MessageStatus,
|
||||||
MessageRequest,
|
MessageRequest,
|
||||||
Model,
|
|
||||||
ConversationalExtension,
|
ConversationalExtension,
|
||||||
MessageEvent,
|
MessageEvent,
|
||||||
MessageRequestType,
|
MessageRequestType,
|
||||||
ModelEvent,
|
ModelEvent,
|
||||||
Thread,
|
Thread,
|
||||||
ModelInitFailed,
|
EngineManager,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
import { useAtomValue, useSetAtom } from 'jotai'
|
import { useAtomValue, useSetAtom } from 'jotai'
|
||||||
import { ulid } from 'ulidx'
|
import { ulid } from 'ulidx'
|
||||||
|
|
||||||
import {
|
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
|
||||||
activeModelAtom,
|
|
||||||
loadModelErrorAtom,
|
|
||||||
stateModelAtom,
|
|
||||||
} from '@/hooks/useActiveModel'
|
|
||||||
|
|
||||||
import { queuedMessageAtom } from '@/hooks/useSendChatMessage'
|
|
||||||
|
|
||||||
import { toaster } from '../Toast'
|
|
||||||
|
|
||||||
import { extensionManager } from '@/extension'
|
import { extensionManager } from '@/extension'
|
||||||
import {
|
import {
|
||||||
@ -51,8 +42,6 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
|||||||
const activeModel = useAtomValue(activeModelAtom)
|
const activeModel = useAtomValue(activeModelAtom)
|
||||||
const setActiveModel = useSetAtom(activeModelAtom)
|
const setActiveModel = useSetAtom(activeModelAtom)
|
||||||
const setStateModel = useSetAtom(stateModelAtom)
|
const setStateModel = useSetAtom(stateModelAtom)
|
||||||
const setQueuedMessage = useSetAtom(queuedMessageAtom)
|
|
||||||
const setLoadModelError = useSetAtom(loadModelErrorAtom)
|
|
||||||
|
|
||||||
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
|
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
|
||||||
const threads = useAtomValue(threadsAtom)
|
const threads = useAtomValue(threadsAtom)
|
||||||
@ -88,44 +77,11 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
|||||||
[addNewMessage]
|
[addNewMessage]
|
||||||
)
|
)
|
||||||
|
|
||||||
const onModelReady = useCallback(
|
|
||||||
(model: Model) => {
|
|
||||||
setActiveModel(model)
|
|
||||||
toaster({
|
|
||||||
title: 'Success!',
|
|
||||||
description: `Model ${model.id} has been started.`,
|
|
||||||
type: 'success',
|
|
||||||
})
|
|
||||||
setStateModel(() => ({
|
|
||||||
state: 'stop',
|
|
||||||
loading: false,
|
|
||||||
model: model.id,
|
|
||||||
}))
|
|
||||||
},
|
|
||||||
[setActiveModel, setStateModel]
|
|
||||||
)
|
|
||||||
|
|
||||||
const onModelStopped = useCallback(() => {
|
const onModelStopped = useCallback(() => {
|
||||||
setTimeout(() => {
|
setActiveModel(undefined)
|
||||||
setActiveModel(undefined)
|
setStateModel({ state: 'start', loading: false, model: '' })
|
||||||
setStateModel({ state: 'start', loading: false, model: '' })
|
|
||||||
}, 500)
|
|
||||||
}, [setActiveModel, setStateModel])
|
}, [setActiveModel, setStateModel])
|
||||||
|
|
||||||
const onModelInitFailed = useCallback(
|
|
||||||
(res: ModelInitFailed) => {
|
|
||||||
console.error('Failed to load model: ', res.error.message)
|
|
||||||
setStateModel(() => ({
|
|
||||||
state: 'start',
|
|
||||||
loading: false,
|
|
||||||
model: res.id,
|
|
||||||
}))
|
|
||||||
setLoadModelError(res.error.message)
|
|
||||||
setQueuedMessage(false)
|
|
||||||
},
|
|
||||||
[setStateModel, setQueuedMessage, setLoadModelError]
|
|
||||||
)
|
|
||||||
|
|
||||||
const updateThreadTitle = useCallback(
|
const updateThreadTitle = useCallback(
|
||||||
(message: ThreadMessage) => {
|
(message: ThreadMessage) => {
|
||||||
// Update only when it's finished
|
// Update only when it's finished
|
||||||
@ -274,7 +230,10 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
|||||||
|
|
||||||
// 2. Update the title with the result of the inference
|
// 2. Update the title with the result of the inference
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
events.emit(MessageEvent.OnMessageSent, messageRequest)
|
const engine = EngineManager.instance()?.get(
|
||||||
|
messageRequest.model?.engine ?? activeModelRef.current?.engine ?? ''
|
||||||
|
)
|
||||||
|
engine?.inference(messageRequest)
|
||||||
}, 1000)
|
}, 1000)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -283,23 +242,16 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
|||||||
if (window.core?.events) {
|
if (window.core?.events) {
|
||||||
events.on(MessageEvent.OnMessageResponse, onNewMessageResponse)
|
events.on(MessageEvent.OnMessageResponse, onNewMessageResponse)
|
||||||
events.on(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
|
events.on(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
|
||||||
events.on(ModelEvent.OnModelReady, onModelReady)
|
|
||||||
events.on(ModelEvent.OnModelFail, onModelInitFailed)
|
|
||||||
events.on(ModelEvent.OnModelStopped, onModelStopped)
|
events.on(ModelEvent.OnModelStopped, onModelStopped)
|
||||||
}
|
}
|
||||||
}, [
|
}, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])
|
||||||
onNewMessageResponse,
|
|
||||||
onMessageResponseUpdate,
|
|
||||||
onModelReady,
|
|
||||||
onModelInitFailed,
|
|
||||||
onModelStopped,
|
|
||||||
])
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
return () => {
|
return () => {
|
||||||
events.off(MessageEvent.OnMessageResponse, onNewMessageResponse)
|
events.off(MessageEvent.OnMessageResponse, onNewMessageResponse)
|
||||||
events.off(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
|
events.off(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
|
||||||
|
events.off(ModelEvent.OnModelStopped, onModelStopped)
|
||||||
}
|
}
|
||||||
}, [onNewMessageResponse, onMessageResponseUpdate])
|
}, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])
|
||||||
return <Fragment>{children}</Fragment>
|
return <Fragment>{children}</Fragment>
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||||
|
|
||||||
import { BaseExtension, ExtensionTypeEnum } from '@janhq/core'
|
import { AIEngine, BaseExtension, ExtensionTypeEnum } from '@janhq/core'
|
||||||
|
|
||||||
import Extension from './Extension'
|
import Extension from './Extension'
|
||||||
|
|
||||||
@ -8,14 +8,26 @@ import Extension from './Extension'
|
|||||||
* Manages the registration and retrieval of extensions.
|
* Manages the registration and retrieval of extensions.
|
||||||
*/
|
*/
|
||||||
export class ExtensionManager {
|
export class ExtensionManager {
|
||||||
|
// Registered extensions
|
||||||
private extensions = new Map<string, BaseExtension>()
|
private extensions = new Map<string, BaseExtension>()
|
||||||
|
|
||||||
|
// Registered inference engines
|
||||||
|
private engines = new Map<string, AIEngine>()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Registers an extension.
|
* Registers an extension.
|
||||||
* @param extension - The extension to register.
|
* @param extension - The extension to register.
|
||||||
*/
|
*/
|
||||||
register<T extends BaseExtension>(name: string, extension: T) {
|
register<T extends BaseExtension>(name: string, extension: T) {
|
||||||
this.extensions.set(extension.type() ?? name, extension)
|
this.extensions.set(extension.type() ?? name, extension)
|
||||||
|
|
||||||
|
// Register AI Engines
|
||||||
|
if ('provider' in extension && typeof extension.provider === 'string') {
|
||||||
|
this.engines.set(
|
||||||
|
extension.provider as unknown as string,
|
||||||
|
extension as unknown as AIEngine
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -29,6 +41,15 @@ export class ExtensionManager {
|
|||||||
return this.extensions.get(type) as T | undefined
|
return this.extensions.get(type) as T | undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retrieves a extension by its type.
|
||||||
|
* @param engine - The engine name to retrieve.
|
||||||
|
* @returns The extension, if found.
|
||||||
|
*/
|
||||||
|
getEngine<T extends AIEngine>(engine: string): T | undefined {
|
||||||
|
return this.engines.get(engine) as T | undefined
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads all registered extension.
|
* Loads all registered extension.
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
import { useCallback, useEffect, useRef } from 'react'
|
import { useCallback, useEffect, useRef } from 'react'
|
||||||
|
|
||||||
import { events, Model, ModelEvent } from '@janhq/core'
|
import { EngineManager, Model } from '@janhq/core'
|
||||||
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||||
|
|
||||||
import { toaster } from '@/containers/Toast'
|
import { toaster } from '@/containers/Toast'
|
||||||
|
|
||||||
import { LAST_USED_MODEL_ID } from './useRecommendedModel'
|
import { LAST_USED_MODEL_ID } from './useRecommendedModel'
|
||||||
|
|
||||||
|
import { extensionManager } from '@/extension'
|
||||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||||
|
|
||||||
@ -38,19 +39,13 @@ export function useActiveModel() {
|
|||||||
(stateModel.model === modelId && stateModel.loading)
|
(stateModel.model === modelId && stateModel.loading)
|
||||||
) {
|
) {
|
||||||
console.debug(`Model ${modelId} is already initialized. Ignore..`)
|
console.debug(`Model ${modelId} is already initialized. Ignore..`)
|
||||||
return
|
return Promise.resolve()
|
||||||
}
|
}
|
||||||
|
|
||||||
let model = downloadedModelsRef?.current.find((e) => e.id === modelId)
|
let model = downloadedModelsRef?.current.find((e) => e.id === modelId)
|
||||||
|
|
||||||
// Switch between engines
|
await stopModel().catch()
|
||||||
if (model && activeModel && activeModel.engine !== model.engine) {
|
|
||||||
stopModel()
|
|
||||||
// TODO: Refactor inference provider would address this
|
|
||||||
await new Promise((res) => setTimeout(res, 1000))
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: incase we have multiple assistants, the configuration will be from assistant
|
|
||||||
setLoadModelError(undefined)
|
setLoadModelError(undefined)
|
||||||
|
|
||||||
setActiveModel(undefined)
|
setActiveModel(undefined)
|
||||||
@ -68,7 +63,8 @@ export function useActiveModel() {
|
|||||||
loading: false,
|
loading: false,
|
||||||
model: '',
|
model: '',
|
||||||
}))
|
}))
|
||||||
return
|
|
||||||
|
return Promise.reject(`Model ${modelId} not found!`)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Apply thread model settings
|
/// Apply thread model settings
|
||||||
@ -83,15 +79,52 @@ export function useActiveModel() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
localStorage.setItem(LAST_USED_MODEL_ID, model.id)
|
localStorage.setItem(LAST_USED_MODEL_ID, model.id)
|
||||||
events.emit(ModelEvent.OnModelInit, model)
|
const engine = EngineManager.instance()?.get(model.engine)
|
||||||
|
return engine
|
||||||
|
?.loadModel(model)
|
||||||
|
.then(() => {
|
||||||
|
setActiveModel(model)
|
||||||
|
setStateModel(() => ({
|
||||||
|
state: 'stop',
|
||||||
|
loading: false,
|
||||||
|
model: model.id,
|
||||||
|
}))
|
||||||
|
toaster({
|
||||||
|
title: 'Success!',
|
||||||
|
description: `Model ${model.id} has been started.`,
|
||||||
|
type: 'success',
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('Failed to load model: ', error)
|
||||||
|
setStateModel(() => ({
|
||||||
|
state: 'start',
|
||||||
|
loading: false,
|
||||||
|
model: model.id,
|
||||||
|
}))
|
||||||
|
|
||||||
|
toaster({
|
||||||
|
title: 'Failed!',
|
||||||
|
description: `Model ${model.id} failed to start.`,
|
||||||
|
type: 'success',
|
||||||
|
})
|
||||||
|
setLoadModelError(error)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
const stopModel = useCallback(async () => {
|
const stopModel = useCallback(async () => {
|
||||||
if (activeModel) {
|
if (activeModel) {
|
||||||
setStateModel({ state: 'stop', loading: true, model: activeModel.id })
|
setStateModel({ state: 'stop', loading: true, model: activeModel.id })
|
||||||
events.emit(ModelEvent.OnModelStop, activeModel)
|
const engine = EngineManager.instance()?.get(activeModel.engine)
|
||||||
|
await engine
|
||||||
|
?.unloadModel(activeModel)
|
||||||
|
.catch()
|
||||||
|
.then(() => {
|
||||||
|
setActiveModel(undefined)
|
||||||
|
setStateModel({ state: 'start', loading: false, model: '' })
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}, [activeModel, setStateModel])
|
}, [activeModel, setActiveModel, setStateModel])
|
||||||
|
|
||||||
return { activeModel, startModel, stopModel, stateModel }
|
return { activeModel, startModel, stopModel, stateModel }
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,13 +11,12 @@ import {
|
|||||||
ExtensionTypeEnum,
|
ExtensionTypeEnum,
|
||||||
Thread,
|
Thread,
|
||||||
ThreadMessage,
|
ThreadMessage,
|
||||||
events,
|
|
||||||
Model,
|
Model,
|
||||||
ConversationalExtension,
|
ConversationalExtension,
|
||||||
MessageEvent,
|
|
||||||
InferenceEngine,
|
InferenceEngine,
|
||||||
ChatCompletionMessageContentType,
|
ChatCompletionMessageContentType,
|
||||||
AssistantTool,
|
AssistantTool,
|
||||||
|
EngineManager,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||||
|
|
||||||
@ -65,7 +64,6 @@ export default function useSendChatMessage() {
|
|||||||
const currentMessages = useAtomValue(getCurrentChatMessagesAtom)
|
const currentMessages = useAtomValue(getCurrentChatMessagesAtom)
|
||||||
const selectedModel = useAtomValue(selectedModelAtom)
|
const selectedModel = useAtomValue(selectedModelAtom)
|
||||||
const { activeModel, startModel } = useActiveModel()
|
const { activeModel, startModel } = useActiveModel()
|
||||||
const setQueuedMessage = useSetAtom(queuedMessageAtom)
|
|
||||||
const loadModelFailed = useAtomValue(loadModelErrorAtom)
|
const loadModelFailed = useAtomValue(loadModelErrorAtom)
|
||||||
|
|
||||||
const modelRef = useRef<Model | undefined>()
|
const modelRef = useRef<Model | undefined>()
|
||||||
@ -78,6 +76,7 @@ export default function useSendChatMessage() {
|
|||||||
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
|
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
|
||||||
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
|
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
|
||||||
const activeThreadRef = useRef<Thread | undefined>()
|
const activeThreadRef = useRef<Thread | undefined>()
|
||||||
|
const setQueuedMessage = useSetAtom(queuedMessageAtom)
|
||||||
|
|
||||||
const selectedModelRef = useRef<Model | undefined>()
|
const selectedModelRef = useRef<Model | undefined>()
|
||||||
|
|
||||||
@ -142,10 +141,7 @@ export default function useSendChatMessage() {
|
|||||||
activeThreadRef.current.assistants[0].model.id
|
activeThreadRef.current.assistants[0].model.id
|
||||||
|
|
||||||
if (modelRef.current?.id !== modelId) {
|
if (modelRef.current?.id !== modelId) {
|
||||||
setQueuedMessage(true)
|
await startModel(modelId)
|
||||||
startModel(modelId)
|
|
||||||
await waitForModelStarting(modelId)
|
|
||||||
setQueuedMessage(false)
|
|
||||||
}
|
}
|
||||||
setIsGeneratingResponse(true)
|
setIsGeneratingResponse(true)
|
||||||
if (currentMessage.role !== ChatCompletionRole.User) {
|
if (currentMessage.role !== ChatCompletionRole.User) {
|
||||||
@ -160,7 +156,10 @@ export default function useSendChatMessage() {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
events.emit(MessageEvent.OnMessageSent, messageRequest)
|
const engine = EngineManager.instance()?.get(
|
||||||
|
messageRequest.model?.engine ?? selectedModelRef.current?.engine ?? ''
|
||||||
|
)
|
||||||
|
engine?.inference(messageRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
const sendChatMessage = async (message: string) => {
|
const sendChatMessage = async (message: string) => {
|
||||||
@ -364,30 +363,20 @@ export default function useSendChatMessage() {
|
|||||||
|
|
||||||
if (modelRef.current?.id !== modelId) {
|
if (modelRef.current?.id !== modelId) {
|
||||||
setQueuedMessage(true)
|
setQueuedMessage(true)
|
||||||
startModel(modelId)
|
await startModel(modelId)
|
||||||
await waitForModelStarting(modelId)
|
|
||||||
setQueuedMessage(false)
|
setQueuedMessage(false)
|
||||||
}
|
}
|
||||||
setIsGeneratingResponse(true)
|
setIsGeneratingResponse(true)
|
||||||
events.emit(MessageEvent.OnMessageSent, messageRequest)
|
|
||||||
|
const engine = EngineManager.instance()?.get(
|
||||||
|
messageRequest.model?.engine ?? modelRequest.engine ?? ''
|
||||||
|
)
|
||||||
|
engine?.inference(messageRequest)
|
||||||
|
|
||||||
setReloadModel(false)
|
setReloadModel(false)
|
||||||
setEngineParamsUpdate(false)
|
setEngineParamsUpdate(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
const waitForModelStarting = async (modelId: string) => {
|
|
||||||
return new Promise<void>((resolve) => {
|
|
||||||
setTimeout(async () => {
|
|
||||||
if (modelRef.current?.id !== modelId && !loadModelFailedRef.current) {
|
|
||||||
await waitForModelStarting(modelId)
|
|
||||||
resolve()
|
|
||||||
} else {
|
|
||||||
resolve()
|
|
||||||
}
|
|
||||||
}, 200)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
sendChatMessage,
|
sendChatMessage,
|
||||||
resendChatMessage,
|
resendChatMessage,
|
||||||
|
|||||||
@ -74,7 +74,8 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
|
|||||||
</p>
|
</p>
|
||||||
<ModalTroubleShooting />
|
<ModalTroubleShooting />
|
||||||
</div>
|
</div>
|
||||||
) : loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
|
) : loadModelError &&
|
||||||
|
loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
|
||||||
<div
|
<div
|
||||||
key={message.id}
|
key={message.id}
|
||||||
className="flex w-full flex-col items-center text-center text-sm font-medium text-gray-500"
|
className="flex w-full flex-col items-center text-center text-sm font-medium text-gray-500"
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
import { EngineManager } from '@janhq/core'
|
||||||
|
|
||||||
import { appService } from './appService'
|
import { appService } from './appService'
|
||||||
import { EventEmitter } from './eventsService'
|
import { EventEmitter } from './eventsService'
|
||||||
import { restAPI } from './restService'
|
import { restAPI } from './restService'
|
||||||
@ -12,6 +14,7 @@ export const setupCoreServices = () => {
|
|||||||
if (!window.core) {
|
if (!window.core) {
|
||||||
window.core = {
|
window.core = {
|
||||||
events: new EventEmitter(),
|
events: new EventEmitter(),
|
||||||
|
engineManager: new EngineManager(),
|
||||||
api: {
|
api: {
|
||||||
...(window.electronAPI ? window.electronAPI : restAPI),
|
...(window.electronAPI ? window.electronAPI : restAPI),
|
||||||
...appService,
|
...appService,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user