/**
 * @file This file exports a class that extends the LocalOAIEngine class from the @janhq/core package.
 * The class provides methods for initializing and stopping a model, and for making inference requests.
 * It also subscribes to download and model events emitted by the cortex.cpp server and handles new message requests.
 * @version 1.0.0
 * @module inference-extension/src/index
 */

import {
  Model,
  EngineEvent,
  LocalOAIEngine,
  extractModelLoadParams,
  events,
  ModelEvent,
} from '@janhq/core'
import ky, { KyInstance } from 'ky'

// Build-time globals injected by the extension bundler. Declared here so the
// file is self-contained; the exact types are assumptions.
declare const CORTEX_API_URL: string
declare const CORTEX_SOCKET_URL: string
declare const SETTINGS: any[]

// The host app exposes its API bridge on `window.core` (typed loosely here).
declare global {
  interface Window {
    core?: any
  }
}

/**
 * Event subscription types of Downloader
 */
enum DownloadTypes {
  DownloadUpdated = 'onFileDownloadUpdate',
  DownloadError = 'onFileDownloadError',
  DownloadSuccess = 'onFileDownloadSuccess',
  DownloadStopped = 'onFileDownloadStopped',
  DownloadStarted = 'onFileDownloadStarted',
}

enum Settings {
  n_parallel = 'n_parallel',
  cont_batching = 'cont_batching',
  caching_enabled = 'caching_enabled',
  flash_attn = 'flash_attn',
  cache_type = 'cache_type',
  use_mmap = 'use_mmap',
  cpu_threads = 'cpu_threads',
  huggingfaceToken = 'hugging-face-access-token',
  auto_unload_models = 'auto_unload_models',
  context_shift = 'context_shift',
}

type LoadedModelResponse = { data: { engine: string; id: string }[] }

/**
 * A class that extends the LocalOAIEngine class from the @janhq/core package.
 * The class provides methods for initializing and stopping a model, and for making inference requests.
 * It also subscribes to events emitted by the cortex.cpp server and handles new message requests.
 */
export default class JanInferenceCortexExtension extends LocalOAIEngine {
  nodeModule: string = 'node'
  provider: string = 'cortex'
  shouldReconnect = true

  /** Default Engine model load settings */
  n_parallel?: number
  cont_batching: boolean = false
  caching_enabled: boolean = true
  flash_attn: boolean = true
  use_mmap: boolean = true
  cache_type: string = 'q8'
  cpu_threads?: number
  auto_unload_models: boolean = true
  reasoning_budget = -1 // Default reasoning budget (-1 = unlimited)
  context_shift = false

  /**
   * The URL for making inference requests.
   */
  inferenceUrl = `${CORTEX_API_URL}/v1/chat/completions`

  /**
   * Socket instance of events subscription
   */
  socket?: WebSocket = undefined

  abortControllers = new Map<string, AbortController>()

  api?: KyInstance

  /**
   * Get the API instance, creating and caching it on first use.
   * @returns
   */
  async apiInstance(): Promise<KyInstance> {
    if (this.api) return this.api
    const apiKey = await window.core?.api.appToken()
    this.api = ky.extend({
      prefixUrl: CORTEX_API_URL,
      headers: apiKey
        ? {
            Authorization: `Bearer ${apiKey}`,
          }
        : {},
      retry: 10,
    })
    return this.api
  }

  /**
   * Authorization headers for the API requests.
   * @returns
   */
  headers(): Promise<HeadersInit> {
    return window.core?.api.appToken().then((token: string) => ({
      Authorization: `Bearer ${token}`,
    }))
  }

  /**
   * Called when the extension is loaded.
   */
  async onLoad() {
    super.onLoad()

    // Register Settings
    this.registerSettings(SETTINGS)

    const numParallel = await this.getSetting<string>(Settings.n_parallel, '')
    if (numParallel.length > 0 && parseInt(numParallel) > 0) {
      this.n_parallel = parseInt(numParallel)
    }
    // Continuous batching only applies when more than one parallel slot is set
    if (this.n_parallel && this.n_parallel > 1)
      this.cont_batching = await this.getSetting<boolean>(
        Settings.cont_batching,
        false
      )
    this.caching_enabled = await this.getSetting<boolean>(
      Settings.caching_enabled,
      true
    )
    this.flash_attn = await this.getSetting<boolean>(Settings.flash_attn, true)
    this.context_shift = await this.getSetting<boolean>(
      Settings.context_shift,
      false
    )
    this.use_mmap = await this.getSetting<boolean>(Settings.use_mmap, true)
    if (this.caching_enabled)
      this.cache_type = await this.getSetting<string>(Settings.cache_type, 'q8')
    this.auto_unload_models = await this.getSetting<boolean>(
      Settings.auto_unload_models,
      true
    )
    const threads_number = Number(
      await this.getSetting<string>(Settings.cpu_threads, '')
    )
    if (!Number.isNaN(threads_number)) this.cpu_threads = threads_number

    const huggingfaceToken = await this.getSetting<string>(
      Settings.huggingfaceToken,
      ''
    )
    if (huggingfaceToken) {
      this.updateCortexConfig({ huggingface_token: huggingfaceToken })
    }

    this.subscribeToEvents()

    window.addEventListener('beforeunload', () => {
      this.clean()
    })

    // Migrate configs
    if (!localStorage.getItem('cortex_migration_completed')) {
      const config = await this.getCortexConfig()
      console.log('Start cortex.cpp migration', config)
      if (config && config.huggingface_token) {
        this.updateSettings([
          {
            key: Settings.huggingfaceToken,
            controllerProps: {
              value: config.huggingface_token,
            },
          },
        ])
        this.updateCortexConfig({
          huggingface_token: config.huggingface_token,
        })
        localStorage.setItem('cortex_migration_completed', 'true')
      }
    }
  }
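
  /**
   * Called when the extension is unloaded: stops websocket reconnection
   * and shuts down any running cortex.cpp processes.
   */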
  async onUnload() {
    console.log('Clean up cortex.cpp services')
    this.shouldReconnect = false
    this.clean()
    super.onUnload()
  }

  /**
   * Subscribe to settings update and make change accordingly
   * @param key
   * @param value
   */
  onSettingUpdate<T>(key: string, value: T): void {
    if (key === Settings.n_parallel && typeof value === 'string') {
      if (value.length > 0 && parseInt(value) > 0) {
        this.n_parallel = parseInt(value)
      }
    } else if (key === Settings.cont_batching && typeof value === 'boolean') {
      this.cont_batching = value as boolean
    } else if (key === Settings.caching_enabled && typeof value === 'boolean') {
      this.caching_enabled = value as boolean
    } else if (key === Settings.flash_attn && typeof value === 'boolean') {
      this.flash_attn = value as boolean
    } else if (key === Settings.cache_type && typeof value === 'string') {
      this.cache_type = value as string
    } else if (key === Settings.use_mmap && typeof value === 'boolean') {
      this.use_mmap = value as boolean
    } else if (key === Settings.cpu_threads && typeof value === 'string') {
      const threads_number = Number(value)
      if (!Number.isNaN(threads_number)) this.cpu_threads = threads_number
    } else if (key === Settings.huggingfaceToken) {
      this.updateCortexConfig({ huggingface_token: value })
    } else if (key === Settings.auto_unload_models) {
      this.auto_unload_models = value as boolean
    } else if (key === Settings.context_shift && typeof value === 'boolean') {
      this.context_shift = value
    }
  }
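
  /**
   * Loads a model into cortex.cpp via the `v1/models/start` endpoint,
   * merging in the engine settings configured on this extension. When
   * `auto_unload_models` is enabled, any other loaded model is unloaded
   * first. A no-op if the model is already loaded.
   * @param model - the model to load, including its id and settings
   * @param abortController - optional controller used to cancel the load request
   */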
  override async loadModel(
    model: Partial<Model> & {
      id: string
      settings?: object
      file_path?: string
    },
    abortController?: AbortController
  ): Promise<void> {
    // Cortex will handle these settings
    const { llama_model_path, mmproj, ...settings } = model.settings ?? {}
    model.settings = settings

    const controller = abortController ?? new AbortController()
    const { signal } = controller
    this.abortControllers.set(model.id, controller)

    const loadedModels = await this.activeModels()
    // This is to avoid loading the same model multiple times
    if (loadedModels.some((e: { id: string }) => e.id === model.id)) {
      console.log(`Model ${model.id} already loaded`)
      return
    }
    if (this.auto_unload_models) {
      // Unload the last used model if it is not the same as the current one
      for (const lastUsedModel of loadedModels) {
        if (lastUsedModel.id !== model.id) {
          console.log(`Unloading last used model: ${lastUsedModel.id}`)
          await this.unloadModel(lastUsedModel as Model)
        }
      }
    }
    const modelSettings = extractModelLoadParams(model.settings)

    return await this.apiInstance().then((api) =>
      api
        .post('v1/models/start', {
          json: {
            ...modelSettings,
            model: model.id,
            engine:
              model.engine === 'nitro' // Legacy model cache
                ? 'llama-cpp'
                : model.engine,
            ...(this.n_parallel ? { n_parallel: this.n_parallel } : {}),
            ...(this.use_mmap ? { use_mmap: true } : {}),
            ...(this.caching_enabled ? { caching_enabled: true } : {}),
            ...(this.flash_attn ? { flash_attn: true } : {}),
            ...(this.caching_enabled && this.cache_type
              ? { cache_type: this.cache_type }
              : {}),
            ...(this.cpu_threads && this.cpu_threads > 0
              ? { cpu_threads: this.cpu_threads }
              : {}),
            ...(this.cont_batching && this.n_parallel && this.n_parallel > 1
              ? { cont_batching: this.cont_batching }
              : {}),
            ...(model.id.toLowerCase().includes('jan-nano')
              ? { reasoning_budget: 0 }
              : { reasoning_budget: this.reasoning_budget }),
            ...(this.context_shift !== true // explicit true required to enable context shift
              ? { 'no-context-shift': true }
              : {}),
            ...(modelSettings.ngl === -1 || modelSettings.ngl === undefined
              ? { ngl: 100 } // default ngl when unset or -1
              : {}),
          },
          timeout: false,
          signal,
        })
        .json()
        .catch(async (e) => {
          throw (await e.response?.json()) ?? e
        })
        .finally(() => this.abortControllers.delete(model.id))
        .then()
    )
  }

  /**
   * Stops a running model via the `v1/models/stop` endpoint and aborts any
   * pending load request for it.
   */
  override async unloadModel(model: Model): Promise<void> {
    return this.apiInstance().then((api) =>
      api
        .post('v1/models/stop', {
          json: { model: model.id },
          retry: {
            limit: 0,
          },
        })
        .json()
        .finally(() => {
          this.abortControllers.get(model.id)?.abort()
        })
        .then()
    )
  }
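
  /**
   * Retrieves the list of models currently loaded in cortex.cpp.
   * Returns an empty array if the server is unreachable.
   */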
  async activeModels(): Promise<(object & { id: string })[]> {
    return await this.apiInstance()
      .then((e) =>
        e.get('inferences/server/models', {
          retry: {
            limit: 0, // Do not retry
          },
        })
      )
      .then((e) => e.json())
      .then((e) => (e as LoadedModelResponse).data ?? [])
      .catch(() => [])
  }

  /**
   * Clean cortex processes
   * @returns
   */
  private async clean(): Promise<any> {
    return this.apiInstance()
      .then((api) =>
        api.delete('processmanager/destroy', {
          timeout: 2000, // maximum 2 seconds
          retry: {
            limit: 0,
          },
        })
      )
      .catch(() => {
        // Do nothing
      })
  }

  /**
   * Update cortex config
   * @param body
   */
  private async updateCortexConfig(body: {
    [key: string]: any
  }): Promise<void> {
    return this.apiInstance()
      .then((api) => api.patch('v1/configs', { json: body }).then(() => {}))
      .catch((e) => console.debug(e))
  }

  /**
   * Get cortex config
   */
  private async getCortexConfig(): Promise<any> {
    return this.apiInstance()
      .then((api) => api.get('v1/configs').json())
      .catch((e) => console.debug(e))
  }

  /**
   * Subscribe to cortex.cpp websocket events
   */
  private subscribeToEvents() {
    this.socket = new WebSocket(`${CORTEX_SOCKET_URL}/events`)

    this.socket.addEventListener('message', (event) => {
      const data = JSON.parse(event.data)

      // Each event carries a download task; sum the per-item progress to
      // compute an overall percentage
      const transferred = data.task.items.reduce(
        (acc: number, cur: any) => acc + cur.downloadedBytes,
        0
      )
      const total = data.task.items.reduce(
        (acc: number, cur: any) => acc + cur.bytes,
        0
      )
      const percent = total > 0 ? transferred / total : 0

      events.emit(DownloadTypes[data.type as keyof typeof DownloadTypes], {
        modelId: data.task.id,
        percent: percent,
        size: {
          transferred: transferred,
          total: total,
        },
        downloadType: data.task.type,
      })

      if (data.task.type === 'Engine') {
        events.emit(EngineEvent.OnEngineUpdate, {
          type: DownloadTypes[data.type as keyof typeof DownloadTypes],
          percent: percent,
          id: data.task.id,
        })
      } else {
        if (data.type === DownloadTypes.DownloadSuccess) {
          // Delay for the state update from cortex.cpp
          // Just to be sure
          setTimeout(() => {
            events.emit(ModelEvent.OnModelsUpdate, {
              fetch: true,
            })
          }, 500)
        }
      }
    })

    /**
     * This is to handle the server segfault issue
     */
    this.socket.onclose = (event) => {
      // Notify app to update model running state
      events.emit(ModelEvent.OnModelStopped, {})
      // Reconnect to the /events websocket
      if (this.shouldReconnect) {
        setTimeout(() => this.subscribeToEvents(), 1000)
      }
    }
  }
}