diff --git a/core/src/browser/extensions/model.ts b/core/src/browser/extensions/model.ts index 1fb94fba3..e224ec5cc 100644 --- a/core/src/browser/extensions/model.ts +++ b/core/src/browser/extensions/model.ts @@ -15,7 +15,13 @@ export abstract class ModelExtension extends BaseExtension implements ModelInter abstract getModels(): Promise abstract pullModel(model: string, id?: string, name?: string): Promise abstract cancelModelPull(modelId: string): Promise - abstract importModel(model: string, modePath: string, name?: string, optionType?: OptionType): Promise + abstract importModel( + model: string, + modePath: string, + name?: string, + optionType?: OptionType + ): Promise abstract updateModel(modelInfo: Partial): Promise abstract deleteModel(model: string): Promise + abstract isModelLoaded(model: string): Promise } diff --git a/extensions/model-extension/src/cortex.ts b/extensions/model-extension/src/cortex.ts index 024aa2223..5b7d1e36b 100644 --- a/extensions/model-extension/src/cortex.ts +++ b/extensions/model-extension/src/cortex.ts @@ -9,7 +9,12 @@ interface ICortexAPI { getModel(model: string): Promise getModels(): Promise pullModel(model: string, id?: string, name?: string): Promise - importModel(path: string, modelPath: string, name?: string, option?: string): Promise + importModel( + path: string, + modelPath: string, + name?: string, + option?: string + ): Promise deleteModel(model: string): Promise updateModel(model: object): Promise cancelModelPull(model: string): Promise @@ -141,6 +146,17 @@ export class CortexAPI implements ICortexAPI { ) } + /** + * Check model status + * @param model + */ + async getModelStatus(model: string): Promise { + return this.queue + .add(() => ky.get(`${API_URL}/models/status/${model}`)) + .then((e) => true) + .catch(() => false) + } + /** * Do health check on cortex.cpp * @returns @@ -215,7 +231,7 @@ export class CortexAPI implements ICortexAPI { } model.metadata = model.metadata ?? { tags: [], - size: model.size ?? model.metadata?.size ?? 0 + size: model.size ?? model.metadata?.size ?? 0, } return model as Model } diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index 7d7514f3b..8f50bd5d0 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -238,6 +238,14 @@ export default class JanModelExtension extends ModelExtension { return this.cortexAPI.importModel(model, modelPath, name, option) } + /** + * Check model status + * @param model + */ + async isModelLoaded(model: string): Promise { + return this.cortexAPI.getModelStatus(model) + } + /** * Handle download state from main app */ diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx index 0f5cf389d..6cad910f7 100644 --- a/web/containers/Providers/EventHandler.tsx +++ b/web/containers/Providers/EventHandler.tsx @@ -16,6 +16,7 @@ import { EngineManager, InferenceEngine, extractInferenceParams, + ModelExtension, } from '@janhq/core' import { useAtomValue, useSetAtom } from 'jotai' import { ulid } from 'ulidx' @@ -180,8 +181,16 @@ export default function EventHandler({ children }: { children: ReactNode }) { } return } else if (message.status === MessageStatus.Error) { - setActiveModel(undefined) - setStateModel({ state: 'start', loading: false, model: undefined }) + ;(async () => { + if ( + !(await extensionManager + .get(ExtensionTypeEnum.Model) + ?.isModelLoaded(activeModelRef.current?.id as string)) + ) { + setActiveModel(undefined) + setStateModel({ state: 'start', loading: false, model: undefined }) + } + })() } // Mark the thread as not waiting for response updateThreadWaiting(message.thread_id, false) diff --git a/web/hooks/useModels.ts b/web/hooks/useModels.ts index 400e02793..b8b680715 100644 --- a/web/hooks/useModels.ts +++ b/web/hooks/useModels.ts @@ -35,6 +35,10 @@ const useModels = () => { const localModels = (await getModels()).map((e) => ({ ...e, name: ModelManager.instance().models.get(e.id)?.name ?? e.id, + settings: + ModelManager.instance().models.get(e.id)?.settings ?? e.settings, + parameters: + ModelManager.instance().models.get(e.id)?.parameters ?? e.parameters, metadata: ModelManager.instance().models.get(e.id)?.metadata ?? e.metadata, }))