chore: load model, unload model, and run inference synchronously

This commit is contained in:
Louis 2024-03-22 22:29:14 +07:00
parent 1ad794c8a7
commit 9551996e34
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
16 changed files with 226 additions and 173 deletions

View File

@ -1,46 +0,0 @@
import { events } from '../../events'
import { Model, ModelEvent } from '../../types'
import { OAIEngine } from './OAIEngine'
/**
 * Base OAI Remote Inference Provider.
 * Subscribes to model load/unload events so remote providers participate in
 * the same model lifecycle as local ones, and supplies API-key auth headers.
 */
export abstract class RemoteOAIEngine extends OAIEngine {
  // API key supplied by the concrete provider; sent with every request.
  abstract apiKey: string

  /**
   * On extension load, subscribe to model lifecycle events.
   */
  onLoad() {
    super.onLoad()
    // These events are applicable to local inference providers
    events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
    events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
  }

  /**
   * Load the model.
   * Remote engines have nothing to spawn, so this only signals readiness.
   * @param model - Ignored unless it targets this provider.
   */
  async loadModel(model: Model) {
    // Guard `engine` before calling toString() — mirrors unloadModel and
    // avoids a rejected promise when a model record has no engine set.
    if (model.engine && model.engine.toString() !== this.provider) return
    events.emit(ModelEvent.OnModelReady, model)
  }

  /**
   * Stops the model.
   * @param model - Ignored if it explicitly targets a different provider.
   */
  unloadModel(model: Model) {
    if (model.engine && model.engine.toString() !== this.provider) return
    events.emit(ModelEvent.OnModelStopped, {})
  }

  /**
   * Headers for the inference request.
   * The key is sent both as a bearer token and as `api-key`, since
   * providers differ in which header they read — TODO confirm both are safe
   * to send unconditionally for every remote provider.
   */
  override headers(): HeadersInit {
    return {
      'Authorization': `Bearer ${this.apiKey}`,
      'api-key': `${this.apiKey}`,
    }
  }
}

View File

@ -2,7 +2,8 @@ import { getJanDataFolderPath, joinPath } from '../../core'
import { events } from '../../events' import { events } from '../../events'
import { BaseExtension } from '../../extension' import { BaseExtension } from '../../extension'
import { fs } from '../../fs' import { fs } from '../../fs'
import { Model, ModelEvent } from '../../types' import { MessageRequest, Model, ModelEvent } from '../../types'
import { EngineManager } from './EngineManager'
/** /**
* Base AIEngine * Base AIEngine
@ -11,30 +12,71 @@ import { Model, ModelEvent } from '../../types'
export abstract class AIEngine extends BaseExtension { export abstract class AIEngine extends BaseExtension {
// The inference engine // The inference engine
abstract provider: string abstract provider: string
// The model folder
modelFolder: string = 'models'
/**
* On extension load, subscribe to events.
*/
override onLoad() {
this.registerEngine()
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
this.prePopulateModels()
}
/**
* Defines models
*/
models(): Promise<Model[]> { models(): Promise<Model[]> {
return Promise.resolve([]) return Promise.resolve([])
} }
/** /**
* On extension load, subscribe to events. * Registers AI Engines
*/ */
onLoad() { registerEngine() {
this.prePopulateModels() EngineManager.instance()?.register(this)
} }
/**
* Loads the model.
*/
async loadModel(model: Model): Promise<any> {
if (model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
}
/**
* Stops the model.
*/
async unloadModel(model?: Model): Promise<any> {
if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve()
events.emit(ModelEvent.OnModelStopped, model ?? {})
return Promise.resolve()
}
/*
* Inference request
*/
inference(data: MessageRequest) {}
/**
* Stop inference
*/
stopInference() {}
/** /**
* Pre-populate models to App Data Folder * Pre-populate models to App Data Folder
*/ */
prePopulateModels(): Promise<void> { prePopulateModels(): Promise<void> {
const modelFolder = 'models'
return this.models().then((models) => { return this.models().then((models) => {
const prePoluateOperations = models.map((model) => const prePoluateOperations = models.map((model) =>
getJanDataFolderPath() getJanDataFolderPath()
.then((janDataFolder) => .then((janDataFolder) =>
// Attempt to create the model folder // Attempt to create the model folder
joinPath([janDataFolder, this.modelFolder, model.id]).then((path) => joinPath([janDataFolder, modelFolder, model.id]).then((path) =>
fs fs
.mkdir(path) .mkdir(path)
.catch() .catch()

View File

@ -0,0 +1,34 @@
import { log } from '../../core'
import { AIEngine } from './AIEngine'
/**
 * Manages the registration and retrieval of inference engines.
 */
export class EngineManager {
  // Registered engines, keyed by their provider name.
  public engines = new Map<string, AIEngine>()

  /**
   * Registers an engine.
   * Re-registering the same provider replaces the previous engine.
   * @param engine - The engine to register; its `provider` is used as the key.
   */
  register<T extends AIEngine>(engine: T) {
    this.engines.set(engine.provider, engine)
  }

  /**
   * Retrieves a engine by provider.
   * @param provider - The name of the engine to retrieve.
   * @returns The engine, if found.
   */
  get<T extends AIEngine>(provider: string): T | undefined {
    return this.engines.get(provider) as T | undefined
  }

  /**
   * The shared manager stored on `window.core`, if core services are set up.
   * @returns The singleton instance, or undefined before setup.
   */
  static instance(): EngineManager | undefined {
    // Assert to `EngineManager | undefined` — a bare `as EngineManager`
    // would hide the undefined case from the type checker.
    return window.core?.engineManager as EngineManager | undefined
  }
}
/**
 * The singleton instance of the EngineManager.
 */

View File

@ -16,7 +16,7 @@ export abstract class LocalOAIEngine extends OAIEngine {
/** /**
* On extension load, subscribe to events. * On extension load, subscribe to events.
*/ */
onLoad() { override onLoad() {
super.onLoad() super.onLoad()
// These events are applicable to local inference providers // These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
@ -26,10 +26,10 @@ export abstract class LocalOAIEngine extends OAIEngine {
/** /**
* Load the model. * Load the model.
*/ */
async loadModel(model: Model) { override async loadModel(model: Model): Promise<void> {
if (model.engine.toString() !== this.provider) return if (model.engine.toString() !== this.provider) return
const modelFolderName = 'models'
const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id]) const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
const systemInfo = await systemInformation() const systemInfo = await systemInformation()
const res = await executeOnMain( const res = await executeOnMain(
this.nodeModule, this.nodeModule,
@ -42,24 +42,22 @@ export abstract class LocalOAIEngine extends OAIEngine {
) )
if (res?.error) { if (res?.error) {
events.emit(ModelEvent.OnModelFail, { events.emit(ModelEvent.OnModelFail, { error: res.error })
...model, return Promise.reject(res.error)
error: res.error,
})
return
} else { } else {
this.loadedModel = model this.loadedModel = model
events.emit(ModelEvent.OnModelReady, model) events.emit(ModelEvent.OnModelReady, model)
return Promise.resolve()
} }
} }
/** /**
* Stops the model. * Stops the model.
*/ */
unloadModel(model: Model) { override async unloadModel(model?: Model): Promise<void> {
if (model.engine && model.engine?.toString() !== this.provider) return if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve()
this.loadedModel = undefined
executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => { this.loadedModel = undefined
return executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
events.emit(ModelEvent.OnModelStopped, {}) events.emit(ModelEvent.OnModelStopped, {})
}) })
} }

View File

@ -34,7 +34,7 @@ export abstract class OAIEngine extends AIEngine {
/** /**
* On extension load, subscribe to events. * On extension load, subscribe to events.
*/ */
onLoad() { override onLoad() {
super.onLoad() super.onLoad()
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data)) events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference()) events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
@ -43,12 +43,12 @@ export abstract class OAIEngine extends AIEngine {
/** /**
* On extension unload * On extension unload
*/ */
onUnload(): void {} override onUnload(): void {}
/* /*
* Inference request * Inference request
*/ */
inference(data: MessageRequest) { override inference(data: MessageRequest) {
if (data.model?.engine?.toString() !== this.provider) return if (data.model?.engine?.toString() !== this.provider) return
const timestamp = Date.now() const timestamp = Date.now()
@ -114,7 +114,7 @@ export abstract class OAIEngine extends AIEngine {
/** /**
* Stops the inference. * Stops the inference.
*/ */
stopInference() { override stopInference() {
this.isCancelled = true this.isCancelled = true
this.controller?.abort() this.controller?.abort()
} }

View File

@ -0,0 +1,26 @@
import { OAIEngine } from './OAIEngine'
/**
 * Base OAI Remote Inference Provider.
 * Model load/unload handling is presumably done by the base engine classes —
 * this class only contributes API-key authentication headers for requests.
 */
export abstract class RemoteOAIEngine extends OAIEngine {
// The API key used to authenticate against the remote provider.
abstract apiKey: string
/**
 * On extension load, subscribe to events via the parent class.
 * Kept as an explicit override to give subclasses a stable hook point.
 */
override onLoad() {
super.onLoad()
}
/**
 * Headers for the inference request.
 * The key is sent both as a bearer token and as `api-key`, since some
 * providers read the latter header instead of `Authorization`.
 */
override headers(): HeadersInit {
return {
'Authorization': `Bearer ${this.apiKey}`,
'api-key': `${this.apiKey}`,
}
}
}

View File

@ -2,3 +2,4 @@ export * from './AIEngine'
export * from './OAIEngine' export * from './OAIEngine'
export * from './LocalOAIEngine' export * from './LocalOAIEngine'
export * from './RemoteOAIEngine' export * from './RemoteOAIEngine'
export * from './EngineManager'

View File

@ -32,4 +32,4 @@ export { HuggingFaceExtension } from './huggingface'
/** /**
* Base AI Engines. * Base AI Engines.
*/ */
export * from './ai-engines' export * from './engines'

View File

@ -91,15 +91,14 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
return super.loadModel(model) return super.loadModel(model)
} }
override unloadModel(model: Model): void { override async unloadModel(model?: Model) {
super.unloadModel(model) if (model?.engine && model.engine !== this.provider) return
if (model.engine && model.engine !== this.provider) return
// stop the periodic health check // stop the periodic health check
if (this.getNitroProcesHealthIntervalId) { if (this.getNitroProcesHealthIntervalId) {
clearInterval(this.getNitroProcesHealthIntervalId) clearInterval(this.getNitroProcesHealthIntervalId)
this.getNitroProcesHealthIntervalId = undefined this.getNitroProcesHealthIntervalId = undefined
} }
return super.unloadModel(model)
} }
} }

View File

@ -8,26 +8,17 @@ import {
ExtensionTypeEnum, ExtensionTypeEnum,
MessageStatus, MessageStatus,
MessageRequest, MessageRequest,
Model,
ConversationalExtension, ConversationalExtension,
MessageEvent, MessageEvent,
MessageRequestType, MessageRequestType,
ModelEvent, ModelEvent,
Thread, Thread,
ModelInitFailed, EngineManager,
} from '@janhq/core' } from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai' import { useAtomValue, useSetAtom } from 'jotai'
import { ulid } from 'ulidx' import { ulid } from 'ulidx'
import { import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
activeModelAtom,
loadModelErrorAtom,
stateModelAtom,
} from '@/hooks/useActiveModel'
import { queuedMessageAtom } from '@/hooks/useSendChatMessage'
import { toaster } from '../Toast'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { import {
@ -51,8 +42,6 @@ export default function EventHandler({ children }: { children: ReactNode }) {
const activeModel = useAtomValue(activeModelAtom) const activeModel = useAtomValue(activeModelAtom)
const setActiveModel = useSetAtom(activeModelAtom) const setActiveModel = useSetAtom(activeModelAtom)
const setStateModel = useSetAtom(stateModelAtom) const setStateModel = useSetAtom(stateModelAtom)
const setQueuedMessage = useSetAtom(queuedMessageAtom)
const setLoadModelError = useSetAtom(loadModelErrorAtom)
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom) const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
const threads = useAtomValue(threadsAtom) const threads = useAtomValue(threadsAtom)
@ -88,44 +77,11 @@ export default function EventHandler({ children }: { children: ReactNode }) {
[addNewMessage] [addNewMessage]
) )
const onModelReady = useCallback(
(model: Model) => {
setActiveModel(model)
toaster({
title: 'Success!',
description: `Model ${model.id} has been started.`,
type: 'success',
})
setStateModel(() => ({
state: 'stop',
loading: false,
model: model.id,
}))
},
[setActiveModel, setStateModel]
)
const onModelStopped = useCallback(() => { const onModelStopped = useCallback(() => {
setTimeout(() => { setActiveModel(undefined)
setActiveModel(undefined) setStateModel({ state: 'start', loading: false, model: '' })
setStateModel({ state: 'start', loading: false, model: '' })
}, 500)
}, [setActiveModel, setStateModel]) }, [setActiveModel, setStateModel])
const onModelInitFailed = useCallback(
(res: ModelInitFailed) => {
console.error('Failed to load model: ', res.error.message)
setStateModel(() => ({
state: 'start',
loading: false,
model: res.id,
}))
setLoadModelError(res.error.message)
setQueuedMessage(false)
},
[setStateModel, setQueuedMessage, setLoadModelError]
)
const updateThreadTitle = useCallback( const updateThreadTitle = useCallback(
(message: ThreadMessage) => { (message: ThreadMessage) => {
// Update only when it's finished // Update only when it's finished
@ -274,7 +230,10 @@ export default function EventHandler({ children }: { children: ReactNode }) {
// 2. Update the title with the result of the inference // 2. Update the title with the result of the inference
setTimeout(() => { setTimeout(() => {
events.emit(MessageEvent.OnMessageSent, messageRequest) const engine = EngineManager.instance()?.get(
messageRequest.model?.engine ?? activeModelRef.current?.engine ?? ''
)
engine?.inference(messageRequest)
}, 1000) }, 1000)
} }
} }
@ -283,23 +242,16 @@ export default function EventHandler({ children }: { children: ReactNode }) {
if (window.core?.events) { if (window.core?.events) {
events.on(MessageEvent.OnMessageResponse, onNewMessageResponse) events.on(MessageEvent.OnMessageResponse, onNewMessageResponse)
events.on(MessageEvent.OnMessageUpdate, onMessageResponseUpdate) events.on(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
events.on(ModelEvent.OnModelReady, onModelReady)
events.on(ModelEvent.OnModelFail, onModelInitFailed)
events.on(ModelEvent.OnModelStopped, onModelStopped) events.on(ModelEvent.OnModelStopped, onModelStopped)
} }
}, [ }, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])
onNewMessageResponse,
onMessageResponseUpdate,
onModelReady,
onModelInitFailed,
onModelStopped,
])
useEffect(() => { useEffect(() => {
return () => { return () => {
events.off(MessageEvent.OnMessageResponse, onNewMessageResponse) events.off(MessageEvent.OnMessageResponse, onNewMessageResponse)
events.off(MessageEvent.OnMessageUpdate, onMessageResponseUpdate) events.off(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
events.off(ModelEvent.OnModelStopped, onModelStopped)
} }
}, [onNewMessageResponse, onMessageResponseUpdate]) }, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])
return <Fragment>{children}</Fragment> return <Fragment>{children}</Fragment>
} }

View File

@ -1,6 +1,6 @@
/* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/no-explicit-any */
import { BaseExtension, ExtensionTypeEnum } from '@janhq/core' import { AIEngine, BaseExtension, ExtensionTypeEnum } from '@janhq/core'
import Extension from './Extension' import Extension from './Extension'
@ -8,14 +8,26 @@ import Extension from './Extension'
* Manages the registration and retrieval of extensions. * Manages the registration and retrieval of extensions.
*/ */
export class ExtensionManager { export class ExtensionManager {
// Registered extensions
private extensions = new Map<string, BaseExtension>() private extensions = new Map<string, BaseExtension>()
// Registered inference engines
private engines = new Map<string, AIEngine>()
/** /**
* Registers an extension. * Registers an extension.
* @param extension - The extension to register. * @param extension - The extension to register.
*/ */
register<T extends BaseExtension>(name: string, extension: T) { register<T extends BaseExtension>(name: string, extension: T) {
this.extensions.set(extension.type() ?? name, extension) this.extensions.set(extension.type() ?? name, extension)
// Register AI Engines
if ('provider' in extension && typeof extension.provider === 'string') {
this.engines.set(
extension.provider as unknown as string,
extension as unknown as AIEngine
)
}
} }
/** /**
@ -29,6 +41,15 @@ export class ExtensionManager {
return this.extensions.get(type) as T | undefined return this.extensions.get(type) as T | undefined
} }
/**
* Retrieves a extension by its type.
* @param engine - The engine name to retrieve.
* @returns The extension, if found.
*/
getEngine<T extends AIEngine>(engine: string): T | undefined {
return this.engines.get(engine) as T | undefined
}
/** /**
* Loads all registered extension. * Loads all registered extension.
*/ */

View File

@ -1,12 +1,13 @@
import { useCallback, useEffect, useRef } from 'react' import { useCallback, useEffect, useRef } from 'react'
import { events, Model, ModelEvent } from '@janhq/core' import { EngineManager, Model } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { toaster } from '@/containers/Toast' import { toaster } from '@/containers/Toast'
import { LAST_USED_MODEL_ID } from './useRecommendedModel' import { LAST_USED_MODEL_ID } from './useRecommendedModel'
import { extensionManager } from '@/extension'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
@ -38,19 +39,13 @@ export function useActiveModel() {
(stateModel.model === modelId && stateModel.loading) (stateModel.model === modelId && stateModel.loading)
) { ) {
console.debug(`Model ${modelId} is already initialized. Ignore..`) console.debug(`Model ${modelId} is already initialized. Ignore..`)
return return Promise.resolve()
} }
let model = downloadedModelsRef?.current.find((e) => e.id === modelId) let model = downloadedModelsRef?.current.find((e) => e.id === modelId)
// Switch between engines await stopModel().catch()
if (model && activeModel && activeModel.engine !== model.engine) {
stopModel()
// TODO: Refactor inference provider would address this
await new Promise((res) => setTimeout(res, 1000))
}
// TODO: incase we have multiple assistants, the configuration will be from assistant
setLoadModelError(undefined) setLoadModelError(undefined)
setActiveModel(undefined) setActiveModel(undefined)
@ -68,7 +63,8 @@ export function useActiveModel() {
loading: false, loading: false,
model: '', model: '',
})) }))
return
return Promise.reject(`Model ${modelId} not found!`)
} }
/// Apply thread model settings /// Apply thread model settings
@ -83,15 +79,52 @@ export function useActiveModel() {
} }
localStorage.setItem(LAST_USED_MODEL_ID, model.id) localStorage.setItem(LAST_USED_MODEL_ID, model.id)
events.emit(ModelEvent.OnModelInit, model) const engine = EngineManager.instance()?.get(model.engine)
return engine
?.loadModel(model)
.then(() => {
setActiveModel(model)
setStateModel(() => ({
state: 'stop',
loading: false,
model: model.id,
}))
toaster({
title: 'Success!',
description: `Model ${model.id} has been started.`,
type: 'success',
})
})
.catch((error) => {
console.error('Failed to load model: ', error)
setStateModel(() => ({
state: 'start',
loading: false,
model: model.id,
}))
toaster({
title: 'Failed!',
description: `Model ${model.id} failed to start.`,
type: 'success',
})
setLoadModelError(error)
})
} }
const stopModel = useCallback(async () => { const stopModel = useCallback(async () => {
if (activeModel) { if (activeModel) {
setStateModel({ state: 'stop', loading: true, model: activeModel.id }) setStateModel({ state: 'stop', loading: true, model: activeModel.id })
events.emit(ModelEvent.OnModelStop, activeModel) const engine = EngineManager.instance()?.get(activeModel.engine)
await engine
?.unloadModel(activeModel)
.catch()
.then(() => {
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
})
} }
}, [activeModel, setStateModel]) }, [activeModel, setActiveModel, setStateModel])
return { activeModel, startModel, stopModel, stateModel } return { activeModel, startModel, stopModel, stateModel }
} }

View File

@ -11,13 +11,12 @@ import {
ExtensionTypeEnum, ExtensionTypeEnum,
Thread, Thread,
ThreadMessage, ThreadMessage,
events,
Model, Model,
ConversationalExtension, ConversationalExtension,
MessageEvent,
InferenceEngine, InferenceEngine,
ChatCompletionMessageContentType, ChatCompletionMessageContentType,
AssistantTool, AssistantTool,
EngineManager,
} from '@janhq/core' } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
@ -65,7 +64,6 @@ export default function useSendChatMessage() {
const currentMessages = useAtomValue(getCurrentChatMessagesAtom) const currentMessages = useAtomValue(getCurrentChatMessagesAtom)
const selectedModel = useAtomValue(selectedModelAtom) const selectedModel = useAtomValue(selectedModelAtom)
const { activeModel, startModel } = useActiveModel() const { activeModel, startModel } = useActiveModel()
const setQueuedMessage = useSetAtom(queuedMessageAtom)
const loadModelFailed = useAtomValue(loadModelErrorAtom) const loadModelFailed = useAtomValue(loadModelErrorAtom)
const modelRef = useRef<Model | undefined>() const modelRef = useRef<Model | undefined>()
@ -78,6 +76,7 @@ export default function useSendChatMessage() {
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom) const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom) const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
const activeThreadRef = useRef<Thread | undefined>() const activeThreadRef = useRef<Thread | undefined>()
const setQueuedMessage = useSetAtom(queuedMessageAtom)
const selectedModelRef = useRef<Model | undefined>() const selectedModelRef = useRef<Model | undefined>()
@ -142,10 +141,7 @@ export default function useSendChatMessage() {
activeThreadRef.current.assistants[0].model.id activeThreadRef.current.assistants[0].model.id
if (modelRef.current?.id !== modelId) { if (modelRef.current?.id !== modelId) {
setQueuedMessage(true) await startModel(modelId)
startModel(modelId)
await waitForModelStarting(modelId)
setQueuedMessage(false)
} }
setIsGeneratingResponse(true) setIsGeneratingResponse(true)
if (currentMessage.role !== ChatCompletionRole.User) { if (currentMessage.role !== ChatCompletionRole.User) {
@ -160,7 +156,10 @@ export default function useSendChatMessage() {
) )
} }
} }
events.emit(MessageEvent.OnMessageSent, messageRequest) const engine = EngineManager.instance()?.get(
messageRequest.model?.engine ?? selectedModelRef.current?.engine ?? ''
)
engine?.inference(messageRequest)
} }
const sendChatMessage = async (message: string) => { const sendChatMessage = async (message: string) => {
@ -364,30 +363,20 @@ export default function useSendChatMessage() {
if (modelRef.current?.id !== modelId) { if (modelRef.current?.id !== modelId) {
setQueuedMessage(true) setQueuedMessage(true)
startModel(modelId) await startModel(modelId)
await waitForModelStarting(modelId)
setQueuedMessage(false) setQueuedMessage(false)
} }
setIsGeneratingResponse(true) setIsGeneratingResponse(true)
events.emit(MessageEvent.OnMessageSent, messageRequest)
const engine = EngineManager.instance()?.get(
messageRequest.model?.engine ?? modelRequest.engine ?? ''
)
engine?.inference(messageRequest)
setReloadModel(false) setReloadModel(false)
setEngineParamsUpdate(false) setEngineParamsUpdate(false)
} }
const waitForModelStarting = async (modelId: string) => {
return new Promise<void>((resolve) => {
setTimeout(async () => {
if (modelRef.current?.id !== modelId && !loadModelFailedRef.current) {
await waitForModelStarting(modelId)
resolve()
} else {
resolve()
}
}, 200)
})
}
return { return {
sendChatMessage, sendChatMessage,
resendChatMessage, resendChatMessage,

View File

@ -74,7 +74,8 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
</p> </p>
<ModalTroubleShooting /> <ModalTroubleShooting />
</div> </div>
) : loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? ( ) : loadModelError &&
loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
<div <div
key={message.id} key={message.id}
className="flex w-full flex-col items-center text-center text-sm font-medium text-gray-500" className="flex w-full flex-col items-center text-center text-sm font-medium text-gray-500"

View File

@ -1,3 +1,5 @@
import { EngineManager } from '@janhq/core'
import { appService } from './appService' import { appService } from './appService'
import { EventEmitter } from './eventsService' import { EventEmitter } from './eventsService'
import { restAPI } from './restService' import { restAPI } from './restService'
@ -12,6 +14,7 @@ export const setupCoreServices = () => {
if (!window.core) { if (!window.core) {
window.core = { window.core = {
events: new EventEmitter(), events: new EventEmitter(),
engineManager: new EngineManager(),
api: { api: {
...(window.electronAPI ? window.electronAPI : restAPI), ...(window.electronAPI ? window.electronAPI : restAPI),
...appService, ...appService,