From 40957f7686ee69fd292ebdecbd81c0aaafb66682 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 22 Oct 2024 15:21:30 +0700 Subject: [PATCH] fix: model reload state - reduce model unload events emit --- core/src/browser/models/manager.ts | 9 +++- docs/src/pages/docs/built-in/llama-cpp.mdx | 2 +- .../inference-cortex-extension/package.json | 1 + .../rollup.config.ts | 2 +- .../inference-cortex-extension/src/index.ts | 46 ++++++++++++----- .../inference-cortex-extension/tsconfig.json | 5 +- web/containers/Loader/ModelReload.tsx | 5 ++ web/hooks/useActiveModel.ts | 50 +++++++++++-------- web/hooks/useModels.ts | 7 ++- .../Thread/ThreadCenterPanel/index.tsx | 11 +--- web/screens/Thread/ThreadRightPanel/index.tsx | 10 +++- 11 files changed, 92 insertions(+), 56 deletions(-) diff --git a/core/src/browser/models/manager.ts b/core/src/browser/models/manager.ts index 4853989fe..d5afe83d5 100644 --- a/core/src/browser/models/manager.ts +++ b/core/src/browser/models/manager.ts @@ -18,7 +18,14 @@ export class ModelManager { * @param model - The model to register. */ register(model: T) { - this.models.set(model.id, model) + if (this.models.has(model.id)) { + this.models.set(model.id, { + ...model, + ...this.models.get(model.id), + }) + } else { + this.models.set(model.id, model) + } events.emit(ModelEvent.OnModelsUpdate, {}) } diff --git a/docs/src/pages/docs/built-in/llama-cpp.mdx b/docs/src/pages/docs/built-in/llama-cpp.mdx index 8e2fa8498..5b7b0453a 100644 --- a/docs/src/pages/docs/built-in/llama-cpp.mdx +++ b/docs/src/pages/docs/built-in/llama-cpp.mdx @@ -102,7 +102,7 @@ Enable the GPU acceleration option within the Jan application by following the [ ], "size": 669000000 }, - "engine": "llama-cpp" + "engine": "nitro" } ``` ### Step 2: Modify the `model.json` diff --git a/extensions/inference-cortex-extension/package.json b/extensions/inference-cortex-extension/package.json index 920989f3b..5a9fc56e9 100644 --- a/extensions/inference-cortex-extension/package.json +++ b/extensions/inference-cortex-extension/package.json @@ -51,6 +51,7 @@ "decompress": "^4.2.1", "fetch-retry": "^5.0.6", "ky": "^1.7.2", + "p-queue": "^8.0.1", "rxjs": "^7.8.1", "tcp-port-used": "^1.0.2", "terminate": "2.6.1", diff --git a/extensions/inference-cortex-extension/rollup.config.ts b/extensions/inference-cortex-extension/rollup.config.ts index d0e9f5fbe..ea873990b 100644 --- a/extensions/inference-cortex-extension/rollup.config.ts +++ b/extensions/inference-cortex-extension/rollup.config.ts @@ -114,7 +114,7 @@ export default [ ]), NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`), DEFAULT_SETTINGS: JSON.stringify(defaultSettingJson), - CORTEX_API_URL: JSON.stringify('http://127.0.0.1:39291/v1'), + CORTEX_API_URL: JSON.stringify('http://127.0.0.1:39291'), }), // Allow json resolution json(), diff --git a/extensions/inference-cortex-extension/src/index.ts b/extensions/inference-cortex-extension/src/index.ts index 93036fc4d..364bfe79c 100644 --- a/extensions/inference-cortex-extension/src/index.ts +++ b/extensions/inference-cortex-extension/src/index.ts @@ -16,7 +16,7 @@ import { LocalOAIEngine, InferenceEngine, } from '@janhq/core' - +import PQueue from 'p-queue' import ky from 'ky' /** @@ -28,12 +28,14 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { // DEPRECATED nodeModule: string = 'node' + queue = new PQueue({ concurrency: 1 }) + provider: string = InferenceEngine.cortex /** * The URL for making inference requests. */ - inferenceUrl = `${CORTEX_API_URL}/chat/completions` + inferenceUrl = `${CORTEX_API_URL}/v1/chat/completions` /** * Subscribes to events emitted by the @janhq/core package. @@ -47,7 +49,9 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { // Run the process watchdog const systemInfo = await systemInformation() - executeOnMain(NODE, 'run', systemInfo) + await executeOnMain(NODE, 'run', systemInfo) + + this.queue.add(() => this.healthz()) } onUnload(): void { @@ -61,16 +65,19 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { // Legacy model cache - should import if (model.engine === InferenceEngine.nitro && model.file_path) { // Try importing the model - await ky - .post(`${CORTEX_API_URL}/models/${model.id}`, { - json: { model: model.id, modelPath: await this.modelPath(model) }, - }) - .json() - .catch((e) => log(e.message ?? e ?? '')) + const modelPath = await this.modelPath(model) + await this.queue.add(() => + ky + .post(`${CORTEX_API_URL}/v1/models/${model.id}`, { + json: { model: model.id, modelPath: modelPath }, + }) + .json() + .catch((e) => log(e.message ?? e ?? '')) + ) } - return ky - .post(`${CORTEX_API_URL}/models/start`, { + return await ky + .post(`${CORTEX_API_URL}/v1/models/start`, { json: { ...model.settings, model: model.id, @@ -89,7 +96,7 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { override async unloadModel(model: Model): Promise { return ky - .post(`${CORTEX_API_URL}/models/stop`, { + .post(`${CORTEX_API_URL}/v1/models/stop`, { json: { model: model.id }, }) .json() @@ -108,4 +115,19 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { model.id, ]) } + + /** + * Do health check on cortex.cpp + * @returns + */ + healthz(): Promise { + return ky + .get(`${CORTEX_API_URL}/healthz`, { + retry: { + limit: 10, + methods: ['get'], + }, + }) + .then(() => {}) + } } diff --git a/extensions/inference-cortex-extension/tsconfig.json b/extensions/inference-cortex-extension/tsconfig.json index bdb35163a..af00a035a 100644 --- a/extensions/inference-cortex-extension/tsconfig.json +++ b/extensions/inference-cortex-extension/tsconfig.json @@ -1,9 +1,8 @@ { "compilerOptions": { "moduleResolution": "node", - "target": "ES2015", - "module": "ES2020", - "lib": ["es2015", "es2016", "es2017", "dom"], + "target": "es2016", + "module": "esnext", "strict": true, "sourceMap": true, "declaration": true, diff --git a/web/containers/Loader/ModelReload.tsx b/web/containers/Loader/ModelReload.tsx index fbe673788..29709c0da 100644 --- a/web/containers/Loader/ModelReload.tsx +++ b/web/containers/Loader/ModelReload.tsx @@ -44,6 +44,11 @@ export default function ModelReload() { Reloading model {stateModel.model?.id} +
+ + Model is reloading to apply new changes. + +
) } diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts index 8dd71fcc5..353288337 100644 --- a/web/hooks/useActiveModel.ts +++ b/web/hooks/useActiveModel.ts @@ -51,6 +51,10 @@ export function useActiveModel() { console.debug(`Model ${modelId} is already initialized. Ignore..`) return Promise.resolve() } + + if (activeModel) { + stopModel(activeModel) + } setPendingModelLoad(true) let model = downloadedModelsRef?.current.find((e) => e.id === modelId) @@ -113,7 +117,7 @@ export function useActiveModel() { setStateModel(() => ({ state: 'start', loading: false, - model, + undefined, })) if (!pendingModelLoad && abortable) { @@ -130,28 +134,30 @@ export function useActiveModel() { }) } - const stopModel = useCallback(async () => { - const stoppingModel = activeModel || stateModel.model - if (!stoppingModel || (stateModel.state === 'stop' && stateModel.loading)) - return + const stopModel = useCallback( + async (model?: Model) => { + const stoppingModel = model ?? activeModel ?? stateModel.model + if (!stoppingModel || (stateModel.state === 'stop' && stateModel.loading)) + return - setStateModel({ state: 'stop', loading: true, model: stoppingModel }) - const engine = EngineManager.instance().get(stoppingModel.engine) - return engine - ?.unloadModel(stoppingModel) - .catch((e) => console.error(e)) - .then(() => { - setActiveModel(undefined) - setStateModel({ state: 'start', loading: false, model: undefined }) - setPendingModelLoad(false) - }) - }, [ - activeModel, - setActiveModel, - setStateModel, - setPendingModelLoad, - stateModel, - ]) + const engine = EngineManager.instance().get(stoppingModel.engine) + return engine + ?.unloadModel(stoppingModel) + .catch((e) => console.error(e)) + .then(() => { + setActiveModel(undefined) + setStateModel({ state: 'start', loading: false, model: undefined }) + setPendingModelLoad(false) + }) + }, + [ + activeModel, + setStateModel, + setActiveModel, + setPendingModelLoad, + stateModel, + ] + ) const stopInference = useCallback(async () => { // Loading model diff --git a/web/hooks/useModels.ts b/web/hooks/useModels.ts index b09839457..742d09beb 100644 --- a/web/hooks/useModels.ts +++ b/web/hooks/useModels.ts @@ -31,10 +31,9 @@ const useModels = () => { const getData = useCallback(() => { const getDownloadedModels = async () => { const localModels = await getModels() - const remoteModels = ModelManager.instance() - .models.values() - .toArray() - .filter((e) => !isLocalEngine(e.engine)) + const hubModels = ModelManager.instance().models.values().toArray() + + const remoteModels = hubModels.filter((e) => !isLocalEngine(e.engine)) setDownloadedModels([...localModels, ...remoteModels]) } diff --git a/web/screens/Thread/ThreadCenterPanel/index.tsx b/web/screens/Thread/ThreadCenterPanel/index.tsx index fe7993e9a..c83a38a1a 100644 --- a/web/screens/Thread/ThreadCenterPanel/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/index.tsx @@ -199,16 +199,7 @@ const ThreadCenterPanel = () => { {!engineParamsUpdate && } - {reloadModel && ( - - -
- - Model is reloading to apply new changes. - -
-
- )} + {reloadModel && } {activeModel && isGeneratingResponse && } diff --git a/web/screens/Thread/ThreadRightPanel/index.tsx b/web/screens/Thread/ThreadRightPanel/index.tsx index 7ccc4957a..5a8fd3ebb 100644 --- a/web/screens/Thread/ThreadRightPanel/index.tsx +++ b/web/screens/Thread/ThreadRightPanel/index.tsx @@ -15,6 +15,8 @@ import { import { useAtom, useAtomValue, useSetAtom } from 'jotai' +import { useDebouncedCallback } from 'use-debounce' + import CopyOverInstruction from '@/containers/CopyInstruction' import EngineSetting from '@/containers/EngineSetting' import ModelDropdown from '@/containers/ModelDropdown' @@ -168,6 +170,10 @@ const ThreadRightPanel = () => { [activeThread, updateThreadMetadata] ) + const resetModel = useDebouncedCallback(() => { + stopModel() + }, 300) + const onValueChanged = useCallback( (key: string, value: string | number | boolean) => { if (!activeThread) { @@ -175,7 +181,7 @@ const ThreadRightPanel = () => { } setEngineParamsUpdate(true) - stopModel() + resetModel() updateModelParameter(activeThread, { params: { [key]: value }, @@ -207,7 +213,7 @@ const ThreadRightPanel = () => { } } }, - [activeThread, setEngineParamsUpdate, stopModel, updateModelParameter] + [activeThread, resetModel, setEngineParamsUpdate, updateModelParameter] ) if (!activeThread) {