From 1eaf13b13ef612744165611995a30aed853ed61f Mon Sep 17 00:00:00 2001
From: Louis
Date: Thu, 4 Apr 2024 10:57:54 +0700
Subject: [PATCH] fix: cancel loading model with stop action (#2607)

---
 .../inference-nitro-extension/package.json |  1 +
 .../src/node/index.ts                      | 51 ++++++++++-----
 web/containers/Loader/ModelReload.tsx      |  2 +-
 web/containers/Loader/ModelStart.tsx       |  2 +-
 web/containers/Providers/EventHandler.tsx  |  2 +-
 web/hooks/useActiveModel.ts                | 63 +++++++++++++------
 web/hooks/useCreateNewThread.ts            |  4 +-
 web/screens/Chat/ChatInput/index.tsx       |  5 +-
 web/screens/Chat/EditChatInput/index.tsx   |  4 +-
 web/screens/Chat/LoadModelError/index.tsx  |  3 +-
 web/screens/Settings/Models/Row.tsx        |  4 +-
 11 files changed, 96 insertions(+), 45 deletions(-)

diff --git a/extensions/inference-nitro-extension/package.json b/extensions/inference-nitro-extension/package.json
index 25abaf049..b0a555cbc 100644
--- a/extensions/inference-nitro-extension/package.json
+++ b/extensions/inference-nitro-extension/package.json
@@ -51,6 +51,7 @@
     "path-browserify": "^1.0.1",
     "rxjs": "^7.8.1",
     "tcp-port-used": "^1.0.2",
+    "terminate": "^2.6.1",
     "ulidx": "^2.3.0"
   },
   "engines": {
diff --git a/extensions/inference-nitro-extension/src/node/index.ts b/extensions/inference-nitro-extension/src/node/index.ts
index 71adac72d..638d4c5eb 100644
--- a/extensions/inference-nitro-extension/src/node/index.ts
+++ b/extensions/inference-nitro-extension/src/node/index.ts
@@ -13,6 +13,7 @@ import {
   SystemInformation,
 } from '@janhq/core/node'
 import { executableNitroFile } from './execute'
+import terminate from 'terminate'
 
 // Polyfill fetch with retry
 const fetchRetry = fetchRT(fetch)
@@ -304,23 +305,43 @@ async function killSubprocess(): Promise<void> {
   setTimeout(() => controller.abort(), 5000)
   log(`[NITRO]::Debug: Request to kill Nitro`)
 
-  return fetch(NITRO_HTTP_KILL_URL, {
-    method: 'DELETE',
-    signal: controller.signal,
-  })
-    .then(() => {
-      subprocess?.kill()
-      subprocess = undefined
+  const killRequest = () => {
+    return fetch(NITRO_HTTP_KILL_URL, {
+      method: 'DELETE',
+      signal: controller.signal,
     })
-    .catch(() => {}) // Do nothing with this attempt
-    .then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000))
-    .then(() => log(`[NITRO]::Debug: Nitro process is terminated`))
-    .catch((err) => {
-      log(
-        `[NITRO]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
-      )
-      throw 'PORT_NOT_AVAILABLE'
+      .catch(() => {}) // Do nothing with this attempt
+      .then(() => tcpPortUsed.waitUntilFree(PORT, 300, 5000))
+      .then(() => log(`[NITRO]::Debug: Nitro process is terminated`))
+      .catch((err) => {
+        log(
+          `[NITRO]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
+        )
+        throw 'PORT_NOT_AVAILABLE'
+      })
+  }
+
+  if (subprocess?.pid) {
+    log(`[NITRO]::Debug: Killing PID ${subprocess.pid}`)
+    const pid = subprocess.pid
+    return new Promise<void>((resolve, reject) => {
+      terminate(pid, function (err) {
+        if (err) {
+          return killRequest()
+        } else {
+          return tcpPortUsed
+            .waitUntilFree(PORT, 300, 5000)
+            .then(() => resolve())
+            .then(() => log(`[NITRO]::Debug: Nitro process is terminated`))
+            .catch(() => {
+              killRequest()
+            })
+        }
+      })
     })
+  } else {
+    return killRequest()
+  }
 }
 
 /**
diff --git a/web/containers/Loader/ModelReload.tsx b/web/containers/Loader/ModelReload.tsx
index a432927aa..44fbb9ab9 100644
--- a/web/containers/Loader/ModelReload.tsx
+++ b/web/containers/Loader/ModelReload.tsx
@@ -41,7 +41,7 @@ export default function ModelReload() {
           style={{ width: `${loader}%` }}
         />
-        Reloading model {stateModel.model}
+        Reloading model {stateModel.model?.id}
diff --git a/web/containers/Loader/ModelStart.tsx b/web/containers/Loader/ModelStart.tsx
index 7002c7b40..f7bc04481 100644
--- a/web/containers/Loader/ModelStart.tsx
+++ b/web/containers/Loader/ModelStart.tsx
@@ -44,7 +44,7 @@ export default function ModelStart() {
           {stateModel.state === 'start' ? 'Starting' : 'Stopping'}
           &nbsp;model&nbsp;
-          {stateModel.model}
+          {stateModel.model?.id}
diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx
index 4d5555a46..110d36e36 100644
--- a/web/containers/Providers/EventHandler.tsx
+++ b/web/containers/Providers/EventHandler.tsx
@@ -79,7 +79,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {
   const onModelStopped = useCallback(() => {
     setActiveModel(undefined)
-    setStateModel({ state: 'start', loading: false, model: '' })
+    setStateModel({ state: 'start', loading: false, model: undefined })
   }, [setActiveModel, setStateModel])
 
   const updateThreadTitle = useCallback(
diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts
index d387861eb..e2cba75b7 100644
--- a/web/hooks/useActiveModel.ts
+++ b/web/hooks/useActiveModel.ts
@@ -13,10 +13,16 @@ import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
 export const activeModelAtom = atom<Model | undefined>(undefined)
 export const loadModelErrorAtom = atom<string | undefined>(undefined)
 
-export const stateModelAtom = atom({
+type ModelState = {
+  state: string
+  loading: boolean
+  model?: Model
+}
+
+export const stateModelAtom = atom<ModelState>({
   state: 'start',
   loading: false,
-  model: '',
+  model: undefined,
 })
 
 export function useActiveModel() {
@@ -35,7 +41,7 @@
   const startModel = async (modelId: string) => {
     if (
       (activeModel && activeModel.id === modelId) ||
-      (stateModel.model === modelId && stateModel.loading)
+      (stateModel.model?.id === modelId && stateModel.loading)
     ) {
       console.debug(`Model ${modelId} is already initialized. Ignore..`)
       return Promise.resolve()
     }
@@ -52,7 +58,7 @@
 
     setActiveModel(undefined)
 
-    setStateModel({ state: 'start', loading: true, model: modelId })
+    setStateModel({ state: 'start', loading: true, model })
 
     if (!model) {
       toaster({
@@ -63,7 +69,7 @@
       setStateModel(() => ({
         state: 'start',
         loading: false,
-        model: '',
+        model: undefined,
       }))
 
       return Promise.reject(`Model ${modelId} not found!`)
@@ -89,7 +95,7 @@
       setStateModel(() => ({
         state: 'stop',
         loading: false,
-        model: model.id,
+        model,
       }))
       toaster({
         title: 'Success!',
@@ -101,7 +107,7 @@
       setStateModel(() => ({
         state: 'start',
         loading: false,
-        model: model.id,
+        model,
       }))
 
       toaster({
@@ -114,20 +120,39 @@
     })
   }
 
-  const stopModel = useCallback(async () => {
-    if (!activeModel || (stateModel.state === 'stop' && stateModel.loading))
+  const stopModel = useCallback(
+    async (model?: Model) => {
+      const stoppingModel = activeModel || model
+      if (
+        !stoppingModel ||
+        (!model && stateModel.state === 'stop' && stateModel.loading)
+      )
+        return
+
+      setStateModel({ state: 'stop', loading: true, model: stoppingModel })
+      const engine = EngineManager.instance().get(stoppingModel.engine)
+      await engine
+        ?.unloadModel(stoppingModel)
+        .catch()
+        .then(() => {
+          setActiveModel(undefined)
+          setStateModel({ state: 'start', loading: false, model: undefined })
+        })
+    },
+    [activeModel, setActiveModel, setStateModel, stateModel]
+  )
+
+  const stopInference = useCallback(async () => {
+    // Loading model
+    if (stateModel.loading) {
+      stopModel(stateModel.model)
       return
+    }
+    if (!activeModel) return
 
-    setStateModel({ state: 'stop', loading: true, model: activeModel.id })
     const engine = EngineManager.instance().get(activeModel.engine)
-    await engine
-      ?.unloadModel(activeModel)
-      .catch()
-      .then(() => {
-        setActiveModel(undefined)
-        setStateModel({ state: 'start', loading: false, model: '' })
-      })
-  }, [activeModel, stateModel, setActiveModel, setStateModel])
+    engine?.stopInference()
+  }, [activeModel, stateModel, stopModel])
 
-  return { activeModel, startModel, stopModel, stateModel }
+  return { activeModel, startModel, stopModel, stopInference, stateModel }
 }
diff --git a/web/hooks/useCreateNewThread.ts b/web/hooks/useCreateNewThread.ts
index 27f27db60..9b4e4261e 100644
--- a/web/hooks/useCreateNewThread.ts
+++ b/web/hooks/useCreateNewThread.ts
@@ -19,6 +19,7 @@ import { fileUploadAtom } from '@/containers/Providers/Jotai'
 
 import { generateThreadId } from '@/utils/thread'
 
+import { useActiveModel } from './useActiveModel'
 import useRecommendedModel from './useRecommendedModel'
 import useSetActiveThread from './useSetActiveThread'
 
@@ -65,6 +66,7 @@ export const useCreateNewThread = () => {
   const { recommendedModel, downloadedModels } = useRecommendedModel()
   const threads = useAtomValue(threadsAtom)
+  const { stopInference } = useActiveModel()
 
   const requestCreateNewThread = async (
     assistant: Assistant,
@@ -72,7 +74,7 @@ export const useCreateNewThread = () => {
   ) => {
     // Stop generating if any
     setIsGeneratingResponse(false)
-    events.emit(InferenceEvent.OnInferenceStopped, {})
+    stopInference()
 
     const defaultModel = model ?? recommendedModel ?? downloadedModels[0]
diff --git a/web/screens/Chat/ChatInput/index.tsx b/web/screens/Chat/ChatInput/index.tsx
index 8707e8bcd..f6fd299b9 100644
--- a/web/screens/Chat/ChatInput/index.tsx
+++ b/web/screens/Chat/ChatInput/index.tsx
@@ -44,7 +44,7 @@ const ChatInput: React.FC = () => {
   const activeThread = useAtomValue(activeThreadAtom)
-  const { stateModel } = useActiveModel()
+  const { stateModel, activeModel } = useActiveModel()
 
   const messages = useAtomValue(getCurrentChatMessagesAtom)
   const [currentPrompt, setCurrentPrompt] = useAtom(currentPromptAtom)
@@ -60,6 +60,7 @@ const ChatInput: React.FC = () => {
   const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom)
   const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom)
   const threadStates = useAtomValue(threadStatesAtom)
+  const { stopInference } = useActiveModel()
 
   const isStreamingResponse = Object.values(threadStates).some(
     (threadState) => threadState.waitingForResponse
@@ -107,7 +108,7 @@
   }
 
   const onStopInferenceClick = async () => {
-    events.emit(InferenceEvent.OnInferenceStopped, {})
+    stopInference()
   }
 
   /**
diff --git a/web/screens/Chat/EditChatInput/index.tsx b/web/screens/Chat/EditChatInput/index.tsx
index 240dc5106..543817c6b 100644
--- a/web/screens/Chat/EditChatInput/index.tsx
+++ b/web/screens/Chat/EditChatInput/index.tsx
@@ -50,7 +50,7 @@ type Props = {
 
 const EditChatInput: React.FC<Props> = ({ message }) => {
   const activeThread = useAtomValue(activeThreadAtom)
-  const { stateModel } = useActiveModel()
+  const { stateModel, stopInference } = useActiveModel()
 
   const messages = useAtomValue(getCurrentChatMessagesAtom)
   const [editPrompt, setEditPrompt] = useAtom(editPromptAtom)
@@ -127,7 +127,7 @@
   }
 
   const onStopInferenceClick = async () => {
-    events.emit(InferenceEvent.OnInferenceStopped, {})
+    stopInference()
   }
 
   return (
diff --git a/web/screens/Chat/LoadModelError/index.tsx b/web/screens/Chat/LoadModelError/index.tsx
index 3bf01f8cb..9bfa328c1 100644
--- a/web/screens/Chat/LoadModelError/index.tsx
+++ b/web/screens/Chat/LoadModelError/index.tsx
@@ -34,7 +34,8 @@
         ) : loadModelError &&
-          loadModelError?.includes('EXTENSION_IS_NOT_INSTALLED') ? (
+          typeof loadModelError.includes === 'function' &&
+          loadModelError.includes('EXTENSION_IS_NOT_INSTALLED') ? (
             Model is currently unavailable. Please switch to a different model
diff --git a/web/screens/Settings/Models/Row.tsx b/web/screens/Settings/Models/Row.tsx
index 9707f6194..1d9283efa 100644
--- a/web/screens/Settings/Models/Row.tsx
+++ b/web/screens/Settings/Models/Row.tsx
@@ -43,7 +43,7 @@ export default function RowModel(props: RowModelProps) {
   const { activeModel, startModel, stopModel, stateModel } = useActiveModel()
   const { deleteModel } = useDeleteModel()
 
-  const isActiveModel = stateModel.model === props.data.id
+  const isActiveModel = stateModel.model?.id === props.data.id
 
   const [serverEnabled, setServerEnabled] = useAtom(serverEnabledAtom)
@@ -84,7 +84,7 @@
             Active
-          ) : stateModel.loading && stateModel.model === props.data.id ? (
+          ) : stateModel.loading && stateModel.model?.id === props.data.id ? (
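
---

Editor's note, not part of the patch: the snippet below is a minimal,
self-contained sketch of the termination strategy the reworked
killSubprocess() implements. The likely reason the patch kills by PID first
is the bug being fixed: while a model is still loading, the engine's HTTP
server may not be able to answer a DELETE request, so the HTTP kill endpoint
is kept only as a fallback, and in both paths success means the TCP port
actually becoming free. The `terminate` and `tcp-port-used` packages and the
5-second/300 ms budgets are the ones used in the diff; the names killEngine
and terminateTree and the parameterised port/killUrl are illustrative, not
code from the repository.

// kill-flow-sketch.ts (assumes Node 18+ for global fetch/AbortController)
import terminate from 'terminate'
import tcpPortUsed from 'tcp-port-used'

// Promise wrapper over terminate()'s error-first callback; terminate kills
// the given PID and its entire child process tree.
function terminateTree(pid: number): Promise<void> {
  return new Promise((resolve, reject) => {
    terminate(pid, (err?: Error) => (err ? reject(err) : resolve()))
  })
}

async function killEngine(
  pid: number | undefined,
  port: number,
  killUrl: string
): Promise<void> {
  // Fallback path: ask the server to shut itself down, aborting after 5 s.
  const httpKill = async () => {
    const controller = new AbortController()
    const timer = setTimeout(() => controller.abort(), 5000)
    try {
      await fetch(killUrl, { method: 'DELETE', signal: controller.signal })
    } catch {
      // Ignored: the server may already be gone; the port check below decides.
    } finally {
      clearTimeout(timer)
    }
  }

  if (pid !== undefined) {
    // Preferred path: kill the process tree directly, even mid model load.
    await terminateTree(pid).catch(() => httpKill())
  } else {
    await httpKill()
  }

  // Poll every 300 ms, give up after 5 s: the same budget the patch uses.
  await tcpPortUsed.waitUntilFree(port, 300, 5000)
}

In the extension's own terms, the call would look like
killEngine(subprocess?.pid, PORT, NITRO_HTTP_KILL_URL).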