fix: Use Events for init, load, stop models

This commit is contained in:
hiro 2023-12-04 12:10:24 +07:00
parent 9aca37a30c
commit 1bc5fe64f3
5 changed files with 93 additions and 55 deletions

View File

@ -8,6 +8,14 @@ export enum EventName {
OnMessageResponse = "OnMessageResponse", OnMessageResponse = "OnMessageResponse",
/** The `OnMessageUpdate` event is emitted when a message is updated. */ /** The `OnMessageUpdate` event is emitted when a message is updated. */
OnMessageUpdate = "OnMessageUpdate", OnMessageUpdate = "OnMessageUpdate",
/** The `OnModelInit` event is emitted when a model inits. */
OnModelInit = "OnModelInit",
/** The `OnModelReady` event is emitted when a model is ready. */
OnModelReady = "OnModelReady",
/** The `OnModelFail` event is emitted when a model fails loading. */
OnModelFail = "OnModelFail",
/** The `OnModelStop` event is emitted when a model stops. */
OnModelStop = "OnModelStop",
} }
/** /**

View File

@ -166,6 +166,17 @@ export type ThreadState = {
error?: Error; error?: Error;
lastMessage?: string; lastMessage?: string;
}; };
/**
* Represents the inference engine.
* @stored
*/
enum InferenceEngine {
// Local llama.cpp-based engine (presumably the "nitro" backend — confirm).
llama_cpp = "llama_cpp",
// OpenAI API engine.
openai = "openai",
// NVIDIA Triton inference server engine.
nvidia_triton = "nvidia_triton",
// Hugging Face inference endpoint engine.
hf_endpoint = "hf_endpoint",
}
/** /**
* Model type defines the shape of a model object. * Model type defines the shape of a model object.
@ -234,6 +245,10 @@ export interface Model {
* Metadata of the model. * Metadata of the model.
*/ */
metadata: ModelMetadata; metadata: ModelMetadata;
/**
 * The inference engine used to run this model.
 * Enum: "llama_cpp" "openai" "nvidia_triton" "hf_endpoint"
 */
engine: InferenceEngine;
} }
export type ModelMetadata = { export type ModelMetadata = {

View File

@ -7,9 +7,10 @@ import {
ThreadMessage, ThreadMessage,
ExtensionType, ExtensionType,
MessageStatus, MessageStatus,
Model
} from '@janhq/core' } from '@janhq/core'
import { ConversationalExtension } from '@janhq/core' import { ConversationalExtension } from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { import {
@ -21,9 +22,16 @@ import {
threadsAtom, threadsAtom,
} from '@/helpers/atoms/Conversation.atom' } from '@/helpers/atoms/Conversation.atom'
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels'
import { toaster } from '../Toast'
export default function EventHandler({ children }: { children: ReactNode }) { export default function EventHandler({ children }: { children: ReactNode }) {
const addNewMessage = useSetAtom(addNewMessageAtom) const addNewMessage = useSetAtom(addNewMessageAtom)
const updateMessage = useSetAtom(updateMessageAtom) const updateMessage = useSetAtom(updateMessageAtom)
const { downloadedModels } = useGetDownloadedModels()
const [activeModel, setActiveModel] = useAtom(activeModelAtom)
const [stateModel, setStateModel] = useAtom(stateModelAtom)
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom) const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
const threads = useAtomValue(threadsAtom) const threads = useAtomValue(threadsAtom)
@ -37,6 +45,42 @@ export default function EventHandler({ children }: { children: ReactNode }) {
addNewMessage(message) addNewMessage(message)
} }
// Handles OnModelReady: records the started model as active, notifies the
// user, and flips the start/stop state for that model id.
async function handleModelReady(res: any) {
  // Look up the full model record by id (undefined if not downloaded).
  const readyModel = downloadedModels.find((m) => m.id === res.modelId)
  setActiveModel(readyModel)
  toaster({
    title: 'Success!',
    description: `Model ${res.modelId} has been started.`,
  })
  setStateModel(() => ({
    state: 'stop',
    loading: false,
    model: res.modelId,
  }))
}
// Handles OnModelStop: clears the active model, resets the start/stop state,
// and notifies the user. The lookup of the model record in the original code
// was unused and has been removed.
async function handleModelStop(res: any) {
  // Small delay before clearing state — presumably to let the backend finish
  // shutting down; confirm against the inference extension's stopModel().
  setTimeout(async () => {
    setActiveModel(undefined)
    setStateModel({ state: 'start', loading: false, model: '' })
    toaster({
      title: 'Success!',
      description: `Model ${res.modelId} has been stopped.`,
    })
  }, 500)
}
// Handles OnModelFail: surfaces the load error to the user and returns the
// model's state to 'start' so it can be retried.
async function handleModelFail(res: any) {
  // Stringify whatever error payload the event carried and show it.
  alert(`${res.error}`)
  setStateModel(() => ({
    state: 'start',
    loading: false,
    model: res.modelId,
  }))
}
async function handleMessageResponseUpdate(message: ThreadMessage) { async function handleMessageResponseUpdate(message: ThreadMessage) {
updateMessage( updateMessage(
message.id, message.id,
@ -73,6 +117,9 @@ export default function EventHandler({ children }: { children: ReactNode }) {
if (window.core.events) { if (window.core.events) {
events.on(EventName.OnMessageResponse, handleNewMessageResponse) events.on(EventName.OnMessageResponse, handleNewMessageResponse)
events.on(EventName.OnMessageUpdate, handleMessageResponseUpdate) events.on(EventName.OnMessageUpdate, handleMessageResponseUpdate)
events.on(EventName.OnModelReady, handleModelReady)
events.on(EventName.OnModelFail, handleModelFail)
events.on(EventName.OnModelStop, handleModelStop)
} }
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, []) }, [])

View File

@ -1,5 +1,5 @@
/* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/no-explicit-any */
import { ExtensionType, InferenceExtension } from '@janhq/core' import { EventName, ExtensionType, InferenceExtension, events } from '@janhq/core'
import { Model, ModelSettingParams } from '@janhq/core' import { Model, ModelSettingParams } from '@janhq/core'
import { atom, useAtom } from 'jotai' import { atom, useAtom } from 'jotai'
@ -9,9 +9,9 @@ import { useGetDownloadedModels } from './useGetDownloadedModels'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
const activeModelAtom = atom<Model | undefined>(undefined) export const activeModelAtom = atom<Model | undefined>(undefined)
const stateModelAtom = atom({ state: 'start', loading: false, model: '' }) export const stateModelAtom = atom({ state: 'start', loading: false, model: '' })
export function useActiveModel() { export function useActiveModel() {
const [activeModel, setActiveModel] = useAtom(activeModelAtom) const [activeModel, setActiveModel] = useAtom(activeModelAtom)
@ -47,59 +47,13 @@ export function useActiveModel() {
return return
} }
const currentTime = Date.now() events.emit(EventName.OnModelInit, model)
const res = await initModel(modelId, model?.settings)
if (res && res.error) {
const errorMessage = `${res.error}`
alert(errorMessage)
setStateModel(() => ({
state: 'start',
loading: false,
model: modelId,
}))
} else {
console.debug(
`Model ${modelId} successfully initialized! Took ${
Date.now() - currentTime
}ms`
)
setActiveModel(model)
toaster({
title: 'Success!',
description: `Model ${modelId} has been started.`,
})
setStateModel(() => ({
state: 'stop',
loading: false,
model: modelId,
}))
}
} }
const stopModel = async (modelId: string) => { const stopModel = async (modelId: string) => {
setStateModel({ state: 'stop', loading: true, model: modelId }) setStateModel({ state: 'stop', loading: true, model: modelId })
setTimeout(async () => { events.emit(EventName.OnModelStop, modelId)
extensionManager
.get<InferenceExtension>(ExtensionType.Inference)
?.stopModel()
setActiveModel(undefined)
setStateModel({ state: 'start', loading: false, model: '' })
toaster({
title: 'Success!',
description: `Model ${modelId} has been stopped.`,
})
}, 500)
} }
return { activeModel, startModel, stopModel, stateModel } return { activeModel, startModel, stopModel, stateModel }
} }
const initModel = async (
modelId: string,
settings?: ModelSettingParams
): Promise<any> => {
return extensionManager
.get<InferenceExtension>(ExtensionType.Inference)
?.initModel(modelId, settings)
}

View File

@ -55,9 +55,23 @@ const ExploreModelItemHeader: React.FC<Props> = ({ model, onClick, open }) => {
const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null
let downloadButton = ( let downloadButton;
<Button onClick={() => onDownloadClick()}>Download</Button>
) if (model.engine !== 'nitro') {
downloadButton = (
<Button onClick={() => onDownloadClick()}>
Use
</Button>
);
} else if (model.engine === 'nitro') {
downloadButton = (
<Button onClick={() => onDownloadClick()}>
{model.metadata.size
? `Download (${toGigabytes(model.metadata.size)})`
: 'Download'}
</Button>
);
}
const onUseModelClick = () => { const onUseModelClick = () => {
startModel(model.id) startModel(model.id)