From 9551996e349b1f75c8293cc117981f9a5e33b238 Mon Sep 17 00:00:00 2001
From: Louis
Date: Fri, 22 Mar 2024 22:29:14 +0700
Subject: [PATCH] chore: load, unload model and inference synchronously
---
.../extensions/ai-engines/RemoteOAIEngine.ts | 46 ------------
.../{ai-engines => engines}/AIEngine.ts | 56 +++++++++++++--
core/src/extensions/engines/EngineManager.ts | 34 +++++++++
.../{ai-engines => engines}/LocalOAIEngine.ts | 24 +++----
.../{ai-engines => engines}/OAIEngine.ts | 8 +--
.../src/extensions/engines/RemoteOAIEngine.ts | 26 +++++++
.../{ai-engines => engines}/helpers/sse.ts | 0
.../{ai-engines => engines}/index.ts | 1 +
core/src/extensions/index.ts | 2 +-
.../inference-nitro-extension/src/index.ts | 7 +-
web/containers/Providers/EventHandler.tsx | 70 +++----------------
web/extension/ExtensionManager.ts | 23 +++++-
web/hooks/useActiveModel.ts | 59 ++++++++++++----
web/hooks/useSendChatMessage.ts | 37 ++++------
web/screens/Chat/ErrorMessage/index.tsx | 3 +-
web/services/coreService.ts | 3 +
16 files changed, 226 insertions(+), 173 deletions(-)
delete mode 100644 core/src/extensions/ai-engines/RemoteOAIEngine.ts
rename core/src/extensions/{ai-engines => engines}/AIEngine.ts (56%)
create mode 100644 core/src/extensions/engines/EngineManager.ts
rename core/src/extensions/{ai-engines => engines}/LocalOAIEngine.ts (74%)
rename core/src/extensions/{ai-engines => engines}/OAIEngine.ts (95%)
create mode 100644 core/src/extensions/engines/RemoteOAIEngine.ts
rename core/src/extensions/{ai-engines => engines}/helpers/sse.ts (100%)
rename core/src/extensions/{ai-engines => engines}/index.ts (79%)
diff --git a/core/src/extensions/ai-engines/RemoteOAIEngine.ts b/core/src/extensions/ai-engines/RemoteOAIEngine.ts
deleted file mode 100644
index 5e9804b23..000000000
--- a/core/src/extensions/ai-engines/RemoteOAIEngine.ts
+++ /dev/null
@@ -1,46 +0,0 @@
-import { events } from '../../events'
-import { Model, ModelEvent } from '../../types'
-import { OAIEngine } from './OAIEngine'
-
-/**
- * Base OAI Remote Inference Provider
- * Added the implementation of loading and unloading model (applicable to local inference providers)
- */
-export abstract class RemoteOAIEngine extends OAIEngine {
- // The inference engine
- abstract apiKey: string
- /**
- * On extension load, subscribe to events.
- */
- onLoad() {
- super.onLoad()
- // These events are applicable to local inference providers
- events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
- events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
- }
-
- /**
- * Load the model.
- */
- async loadModel(model: Model) {
- if (model.engine.toString() !== this.provider) return
- events.emit(ModelEvent.OnModelReady, model)
- }
- /**
- * Stops the model.
- */
- unloadModel(model: Model) {
- if (model.engine && model.engine.toString() !== this.provider) return
- events.emit(ModelEvent.OnModelStopped, {})
- }
-
- /**
- * Headers for the inference request
- */
- override headers(): HeadersInit {
- return {
- 'Authorization': `Bearer ${this.apiKey}`,
- 'api-key': `${this.apiKey}`,
- }
- }
-}
diff --git a/core/src/extensions/ai-engines/AIEngine.ts b/core/src/extensions/engines/AIEngine.ts
similarity index 56%
rename from core/src/extensions/ai-engines/AIEngine.ts
rename to core/src/extensions/engines/AIEngine.ts
index 8af89f336..2323c07ef 100644
--- a/core/src/extensions/ai-engines/AIEngine.ts
+++ b/core/src/extensions/engines/AIEngine.ts
@@ -2,7 +2,8 @@ import { getJanDataFolderPath, joinPath } from '../../core'
import { events } from '../../events'
import { BaseExtension } from '../../extension'
import { fs } from '../../fs'
-import { Model, ModelEvent } from '../../types'
+import { MessageRequest, Model, ModelEvent } from '../../types'
+import { EngineManager } from './EngineManager'
/**
* Base AIEngine
@@ -11,30 +12,71 @@ import { Model, ModelEvent } from '../../types'
export abstract class AIEngine extends BaseExtension {
// The inference engine
abstract provider: string
- // The model folder
- modelFolder: string = 'models'
+ /**
+ * On extension load, subscribe to events.
+ */
+ override onLoad() {
+ this.registerEngine()
+
+ events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
+ events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
+
+ this.prePopulateModels()
+ }
+
+ /**
+ * Defines models
+ */
models(): Promise<Model[]> {
return Promise.resolve([])
}
/**
- * On extension load, subscribe to events.
+ * Registers AI Engines
*/
- onLoad() {
- this.prePopulateModels()
+ registerEngine() {
+ EngineManager.instance()?.register(this)
}
+ /**
+ * Loads the model.
+ */
+  async loadModel(model: Model): Promise<void> {
+ if (model.engine.toString() !== this.provider) return Promise.resolve()
+ events.emit(ModelEvent.OnModelReady, model)
+ return Promise.resolve()
+ }
+ /**
+ * Stops the model.
+ */
+  async unloadModel(model?: Model): Promise<void> {
+ if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve()
+ events.emit(ModelEvent.OnModelStopped, model ?? {})
+ return Promise.resolve()
+ }
+
+ /*
+ * Inference request
+ */
+ inference(data: MessageRequest) {}
+
+ /**
+ * Stop inference
+ */
+ stopInference() {}
+
/**
* Pre-populate models to App Data Folder
*/
prePopulateModels(): Promise<void> {
+ const modelFolder = 'models'
return this.models().then((models) => {
const prePoluateOperations = models.map((model) =>
getJanDataFolderPath()
.then((janDataFolder) =>
// Attempt to create the model folder
- joinPath([janDataFolder, this.modelFolder, model.id]).then((path) =>
+ joinPath([janDataFolder, modelFolder, model.id]).then((path) =>
fs
.mkdir(path)
.catch()
diff --git a/core/src/extensions/engines/EngineManager.ts b/core/src/extensions/engines/EngineManager.ts
new file mode 100644
index 000000000..6931e605e
--- /dev/null
+++ b/core/src/extensions/engines/EngineManager.ts
@@ -0,0 +1,34 @@
+import { log } from '../../core'
+import { AIEngine } from './AIEngine'
+
+/**
+ * Manages the registration and retrieval of inference engines.
+ */
+export class EngineManager {
+  public engines = new Map<string, AIEngine>()
+
+ /**
+ * Registers an engine.
+ * @param engine - The engine to register.
+ */
register<T extends AIEngine>(engine: T) {
+ this.engines.set(engine.provider, engine)
+ }
+
+ /**
   * Retrieves an engine by provider.
+ * @param provider - The name of the engine to retrieve.
+ * @returns The engine, if found.
+ */
get<T extends AIEngine>(provider: string): T | undefined {
+ return this.engines.get(provider) as T | undefined
+ }
+
+ static instance(): EngineManager | undefined {
+ return window.core?.engineManager as EngineManager
+ }
+}
+
+/**
+ * The singleton instance of the EngineManager.
+ */
diff --git a/core/src/extensions/ai-engines/LocalOAIEngine.ts b/core/src/extensions/engines/LocalOAIEngine.ts
similarity index 74%
rename from core/src/extensions/ai-engines/LocalOAIEngine.ts
rename to core/src/extensions/engines/LocalOAIEngine.ts
index f6557cd8f..ce92ac804 100644
--- a/core/src/extensions/ai-engines/LocalOAIEngine.ts
+++ b/core/src/extensions/engines/LocalOAIEngine.ts
@@ -16,7 +16,7 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* On extension load, subscribe to events.
*/
- onLoad() {
+ override onLoad() {
super.onLoad()
// These events are applicable to local inference providers
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
@@ -26,10 +26,10 @@ export abstract class LocalOAIEngine extends OAIEngine {
/**
* Load the model.
*/
- async loadModel(model: Model) {
+  override async loadModel(model: Model): Promise<void> {
if (model.engine.toString() !== this.provider) return
-
- const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id])
+ const modelFolderName = 'models'
+ const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
const systemInfo = await systemInformation()
const res = await executeOnMain(
this.nodeModule,
@@ -42,24 +42,22 @@ export abstract class LocalOAIEngine extends OAIEngine {
)
if (res?.error) {
- events.emit(ModelEvent.OnModelFail, {
- ...model,
- error: res.error,
- })
- return
+ events.emit(ModelEvent.OnModelFail, { error: res.error })
+ return Promise.reject(res.error)
} else {
this.loadedModel = model
events.emit(ModelEvent.OnModelReady, model)
+ return Promise.resolve()
}
}
/**
* Stops the model.
*/
- unloadModel(model: Model) {
- if (model.engine && model.engine?.toString() !== this.provider) return
- this.loadedModel = undefined
+  override async unloadModel(model?: Model): Promise<void> {
+ if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve()
- executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
+ this.loadedModel = undefined
+ return executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
events.emit(ModelEvent.OnModelStopped, {})
})
}
diff --git a/core/src/extensions/ai-engines/OAIEngine.ts b/core/src/extensions/engines/OAIEngine.ts
similarity index 95%
rename from core/src/extensions/ai-engines/OAIEngine.ts
rename to core/src/extensions/engines/OAIEngine.ts
index 5936005bb..772f6504f 100644
--- a/core/src/extensions/ai-engines/OAIEngine.ts
+++ b/core/src/extensions/engines/OAIEngine.ts
@@ -34,7 +34,7 @@ export abstract class OAIEngine extends AIEngine {
/**
* On extension load, subscribe to events.
*/
- onLoad() {
+ override onLoad() {
super.onLoad()
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
@@ -43,12 +43,12 @@ export abstract class OAIEngine extends AIEngine {
/**
* On extension unload
*/
- onUnload(): void {}
+ override onUnload(): void {}
/*
* Inference request
*/
- inference(data: MessageRequest) {
+ override inference(data: MessageRequest) {
if (data.model?.engine?.toString() !== this.provider) return
const timestamp = Date.now()
@@ -114,7 +114,7 @@ export abstract class OAIEngine extends AIEngine {
/**
* Stops the inference.
*/
- stopInference() {
+ override stopInference() {
this.isCancelled = true
this.controller?.abort()
}
diff --git a/core/src/extensions/engines/RemoteOAIEngine.ts b/core/src/extensions/engines/RemoteOAIEngine.ts
new file mode 100644
index 000000000..2d5126c6b
--- /dev/null
+++ b/core/src/extensions/engines/RemoteOAIEngine.ts
@@ -0,0 +1,26 @@
+import { OAIEngine } from './OAIEngine'
+
+/**
+ * Base OAI Remote Inference Provider
+ * Remote engines do not implement model loading/unloading; those events are handled by the base AIEngine.
+ */
+export abstract class RemoteOAIEngine extends OAIEngine {
+ // The inference engine
+ abstract apiKey: string
+ /**
+ * On extension load, subscribe to events.
+ */
+ override onLoad() {
+ super.onLoad()
+ }
+
+ /**
+ * Headers for the inference request
+ */
+ override headers(): HeadersInit {
+ return {
+ 'Authorization': `Bearer ${this.apiKey}`,
+ 'api-key': `${this.apiKey}`,
+ }
+ }
+}
diff --git a/core/src/extensions/ai-engines/helpers/sse.ts b/core/src/extensions/engines/helpers/sse.ts
similarity index 100%
rename from core/src/extensions/ai-engines/helpers/sse.ts
rename to core/src/extensions/engines/helpers/sse.ts
diff --git a/core/src/extensions/ai-engines/index.ts b/core/src/extensions/engines/index.ts
similarity index 79%
rename from core/src/extensions/ai-engines/index.ts
rename to core/src/extensions/engines/index.ts
index fc341380a..34ef45afd 100644
--- a/core/src/extensions/ai-engines/index.ts
+++ b/core/src/extensions/engines/index.ts
@@ -2,3 +2,4 @@ export * from './AIEngine'
export * from './OAIEngine'
export * from './LocalOAIEngine'
export * from './RemoteOAIEngine'
+export * from './EngineManager'
diff --git a/core/src/extensions/index.ts b/core/src/extensions/index.ts
index c049f3b3a..768886d49 100644
--- a/core/src/extensions/index.ts
+++ b/core/src/extensions/index.ts
@@ -32,4 +32,4 @@ export { HuggingFaceExtension } from './huggingface'
/**
* Base AI Engines.
*/
-export * from './ai-engines'
+export * from './engines'
diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts
index 3a23082ba..313b67365 100644
--- a/extensions/inference-nitro-extension/src/index.ts
+++ b/extensions/inference-nitro-extension/src/index.ts
@@ -91,15 +91,14 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
return super.loadModel(model)
}
- override unloadModel(model: Model): void {
- super.unloadModel(model)
-
- if (model.engine && model.engine !== this.provider) return
+ override async unloadModel(model?: Model) {
+ if (model?.engine && model.engine !== this.provider) return
// stop the periodic health check
if (this.getNitroProcesHealthIntervalId) {
clearInterval(this.getNitroProcesHealthIntervalId)
this.getNitroProcesHealthIntervalId = undefined
}
+ return super.unloadModel(model)
}
}
diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx
index d44c950e1..d62cb3f8f 100644
--- a/web/containers/Providers/EventHandler.tsx
+++ b/web/containers/Providers/EventHandler.tsx
@@ -8,26 +8,17 @@ import {
ExtensionTypeEnum,
MessageStatus,
MessageRequest,
- Model,
ConversationalExtension,
MessageEvent,
MessageRequestType,
ModelEvent,
Thread,
- ModelInitFailed,
+ EngineManager,
} from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
import { ulid } from 'ulidx'
-import {
- activeModelAtom,
- loadModelErrorAtom,
- stateModelAtom,
-} from '@/hooks/useActiveModel'
-
-import { queuedMessageAtom } from '@/hooks/useSendChatMessage'
-
-import { toaster } from '../Toast'
+import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
import { extensionManager } from '@/extension'
import {
@@ -51,8 +42,6 @@ export default function EventHandler({ children }: { children: ReactNode }) {
const activeModel = useAtomValue(activeModelAtom)
const setActiveModel = useSetAtom(activeModelAtom)
const setStateModel = useSetAtom(stateModelAtom)
- const setQueuedMessage = useSetAtom(queuedMessageAtom)
- const setLoadModelError = useSetAtom(loadModelErrorAtom)
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
const threads = useAtomValue(threadsAtom)
@@ -88,44 +77,11 @@ export default function EventHandler({ children }: { children: ReactNode }) {
[addNewMessage]
)
- const onModelReady = useCallback(
- (model: Model) => {
- setActiveModel(model)
- toaster({
- title: 'Success!',
- description: `Model ${model.id} has been started.`,
- type: 'success',
- })
- setStateModel(() => ({
- state: 'stop',
- loading: false,
- model: model.id,
- }))
- },
- [setActiveModel, setStateModel]
- )
-
const onModelStopped = useCallback(() => {
- setTimeout(() => {
- setActiveModel(undefined)
- setStateModel({ state: 'start', loading: false, model: '' })
- }, 500)
+ setActiveModel(undefined)
+ setStateModel({ state: 'start', loading: false, model: '' })
}, [setActiveModel, setStateModel])
- const onModelInitFailed = useCallback(
- (res: ModelInitFailed) => {
- console.error('Failed to load model: ', res.error.message)
- setStateModel(() => ({
- state: 'start',
- loading: false,
- model: res.id,
- }))
- setLoadModelError(res.error.message)
- setQueuedMessage(false)
- },
- [setStateModel, setQueuedMessage, setLoadModelError]
- )
-
const updateThreadTitle = useCallback(
(message: ThreadMessage) => {
// Update only when it's finished
@@ -274,7 +230,10 @@ export default function EventHandler({ children }: { children: ReactNode }) {
// 2. Update the title with the result of the inference
setTimeout(() => {
- events.emit(MessageEvent.OnMessageSent, messageRequest)
+ const engine = EngineManager.instance()?.get(
+ messageRequest.model?.engine ?? activeModelRef.current?.engine ?? ''
+ )
+ engine?.inference(messageRequest)
}, 1000)
}
}
@@ -283,23 +242,16 @@ export default function EventHandler({ children }: { children: ReactNode }) {
if (window.core?.events) {
events.on(MessageEvent.OnMessageResponse, onNewMessageResponse)
events.on(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
- events.on(ModelEvent.OnModelReady, onModelReady)
- events.on(ModelEvent.OnModelFail, onModelInitFailed)
events.on(ModelEvent.OnModelStopped, onModelStopped)
}
- }, [
- onNewMessageResponse,
- onMessageResponseUpdate,
- onModelReady,
- onModelInitFailed,
- onModelStopped,
- ])
+ }, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])
useEffect(() => {
return () => {
events.off(MessageEvent.OnMessageResponse, onNewMessageResponse)
events.off(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
+ events.off(ModelEvent.OnModelStopped, onModelStopped)
}
- }, [onNewMessageResponse, onMessageResponseUpdate])
+ }, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])
return {children}
}
diff --git a/web/extension/ExtensionManager.ts b/web/extension/ExtensionManager.ts
index c976010c6..6d96d71b5 100644
--- a/web/extension/ExtensionManager.ts
+++ b/web/extension/ExtensionManager.ts
@@ -1,6 +1,6 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
-import { BaseExtension, ExtensionTypeEnum } from '@janhq/core'
+import { AIEngine, BaseExtension, ExtensionTypeEnum } from '@janhq/core'
import Extension from './Extension'
@@ -8,14 +8,26 @@ import Extension from './Extension'
* Manages the registration and retrieval of extensions.
*/
export class ExtensionManager {
+ // Registered extensions
private extensions = new Map()
+ // Registered inference engines
+  private engines = new Map<string, AIEngine>()
+
/**
* Registers an extension.
* @param extension - The extension to register.
*/
register(name: string, extension: T) {
this.extensions.set(extension.type() ?? name, extension)
+
+ // Register AI Engines
+ if ('provider' in extension && typeof extension.provider === 'string') {
+ this.engines.set(
+ extension.provider as unknown as string,
+ extension as unknown as AIEngine
+ )
+ }
}
/**
@@ -29,6 +41,15 @@ export class ExtensionManager {
return this.extensions.get(type) as T | undefined
}
+ /**
+   * Retrieves an engine by its provider name.
+   * @param engine - The engine name to retrieve.
+   * @returns The engine, if found.
+ */
+  getEngine<T extends AIEngine>(engine: string): T | undefined {
+ return this.engines.get(engine) as T | undefined
+ }
+
/**
* Loads all registered extension.
*/
diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts
index 98433c2ea..c2e11d182 100644
--- a/web/hooks/useActiveModel.ts
+++ b/web/hooks/useActiveModel.ts
@@ -1,12 +1,13 @@
import { useCallback, useEffect, useRef } from 'react'
-import { events, Model, ModelEvent } from '@janhq/core'
+import { EngineManager, Model } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { toaster } from '@/containers/Toast'
import { LAST_USED_MODEL_ID } from './useRecommendedModel'
+import { extensionManager } from '@/extension'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
@@ -38,19 +39,13 @@ export function useActiveModel() {
(stateModel.model === modelId && stateModel.loading)
) {
console.debug(`Model ${modelId} is already initialized. Ignore..`)
- return
+ return Promise.resolve()
}
let model = downloadedModelsRef?.current.find((e) => e.id === modelId)
- // Switch between engines
- if (model && activeModel && activeModel.engine !== model.engine) {
- stopModel()
- // TODO: Refactor inference provider would address this
- await new Promise((res) => setTimeout(res, 1000))
- }
+ await stopModel().catch()
- // TODO: incase we have multiple assistants, the configuration will be from assistant
setLoadModelError(undefined)
setActiveModel(undefined)
@@ -68,7 +63,8 @@ export function useActiveModel() {
loading: false,
model: '',
}))
- return
+
+ return Promise.reject(`Model ${modelId} not found!`)
}
/// Apply thread model settings
@@ -83,15 +79,52 @@ export function useActiveModel() {
}
localStorage.setItem(LAST_USED_MODEL_ID, model.id)
- events.emit(ModelEvent.OnModelInit, model)
+ const engine = EngineManager.instance()?.get(model.engine)
+ return engine
+ ?.loadModel(model)
+ .then(() => {
+ setActiveModel(model)
+ setStateModel(() => ({
+ state: 'stop',
+ loading: false,
+ model: model.id,
+ }))
+ toaster({
+ title: 'Success!',
+ description: `Model ${model.id} has been started.`,
+ type: 'success',
+ })
+ })
+ .catch((error) => {
+ console.error('Failed to load model: ', error)
+ setStateModel(() => ({
+ state: 'start',
+ loading: false,
+ model: model.id,
+ }))
+
+ toaster({
+ title: 'Failed!',
+ description: `Model ${model.id} failed to start.`,
+        type: 'error',
+ })
+ setLoadModelError(error)
+ })
}
const stopModel = useCallback(async () => {
if (activeModel) {
setStateModel({ state: 'stop', loading: true, model: activeModel.id })
- events.emit(ModelEvent.OnModelStop, activeModel)
+ const engine = EngineManager.instance()?.get(activeModel.engine)
+ await engine
+ ?.unloadModel(activeModel)
+ .catch()
+ .then(() => {
+ setActiveModel(undefined)
+ setStateModel({ state: 'start', loading: false, model: '' })
+ })
}
- }, [activeModel, setStateModel])
+ }, [activeModel, setActiveModel, setStateModel])
return { activeModel, startModel, stopModel, stateModel }
}
diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts
index 1ba68f85e..65daf450f 100644
--- a/web/hooks/useSendChatMessage.ts
+++ b/web/hooks/useSendChatMessage.ts
@@ -11,13 +11,12 @@ import {
ExtensionTypeEnum,
Thread,
ThreadMessage,
- events,
Model,
ConversationalExtension,
- MessageEvent,
InferenceEngine,
ChatCompletionMessageContentType,
AssistantTool,
+ EngineManager,
} from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
@@ -65,7 +64,6 @@ export default function useSendChatMessage() {
const currentMessages = useAtomValue(getCurrentChatMessagesAtom)
const selectedModel = useAtomValue(selectedModelAtom)
const { activeModel, startModel } = useActiveModel()
- const setQueuedMessage = useSetAtom(queuedMessageAtom)
const loadModelFailed = useAtomValue(loadModelErrorAtom)
const modelRef = useRef<Model | undefined>()
@@ -78,6 +76,7 @@ export default function useSendChatMessage() {
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
const activeThreadRef = useRef<Thread | undefined>()
+ const setQueuedMessage = useSetAtom(queuedMessageAtom)
const selectedModelRef = useRef<Model | undefined>()
@@ -142,10 +141,7 @@ export default function useSendChatMessage() {
activeThreadRef.current.assistants[0].model.id
if (modelRef.current?.id !== modelId) {
- setQueuedMessage(true)
- startModel(modelId)
- await waitForModelStarting(modelId)
- setQueuedMessage(false)
+ await startModel(modelId)
}
setIsGeneratingResponse(true)
if (currentMessage.role !== ChatCompletionRole.User) {
@@ -160,7 +156,10 @@ export default function useSendChatMessage() {
)
}
}
- events.emit(MessageEvent.OnMessageSent, messageRequest)
+ const engine = EngineManager.instance()?.get(
+ messageRequest.model?.engine ?? selectedModelRef.current?.engine ?? ''
+ )
+ engine?.inference(messageRequest)
}
const sendChatMessage = async (message: string) => {
@@ -364,30 +363,20 @@ export default function useSendChatMessage() {
if (modelRef.current?.id !== modelId) {
setQueuedMessage(true)
- startModel(modelId)
- await waitForModelStarting(modelId)
+ await startModel(modelId)
setQueuedMessage(false)
}
setIsGeneratingResponse(true)
- events.emit(MessageEvent.OnMessageSent, messageRequest)
+
+ const engine = EngineManager.instance()?.get(
+ messageRequest.model?.engine ?? modelRequest.engine ?? ''
+ )
+ engine?.inference(messageRequest)
setReloadModel(false)
setEngineParamsUpdate(false)
}
- const waitForModelStarting = async (modelId: string) => {
- return new Promise((resolve) => {
- setTimeout(async () => {
- if (modelRef.current?.id !== modelId && !loadModelFailedRef.current) {
- await waitForModelStarting(modelId)
- resolve()
- } else {
- resolve()
- }
- }, 200)
- })
- }
-
return {
sendChatMessage,
resendChatMessage,
diff --git a/web/screens/Chat/ErrorMessage/index.tsx b/web/screens/Chat/ErrorMessage/index.tsx
index 5be87a59d..2104beb92 100644
--- a/web/screens/Chat/ErrorMessage/index.tsx
+++ b/web/screens/Chat/ErrorMessage/index.tsx
@@ -74,7 +74,8 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
- ) : loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
+ ) : loadModelError &&
+ loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
{
if (!window.core) {
window.core = {
events: new EventEmitter(),
+ engineManager: new EngineManager(),
api: {
...(window.electronAPI ? window.electronAPI : restAPI),
...appService,