
/**
 * @file This file exports a class that extends LocalOAIEngine from the @janhq/core package.
 * The class provides methods for loading and unloading models, and for routing inference
 * requests to a local cortex.cpp server.
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 * @version 1.0.0
 * @module inference-extension/src/index
 */
import {
Model,
EngineEvent,
LocalOAIEngine,
extractModelLoadParams,
events,
ModelEvent,
} from '@janhq/core'
import ky, { KyInstance } from 'ky'
/**
 * Download event types re-emitted on the @janhq/core event bus
 */
enum DownloadTypes {
DownloadUpdated = 'onFileDownloadUpdate',
DownloadError = 'onFileDownloadError',
DownloadSuccess = 'onFileDownloadSuccess',
DownloadStopped = 'onFileDownloadStopped',
DownloadStarted = 'onFileDownloadStarted',
}
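// Keys mirror the `type` field of cortex.cpp websocket payloads; values are the
// event names re-emitted on the @janhq/core event bus, e.g. (illustrative sketch
// based on subscribeToEvents below):
//   events.emit(DownloadTypes.DownloadUpdated, { modelId, percent, size, downloadType })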
enum Settings {
n_parallel = 'n_parallel',
cont_batching = 'cont_batching',
caching_enabled = 'caching_enabled',
flash_attn = 'flash_attn',
cache_type = 'cache_type',
use_mmap = 'use_mmap',
cpu_threads = 'cpu_threads',
huggingfaceToken = 'hugging-face-access-token',
auto_unload_models = 'auto_unload_models',
context_shift = 'context_shift',
}
type LoadedModelResponse = { data: { engine: string; id: string }[] }
/**
 * A class that extends LocalOAIEngine from the @janhq/core package.
 * It provides methods for loading and unloading models, and for making inference requests.
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 */
export default class JanInferenceCortexExtension extends LocalOAIEngine {
nodeModule: string = 'node'
provider: string = 'cortex'
shouldReconnect = true
/** Default Engine model load settings */
n_parallel?: number
cont_batching: boolean = false
caching_enabled: boolean = true
flash_attn: boolean = true
use_mmap: boolean = true
cache_type: string = 'q8'
cpu_threads?: number
auto_unload_models: boolean = true
  reasoning_budget = -1 // Default reasoning budget: -1 allows unrestricted thinking; 0 disables it
context_shift = false
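  // CORTEX_API_URL and CORTEX_SOCKET_URL are assumed to be compile-time constants
  // injected by the extension build; they are not imported in this file.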
/**
* The URL for making inference requests.
*/
inferenceUrl = `${CORTEX_API_URL}/v1/chat/completions`
  /**
   * WebSocket for the cortex.cpp events subscription
   */
socket?: WebSocket = undefined
abortControllers = new Map<string, AbortController>()
api?: KyInstance
  /**
   * Get (lazily creating) the ky instance for the cortex.cpp API
   * @returns a ky instance authorized with the app token
   */
async apiInstance(): Promise<KyInstance> {
if (this.api) return this.api
const apiKey = await window.core?.api.appToken()
this.api = ky.extend({
prefixUrl: CORTEX_API_URL,
headers: apiKey
? {
Authorization: `Bearer ${apiKey}`,
}
: {},
retry: 10,
})
return this.api
}
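  // Usage sketch: every request in this extension goes through this instance, e.g.
  //   const api = await this.apiInstance()
  //   const config = await api.get('v1/configs').json()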
  /**
   * Authorization headers for the API requests.
   * @returns headers carrying the app token
   */
headers(): Promise<HeadersInit> {
return window.core?.api.appToken().then((token: string) => ({
Authorization: `Bearer ${token}`,
}))
}
/**
* Called when the extension is loaded.
*/
async onLoad() {
super.onLoad()
// Register Settings
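    // SETTINGS is assumed to be a compile-time constant injected by the extension
    // build (like CORTEX_API_URL); its entries are assumed to be descriptors shaped
    // like { key, controllerProps: { value } }, inferred from the updateSettings
    // call in the migration block below.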
this.registerSettings(SETTINGS)
const numParallel = await this.getSetting<string>(Settings.n_parallel, '')
if (numParallel.length > 0 && parseInt(numParallel) > 0) {
this.n_parallel = parseInt(numParallel)
}
if (this.n_parallel && this.n_parallel > 1)
this.cont_batching = await this.getSetting<boolean>(
Settings.cont_batching,
false
)
this.caching_enabled = await this.getSetting<boolean>(
Settings.caching_enabled,
true
)
this.flash_attn = await this.getSetting<boolean>(Settings.flash_attn, true)
this.context_shift = await this.getSetting<boolean>(
Settings.context_shift,
false
)
this.use_mmap = await this.getSetting<boolean>(Settings.use_mmap, true)
if (this.caching_enabled)
this.cache_type = await this.getSetting<string>(Settings.cache_type, 'q8')
this.auto_unload_models = await this.getSetting<boolean>(
Settings.auto_unload_models,
true
)
    const threads_number = Number(
      await this.getSetting<string>(Settings.cpu_threads, '')
    )
    // Number('') is 0, so also require a positive value before applying
    if (!Number.isNaN(threads_number) && threads_number > 0)
      this.cpu_threads = threads_number
const huggingfaceToken = await this.getSetting<string>(
Settings.huggingfaceToken,
''
)
if (huggingfaceToken) {
this.updateCortexConfig({ huggingface_token: huggingfaceToken })
}
this.subscribeToEvents()
window.addEventListener('beforeunload', () => {
this.clean()
})
// Migrate configs
if (!localStorage.getItem('cortex_migration_completed')) {
const config = await this.getCortexConfig()
console.log('Start cortex.cpp migration', config)
if (config && config.huggingface_token) {
this.updateSettings([
{
key: Settings.huggingfaceToken,
controllerProps: {
value: config.huggingface_token,
},
},
])
this.updateCortexConfig({
huggingface_token: config.huggingface_token,
})
localStorage.setItem('cortex_migration_completed', 'true')
}
}
}
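  /**
   * Called when the extension is unloaded: stop reconnecting and shut down cortex.cpp
   */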
async onUnload() {
console.log('Clean up cortex.cpp services')
this.shouldReconnect = false
this.clean()
super.onUnload()
}
  /**
   * Handle setting updates and apply the changes accordingly
   * @param key the setting key
   * @param value the new setting value
   */
  onSettingUpdate<T>(key: string, value: T): void {
    if (key === Settings.n_parallel && typeof value === 'string') {
      if (value.length > 0 && parseInt(value) > 0) {
        this.n_parallel = parseInt(value)
      }
    } else if (key === Settings.cont_batching && typeof value === 'boolean') {
      this.cont_batching = value
    } else if (key === Settings.caching_enabled && typeof value === 'boolean') {
      this.caching_enabled = value
    } else if (key === Settings.flash_attn && typeof value === 'boolean') {
      this.flash_attn = value
    } else if (key === Settings.cache_type && typeof value === 'string') {
      this.cache_type = value
    } else if (key === Settings.use_mmap && typeof value === 'boolean') {
      this.use_mmap = value
    } else if (key === Settings.cpu_threads && typeof value === 'string') {
      const threads_number = Number(value)
      if (!Number.isNaN(threads_number) && threads_number > 0)
        this.cpu_threads = threads_number
    } else if (key === Settings.huggingfaceToken && typeof value === 'string') {
      this.updateCortexConfig({ huggingface_token: value })
    } else if (
      key === Settings.auto_unload_models &&
      typeof value === 'boolean'
    ) {
      this.auto_unload_models = value
    } else if (key === Settings.context_shift && typeof value === 'boolean') {
      this.context_shift = value
    }
  }
override async loadModel(
model: Partial<Model> & {
id: string
settings?: object
file_path?: string
},
abortController: AbortController
): Promise<void> {
    // Strip llama_model_path and mmproj: cortex.cpp resolves these paths itself
    const { llama_model_path, mmproj, ...settings } = model.settings ?? {}
model.settings = settings
const controller = abortController ?? new AbortController()
const { signal } = controller
this.abortControllers.set(model.id, controller)
const loadedModels = await this.activeModels()
// This is to avoid loading the same model multiple times
if (loadedModels.some((e: { id: string }) => e.id === model.id)) {
console.log(`Model ${model.id} already loaded`)
return
}
if (this.auto_unload_models) {
      // Unload every other loaded model so only the requested one stays resident
      for (const loadedModel of loadedModels) {
        if (loadedModel.id !== model.id) {
          console.log(`Unloading model: ${loadedModel.id}`)
          await this.unloadModel(loadedModel as Model)
}
}
}
const modelSettings = extractModelLoadParams(model.settings)
return await this.apiInstance().then((api) =>
api
.post('v1/models/start', {
json: {
...modelSettings,
model: model.id,
engine:
model.engine === 'nitro' // Legacy model cache
? 'llama-cpp'
: model.engine,
...(this.n_parallel ? { n_parallel: this.n_parallel } : {}),
...(this.use_mmap ? { use_mmap: true } : {}),
...(this.caching_enabled ? { caching_enabled: true } : {}),
...(this.flash_attn ? { flash_attn: true } : {}),
...(this.caching_enabled && this.cache_type
? { cache_type: this.cache_type }
: {}),
...(this.cpu_threads && this.cpu_threads > 0
? { cpu_threads: this.cpu_threads }
: {}),
...(this.cont_batching && this.n_parallel && this.n_parallel > 1
? { cont_batching: this.cont_batching }
: {}),
...(model.id.toLowerCase().includes('jan-nano')
? { reasoning_budget: 0 }
: { reasoning_budget: this.reasoning_budget }),
...(this.context_shift !== true
? { 'no-context-shift': true }
: {}),
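          // Offload all layers to the GPU by default when ngl is unset or -1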
...(modelSettings.ngl === -1 || modelSettings.ngl === undefined
? { ngl: 100 }
: {}),
},
timeout: false,
signal,
})
.json()
.catch(async (e) => {
throw (await e.response?.json()) ?? e
})
.finally(() => this.abortControllers.delete(model.id))
.then()
)
}
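  // For illustration only, a composed v1/models/start payload might look like the
  // following (model id hypothetical; fields vary with the user's settings):
  //   { "model": "llama3.2-1b", "engine": "llama-cpp", "ngl": 100, "use_mmap": true,
  //     "caching_enabled": true, "flash_attn": true, "cache_type": "q8",
  //     "reasoning_budget": -1, "no-context-shift": true }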
override async unloadModel(model: Model): Promise<void> {
return this.apiInstance().then((api) =>
api
.post('v1/models/stop', {
json: { model: model.id },
retry: {
limit: 0,
},
})
.json()
.finally(() => {
this.abortControllers.get(model.id)?.abort()
})
.then()
)
}
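  /**
   * List the models currently loaded in cortex.cpp
   * @returns the loaded model descriptors, or an empty array on failure
   */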
async activeModels(): Promise<(object & { id: string })[]> {
return await this.apiInstance()
.then((e) =>
e.get('inferences/server/models', {
retry: {
limit: 0, // Do not retry
},
})
)
.then((e) => e.json())
.then((e) => (e as LoadedModelResponse).data ?? [])
.catch(() => [])
}
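  // The endpoint is expected to answer with a LoadedModelResponse-shaped body,
  // e.g. (sketch, model id hypothetical):
  //   { "data": [{ "engine": "llama-cpp", "id": "my-model" }] }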
  /**
   * Shut down cortex.cpp processes
   */
private async clean(): Promise<any> {
return this.apiInstance()
.then((api) =>
api.delete('processmanager/destroy', {
timeout: 2000, // maximum 2 seconds
retry: {
limit: 0,
},
})
)
.catch(() => {
// Do nothing
})
}
/**
* Update cortex config
* @param body
*/
private async updateCortexConfig(body: {
[key: string]: any
}): Promise<void> {
return this.apiInstance()
.then((api) => api.patch('v1/configs', { json: body }).then(() => {}))
.catch((e) => console.debug(e))
}
  /**
   * Get cortex config
   * @returns the current cortex.cpp configuration
   */
private async getCortexConfig(): Promise<any> {
return this.apiInstance()
.then((api) => api.get('v1/configs').json())
.catch((e) => console.debug(e))
}
/**
* Subscribe to cortex.cpp websocket events
*/
private subscribeToEvents() {
this.socket = new WebSocket(`${CORTEX_SOCKET_URL}/events`)
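    // Messages are assumed to follow the cortex.cpp download-event shape, e.g.
    // (sketch inferred from the handler below):
    //   { "type": "DownloadUpdated",
    //     "task": { "id": "model-id", "type": "Model",
    //               "items": [{ "bytes": 1000, "downloadedBytes": 250 }] } }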
this.socket.addEventListener('message', (event) => {
const data = JSON.parse(event.data)
const transferred = data.task.items.reduce(
(acc: number, cur: any) => acc + cur.downloadedBytes,
0
)
const total = data.task.items.reduce(
(acc: number, cur: any) => acc + cur.bytes,
0
)
const percent = total > 0 ? transferred / total : 0
events.emit(DownloadTypes[data.type as keyof typeof DownloadTypes], {
modelId: data.task.id,
percent: percent,
size: {
transferred: transferred,
total: total,
},
downloadType: data.task.type,
})
if (data.task.type === 'Engine') {
events.emit(EngineEvent.OnEngineUpdate, {
type: DownloadTypes[data.type as keyof typeof DownloadTypes],
percent: percent,
id: data.task.id,
})
} else {
if (data.type === DownloadTypes.DownloadSuccess) {
          // Give cortex.cpp a moment to settle its state before refreshing the model list
setTimeout(() => {
events.emit(ModelEvent.OnModelsUpdate, {
fetch: true,
})
}, 500)
}
}
})
    /**
     * Handle unexpected socket close, e.g. after a cortex.cpp server crash (segfault)
     */
    this.socket.onclose = () => {
// Notify app to update model running state
events.emit(ModelEvent.OnModelStopped, {})
// Reconnect to the /events websocket
if (this.shouldReconnect) {
setTimeout(() => this.subscribeToEvents(), 1000)
}
}
}
}