diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts index 1e648f60e..ce182483e 100644 --- a/web/hooks/useActiveModel.ts +++ b/web/hooks/useActiveModel.ts @@ -25,7 +25,7 @@ export const stateModelAtom = atom({ model: undefined, }) -export let loadModelController: AbortController | undefined +const pendingModelLoadAtom = atom(false) export function useActiveModel() { const [activeModel, setActiveModel] = useAtom(activeModelAtom) @@ -33,6 +33,7 @@ export function useActiveModel() { const [stateModel, setStateModel] = useAtom(stateModelAtom) const downloadedModels = useAtomValue(downloadedModelsAtom) const setLoadModelError = useSetAtom(loadModelErrorAtom) + const [pendingModelLoad, setPendingModelLoad] = useAtom(pendingModelLoadAtom) const downloadedModelsRef = useRef([]) @@ -40,7 +41,7 @@ export function useActiveModel() { downloadedModelsRef.current = downloadedModels }, [downloadedModels]) - const startModel = async (modelId: string) => { + const startModel = async (modelId: string, abortable: boolean = true) => { if ( (activeModel && activeModel.id === modelId) || (stateModel.model?.id === modelId && stateModel.loading) @@ -48,7 +49,7 @@ export function useActiveModel() { console.debug(`Model ${modelId} is already initialized. Ignore..`) return Promise.resolve() } - loadModelController = new AbortController() + setPendingModelLoad(true) let model = downloadedModelsRef?.current.find((e) => e.id === modelId) @@ -107,15 +108,16 @@ export function useActiveModel() { }) }) .catch((error) => { - if (loadModelController?.signal.aborted) - return Promise.reject(new Error('aborted')) - setStateModel(() => ({ state: 'start', loading: false, model, })) + if (!pendingModelLoad && abortable) { + return Promise.reject(new Error('aborted')) + } + toaster({ title: 'Failed!', description: `Model ${model.id} failed to start.`, @@ -139,9 +141,15 @@ export function useActiveModel() { .then(() => { setActiveModel(undefined) setStateModel({ state: 'start', loading: false, model: undefined }) - loadModelController?.abort() + setPendingModelLoad(false) }) - }, [activeModel, setActiveModel, setStateModel, stateModel]) + }, [ + activeModel, + setActiveModel, + setStateModel, + setPendingModelLoad, + stateModel, + ]) const stopInference = useCallback(async () => { // Loading model diff --git a/web/screens/LocalServer/index.tsx b/web/screens/LocalServer/index.tsx index db7baec5a..aa7dbd57c 100644 --- a/web/screens/LocalServer/index.tsx +++ b/web/screens/LocalServer/index.tsx @@ -155,12 +155,12 @@ const LocalServerScreen = () => { isCorsEnabled, isVerboseEnabled, }) - await startModel(selectedModel.id) if (isStarted) setServerEnabled(true) if (firstTimeVisitAPIServer) { localStorage.setItem(FIRST_TIME_VISIT_API_SERVER, 'false') setFirstTimeVisitAPIServer(false) } + startModel(selectedModel.id, false).catch((e) => console.error(e)) } catch (e) { console.error(e) toaster({