From 1bc5fe64f3bf4e3deb390a0690a5425cce41dcfe Mon Sep 17 00:00:00 2001 From: hiro Date: Mon, 4 Dec 2023 12:10:24 +0700 Subject: [PATCH] fix: Use Events for init, load, stop models --- core/src/events.ts | 8 +++ core/src/types/index.ts | 15 +++++ web/containers/Providers/EventHandler.tsx | 49 +++++++++++++++- web/hooks/useActiveModel.ts | 56 ++----------------- .../ExploreModelItemHeader/index.tsx | 20 ++++++- 5 files changed, 93 insertions(+), 55 deletions(-) diff --git a/core/src/events.ts b/core/src/events.ts index f588daad7..81451c1f0 100644 --- a/core/src/events.ts +++ b/core/src/events.ts @@ -8,6 +8,14 @@ export enum EventName { OnMessageResponse = "OnMessageResponse", /** The `OnMessageUpdate` event is emitted when a message is updated. */ OnMessageUpdate = "OnMessageUpdate", + /** The `OnModelInit` event is emitted when a model inits. */ + OnModelInit = "OnModelInit", + /** The `OnModelReady` event is emitted when a model ready. */ + OnModelReady = "OnModelReady", + /** The `OnModelFail` event is emitted when a model fails loading. */ + OnModelFail = "OnModelFail", + /** The `OnModelStop` event is emitted when a model fails loading. */ + OnModelStop = "OnModelStop", } /** diff --git a/core/src/types/index.ts b/core/src/types/index.ts index 87343aa65..5b45d4cc8 100644 --- a/core/src/types/index.ts +++ b/core/src/types/index.ts @@ -166,6 +166,17 @@ export type ThreadState = { error?: Error; lastMessage?: string; }; +/** + * Represents the inference engine. + * @stored + */ + +enum InferenceEngine { + llama_cpp = "llama_cpp", + openai = "openai", + nvidia_triton = "nvidia_triton", + hf_endpoint = "hf_endpoint", +} /** * Model type defines the shape of a model object. @@ -234,6 +245,10 @@ export interface Model { * Metadata of the model. */ metadata: ModelMetadata; + /** + * The model engine. Enum: "llamacpp" "openai" + */ + engine: InferenceEngine; } export type ModelMetadata = { diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx index 46f4b19d4..a3910e266 100644 --- a/web/containers/Providers/EventHandler.tsx +++ b/web/containers/Providers/EventHandler.tsx @@ -7,9 +7,10 @@ import { ThreadMessage, ExtensionType, MessageStatus, + Model } from '@janhq/core' import { ConversationalExtension } from '@janhq/core' -import { useAtomValue, useSetAtom } from 'jotai' +import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' import { extensionManager } from '@/extension' import { @@ -21,9 +22,16 @@ import { threadsAtom, } from '@/helpers/atoms/Conversation.atom' +import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel' +import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels' +import { toaster } from '../Toast' + export default function EventHandler({ children }: { children: ReactNode }) { const addNewMessage = useSetAtom(addNewMessageAtom) const updateMessage = useSetAtom(updateMessageAtom) + const { downloadedModels } = useGetDownloadedModels() + const [activeModel, setActiveModel] = useAtom(activeModelAtom) + const [stateModel, setStateModel] = useAtom(stateModelAtom) const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom) const threads = useAtomValue(threadsAtom) @@ -37,6 +45,42 @@ export default function EventHandler({ children }: { children: ReactNode }) { addNewMessage(message) } + async function handleModelReady(res: any) { + const model = downloadedModels.find((e) => e.id === res.modelId) + setActiveModel(model) + toaster({ + title: 'Success!', + description: `Model ${res.modelId} has been started.`, + }) + setStateModel(() => ({ + state: 'stop', + loading: false, + model: res.modelId, + })) + } + + async function handleModelStop(res: any) { + const model = downloadedModels.find((e) => e.id === res.modelId) + setTimeout(async () => { + setActiveModel(undefined) + setStateModel({ state: 'start', loading: false, model: '' }) + toaster({ + title: 'Success!', + description: `Model ${res.modelId} has been stopped.`, + }) + }, 500) + } + + async function handleModelFail(res: any) { + const errorMessage = `${res.error}` + alert(errorMessage) + setStateModel(() => ({ + state: 'start', + loading: false, + model: res.modelId, + })) + } + async function handleMessageResponseUpdate(message: ThreadMessage) { updateMessage( message.id, @@ -73,6 +117,9 @@ export default function EventHandler({ children }: { children: ReactNode }) { if (window.core.events) { events.on(EventName.OnMessageResponse, handleNewMessageResponse) events.on(EventName.OnMessageUpdate, handleMessageResponseUpdate) + events.on(EventName.OnModelReady, handleModelReady) + events.on(EventName.OnModelFail, handleModelFail) + events.on(EventName.OnModelStop, handleModelStop) } // eslint-disable-next-line react-hooks/exhaustive-deps }, []) diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts index 60be0f2c4..4f1565e15 100644 --- a/web/hooks/useActiveModel.ts +++ b/web/hooks/useActiveModel.ts @@ -1,5 +1,5 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { ExtensionType, InferenceExtension } from '@janhq/core' +import { EventName, ExtensionType, InferenceExtension, events } from '@janhq/core' import { Model, ModelSettingParams } from '@janhq/core' import { atom, useAtom } from 'jotai' @@ -9,9 +9,9 @@ import { useGetDownloadedModels } from './useGetDownloadedModels' import { extensionManager } from '@/extension' -const activeModelAtom = atom(undefined) +export const activeModelAtom = atom(undefined) -const stateModelAtom = atom({ state: 'start', loading: false, model: '' }) +export const stateModelAtom = atom({ state: 'start', loading: false, model: '' }) export function useActiveModel() { const [activeModel, setActiveModel] = useAtom(activeModelAtom) @@ -47,59 +47,13 @@ export function useActiveModel() { return } - const currentTime = Date.now() - const res = await initModel(modelId, model?.settings) - if (res && res.error) { - const errorMessage = `${res.error}` - alert(errorMessage) - setStateModel(() => ({ - state: 'start', - loading: false, - model: modelId, - })) - } else { - console.debug( - `Model ${modelId} successfully initialized! Took ${ - Date.now() - currentTime - }ms` - ) - setActiveModel(model) - toaster({ - title: 'Success!', - description: `Model ${modelId} has been started.`, - }) - setStateModel(() => ({ - state: 'stop', - loading: false, - model: modelId, - })) - } + events.emit(EventName.OnModelInit, model) } const stopModel = async (modelId: string) => { setStateModel({ state: 'stop', loading: true, model: modelId }) - setTimeout(async () => { - extensionManager - .get(ExtensionType.Inference) - ?.stopModel() - - setActiveModel(undefined) - setStateModel({ state: 'start', loading: false, model: '' }) - toaster({ - title: 'Success!', - description: `Model ${modelId} has been stopped.`, - }) - }, 500) + events.emit(EventName.OnModelStop, modelId) } return { activeModel, startModel, stopModel, stateModel } } - -const initModel = async ( - modelId: string, - settings?: ModelSettingParams -): Promise => { - return extensionManager - .get(ExtensionType.Inference) - ?.initModel(modelId, settings) -} diff --git a/web/screens/ExploreModels/ExploreModelItemHeader/index.tsx b/web/screens/ExploreModels/ExploreModelItemHeader/index.tsx index ba23056c6..f5d54f0be 100644 --- a/web/screens/ExploreModels/ExploreModelItemHeader/index.tsx +++ b/web/screens/ExploreModels/ExploreModelItemHeader/index.tsx @@ -55,9 +55,23 @@ const ExploreModelItemHeader: React.FC = ({ model, onClick, open }) => { const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null - let downloadButton = ( - - ) + let downloadButton; + + if (model.engine !== 'nitro') { + downloadButton = ( + + ); + } else if (model.engine === 'nitro') { + downloadButton = ( + + ); + } const onUseModelClick = () => { startModel(model.id)