fix: model reload state - reduce model unload event emissions

Louis committed 2024-10-22 15:21:30 +07:00
parent 523c745150
commit 40957f7686
GPG Key ID: 44FA9F4D33C37DE2 (no known key found for this signature in database)
11 changed files with 92 additions and 56 deletions

View File

@@ -18,7 +18,14 @@ export class ModelManager {
    * @param model - The model to register.
    */
   register<T extends Model>(model: T) {
-    this.models.set(model.id, model)
+    if (this.models.has(model.id)) {
+      this.models.set(model.id, {
+        ...model,
+        ...this.models.get(model.id),
+      })
+    } else {
+      this.models.set(model.id, model)
+    }
     events.emit(ModelEvent.OnModelsUpdate, {})
   }
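Note the spread order in the new branch: properties already stored for the model take precedence over the incoming payload, so re-registering a known model preserves any state attached since the first registration instead of clobbering it, and only one OnModelsUpdate event fires either way. A minimal sketch of that merge semantics, with hypothetical model objects:

```ts
// Hypothetical objects illustrating register()'s merge order:
// fields of the existing entry win over the re-registered payload.
const existing = { id: 'model-a', status: 'downloaded' }
const incoming = { id: 'model-a', status: 'pending', description: 'new' }

const merged = { ...incoming, ...existing }
console.log(merged) // { id: 'model-a', status: 'downloaded', description: 'new' }
```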

View File

@@ -102,7 +102,7 @@ Enable the GPU acceleration option within the Jan application by following the [
       ],
       "size": 669000000
     },
-    "engine": "llama-cpp"
+    "engine": "nitro"
   }
   ```
 ### Step 2: Modify the `model.json`

View File

@@ -51,6 +51,7 @@
     "decompress": "^4.2.1",
     "fetch-retry": "^5.0.6",
     "ky": "^1.7.2",
+    "p-queue": "^8.0.1",
     "rxjs": "^7.8.1",
     "tcp-port-used": "^1.0.2",
     "terminate": "2.6.1",

View File

@@ -114,7 +114,7 @@ export default [
     ]),
     NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`),
     DEFAULT_SETTINGS: JSON.stringify(defaultSettingJson),
-    CORTEX_API_URL: JSON.stringify('http://127.0.0.1:39291/v1'),
+    CORTEX_API_URL: JSON.stringify('http://127.0.0.1:39291'),
   }),
   // Allow json resolution
   json(),
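Dropping the /v1 suffix from the injected constant lets the same base URL serve both versioned API routes and the unversioned /healthz probe; each versioned call site now appends /v1 itself, as the extension diff below shows. A small sketch of the resulting URL composition (the value is the build-time default injected by this config):

```ts
// CORTEX_API_URL is injected at build time via the rollup config above.
const CORTEX_API_URL = 'http://127.0.0.1:39291'

// Versioned endpoints add /v1 at the call site...
const chatUrl = `${CORTEX_API_URL}/v1/chat/completions`
// ...while the health probe stays unversioned.
const healthzUrl = `${CORTEX_API_URL}/healthz`
```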

View File

@@ -16,7 +16,7 @@ import {
   LocalOAIEngine,
   InferenceEngine,
 } from '@janhq/core'
+import PQueue from 'p-queue'
 import ky from 'ky'
 
 /**
@@ -28,12 +28,14 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
   // DEPRECATED
   nodeModule: string = 'node'
 
+  queue = new PQueue({ concurrency: 1 })
+
   provider: string = InferenceEngine.cortex
 
   /**
    * The URL for making inference requests.
    */
-  inferenceUrl = `${CORTEX_API_URL}/chat/completions`
+  inferenceUrl = `${CORTEX_API_URL}/v1/chat/completions`
 
   /**
    * Subscribes to events emitted by the @janhq/core package.
@@ -47,7 +49,9 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
     // Run the process watchdog
     const systemInfo = await systemInformation()
-    executeOnMain(NODE, 'run', systemInfo)
+    await executeOnMain(NODE, 'run', systemInfo)
+
+    this.queue.add(() => this.healthz())
   }
 
   onUnload(): void {
@@ -61,16 +65,19 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
     // Legacy model cache - should import
     if (model.engine === InferenceEngine.nitro && model.file_path) {
       // Try importing the model
-      await ky
-        .post(`${CORTEX_API_URL}/models/${model.id}`, {
-          json: { model: model.id, modelPath: await this.modelPath(model) },
-        })
-        .json()
-        .catch((e) => log(e.message ?? e ?? ''))
+      const modelPath = await this.modelPath(model)
+      await this.queue.add(() =>
+        ky
+          .post(`${CORTEX_API_URL}/v1/models/${model.id}`, {
+            json: { model: model.id, modelPath: modelPath },
+          })
+          .json()
+          .catch((e) => log(e.message ?? e ?? ''))
+      )
     }
 
-    return ky
-      .post(`${CORTEX_API_URL}/models/start`, {
+    return await ky
+      .post(`${CORTEX_API_URL}/v1/models/start`, {
         json: {
           ...model.settings,
           model: model.id,
@@ -89,7 +96,7 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
   override async unloadModel(model: Model): Promise<void> {
     return ky
-      .post(`${CORTEX_API_URL}/models/stop`, {
+      .post(`${CORTEX_API_URL}/v1/models/stop`, {
         json: { model: model.id },
       })
       .json()
@@ -108,4 +115,19 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
       model.id,
     ])
   }
+
+  /**
+   * Do health check on cortex.cpp
+   * @returns
+   */
+  healthz(): Promise<void> {
+    return ky
+      .get(`${CORTEX_API_URL}/healthz`, {
+        retry: {
+          limit: 10,
+          methods: ['get'],
+        },
+      })
+      .then(() => {})
+  }
 }
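The net effect: a concurrency-1 p-queue serializes the extension's calls to cortex.cpp, and the first job queued is a healthz probe that ky retries up to 10 times, so queued model imports cannot race a server that is still starting up. A self-contained sketch of the pattern, with a placeholder task standing in for the real model calls:

```ts
import PQueue from 'p-queue'
import ky from 'ky'

const queue = new PQueue({ concurrency: 1 })
const BASE_URL = 'http://127.0.0.1:39291' // default from the rollup config above

// Queued first: ky retries the GET up to 10 times while the server boots.
queue.add(() =>
  ky.get(`${BASE_URL}/healthz`, { retry: { limit: 10, methods: ['get'] } })
)

// Runs only after the health probe has settled, never concurrently with it.
queue.add(() => console.log('cortex.cpp reachable; safe to import models'))
```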

View File

@@ -1,9 +1,8 @@
 {
   "compilerOptions": {
     "moduleResolution": "node",
-    "target": "ES2015",
-    "module": "ES2020",
-    "lib": ["es2015", "es2016", "es2017", "dom"],
+    "target": "es2016",
+    "module": "esnext",
     "strict": true,
     "sourceMap": true,
     "declaration": true,

View File

@@ -44,6 +44,11 @@ export default function ModelReload() {
           Reloading model {stateModel.model?.id}
         </span>
       </div>
+      <div className="my-4 mb-2 text-center">
+        <span className="text-[hsla(var(--text-secondary)]">
+          Model is reloading to apply new changes.
+        </span>
+      </div>
     </div>
   )
 }

View File

@@ -51,6 +51,10 @@ export function useActiveModel() {
       console.debug(`Model ${modelId} is already initialized. Ignore..`)
       return Promise.resolve()
     }
+
+    if (activeModel) {
+      stopModel(activeModel)
+    }
     setPendingModelLoad(true)
 
     let model = downloadedModelsRef?.current.find((e) => e.id === modelId)
@@ -113,7 +117,7 @@
       setStateModel(() => ({
         state: 'start',
         loading: false,
-        model,
+        undefined,
       }))
 
       if (!pendingModelLoad && abortable) {
@@ -130,28 +134,30 @@
     })
   }
 
-  const stopModel = useCallback(async () => {
-    const stoppingModel = activeModel || stateModel.model
-    if (!stoppingModel || (stateModel.state === 'stop' && stateModel.loading))
-      return
-
-    setStateModel({ state: 'stop', loading: true, model: stoppingModel })
-    const engine = EngineManager.instance().get(stoppingModel.engine)
-    return engine
-      ?.unloadModel(stoppingModel)
-      .catch((e) => console.error(e))
-      .then(() => {
-        setActiveModel(undefined)
-        setStateModel({ state: 'start', loading: false, model: undefined })
-        setPendingModelLoad(false)
-      })
-  }, [
-    activeModel,
-    setActiveModel,
-    setStateModel,
-    setPendingModelLoad,
-    stateModel,
-  ])
+  const stopModel = useCallback(
+    async (model?: Model) => {
+      const stoppingModel = model ?? activeModel ?? stateModel.model
+      if (!stoppingModel || (stateModel.state === 'stop' && stateModel.loading))
+        return
+
+      const engine = EngineManager.instance().get(stoppingModel.engine)
+      return engine
+        ?.unloadModel(stoppingModel)
+        .catch((e) => console.error(e))
+        .then(() => {
+          setActiveModel(undefined)
+          setStateModel({ state: 'start', loading: false, model: undefined })
+          setPendingModelLoad(false)
+        })
+    },
+    [
+      activeModel,
+      setStateModel,
+      setActiveModel,
+      setPendingModelLoad,
+      stateModel,
+    ]
+  )
 
   const stopInference = useCallback(async () => {
     // Loading model
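Two changes here drive the commit's goal. First, startModel now unloads the previous active model itself, and stopModel accepts an optional target, so switching models produces exactly one deliberate unload instead of relying on hook state. Second, stopModel no longer sets an intermediate { state: 'stop', loading: true } UI state, removing the extra state churn. A reduced sketch of the new flow (plain functions instead of React hooks; everything except the startModel/stopModel names is illustrative):

```ts
type Model = { id: string; engine: string }

let activeModel: Model | undefined

async function stopModel(model?: Model) {
  // Prefer the explicit target; fall back to whatever is active.
  const stopping = model ?? activeModel
  if (!stopping) return // nothing loaded: emit no unload at all
  console.log(`unload ${stopping.id}`) // one unload per real stop
  activeModel = undefined
}

async function startModel(next: Model) {
  if (activeModel) {
    await stopModel(activeModel) // single explicit unload of the outgoing model
  }
  activeModel = next
  console.log(`load ${next.id}`)
}
```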

View File

@@ -31,10 +31,9 @@ const useModels = () => {
   const getData = useCallback(() => {
     const getDownloadedModels = async () => {
       const localModels = await getModels()
-      const remoteModels = ModelManager.instance()
-        .models.values()
-        .toArray()
-        .filter((e) => !isLocalEngine(e.engine))
+      const hubModels = ModelManager.instance().models.values().toArray()
+
+      const remoteModels = hubModels.filter((e) => !isLocalEngine(e.engine))
       setDownloadedModels([...localModels, ...remoteModels])
     }
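One portability note on the rewritten line: .toArray() on a Map values iterator comes from the Iterator Helpers feature, which only recent runtimes ship (Node 22+, current Chromium); on older targets the equivalent would be a plain Array.from. A sketch of that fallback (the filter predicate is an illustrative stand-in for the real isLocalEngine):

```ts
// Same result without Iterator Helpers, for runtimes lacking .toArray():
const models = new Map<string, { id: string; engine: string }>()
const hubModels = Array.from(models.values())
// Illustrative stand-in for isLocalEngine from the real code:
const remoteModels = hubModels.filter((e) => e.engine !== 'llama-cpp')
```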

View File

@@ -199,16 +199,7 @@ const ThreadCenterPanel = () => {
             {!engineParamsUpdate && <ModelStart />}
 
-            {reloadModel && (
-              <Fragment>
-                <ModelReload />
-                <div className="mb-2 text-center">
-                  <span className="text-[hsla(var(--text-secondary)]">
-                    Model is reloading to apply new changes.
-                  </span>
-                </div>
-              </Fragment>
-            )}
+            {reloadModel && <ModelReload />}
 
             {activeModel && isGeneratingResponse && <GenerateResponse />}
             <ChatInput />

View File

@@ -15,6 +15,8 @@ import {
 import { useAtom, useAtomValue, useSetAtom } from 'jotai'
 
+import { useDebouncedCallback } from 'use-debounce'
+
 import CopyOverInstruction from '@/containers/CopyInstruction'
 import EngineSetting from '@/containers/EngineSetting'
 import ModelDropdown from '@/containers/ModelDropdown'
@@ -168,6 +170,10 @@ const ThreadRightPanel = () => {
     [activeThread, updateThreadMetadata]
   )
 
+  const resetModel = useDebouncedCallback(() => {
+    stopModel()
+  }, 300)
+
   const onValueChanged = useCallback(
     (key: string, value: string | number | boolean) => {
       if (!activeThread) {
@@ -175,7 +181,7 @@
       }
       setEngineParamsUpdate(true)
-      stopModel()
+      resetModel()
 
       updateModelParameter(activeThread, {
         params: { [key]: value },
@@ -207,7 +213,7 @@
         }
       }
     },
-    [activeThread, setEngineParamsUpdate, stopModel, updateModelParameter]
+    [activeThread, resetModel, setEngineParamsUpdate, updateModelParameter]
   )
 
   if (!activeThread) {
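The debounce is the other half of "reduce model unload events": onValueChanged fires on every keystroke or slider tick, but the trailing-edge debounce collapses each burst into a single stopModel() call 300 ms after the last change. A small sketch of the effect (the 300 ms window matches the diff; the wrapper hook is illustrative):

```ts
import { useDebouncedCallback } from 'use-debounce'

// Illustrative wrapper: many rapid parameter edits -> one unload request.
function useResetModel(stopModel: () => void) {
  // Trailing-edge by default: the callback runs once, 300 ms after the
  // last invocation, so dragging a slider no longer spams unloadModel.
  return useDebouncedCallback(() => {
    stopModel()
  }, 300)
}
```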