Merge pull request #4177 from janhq/fix/4171-model-loading-takes-extremely-long

fix: 4171 - Model loading gets stuck on stop
Louis 2024-12-02 14:32:00 +07:00 committed by GitHub
commit 3118bbaa51
8 changed files with 120 additions and 27 deletions

View File

@@ -1 +1 @@
-1.0.4-rc4
+1.0.4-rc5

View File

@@ -1,33 +1,59 @@
 [
   {
-    "key": "test",
-    "title": "Test",
-    "description": "Test",
-    "controllerType": "input",
-    "controllerProps": {
-      "placeholder": "Test",
-      "value": ""
-    }
-  },
-  {
-    "key": "embedding",
-    "title": "Embedding",
-    "description": "Whether to enable embedding.",
+    "key": "cont_batching",
+    "title": "Continuous batching",
+    "description": "The number of parallel operations",
     "controllerType": "checkbox",
     "controllerProps": {
       "value": true
     }
   },
   {
-    "key": "ctx_len",
-    "title": "Context Length",
-    "description": "The context length for model operations varies; the maximum depends on the specific model used.",
-    "controllerType": "slider",
+    "key": "n_parallel",
+    "title": "Parallel operations",
+    "description": "The number of parallel operations",
+    "controllerType": "input",
     "controllerProps": {
-      "min": 0,
-      "max": 4096,
-      "step": 128,
-      "value": 2048
+      "value": "4",
+      "placeholder": "4"
+    }
+  },
+  {
+    "key": "flash_attn",
+    "title": "Flash Attention enabled",
+    "description": "To enable Flash Attention, default is true",
+    "controllerType": "checkbox",
+    "controllerProps": {
+      "value": true
+    }
+  },
+  {
+    "key": "caching_enabled",
+    "title": "Caching enabled",
+    "description": "To enable prompt caching or not",
+    "controllerType": "checkbox",
+    "controllerProps": {
+      "value": true
+    }
+  },
+  {
+    "key": "cache_type",
+    "title": "KV Cache Type",
+    "description": "KV cache type: f16, q8_0, q4_0, default is f16 (change this could break the model).",
+    "controllerType": "input",
+    "controllerProps": {
+      "placeholder": "f16",
+      "value": "f16"
+    }
+  },
+  {
+    "key": "use_mmap",
+    "title": "To enable mmap",
+    "description": "To enable mmap, default is true",
+    "controllerType": "checkbox",
+    "controllerProps": {
+      "value": true
     }
   }
 ]
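Each entry in this settings file follows one schema: a "key" that matches the Settings enum added later in this PR, display text, a "controllerType", and "controllerProps" holding the control's state. A minimal TypeScript model of that shape (the type and variable names below are illustrative, not taken from the extension):

// Hypothetical types inferred from the JSON above; the extension itself
// may declare these differently.
type CheckboxProps = { value: boolean }
type InputProps = { value: string; placeholder: string }

interface SettingEntry {
  key: string            // matches a Settings enum member, e.g. 'flash_attn'
  title: string          // label shown in the settings UI
  description: string
  controllerType: 'checkbox' | 'input'
  controllerProps: CheckboxProps | InputProps
}

// Example: the flash_attn entry from the diff, expressed against that type.
const flashAttn: SettingEntry = {
  key: 'flash_attn',
  title: 'Flash Attention enabled',
  description: 'To enable Flash Attention, default is true',
  controllerType: 'checkbox',
  controllerProps: { value: true },
}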

View File

@@ -117,7 +117,7 @@ export default [
       qwen2572bJson,
     ]),
     NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`),
-    DEFAULT_SETTINGS: JSON.stringify(defaultSettingJson),
+    SETTINGS: JSON.stringify(defaultSettingJson),
    CORTEX_API_URL: JSON.stringify('http://127.0.0.1:39291'),
    CORTEX_SOCKET_URL: JSON.stringify('ws://127.0.0.1:39291'),
    CORTEX_ENGINE_VERSION: JSON.stringify('v0.1.40'),
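DEFAULT_SETTINGS is renamed to SETTINGS here. This is a compile-time constant injected by the bundler, so the rename must be mirrored in the ambient declaration in the next file and at every use site. A sketch of how such injection is typically wired, assuming a rolldown/esbuild-style define option (the project's full config file is not shown here):

// Sketch of build-time constant injection; names mirror the diff above.
import defaultSettingJson from './resources/default_settings.json'

export default {
  define: {
    // Every occurrence of the bare identifier SETTINGS in the bundled
    // source is replaced with this JSON literal at build time.
    SETTINGS: JSON.stringify(defaultSettingJson),
  },
}

On the consuming side, TypeScript only accepts the bare identifier because of the matching ambient declaration (declare const SETTINGS: Array<any>) in the next file.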

View File

@@ -2,7 +2,7 @@ declare const NODE: string
 declare const CORTEX_API_URL: string
 declare const CORTEX_SOCKET_URL: string
 declare const CORTEX_ENGINE_VERSION: string
-declare const DEFAULT_SETTINGS: Array<any>
+declare const SETTINGS: Array<any>
 declare const MODELS: Array<any>
 
 /**

View File

@@ -36,6 +36,15 @@ enum DownloadTypes {
   DownloadStarted = 'onFileDownloadStarted',
 }
 
+export enum Settings {
+  n_parallel = 'n_parallel',
+  cont_batching = 'cont_batching',
+  caching_enabled = 'caching_enabled',
+  flash_attn = 'flash_attn',
+  cache_type = 'cache_type',
+  use_mmap = 'use_mmap',
+}
+
 /**
  * A class that implements the InferenceExtension interface from the @janhq/core package.
  * The class provides methods for initializing and stopping a model, and for making inference requests.
@@ -50,6 +59,14 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
   shouldReconnect = true
 
+  /** Default Engine model load settings */
+  n_parallel: number = 4
+  cont_batching: boolean = true
+  caching_enabled: boolean = true
+  flash_attn: boolean = true
+  use_mmap: boolean = true
+  cache_type: string = 'f16'
+
   /**
    * The URL for making inference requests.
    */
@@ -60,6 +77,8 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
    */
   socket?: WebSocket = undefined
 
+  abortControllers = new Map<string, AbortController>()
+
   /**
    * Subscribes to events emitted by the @janhq/core package.
    */
@@ -70,6 +89,23 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
     super.onLoad()
 
+    // Register Settings
+    this.registerSettings(SETTINGS)
+
+    this.n_parallel =
+      Number(await this.getSetting<string>(Settings.n_parallel, '4')) ?? 4
+    this.cont_batching = await this.getSetting<boolean>(
+      Settings.cont_batching,
+      true
+    )
+    this.caching_enabled = await this.getSetting<boolean>(
+      Settings.caching_enabled,
+      true
+    )
+    this.flash_attn = await this.getSetting<boolean>(Settings.flash_attn, true)
+    this.use_mmap = await this.getSetting<boolean>(Settings.use_mmap, true)
+    this.cache_type = await this.getSetting<string>(Settings.cache_type, 'f16')
+
     this.queue.add(() => this.clean())
 
     // Run the process watchdog
@@ -101,6 +137,22 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
     super.onUnload()
   }
 
+  onSettingUpdate<T>(key: string, value: T): void {
+    if (key === Settings.n_parallel && typeof value === 'string') {
+      this.n_parallel = Number(value) ?? 1
+    } else if (key === Settings.cont_batching && typeof value === 'boolean') {
+      this.cont_batching = value as boolean
+    } else if (key === Settings.caching_enabled && typeof value === 'boolean') {
+      this.caching_enabled = value as boolean
+    } else if (key === Settings.flash_attn && typeof value === 'boolean') {
+      this.flash_attn = value as boolean
+    } else if (key === Settings.cache_type && typeof value === 'string') {
+      this.cache_type = value as string
+    } else if (key === Settings.use_mmap && typeof value === 'boolean') {
+      this.use_mmap = value as boolean
+    }
+  }
+
   override async loadModel(
     model: Model & { file_path?: string }
   ): Promise<void> {
@@ -134,6 +186,10 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
       const { mmproj, ...settings } = model.settings
       model.settings = settings
     }
 
+    const controller = new AbortController()
+    const { signal } = controller
+    this.abortControllers.set(model.id, controller)
+
     return await this.queue.add(() =>
       ky
@@ -145,13 +201,21 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
           model.engine === InferenceEngine.nitro // Legacy model cache
             ? InferenceEngine.cortex_llamacpp
             : model.engine,
+          cont_batching: this.cont_batching,
+          n_parallel: this.n_parallel,
+          caching_enabled: this.caching_enabled,
+          flash_attn: this.flash_attn,
+          cache_type: this.cache_type,
+          use_mmap: this.use_mmap,
         },
         timeout: false,
+        signal,
       })
       .json()
       .catch(async (e) => {
         throw (await e.response?.json()) ?? e
       })
+      .finally(() => this.abortControllers.delete(model.id))
       .then()
   )
 }
@@ -162,6 +226,9 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
       json: { model: model.id },
     })
       .json()
+      .finally(() => {
+        this.abortControllers.get(model.id)?.abort()
+      })
       .then()
 }
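This file carries the actual fix for #4171: loadModel now creates an AbortController per model id and passes its signal to the ky request, and unloadModel aborts any in-flight load for that model, so stopping a model can no longer hang behind a load request that never completes. A condensed, self-contained sketch of the pattern (class and method names here are simplified; the real extension also routes the load through its request queue):

// Sketch of the cancellation pattern from the diff above, under simplified names.
import ky from 'ky'

class ModelLoader {
  private abortControllers = new Map<string, AbortController>()

  async loadModel(modelId: string, url: string): Promise<unknown> {
    const controller = new AbortController()
    this.abortControllers.set(modelId, controller)

    return ky
      .post(url, {
        json: { model: modelId },
        timeout: false,              // the engine may take a long time to load
        signal: controller.signal,   // lets unloadModel cancel this request
      })
      .json()
      .finally(() => this.abortControllers.delete(modelId))
  }

  async unloadModel(modelId: string, url: string): Promise<unknown> {
    return ky
      .post(url, { json: { model: modelId } })
      .json()
      // If a load for this model is still in flight, abort it so the stop
      // action doesn't hang behind a stuck load (the bug in #4171).
      .finally(() => this.abortControllers.get(modelId)?.abort())
  }
}

Putting the abort in finally means the stale load is cancelled whether or not the unload request itself succeeds, and deleting the controller in loadModel's own finally keeps the map from accumulating entries for completed loads.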

View File

@@ -52,7 +52,7 @@ const ErrorMessage = ({ message }: { message: ThreadMessage }) => {
       )
     default:
       return (
-        <p data-testid="passthrough-error-message">
+        <p data-testid="passthrough-error-message" className="capitalize">
          {message.content[0]?.text?.value && (
            <AutoLink text={message.content[0].text.value} />
          )}

View File

@@ -49,7 +49,7 @@ const LoadModelError = () => {
   } else {
     return (
       <div className="mx-6 flex flex-col items-center space-y-2 text-center font-medium text-[hsla(var(--text-secondary))]">
-        {loadModelError && <p>{loadModelError}</p>}
+        {loadModelError && <p className="capitalize">{loadModelError}</p>}
         <p>
           {`Something's wrong.`}&nbsp;Access&nbsp;
           <span

View File

@@ -125,7 +125,7 @@ const MessageContainer: React.FC<ThreadMessage> = (props) => {
   ) : (
     <div
       className={twMerge(
-        'message max-width-[100%] flex flex-col gap-y-2 overflow-auto leading-relaxed'
+        'message max-width-[100%] flex flex-col gap-y-2 overflow-x-auto overflow-y-hidden leading-relaxed'
      )}
      dir="ltr"
    >