fix: model reload state - reduce model unload event emissions

commit 40957f7686
parent 523c745150
Author: Louis
Date: 2024-10-22 15:21:30 +07:00
GPG Key ID: 44FA9F4D33C37DE2

11 changed files with 92 additions and 56 deletions

@@ -18,7 +18,14 @@ export class ModelManager {
    * @param model - The model to register.
    */
   register<T extends Model>(model: T) {
-    this.models.set(model.id, model)
+    if (this.models.has(model.id)) {
+      this.models.set(model.id, {
+        ...model,
+        ...this.models.get(model.id),
+      })
+    } else {
+      this.models.set(model.id, model)
+    }
     events.emit(ModelEvent.OnModelsUpdate, {})
   }
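
Note: the spread order in `register` is what makes re-registration non-destructive — fields already present on the registered entry override the incoming model's defaults. A minimal standalone sketch of that precedence, using hypothetical field values rather than the real `Model` shape from `@janhq/core`:

```ts
// Minimal sketch of the spread precedence used in register().
// These model objects are hypothetical, not the @janhq/core Model type.
const incoming = { id: 'llama-3', ctx_len: 2048, ngl: 100 }
const existing = { id: 'llama-3', ctx_len: 4096 }

// Later spreads win, so the existing entry's ctx_len (4096) is preserved,
// while fields only present on the incoming model (ngl) are still added.
const merged = { ...incoming, ...existing }
console.log(merged) // { id: 'llama-3', ctx_len: 4096, ngl: 100 }
```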

@@ -102,7 +102,7 @@ Enable the GPU acceleration option within the Jan application by following the [
       ],
       "size": 669000000
     },
-    "engine": "llama-cpp"
+    "engine": "nitro"
   }
 ```
 ### Step 2: Modify the `model.json`

@@ -51,6 +51,7 @@
     "decompress": "^4.2.1",
     "fetch-retry": "^5.0.6",
     "ky": "^1.7.2",
+    "p-queue": "^8.0.1",
     "rxjs": "^7.8.1",
     "tcp-port-used": "^1.0.2",
     "terminate": "2.6.1",

@@ -114,7 +114,7 @@ export default [
     ]),
     NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`),
     DEFAULT_SETTINGS: JSON.stringify(defaultSettingJson),
-    CORTEX_API_URL: JSON.stringify('http://127.0.0.1:39291/v1'),
+    CORTEX_API_URL: JSON.stringify('http://127.0.0.1:39291'),
   }),
   // Allow json resolution
   json(),
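
Note: dropping `/v1` from the injected constant moves API versioning to each call site, so the same base URL can serve both versioned endpoints and the unversioned health check:

```ts
// Sketch of the URL layout after this change (values taken from the diff).
const CORTEX_API_URL = 'http://127.0.0.1:39291'

const inferenceUrl = `${CORTEX_API_URL}/v1/chat/completions` // versioned API
const healthzUrl = `${CORTEX_API_URL}/healthz` // unversioned health probe
```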

@@ -16,7 +16,7 @@ import {
   LocalOAIEngine,
   InferenceEngine,
 } from '@janhq/core'
+import PQueue from 'p-queue'
 import ky from 'ky'

 /**
@@ -28,12 +28,14 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
   // DEPRECATED
   nodeModule: string = 'node'

+  queue = new PQueue({ concurrency: 1 })
+
   provider: string = InferenceEngine.cortex

   /**
    * The URL for making inference requests.
    */
-  inferenceUrl = `${CORTEX_API_URL}/chat/completions`
+  inferenceUrl = `${CORTEX_API_URL}/v1/chat/completions`

   /**
    * Subscribes to events emitted by the @janhq/core package.
@@ -47,7 +49,9 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
     // Run the process watchdog
     const systemInfo = await systemInformation()
-    executeOnMain(NODE, 'run', systemInfo)
+    await executeOnMain(NODE, 'run', systemInfo)
+
+    this.queue.add(() => this.healthz())
   }

   onUnload(): void {
@@ -61,16 +65,19 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
     // Legacy model cache - should import
     if (model.engine === InferenceEngine.nitro && model.file_path) {
       // Try importing the model
-      await ky
-        .post(`${CORTEX_API_URL}/models/${model.id}`, {
-          json: { model: model.id, modelPath: await this.modelPath(model) },
-        })
-        .json()
-        .catch((e) => log(e.message ?? e ?? ''))
+      const modelPath = await this.modelPath(model)
+      await this.queue.add(() =>
+        ky
+          .post(`${CORTEX_API_URL}/v1/models/${model.id}`, {
+            json: { model: model.id, modelPath: modelPath },
+          })
+          .json()
+          .catch((e) => log(e.message ?? e ?? ''))
+      )
     }

-    return ky
-      .post(`${CORTEX_API_URL}/models/start`, {
+    return await ky
+      .post(`${CORTEX_API_URL}/v1/models/start`, {
         json: {
           ...model.settings,
           model: model.id,
@@ -89,7 +96,7 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
   override async unloadModel(model: Model): Promise<void> {
     return ky
-      .post(`${CORTEX_API_URL}/models/stop`, {
+      .post(`${CORTEX_API_URL}/v1/models/stop`, {
         json: { model: model.id },
       })
       .json()
@@ -108,4 +115,19 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
       model.id,
     ])
   }
+
+  /**
+   * Do health check on cortex.cpp
+   * @returns
+   */
+  healthz(): Promise<void> {
+    return ky
+      .get(`${CORTEX_API_URL}/healthz`, {
+        retry: {
+          limit: 10,
+          methods: ['get'],
+        },
+      })
+      .then(() => {})
+  }
 }
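
Note: because `onLoad` enqueues `healthz()` first, every later `queue.add()` call implicitly waits until cortex.cpp answers its health endpoint; ky retries the GET up to 10 times before giving up. A standalone sketch of that gating pattern, with a placeholder model name:

```ts
import ky from 'ky'
import PQueue from 'p-queue'

const queue = new PQueue({ concurrency: 1 })
const BASE_URL = 'http://127.0.0.1:39291' // value taken from the diff

// First task: poll the health endpoint. ky's retry option re-issues
// the GET (up to 10 attempts) while the server is still starting up.
queue.add(() =>
  ky.get(`${BASE_URL}/healthz`, { retry: { limit: 10, methods: ['get'] } })
)

// Second task: with concurrency 1 this runs only after the health check
// resolves, so the server is known to be up before the model starts.
queue.add(() =>
  ky.post(`${BASE_URL}/v1/models/start`, { json: { model: 'placeholder-model' } })
)
```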

@@ -1,9 +1,8 @@
 {
   "compilerOptions": {
     "moduleResolution": "node",
-    "target": "ES2015",
-    "module": "ES2020",
-    "lib": ["es2015", "es2016", "es2017", "dom"],
+    "target": "es2016",
+    "module": "esnext",
     "strict": true,
     "sourceMap": true,
     "declaration": true,

@@ -44,6 +44,11 @@ export default function ModelReload() {
           Reloading model {stateModel.model?.id}
         </span>
       </div>
+      <div className="my-4 mb-2 text-center">
+        <span className="text-[hsla(var(--text-secondary)]">
+          Model is reloading to apply new changes.
+        </span>
+      </div>
     </div>
   )
 }

@@ -51,6 +51,10 @@ export function useActiveModel() {
       console.debug(`Model ${modelId} is already initialized. Ignore..`)
       return Promise.resolve()
     }
+    if (activeModel) {
+      stopModel(activeModel)
+    }
+
     setPendingModelLoad(true)

     let model = downloadedModelsRef?.current.find((e) => e.id === modelId)
@@ -113,7 +117,7 @@
       setStateModel(() => ({
         state: 'start',
         loading: false,
-        model,
+        undefined,
       }))

       if (!pendingModelLoad && abortable) {
@@ -130,28 +134,30 @@
     })
   }

-  const stopModel = useCallback(async () => {
-    const stoppingModel = activeModel || stateModel.model
-    if (!stoppingModel || (stateModel.state === 'stop' && stateModel.loading))
-      return
-
-    setStateModel({ state: 'stop', loading: true, model: stoppingModel })
-
-    const engine = EngineManager.instance().get(stoppingModel.engine)
-    return engine
-      ?.unloadModel(stoppingModel)
-      .catch((e) => console.error(e))
-      .then(() => {
-        setActiveModel(undefined)
-        setStateModel({ state: 'start', loading: false, model: undefined })
-        setPendingModelLoad(false)
-      })
-  }, [
-    activeModel,
-    setActiveModel,
-    setStateModel,
-    setPendingModelLoad,
-    stateModel,
-  ])
+  const stopModel = useCallback(
+    async (model?: Model) => {
+      const stoppingModel = model ?? activeModel ?? stateModel.model
+      if (!stoppingModel || (stateModel.state === 'stop' && stateModel.loading))
+        return
+
+      setStateModel({ state: 'stop', loading: true, model: stoppingModel })
+
+      const engine = EngineManager.instance().get(stoppingModel.engine)
+      return engine
+        ?.unloadModel(stoppingModel)
+        .catch((e) => console.error(e))
+        .then(() => {
+          setActiveModel(undefined)
+          setStateModel({ state: 'start', loading: false, model: undefined })
+          setPendingModelLoad(false)
+        })
+    },
+    [
+      activeModel,
+      setStateModel,
+      setActiveModel,
+      setPendingModelLoad,
+      stateModel,
+    ]
+  )

   const stopInference = useCallback(async () => {
     // Loading model
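
Note: the optional `model` argument is the core of this refactor — when switching models, the hook can stop the previous `activeModel` explicitly instead of relying on possibly stale atom state, which is what cuts down redundant unload events. A hypothetical call sequence inside a component that uses the hook (`previousModel` is illustrative):

```ts
// Hypothetical usage of the new signature.
const { startModel, stopModel } = useActiveModel()

stopModel(previousModel) // stop a specific, known model
stopModel() // no argument: falls back to activeModel ?? stateModel.model
```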

@@ -31,10 +31,9 @@ const useModels = () => {
   const getData = useCallback(() => {
     const getDownloadedModels = async () => {
       const localModels = await getModels()
-      const remoteModels = ModelManager.instance()
-        .models.values()
-        .toArray()
-        .filter((e) => !isLocalEngine(e.engine))
+      const hubModels = ModelManager.instance().models.values().toArray()
+      const remoteModels = hubModels.filter((e) => !isLocalEngine(e.engine))

       setDownloadedModels([...localModels, ...remoteModels])
     }
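
Note: `.values().toArray()` relies on the iterator-helpers API, which not every runtime ships. The standard-library equivalent, noted here only as an aside (it is not part of the commit):

```ts
// Array.from consumes the Map iterator the same way .toArray() does,
// without requiring iterator-helpers support.
const hubModels = Array.from(ModelManager.instance().models.values())
const remoteModels = hubModels.filter((e) => !isLocalEngine(e.engine))
```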

@@ -199,16 +199,7 @@ const ThreadCenterPanel = () => {
           {!engineParamsUpdate && <ModelStart />}

-          {reloadModel && (
-            <Fragment>
-              <ModelReload />
-              <div className="mb-2 text-center">
-                <span className="text-[hsla(var(--text-secondary)]">
-                  Model is reloading to apply new changes.
-                </span>
-              </div>
-            </Fragment>
-          )}
+          {reloadModel && <ModelReload />}

           {activeModel && isGeneratingResponse && <GenerateResponse />}
           <ChatInput />

@@ -15,6 +15,8 @@ import {
 import { useAtom, useAtomValue, useSetAtom } from 'jotai'
+import { useDebouncedCallback } from 'use-debounce'
+
 import CopyOverInstruction from '@/containers/CopyInstruction'
 import EngineSetting from '@/containers/EngineSetting'
 import ModelDropdown from '@/containers/ModelDropdown'
@@ -168,6 +170,10 @@ const ThreadRightPanel = () => {
     [activeThread, updateThreadMetadata]
   )

+  const resetModel = useDebouncedCallback(() => {
+    stopModel()
+  }, 300)
+
   const onValueChanged = useCallback(
     (key: string, value: string | number | boolean) => {
       if (!activeThread) {
@@ -175,7 +181,7 @@
       }
       setEngineParamsUpdate(true)
-      stopModel()
+      resetModel()

       updateModelParameter(activeThread, {
         params: { [key]: value },
@@ -207,7 +213,7 @@
         }
       }
     },
-    [activeThread, setEngineParamsUpdate, stopModel, updateModelParameter]
+    [activeThread, resetModel, setEngineParamsUpdate, updateModelParameter]
   )

   if (!activeThread) {
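
Note: routing the stop through `useDebouncedCallback(..., 300)` coalesces rapid parameter edits (for example, dragging a slider) into a single stop/reload cycle, the other half of reducing model unload events. A minimal sketch of the pattern with a hypothetical component:

```tsx
import { useDebouncedCallback } from 'use-debounce'

// Hypothetical component; stopModel is passed in for illustration.
function SettingsPanel({ stopModel }: { stopModel: () => void }) {
  // Only the last call within a 300 ms burst fires, so ten rapid
  // slider ticks trigger one stop/reload cycle instead of ten.
  const resetModel = useDebouncedCallback(() => {
    stopModel()
  }, 300)

  return <input type="range" onChange={() => resetModel()} />
}
```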