chore: load, unload model and inference synchronously
This commit is contained in:
parent
1ad794c8a7
commit
9551996e34
@ -1,46 +0,0 @@
|
||||
import { events } from '../../events'
|
||||
import { Model, ModelEvent } from '../../types'
|
||||
import { OAIEngine } from './OAIEngine'
|
||||
|
||||
/**
|
||||
* Base OAI Remote Inference Provider
|
||||
* Added the implementation of loading and unloading model (applicable to local inference providers)
|
||||
*/
|
||||
export abstract class RemoteOAIEngine extends OAIEngine {
|
||||
// The inference engine
|
||||
abstract apiKey: string
|
||||
/**
|
||||
* On extension load, subscribe to events.
|
||||
*/
|
||||
onLoad() {
|
||||
super.onLoad()
|
||||
// These events are applicable to local inference providers
|
||||
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
|
||||
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
|
||||
}
|
||||
|
||||
/**
|
||||
* Load the model.
|
||||
*/
|
||||
async loadModel(model: Model) {
|
||||
if (model.engine.toString() !== this.provider) return
|
||||
events.emit(ModelEvent.OnModelReady, model)
|
||||
}
|
||||
/**
|
||||
* Stops the model.
|
||||
*/
|
||||
unloadModel(model: Model) {
|
||||
if (model.engine && model.engine.toString() !== this.provider) return
|
||||
events.emit(ModelEvent.OnModelStopped, {})
|
||||
}
|
||||
|
||||
/**
|
||||
* Headers for the inference request
|
||||
*/
|
||||
override headers(): HeadersInit {
|
||||
return {
|
||||
'Authorization': `Bearer ${this.apiKey}`,
|
||||
'api-key': `${this.apiKey}`,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2,7 +2,8 @@ import { getJanDataFolderPath, joinPath } from '../../core'
|
||||
import { events } from '../../events'
|
||||
import { BaseExtension } from '../../extension'
|
||||
import { fs } from '../../fs'
|
||||
import { Model, ModelEvent } from '../../types'
|
||||
import { MessageRequest, Model, ModelEvent } from '../../types'
|
||||
import { EngineManager } from './EngineManager'
|
||||
|
||||
/**
|
||||
* Base AIEngine
|
||||
@ -11,30 +12,71 @@ import { Model, ModelEvent } from '../../types'
|
||||
export abstract class AIEngine extends BaseExtension {
|
||||
// The inference engine
|
||||
abstract provider: string
|
||||
// The model folder
|
||||
modelFolder: string = 'models'
|
||||
|
||||
/**
|
||||
* On extension load, subscribe to events.
|
||||
*/
|
||||
override onLoad() {
|
||||
this.registerEngine()
|
||||
|
||||
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
|
||||
events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
|
||||
|
||||
this.prePopulateModels()
|
||||
}
|
||||
|
||||
/**
|
||||
* Defines models
|
||||
*/
|
||||
models(): Promise<Model[]> {
|
||||
return Promise.resolve([])
|
||||
}
|
||||
|
||||
/**
|
||||
* On extension load, subscribe to events.
|
||||
* Registers AI Engines
|
||||
*/
|
||||
onLoad() {
|
||||
this.prePopulateModels()
|
||||
registerEngine() {
|
||||
EngineManager.instance()?.register(this)
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads the model.
|
||||
*/
|
||||
async loadModel(model: Model): Promise<any> {
|
||||
if (model.engine.toString() !== this.provider) return Promise.resolve()
|
||||
events.emit(ModelEvent.OnModelReady, model)
|
||||
return Promise.resolve()
|
||||
}
|
||||
/**
|
||||
* Stops the model.
|
||||
*/
|
||||
async unloadModel(model?: Model): Promise<any> {
|
||||
if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve()
|
||||
events.emit(ModelEvent.OnModelStopped, model ?? {})
|
||||
return Promise.resolve()
|
||||
}
|
||||
|
||||
/*
|
||||
* Inference request
|
||||
*/
|
||||
inference(data: MessageRequest) {}
|
||||
|
||||
/**
|
||||
* Stop inference
|
||||
*/
|
||||
stopInference() {}
|
||||
|
||||
/**
|
||||
* Pre-populate models to App Data Folder
|
||||
*/
|
||||
prePopulateModels(): Promise<void> {
|
||||
const modelFolder = 'models'
|
||||
return this.models().then((models) => {
|
||||
const prePoluateOperations = models.map((model) =>
|
||||
getJanDataFolderPath()
|
||||
.then((janDataFolder) =>
|
||||
// Attempt to create the model folder
|
||||
joinPath([janDataFolder, this.modelFolder, model.id]).then((path) =>
|
||||
joinPath([janDataFolder, modelFolder, model.id]).then((path) =>
|
||||
fs
|
||||
.mkdir(path)
|
||||
.catch()
|
||||
34
core/src/extensions/engines/EngineManager.ts
Normal file
34
core/src/extensions/engines/EngineManager.ts
Normal file
@ -0,0 +1,34 @@
|
||||
import { log } from '../../core'
|
||||
import { AIEngine } from './AIEngine'
|
||||
|
||||
/**
|
||||
* Manages the registration and retrieval of inference engines.
|
||||
*/
|
||||
export class EngineManager {
|
||||
public engines = new Map<string, AIEngine>()
|
||||
|
||||
/**
|
||||
* Registers an engine.
|
||||
* @param engine - The engine to register.
|
||||
*/
|
||||
register<T extends AIEngine>(engine: T) {
|
||||
this.engines.set(engine.provider, engine)
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a engine by provider.
|
||||
* @param provider - The name of the engine to retrieve.
|
||||
* @returns The engine, if found.
|
||||
*/
|
||||
get<T extends AIEngine>(provider: string): T | undefined {
|
||||
return this.engines.get(provider) as T | undefined
|
||||
}
|
||||
|
||||
static instance(): EngineManager | undefined {
|
||||
return window.core?.engineManager as EngineManager
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The singleton instance of the ExtensionManager.
|
||||
*/
|
||||
@ -16,7 +16,7 @@ export abstract class LocalOAIEngine extends OAIEngine {
|
||||
/**
|
||||
* On extension load, subscribe to events.
|
||||
*/
|
||||
onLoad() {
|
||||
override onLoad() {
|
||||
super.onLoad()
|
||||
// These events are applicable to local inference providers
|
||||
events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
|
||||
@ -26,10 +26,10 @@ export abstract class LocalOAIEngine extends OAIEngine {
|
||||
/**
|
||||
* Load the model.
|
||||
*/
|
||||
async loadModel(model: Model) {
|
||||
override async loadModel(model: Model): Promise<void> {
|
||||
if (model.engine.toString() !== this.provider) return
|
||||
|
||||
const modelFolder = await joinPath([await getJanDataFolderPath(), this.modelFolder, model.id])
|
||||
const modelFolderName = 'models'
|
||||
const modelFolder = await joinPath([await getJanDataFolderPath(), modelFolderName, model.id])
|
||||
const systemInfo = await systemInformation()
|
||||
const res = await executeOnMain(
|
||||
this.nodeModule,
|
||||
@ -42,24 +42,22 @@ export abstract class LocalOAIEngine extends OAIEngine {
|
||||
)
|
||||
|
||||
if (res?.error) {
|
||||
events.emit(ModelEvent.OnModelFail, {
|
||||
...model,
|
||||
error: res.error,
|
||||
})
|
||||
return
|
||||
events.emit(ModelEvent.OnModelFail, { error: res.error })
|
||||
return Promise.reject(res.error)
|
||||
} else {
|
||||
this.loadedModel = model
|
||||
events.emit(ModelEvent.OnModelReady, model)
|
||||
return Promise.resolve()
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Stops the model.
|
||||
*/
|
||||
unloadModel(model: Model) {
|
||||
if (model.engine && model.engine?.toString() !== this.provider) return
|
||||
this.loadedModel = undefined
|
||||
override async unloadModel(model?: Model): Promise<void> {
|
||||
if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve()
|
||||
|
||||
executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
|
||||
this.loadedModel = undefined
|
||||
return executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => {
|
||||
events.emit(ModelEvent.OnModelStopped, {})
|
||||
})
|
||||
}
|
||||
@ -34,7 +34,7 @@ export abstract class OAIEngine extends AIEngine {
|
||||
/**
|
||||
* On extension load, subscribe to events.
|
||||
*/
|
||||
onLoad() {
|
||||
override onLoad() {
|
||||
super.onLoad()
|
||||
events.on(MessageEvent.OnMessageSent, (data: MessageRequest) => this.inference(data))
|
||||
events.on(InferenceEvent.OnInferenceStopped, () => this.stopInference())
|
||||
@ -43,12 +43,12 @@ export abstract class OAIEngine extends AIEngine {
|
||||
/**
|
||||
* On extension unload
|
||||
*/
|
||||
onUnload(): void {}
|
||||
override onUnload(): void {}
|
||||
|
||||
/*
|
||||
* Inference request
|
||||
*/
|
||||
inference(data: MessageRequest) {
|
||||
override inference(data: MessageRequest) {
|
||||
if (data.model?.engine?.toString() !== this.provider) return
|
||||
|
||||
const timestamp = Date.now()
|
||||
@ -114,7 +114,7 @@ export abstract class OAIEngine extends AIEngine {
|
||||
/**
|
||||
* Stops the inference.
|
||||
*/
|
||||
stopInference() {
|
||||
override stopInference() {
|
||||
this.isCancelled = true
|
||||
this.controller?.abort()
|
||||
}
|
||||
26
core/src/extensions/engines/RemoteOAIEngine.ts
Normal file
26
core/src/extensions/engines/RemoteOAIEngine.ts
Normal file
@ -0,0 +1,26 @@
|
||||
import { OAIEngine } from './OAIEngine'
|
||||
|
||||
/**
|
||||
* Base OAI Remote Inference Provider
|
||||
* Added the implementation of loading and unloading model (applicable to local inference providers)
|
||||
*/
|
||||
export abstract class RemoteOAIEngine extends OAIEngine {
|
||||
// The inference engine
|
||||
abstract apiKey: string
|
||||
/**
|
||||
* On extension load, subscribe to events.
|
||||
*/
|
||||
override onLoad() {
|
||||
super.onLoad()
|
||||
}
|
||||
|
||||
/**
|
||||
* Headers for the inference request
|
||||
*/
|
||||
override headers(): HeadersInit {
|
||||
return {
|
||||
'Authorization': `Bearer ${this.apiKey}`,
|
||||
'api-key': `${this.apiKey}`,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2,3 +2,4 @@ export * from './AIEngine'
|
||||
export * from './OAIEngine'
|
||||
export * from './LocalOAIEngine'
|
||||
export * from './RemoteOAIEngine'
|
||||
export * from './EngineManager'
|
||||
@ -32,4 +32,4 @@ export { HuggingFaceExtension } from './huggingface'
|
||||
/**
|
||||
* Base AI Engines.
|
||||
*/
|
||||
export * from './ai-engines'
|
||||
export * from './engines'
|
||||
|
||||
@ -91,15 +91,14 @@ export default class JanInferenceNitroExtension extends LocalOAIEngine {
|
||||
return super.loadModel(model)
|
||||
}
|
||||
|
||||
override unloadModel(model: Model): void {
|
||||
super.unloadModel(model)
|
||||
|
||||
if (model.engine && model.engine !== this.provider) return
|
||||
override async unloadModel(model?: Model) {
|
||||
if (model?.engine && model.engine !== this.provider) return
|
||||
|
||||
// stop the periocally health check
|
||||
if (this.getNitroProcesHealthIntervalId) {
|
||||
clearInterval(this.getNitroProcesHealthIntervalId)
|
||||
this.getNitroProcesHealthIntervalId = undefined
|
||||
}
|
||||
return super.unloadModel(model)
|
||||
}
|
||||
}
|
||||
|
||||
@ -8,26 +8,17 @@ import {
|
||||
ExtensionTypeEnum,
|
||||
MessageStatus,
|
||||
MessageRequest,
|
||||
Model,
|
||||
ConversationalExtension,
|
||||
MessageEvent,
|
||||
MessageRequestType,
|
||||
ModelEvent,
|
||||
Thread,
|
||||
ModelInitFailed,
|
||||
EngineManager,
|
||||
} from '@janhq/core'
|
||||
import { useAtomValue, useSetAtom } from 'jotai'
|
||||
import { ulid } from 'ulidx'
|
||||
|
||||
import {
|
||||
activeModelAtom,
|
||||
loadModelErrorAtom,
|
||||
stateModelAtom,
|
||||
} from '@/hooks/useActiveModel'
|
||||
|
||||
import { queuedMessageAtom } from '@/hooks/useSendChatMessage'
|
||||
|
||||
import { toaster } from '../Toast'
|
||||
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import {
|
||||
@ -51,8 +42,6 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
||||
const activeModel = useAtomValue(activeModelAtom)
|
||||
const setActiveModel = useSetAtom(activeModelAtom)
|
||||
const setStateModel = useSetAtom(stateModelAtom)
|
||||
const setQueuedMessage = useSetAtom(queuedMessageAtom)
|
||||
const setLoadModelError = useSetAtom(loadModelErrorAtom)
|
||||
|
||||
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
|
||||
const threads = useAtomValue(threadsAtom)
|
||||
@ -88,44 +77,11 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
||||
[addNewMessage]
|
||||
)
|
||||
|
||||
const onModelReady = useCallback(
|
||||
(model: Model) => {
|
||||
setActiveModel(model)
|
||||
toaster({
|
||||
title: 'Success!',
|
||||
description: `Model ${model.id} has been started.`,
|
||||
type: 'success',
|
||||
})
|
||||
setStateModel(() => ({
|
||||
state: 'stop',
|
||||
loading: false,
|
||||
model: model.id,
|
||||
}))
|
||||
},
|
||||
[setActiveModel, setStateModel]
|
||||
)
|
||||
|
||||
const onModelStopped = useCallback(() => {
|
||||
setTimeout(() => {
|
||||
setActiveModel(undefined)
|
||||
setStateModel({ state: 'start', loading: false, model: '' })
|
||||
}, 500)
|
||||
setActiveModel(undefined)
|
||||
setStateModel({ state: 'start', loading: false, model: '' })
|
||||
}, [setActiveModel, setStateModel])
|
||||
|
||||
const onModelInitFailed = useCallback(
|
||||
(res: ModelInitFailed) => {
|
||||
console.error('Failed to load model: ', res.error.message)
|
||||
setStateModel(() => ({
|
||||
state: 'start',
|
||||
loading: false,
|
||||
model: res.id,
|
||||
}))
|
||||
setLoadModelError(res.error.message)
|
||||
setQueuedMessage(false)
|
||||
},
|
||||
[setStateModel, setQueuedMessage, setLoadModelError]
|
||||
)
|
||||
|
||||
const updateThreadTitle = useCallback(
|
||||
(message: ThreadMessage) => {
|
||||
// Update only when it's finished
|
||||
@ -274,7 +230,10 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
||||
|
||||
// 2. Update the title with the result of the inference
|
||||
setTimeout(() => {
|
||||
events.emit(MessageEvent.OnMessageSent, messageRequest)
|
||||
const engine = EngineManager.instance()?.get(
|
||||
messageRequest.model?.engine ?? activeModelRef.current?.engine ?? ''
|
||||
)
|
||||
engine?.inference(messageRequest)
|
||||
}, 1000)
|
||||
}
|
||||
}
|
||||
@ -283,23 +242,16 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
||||
if (window.core?.events) {
|
||||
events.on(MessageEvent.OnMessageResponse, onNewMessageResponse)
|
||||
events.on(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
|
||||
events.on(ModelEvent.OnModelReady, onModelReady)
|
||||
events.on(ModelEvent.OnModelFail, onModelInitFailed)
|
||||
events.on(ModelEvent.OnModelStopped, onModelStopped)
|
||||
}
|
||||
}, [
|
||||
onNewMessageResponse,
|
||||
onMessageResponseUpdate,
|
||||
onModelReady,
|
||||
onModelInitFailed,
|
||||
onModelStopped,
|
||||
])
|
||||
}, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
events.off(MessageEvent.OnMessageResponse, onNewMessageResponse)
|
||||
events.off(MessageEvent.OnMessageUpdate, onMessageResponseUpdate)
|
||||
events.off(ModelEvent.OnModelStopped, onModelStopped)
|
||||
}
|
||||
}, [onNewMessageResponse, onMessageResponseUpdate])
|
||||
}, [onNewMessageResponse, onMessageResponseUpdate, onModelStopped])
|
||||
return <Fragment>{children}</Fragment>
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
|
||||
import { BaseExtension, ExtensionTypeEnum } from '@janhq/core'
|
||||
import { AIEngine, BaseExtension, ExtensionTypeEnum } from '@janhq/core'
|
||||
|
||||
import Extension from './Extension'
|
||||
|
||||
@ -8,14 +8,26 @@ import Extension from './Extension'
|
||||
* Manages the registration and retrieval of extensions.
|
||||
*/
|
||||
export class ExtensionManager {
|
||||
// Registered extensions
|
||||
private extensions = new Map<string, BaseExtension>()
|
||||
|
||||
// Registered inference engines
|
||||
private engines = new Map<string, AIEngine>()
|
||||
|
||||
/**
|
||||
* Registers an extension.
|
||||
* @param extension - The extension to register.
|
||||
*/
|
||||
register<T extends BaseExtension>(name: string, extension: T) {
|
||||
this.extensions.set(extension.type() ?? name, extension)
|
||||
|
||||
// Register AI Engines
|
||||
if ('provider' in extension && typeof extension.provider === 'string') {
|
||||
this.engines.set(
|
||||
extension.provider as unknown as string,
|
||||
extension as unknown as AIEngine
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -29,6 +41,15 @@ export class ExtensionManager {
|
||||
return this.extensions.get(type) as T | undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves a extension by its type.
|
||||
* @param engine - The engine name to retrieve.
|
||||
* @returns The extension, if found.
|
||||
*/
|
||||
getEngine<T extends AIEngine>(engine: string): T | undefined {
|
||||
return this.engines.get(engine) as T | undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads all registered extension.
|
||||
*/
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
import { useCallback, useEffect, useRef } from 'react'
|
||||
|
||||
import { events, Model, ModelEvent } from '@janhq/core'
|
||||
import { EngineManager, Model } from '@janhq/core'
|
||||
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import { toaster } from '@/containers/Toast'
|
||||
|
||||
import { LAST_USED_MODEL_ID } from './useRecommendedModel'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
@ -38,19 +39,13 @@ export function useActiveModel() {
|
||||
(stateModel.model === modelId && stateModel.loading)
|
||||
) {
|
||||
console.debug(`Model ${modelId} is already initialized. Ignore..`)
|
||||
return
|
||||
return Promise.resolve()
|
||||
}
|
||||
|
||||
let model = downloadedModelsRef?.current.find((e) => e.id === modelId)
|
||||
|
||||
// Switch between engines
|
||||
if (model && activeModel && activeModel.engine !== model.engine) {
|
||||
stopModel()
|
||||
// TODO: Refactor inference provider would address this
|
||||
await new Promise((res) => setTimeout(res, 1000))
|
||||
}
|
||||
await stopModel().catch()
|
||||
|
||||
// TODO: incase we have multiple assistants, the configuration will be from assistant
|
||||
setLoadModelError(undefined)
|
||||
|
||||
setActiveModel(undefined)
|
||||
@ -68,7 +63,8 @@ export function useActiveModel() {
|
||||
loading: false,
|
||||
model: '',
|
||||
}))
|
||||
return
|
||||
|
||||
return Promise.reject(`Model ${modelId} not found!`)
|
||||
}
|
||||
|
||||
/// Apply thread model settings
|
||||
@ -83,15 +79,52 @@ export function useActiveModel() {
|
||||
}
|
||||
|
||||
localStorage.setItem(LAST_USED_MODEL_ID, model.id)
|
||||
events.emit(ModelEvent.OnModelInit, model)
|
||||
const engine = EngineManager.instance()?.get(model.engine)
|
||||
return engine
|
||||
?.loadModel(model)
|
||||
.then(() => {
|
||||
setActiveModel(model)
|
||||
setStateModel(() => ({
|
||||
state: 'stop',
|
||||
loading: false,
|
||||
model: model.id,
|
||||
}))
|
||||
toaster({
|
||||
title: 'Success!',
|
||||
description: `Model ${model.id} has been started.`,
|
||||
type: 'success',
|
||||
})
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Failed to load model: ', error)
|
||||
setStateModel(() => ({
|
||||
state: 'start',
|
||||
loading: false,
|
||||
model: model.id,
|
||||
}))
|
||||
|
||||
toaster({
|
||||
title: 'Failed!',
|
||||
description: `Model ${model.id} failed to start.`,
|
||||
type: 'success',
|
||||
})
|
||||
setLoadModelError(error)
|
||||
})
|
||||
}
|
||||
|
||||
const stopModel = useCallback(async () => {
|
||||
if (activeModel) {
|
||||
setStateModel({ state: 'stop', loading: true, model: activeModel.id })
|
||||
events.emit(ModelEvent.OnModelStop, activeModel)
|
||||
const engine = EngineManager.instance()?.get(activeModel.engine)
|
||||
await engine
|
||||
?.unloadModel(activeModel)
|
||||
.catch()
|
||||
.then(() => {
|
||||
setActiveModel(undefined)
|
||||
setStateModel({ state: 'start', loading: false, model: '' })
|
||||
})
|
||||
}
|
||||
}, [activeModel, setStateModel])
|
||||
}, [activeModel, setActiveModel, setStateModel])
|
||||
|
||||
return { activeModel, startModel, stopModel, stateModel }
|
||||
}
|
||||
|
||||
@ -11,13 +11,12 @@ import {
|
||||
ExtensionTypeEnum,
|
||||
Thread,
|
||||
ThreadMessage,
|
||||
events,
|
||||
Model,
|
||||
ConversationalExtension,
|
||||
MessageEvent,
|
||||
InferenceEngine,
|
||||
ChatCompletionMessageContentType,
|
||||
AssistantTool,
|
||||
EngineManager,
|
||||
} from '@janhq/core'
|
||||
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
@ -65,7 +64,6 @@ export default function useSendChatMessage() {
|
||||
const currentMessages = useAtomValue(getCurrentChatMessagesAtom)
|
||||
const selectedModel = useAtomValue(selectedModelAtom)
|
||||
const { activeModel, startModel } = useActiveModel()
|
||||
const setQueuedMessage = useSetAtom(queuedMessageAtom)
|
||||
const loadModelFailed = useAtomValue(loadModelErrorAtom)
|
||||
|
||||
const modelRef = useRef<Model | undefined>()
|
||||
@ -78,6 +76,7 @@ export default function useSendChatMessage() {
|
||||
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
|
||||
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
|
||||
const activeThreadRef = useRef<Thread | undefined>()
|
||||
const setQueuedMessage = useSetAtom(queuedMessageAtom)
|
||||
|
||||
const selectedModelRef = useRef<Model | undefined>()
|
||||
|
||||
@ -142,10 +141,7 @@ export default function useSendChatMessage() {
|
||||
activeThreadRef.current.assistants[0].model.id
|
||||
|
||||
if (modelRef.current?.id !== modelId) {
|
||||
setQueuedMessage(true)
|
||||
startModel(modelId)
|
||||
await waitForModelStarting(modelId)
|
||||
setQueuedMessage(false)
|
||||
await startModel(modelId)
|
||||
}
|
||||
setIsGeneratingResponse(true)
|
||||
if (currentMessage.role !== ChatCompletionRole.User) {
|
||||
@ -160,7 +156,10 @@ export default function useSendChatMessage() {
|
||||
)
|
||||
}
|
||||
}
|
||||
events.emit(MessageEvent.OnMessageSent, messageRequest)
|
||||
const engine = EngineManager.instance()?.get(
|
||||
messageRequest.model?.engine ?? selectedModelRef.current?.engine ?? ''
|
||||
)
|
||||
engine?.inference(messageRequest)
|
||||
}
|
||||
|
||||
const sendChatMessage = async (message: string) => {
|
||||
@ -364,30 +363,20 @@ export default function useSendChatMessage() {
|
||||
|
||||
if (modelRef.current?.id !== modelId) {
|
||||
setQueuedMessage(true)
|
||||
startModel(modelId)
|
||||
await waitForModelStarting(modelId)
|
||||
await startModel(modelId)
|
||||
setQueuedMessage(false)
|
||||
}
|
||||
setIsGeneratingResponse(true)
|
||||
events.emit(MessageEvent.OnMessageSent, messageRequest)
|
||||
|
||||
const engine = EngineManager.instance()?.get(
|
||||
messageRequest.model?.engine ?? modelRequest.engine ?? ''
|
||||
)
|
||||
engine?.inference(messageRequest)
|
||||
|
||||
setReloadModel(false)
|
||||
setEngineParamsUpdate(false)
|
||||
}
|
||||
|
||||
const waitForModelStarting = async (modelId: string) => {
|
||||
return new Promise<void>((resolve) => {
|
||||
setTimeout(async () => {
|
||||
if (modelRef.current?.id !== modelId && !loadModelFailedRef.current) {
|
||||
await waitForModelStarting(modelId)
|
||||
resolve()
|
||||
} else {
|
||||
resolve()
|
||||
}
|
||||
}, 200)
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
sendChatMessage,
|
||||
resendChatMessage,
|
||||
|
||||
@ -74,7 +74,8 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
|
||||
</p>
|
||||
<ModalTroubleShooting />
|
||||
</div>
|
||||
) : loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
|
||||
) : loadModelError &&
|
||||
loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
|
||||
<div
|
||||
key={message.id}
|
||||
className="flex w-full flex-col items-center text-center text-sm font-medium text-gray-500"
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import { EngineManager } from '@janhq/core'
|
||||
|
||||
import { appService } from './appService'
|
||||
import { EventEmitter } from './eventsService'
|
||||
import { restAPI } from './restService'
|
||||
@ -12,6 +14,7 @@ export const setupCoreServices = () => {
|
||||
if (!window.core) {
|
||||
window.core = {
|
||||
events: new EventEmitter(),
|
||||
engineManager: new EngineManager(),
|
||||
api: {
|
||||
...(window.electronAPI ? window.electronAPI : restAPI),
|
||||
...appService,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user