From e9fd7f4554f88dadd6d662150138158184894ee7 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 18 Nov 2024 18:22:09 +0700 Subject: [PATCH] fix: models load --- core/src/browser/extension.ts | 1 - .../inference-cortex-extension/src/index.ts | 4 +- extensions/model-extension/src/index.ts | 33 ++++++------- web/containers/Providers/DataLoader.tsx | 7 +-- web/containers/Providers/EventListener.tsx | 2 +- .../Providers/ModelImportListener.tsx | 2 +- web/hooks/useModels.ts | 47 +++++++++++-------- 7 files changed, 53 insertions(+), 43 deletions(-) diff --git a/core/src/browser/extension.ts b/core/src/browser/extension.ts index d934e1c06..b7a9fca4e 100644 --- a/core/src/browser/extension.ts +++ b/core/src/browser/extension.ts @@ -113,7 +113,6 @@ export abstract class BaseExtension implements ExtensionType { for (const model of models) { ModelManager.instance().register(model) } - events.emit(ModelEvent.OnModelsUpdate, {}) } /** diff --git a/extensions/inference-cortex-extension/src/index.ts b/extensions/inference-cortex-extension/src/index.ts index e83a17561..34a376ac8 100644 --- a/extensions/inference-cortex-extension/src/index.ts +++ b/extensions/inference-cortex-extension/src/index.ts @@ -215,7 +215,9 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { // Delay for the state update from cortex.cpp // Just to be sure setTimeout(() => { - events.emit(ModelEvent.OnModelsUpdate, {}) + events.emit(ModelEvent.OnModelsUpdate, { + fetch: true, + }) }, 500) } }) diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index b3ad2a012..63f505bd6 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -40,11 +40,6 @@ export default class JanModelExtension extends ModelExtension { async onLoad() { this.registerSettings(SETTINGS) - // Try get models from cortex.cpp - this.getModels().then((models) => { - this.registerModels(models) - }) - // Listen to app download events this.handleDesktopEvents() } @@ -163,19 +158,27 @@ export default class JanModelExtension extends ModelExtension { (e) => e.engine === InferenceEngine.nitro ) - await this.cortexAPI.getModels().then((models) => { - const existingIds = models.map((e) => e.id) - toImportModels = toImportModels.filter( - (e: Model) => !existingIds.includes(e.id) && !e.settings?.vision_model - ) - }) + /** + * Fetch models from cortex.cpp + */ + var fetchedModels = await this.cortexAPI.getModels().catch(() => []) + + // Checking if there are models to import + const existingIds = fetchedModels.map((e) => e.id) + toImportModels = toImportModels.filter( + (e: Model) => !existingIds.includes(e.id) && !e.settings?.vision_model + ) + + /** + * There is no model to import + * just return fetched models + */ + if (!toImportModels.length) return fetchedModels console.log('To import models:', toImportModels.length) /** * There are models to import - * do not return models from cortex.cpp yet - * otherwise it will reset the app cache - * */ + */ if (toImportModels.length > 0) { // Import models await Promise.all( @@ -202,8 +205,6 @@ export default class JanModelExtension extends ModelExtension { }) }) ) - - return currentModels } /** diff --git a/web/containers/Providers/DataLoader.tsx b/web/containers/Providers/DataLoader.tsx index ed4c07ec3..d3d747d02 100644 --- a/web/containers/Providers/DataLoader.tsx +++ b/web/containers/Providers/DataLoader.tsx @@ -29,17 +29,18 @@ const DataLoader: React.FC = ({ children }) => { const setQuickAskEnabled = useSetAtom(quickAskEnabledAtom) const setJanDefaultDataFolder = useSetAtom(defaultJanDataFolderAtom) const setJanSettingScreen = useSetAtom(janSettingScreenAtom) + const { loadDataModel } = useModels() useThreads() useAssistants() useGetSystemResources() useLoadTheme() - const { loadDataModel, isUpdated } = useModels() useEffect(() => { - // Listen for model updates + // Load data once loadDataModel() - }, [isUpdated, loadDataModel]) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []) useEffect(() => { window.core?.api diff --git a/web/containers/Providers/EventListener.tsx b/web/containers/Providers/EventListener.tsx index 5cb0debab..c1dcf7c40 100644 --- a/web/containers/Providers/EventListener.tsx +++ b/web/containers/Providers/EventListener.tsx @@ -112,8 +112,8 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => { state.downloadState = 'end' setDownloadState(state) removeDownloadingModel(state.modelId) + events.emit(ModelEvent.OnModelsUpdate, { fetch: true }) } - events.emit(ModelEvent.OnModelsUpdate, {}) }, [removeDownloadingModel, setDownloadState] ) diff --git a/web/containers/Providers/ModelImportListener.tsx b/web/containers/Providers/ModelImportListener.tsx index f1ca2a768..a60b7be80 100644 --- a/web/containers/Providers/ModelImportListener.tsx +++ b/web/containers/Providers/ModelImportListener.tsx @@ -43,7 +43,7 @@ const ModelImportListener = ({ children }: PropsWithChildren) => { const onImportModelSuccess = useCallback( (state: ImportingModel) => { if (!state.modelId) return - events.emit(ModelEvent.OnModelsUpdate, {}) + events.emit(ModelEvent.OnModelsUpdate, { fetch: true }) setImportingModelSuccess(state.importId, state.modelId) }, [setImportingModelSuccess] diff --git a/web/hooks/useModels.ts b/web/hooks/useModels.ts index 0aed91ed2..d2b05779f 100644 --- a/web/hooks/useModels.ts +++ b/web/hooks/useModels.ts @@ -1,4 +1,4 @@ -import { useCallback, useEffect, useRef } from 'react' +import { useCallback, useEffect } from 'react' import { ExtensionTypeEnum, @@ -9,7 +9,7 @@ import { ModelManager, } from '@janhq/core' -import { useSetAtom } from 'jotai' +import { useSetAtom, useAtom } from 'jotai' import { useDebouncedCallback } from 'use-debounce' @@ -27,17 +27,11 @@ import { * and updates the atoms accordingly. */ const useModels = () => { - const setDownloadedModels = useSetAtom(downloadedModelsAtom) + const [downloadedModels, setDownloadedModels] = useAtom(downloadedModelsAtom) const setExtensionModels = useSetAtom(configuredModelsAtom) - const hasFetchedDownloadedModels = useRef(false) // Track whether the function has been executed - - let isUpdated = false const getData = useCallback(() => { - if (hasFetchedDownloadedModels.current) return - const getDownloadedModels = async () => { - hasFetchedDownloadedModels.current = true const localModels = (await getModels()).map((e) => ({ ...e, name: ModelManager.instance().models.get(e.id)?.name ?? e.id, @@ -58,6 +52,8 @@ const useModels = () => { setDownloadedModels(toUpdate) + let isUpdated = false + toUpdate.forEach((model) => { if (!ModelManager.instance().models.has(model.id)) { ModelManager.instance().models.set(model.id, model) @@ -77,30 +73,41 @@ const useModels = () => { // Fetch all data getExtensionModels() getDownloadedModels() - }, []) + }, [setDownloadedModels, setExtensionModels]) const reloadData = useDebouncedCallback(() => getData(), 300) + const updateStates = useCallback(() => { + const cachedModels = ModelManager.instance().models.values().toArray() + const toUpdate = [ + ...downloadedModels, + ...cachedModels.filter( + (e: Model) => !downloadedModels.some((g: Model) => g.id === e.id) + ), + ] + + setDownloadedModels(toUpdate) + }, [downloadedModels, setDownloadedModels]) + const getModels = async (): Promise => extensionManager .get(ExtensionTypeEnum.Model) ?.getModels() ?? [] useEffect(() => { - // Try get data on mount - if (isUpdated) { - // Listen for model updates - events.on(ModelEvent.OnModelsUpdate, async () => reloadData()) - return () => { - // Remove listener on unmount - events.off(ModelEvent.OnModelsUpdate, async () => {}) - } + // Listen for model updates + events.on(ModelEvent.OnModelsUpdate, async (data: { fetch?: boolean }) => { + if (data.fetch) reloadData() + else updateStates() + }) + return () => { + // Remove listener on unmount + events.off(ModelEvent.OnModelsUpdate, async () => {}) } - }, [isUpdated, reloadData]) + }, [reloadData, updateStates]) return { loadDataModel: getData, - isUpdated: isUpdated, } }