fix: Use Events for init, load, stop models
This commit is contained in:
parent
9aca37a30c
commit
1bc5fe64f3
@ -8,6 +8,14 @@ export enum EventName {
|
|||||||
OnMessageResponse = "OnMessageResponse",
|
OnMessageResponse = "OnMessageResponse",
|
||||||
/** The `OnMessageUpdate` event is emitted when a message is updated. */
|
/** The `OnMessageUpdate` event is emitted when a message is updated. */
|
||||||
OnMessageUpdate = "OnMessageUpdate",
|
OnMessageUpdate = "OnMessageUpdate",
|
||||||
|
/** The `OnModelInit` event is emitted when a model inits. */
|
||||||
|
OnModelInit = "OnModelInit",
|
||||||
|
/** The `OnModelReady` event is emitted when a model ready. */
|
||||||
|
OnModelReady = "OnModelReady",
|
||||||
|
/** The `OnModelFail` event is emitted when a model fails loading. */
|
||||||
|
OnModelFail = "OnModelFail",
|
||||||
|
/** The `OnModelStop` event is emitted when a model fails loading. */
|
||||||
|
OnModelStop = "OnModelStop",
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -166,6 +166,17 @@ export type ThreadState = {
|
|||||||
error?: Error;
|
error?: Error;
|
||||||
lastMessage?: string;
|
lastMessage?: string;
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* Represents the inference engine.
|
||||||
|
* @stored
|
||||||
|
*/
|
||||||
|
|
||||||
|
enum InferenceEngine {
|
||||||
|
llama_cpp = "llama_cpp",
|
||||||
|
openai = "openai",
|
||||||
|
nvidia_triton = "nvidia_triton",
|
||||||
|
hf_endpoint = "hf_endpoint",
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Model type defines the shape of a model object.
|
* Model type defines the shape of a model object.
|
||||||
@ -234,6 +245,10 @@ export interface Model {
|
|||||||
* Metadata of the model.
|
* Metadata of the model.
|
||||||
*/
|
*/
|
||||||
metadata: ModelMetadata;
|
metadata: ModelMetadata;
|
||||||
|
/**
|
||||||
|
* The model engine. Enum: "llamacpp" "openai"
|
||||||
|
*/
|
||||||
|
engine: InferenceEngine;
|
||||||
}
|
}
|
||||||
|
|
||||||
export type ModelMetadata = {
|
export type ModelMetadata = {
|
||||||
|
|||||||
@ -7,9 +7,10 @@ import {
|
|||||||
ThreadMessage,
|
ThreadMessage,
|
||||||
ExtensionType,
|
ExtensionType,
|
||||||
MessageStatus,
|
MessageStatus,
|
||||||
|
Model
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
import { ConversationalExtension } from '@janhq/core'
|
import { ConversationalExtension } from '@janhq/core'
|
||||||
import { useAtomValue, useSetAtom } from 'jotai'
|
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||||
|
|
||||||
import { extensionManager } from '@/extension'
|
import { extensionManager } from '@/extension'
|
||||||
import {
|
import {
|
||||||
@ -21,9 +22,16 @@ import {
|
|||||||
threadsAtom,
|
threadsAtom,
|
||||||
} from '@/helpers/atoms/Conversation.atom'
|
} from '@/helpers/atoms/Conversation.atom'
|
||||||
|
|
||||||
|
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
|
||||||
|
import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels'
|
||||||
|
import { toaster } from '../Toast'
|
||||||
|
|
||||||
export default function EventHandler({ children }: { children: ReactNode }) {
|
export default function EventHandler({ children }: { children: ReactNode }) {
|
||||||
const addNewMessage = useSetAtom(addNewMessageAtom)
|
const addNewMessage = useSetAtom(addNewMessageAtom)
|
||||||
const updateMessage = useSetAtom(updateMessageAtom)
|
const updateMessage = useSetAtom(updateMessageAtom)
|
||||||
|
const { downloadedModels } = useGetDownloadedModels()
|
||||||
|
const [activeModel, setActiveModel] = useAtom(activeModelAtom)
|
||||||
|
const [stateModel, setStateModel] = useAtom(stateModelAtom)
|
||||||
|
|
||||||
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
|
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
|
||||||
const threads = useAtomValue(threadsAtom)
|
const threads = useAtomValue(threadsAtom)
|
||||||
@ -37,6 +45,42 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
|||||||
addNewMessage(message)
|
addNewMessage(message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function handleModelReady(res: any) {
|
||||||
|
const model = downloadedModels.find((e) => e.id === res.modelId)
|
||||||
|
setActiveModel(model)
|
||||||
|
toaster({
|
||||||
|
title: 'Success!',
|
||||||
|
description: `Model ${res.modelId} has been started.`,
|
||||||
|
})
|
||||||
|
setStateModel(() => ({
|
||||||
|
state: 'stop',
|
||||||
|
loading: false,
|
||||||
|
model: res.modelId,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async function handleModelStop(res: any) {
|
||||||
|
const model = downloadedModels.find((e) => e.id === res.modelId)
|
||||||
|
setTimeout(async () => {
|
||||||
|
setActiveModel(undefined)
|
||||||
|
setStateModel({ state: 'start', loading: false, model: '' })
|
||||||
|
toaster({
|
||||||
|
title: 'Success!',
|
||||||
|
description: `Model ${res.modelId} has been stopped.`,
|
||||||
|
})
|
||||||
|
}, 500)
|
||||||
|
}
|
||||||
|
|
||||||
|
async function handleModelFail(res: any) {
|
||||||
|
const errorMessage = `${res.error}`
|
||||||
|
alert(errorMessage)
|
||||||
|
setStateModel(() => ({
|
||||||
|
state: 'start',
|
||||||
|
loading: false,
|
||||||
|
model: res.modelId,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
async function handleMessageResponseUpdate(message: ThreadMessage) {
|
async function handleMessageResponseUpdate(message: ThreadMessage) {
|
||||||
updateMessage(
|
updateMessage(
|
||||||
message.id,
|
message.id,
|
||||||
@ -73,6 +117,9 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
|||||||
if (window.core.events) {
|
if (window.core.events) {
|
||||||
events.on(EventName.OnMessageResponse, handleNewMessageResponse)
|
events.on(EventName.OnMessageResponse, handleNewMessageResponse)
|
||||||
events.on(EventName.OnMessageUpdate, handleMessageResponseUpdate)
|
events.on(EventName.OnMessageUpdate, handleMessageResponseUpdate)
|
||||||
|
events.on(EventName.OnModelReady, handleModelReady)
|
||||||
|
events.on(EventName.OnModelFail, handleModelFail)
|
||||||
|
events.on(EventName.OnModelStop, handleModelStop)
|
||||||
}
|
}
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
}, [])
|
}, [])
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||||
import { ExtensionType, InferenceExtension } from '@janhq/core'
|
import { EventName, ExtensionType, InferenceExtension, events } from '@janhq/core'
|
||||||
import { Model, ModelSettingParams } from '@janhq/core'
|
import { Model, ModelSettingParams } from '@janhq/core'
|
||||||
import { atom, useAtom } from 'jotai'
|
import { atom, useAtom } from 'jotai'
|
||||||
|
|
||||||
@ -9,9 +9,9 @@ import { useGetDownloadedModels } from './useGetDownloadedModels'
|
|||||||
|
|
||||||
import { extensionManager } from '@/extension'
|
import { extensionManager } from '@/extension'
|
||||||
|
|
||||||
const activeModelAtom = atom<Model | undefined>(undefined)
|
export const activeModelAtom = atom<Model | undefined>(undefined)
|
||||||
|
|
||||||
const stateModelAtom = atom({ state: 'start', loading: false, model: '' })
|
export const stateModelAtom = atom({ state: 'start', loading: false, model: '' })
|
||||||
|
|
||||||
export function useActiveModel() {
|
export function useActiveModel() {
|
||||||
const [activeModel, setActiveModel] = useAtom(activeModelAtom)
|
const [activeModel, setActiveModel] = useAtom(activeModelAtom)
|
||||||
@ -47,59 +47,13 @@ export function useActiveModel() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
const currentTime = Date.now()
|
events.emit(EventName.OnModelInit, model)
|
||||||
const res = await initModel(modelId, model?.settings)
|
|
||||||
if (res && res.error) {
|
|
||||||
const errorMessage = `${res.error}`
|
|
||||||
alert(errorMessage)
|
|
||||||
setStateModel(() => ({
|
|
||||||
state: 'start',
|
|
||||||
loading: false,
|
|
||||||
model: modelId,
|
|
||||||
}))
|
|
||||||
} else {
|
|
||||||
console.debug(
|
|
||||||
`Model ${modelId} successfully initialized! Took ${
|
|
||||||
Date.now() - currentTime
|
|
||||||
}ms`
|
|
||||||
)
|
|
||||||
setActiveModel(model)
|
|
||||||
toaster({
|
|
||||||
title: 'Success!',
|
|
||||||
description: `Model ${modelId} has been started.`,
|
|
||||||
})
|
|
||||||
setStateModel(() => ({
|
|
||||||
state: 'stop',
|
|
||||||
loading: false,
|
|
||||||
model: modelId,
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const stopModel = async (modelId: string) => {
|
const stopModel = async (modelId: string) => {
|
||||||
setStateModel({ state: 'stop', loading: true, model: modelId })
|
setStateModel({ state: 'stop', loading: true, model: modelId })
|
||||||
setTimeout(async () => {
|
events.emit(EventName.OnModelStop, modelId)
|
||||||
extensionManager
|
|
||||||
.get<InferenceExtension>(ExtensionType.Inference)
|
|
||||||
?.stopModel()
|
|
||||||
|
|
||||||
setActiveModel(undefined)
|
|
||||||
setStateModel({ state: 'start', loading: false, model: '' })
|
|
||||||
toaster({
|
|
||||||
title: 'Success!',
|
|
||||||
description: `Model ${modelId} has been stopped.`,
|
|
||||||
})
|
|
||||||
}, 500)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return { activeModel, startModel, stopModel, stateModel }
|
return { activeModel, startModel, stopModel, stateModel }
|
||||||
}
|
}
|
||||||
|
|
||||||
const initModel = async (
|
|
||||||
modelId: string,
|
|
||||||
settings?: ModelSettingParams
|
|
||||||
): Promise<any> => {
|
|
||||||
return extensionManager
|
|
||||||
.get<InferenceExtension>(ExtensionType.Inference)
|
|
||||||
?.initModel(modelId, settings)
|
|
||||||
}
|
|
||||||
|
|||||||
@ -55,9 +55,23 @@ const ExploreModelItemHeader: React.FC<Props> = ({ model, onClick, open }) => {
|
|||||||
|
|
||||||
const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null
|
const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null
|
||||||
|
|
||||||
let downloadButton = (
|
let downloadButton;
|
||||||
<Button onClick={() => onDownloadClick()}>Download</Button>
|
|
||||||
)
|
if (model.engine !== 'nitro') {
|
||||||
|
downloadButton = (
|
||||||
|
<Button onClick={() => onDownloadClick()}>
|
||||||
|
Use
|
||||||
|
</Button>
|
||||||
|
);
|
||||||
|
} else if (model.engine === 'nitro') {
|
||||||
|
downloadButton = (
|
||||||
|
<Button onClick={() => onDownloadClick()}>
|
||||||
|
{model.metadata.size
|
||||||
|
? `Download (${toGigabytes(model.metadata.size)})`
|
||||||
|
: 'Download'}
|
||||||
|
</Button>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
const onUseModelClick = () => {
|
const onUseModelClick = () => {
|
||||||
startModel(model.id)
|
startModel(model.id)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user