/** * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package. * The class provides methods for initializing and stopping a model, and for making inference requests. * It also subscribes to events emitted by the @janhq/core package and handles new message requests. * @version 1.0.0 * @module llamacpp-extension/src/index */ import { AIEngine, getJanDataFolderPath, fs, joinPath, modelInfo, SessionInfo, UnloadResult, chatCompletion, chatCompletionChunk, ImportOptions, chatCompletionRequest, events, AppEvent, DownloadEvent, } from '@janhq/core' import { error, info, warn } from '@tauri-apps/plugin-log' import { listen } from '@tauri-apps/api/event' import { listSupportedBackends, downloadBackend, isBackendInstalled, getBackendExePath, getBackendDir, } from './backend' import { invoke } from '@tauri-apps/api/core' import { getProxyConfig } from './util' import { basename } from '@tauri-apps/api/path' import { readGgufMetadata } from '@janhq/tauri-plugin-llamacpp-api' import { getSystemUsage } from '@janhq/tauri-plugin-hardware-api' type LlamacppConfig = { version_backend: string auto_update_engine: boolean auto_unload: boolean llamacpp_env: string memory_util: string chat_template: string n_gpu_layers: number offload_mmproj: boolean override_tensor_buffer_t: string ctx_size: number threads: number threads_batch: number n_predict: number batch_size: number ubatch_size: number device: string split_mode: string main_gpu: number flash_attn: boolean cont_batching: boolean no_mmap: boolean mlock: boolean no_kv_offload: boolean cache_type_k: string cache_type_v: string defrag_thold: number rope_scaling: string rope_scale: number rope_freq_base: number rope_freq_scale: number ctx_shift: boolean } type ModelPlan = { gpuLayers: number maxContextLength: number noOffloadKVCache: boolean noOffloadMmproj?: boolean mode: 'GPU' | 'Hybrid' | 'CPU' | 'Unsupported' } interface DownloadItem { url: string save_path: string proxy?: Record sha256?: string size?: number } interface ModelConfig { model_path: string mmproj_path?: string name: string // user-friendly // some model info that we cache upon import size_bytes: number sha256?: string mmproj_sha256?: string mmproj_size_bytes?: number } interface EmbeddingResponse { model: string object: string usage: { prompt_tokens: number total_tokens: number } data: EmbeddingData[] } interface EmbeddingData { embedding: number[] index: number object: string } interface DeviceList { id: string name: string mem: number free: number } interface SystemMemory { totalVRAM: number totalRAM: number totalMemory: number } /** * Override the default app.log function to use Jan's logging system. * @param args */ const logger = { info: function (...args: any[]) { console.log(...args) info(args.map((arg) => ` ${arg}`).join(` `)) }, warn: function (...args: any[]) { console.warn(...args) warn(args.map((arg) => ` ${arg}`).join(` `)) }, error: function (...args: any[]) { console.error(...args) error(args.map((arg) => ` ${arg}`).join(` `)) }, } /** * A class that implements the InferenceExtension interface from the @janhq/core package. * The class provides methods for initializing and stopping a model, and for making inference requests. * It also subscribes to events emitted by the @janhq/core package and handles new message requests. 
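 *
 * A hedged usage sketch (editor's illustration, not part of the extension —
 * 'my-model' is a placeholder id and the Jan runtime normally drives these calls):
 *
 *   const engine = new llamacpp_extension()
 *   await engine.onLoad()                 // registers settings, selects a backend
 *   const session = await engine.load('my-model')
 *   const reply = await engine.chat({
 *     model: 'my-model',
 *     messages: [{ role: 'user', content: 'Hello' }],
 *     stream: false,
 *   })
 *   await engine.unload('my-model')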
 */
// Folder structure for llamacpp extension:
// /llamacpp
//   - models/<model_id>/
//     - model.yml   (required)
//     - model.gguf  (optional, present if downloaded from URL)
//     - mmproj.gguf (optional, present if mmproj exists and it was downloaded from URL)
//     Contents of model.yml can be found in the ModelConfig interface
//
//   - backends/<version>/<backend>/
//     - build/bin/llama-server (or llama-server.exe on Windows)
//
//   - lib/
//     - e.g. libcudart.so.12
export default class llamacpp_extension extends AIEngine {
  provider: string = 'llamacpp'
  autoUnload: boolean = true
  llamacpp_env: string = ''
  memoryMode: string = 'high'
  readonly providerId: string = 'llamacpp'

  private config: LlamacppConfig
  private providerPath!: string
  private apiSecret: string = 'JustAskNow'
  private pendingDownloads: Map<string, Promise<void>> = new Map()
  private isConfiguringBackends: boolean = false
  private loadingModels = new Map<string, Promise<SessionInfo>>() // Track loading promises
  private unlistenValidationStarted?: () => void

  override async onLoad(): Promise<void> {
    super.onLoad() // Calls registerEngine() from AIEngine

    // Clone to modify settings definition before registration.
    // This makes the settings (including the backend options and initial value) available to the Jan UI.
    let settings = structuredClone(SETTINGS)
    this.registerSettings(settings)

    let loadedConfig: any = {}
    for (const item of settings) {
      // Use the potentially updated default value from the settings array as the fallback for getSetting
      const defaultValue = item.controllerProps.value
      loadedConfig[item.key] = await this.getSetting(item.key, defaultValue)
    }
    this.config = loadedConfig as LlamacppConfig
    this.autoUnload = this.config.auto_unload
    this.llamacpp_env = this.config.llamacpp_env
    this.memoryMode = this.config.memory_util

    // This sets the base directory where model files for this provider are stored.
this.providerPath = await joinPath([ await getJanDataFolderPath(), this.providerId, ]) // Set up validation event listeners to bridge Tauri events to frontend this.unlistenValidationStarted = await listen<{ modelId: string downloadType: string }>('onModelValidationStarted', (event) => { console.debug( 'LlamaCPP: bridging onModelValidationStarted event', event.payload ) events.emit(DownloadEvent.onModelValidationStarted, event.payload) }) this.configureBackends() } private getStoredBackendType(): string | null { try { return localStorage.getItem('llama_cpp_backend_type') } catch (error) { logger.warn('Failed to read backend type from localStorage:', error) return null } } private setStoredBackendType(backendType: string): void { try { localStorage.setItem('llama_cpp_backend_type', backendType) logger.info(`Stored backend type preference: ${backendType}`) } catch (error) { logger.warn('Failed to store backend type in localStorage:', error) } } private clearStoredBackendType(): void { try { localStorage.removeItem('llama_cpp_backend_type') logger.info('Cleared stored backend type preference') } catch (error) { logger.warn('Failed to clear backend type from localStorage:', error) } } private findLatestVersionForBackend( version_backends: { version: string; backend: string }[], backendType: string ): string | null { const matchingBackends = version_backends.filter( (vb) => vb.backend === backendType ) if (matchingBackends.length === 0) { return null } // Sort by version (newest first) and get the latest matchingBackends.sort((a, b) => b.version.localeCompare(a.version)) return `${matchingBackends[0].version}/${matchingBackends[0].backend}` } async configureBackends(): Promise { if (this.isConfiguringBackends) { logger.info( 'configureBackends already in progress, skipping duplicate call' ) return } this.isConfiguringBackends = true try { let version_backends: { version: string; backend: string }[] = [] try { version_backends = await listSupportedBackends() if (version_backends.length === 0) { throw new Error( 'No supported backend binaries found for this system. Backend selection and auto-update will be unavailable.' ) } else { version_backends.sort((a, b) => b.version.localeCompare(a.version)) } } catch (error) { throw new Error( `Failed to fetch supported backends: ${ error instanceof Error ? 
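// Editor's note (sketch, not invoked anywhere): `localeCompare` orders build
// tags lexicographically, so 'b999' would sort after 'b1000'. If llama.cpp
// build tags ever vary in digit count, a numeric comparison in the spirit of
// `parseBackendVersion` (defined later in this class) is safer:
//
//   const buildNum = (v: string) => Number(v.slice(1)) // 'b6325' -> 6325
//   version_backends.sort((a, b) => buildNum(b.version) - buildNum(a.version))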
error.message : error }` ) } // Get stored backend preference const storedBackendType = this.getStoredBackendType() let bestAvailableBackendString = '' if (storedBackendType) { // Find the latest version of the stored backend type const preferredBackendString = this.findLatestVersionForBackend( version_backends, storedBackendType ) if (preferredBackendString) { bestAvailableBackendString = preferredBackendString logger.info( `Using stored backend preference: ${bestAvailableBackendString}` ) } else { logger.warn( `Stored backend type '${storedBackendType}' not available, falling back to best backend` ) // Clear the invalid stored preference this.clearStoredBackendType() bestAvailableBackendString = this.determineBestBackend(version_backends) } } else { bestAvailableBackendString = this.determineBestBackend(version_backends) } let settings = structuredClone(SETTINGS) const backendSettingIndex = settings.findIndex( (item) => item.key === 'version_backend' ) let originalDefaultBackendValue = '' if (backendSettingIndex !== -1) { const backendSetting = settings[backendSettingIndex] originalDefaultBackendValue = backendSetting.controllerProps .value as string backendSetting.controllerProps.options = version_backends.map((b) => { const key = `${b.version}/${b.backend}` return { value: key, name: key } }) // Set the recommended backend based on bestAvailableBackendString if (bestAvailableBackendString) { backendSetting.controllerProps.recommended = bestAvailableBackendString } const savedBackendSetting = await this.getSetting( 'version_backend', originalDefaultBackendValue ) // Determine initial UI default based on priority: // 1. Saved setting (if valid and not original default) // 2. Best available for stored backend type // 3. Original default let initialUiDefault = originalDefaultBackendValue if ( savedBackendSetting && savedBackendSetting !== originalDefaultBackendValue ) { initialUiDefault = savedBackendSetting // Store the backend type from the saved setting only if different const [, backendType] = savedBackendSetting.split('/') if (backendType) { const currentStoredBackend = this.getStoredBackendType() if (currentStoredBackend !== backendType) { this.setStoredBackendType(backendType) logger.info( `Stored backend type preference from saved setting: ${backendType}` ) } } } else if (bestAvailableBackendString) { initialUiDefault = bestAvailableBackendString // Store the backend type from the best available only if different const [, backendType] = bestAvailableBackendString.split('/') if (backendType) { const currentStoredBackend = this.getStoredBackendType() if (currentStoredBackend !== backendType) { this.setStoredBackendType(backendType) logger.info( `Stored backend type preference from best available: ${backendType}` ) } } } backendSetting.controllerProps.value = initialUiDefault logger.info( `Initial UI default for version_backend set to: ${initialUiDefault}` ) } else { logger.error( 'Critical setting "version_backend" definition not found in SETTINGS.' 
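// Editor's example of the resolution above (hypothetical data): backend
// settings are keyed as '{version}/{backend}'. Given supported builds
//   [{ version: 'b6325', backend: 'cuda-cu12.0' },
//    { version: 'b6100', backend: 'cuda-cu12.0' },
//    { version: 'b6325', backend: 'avx2' }]
// a stored preference of 'cuda-cu12.0' resolves via
// findLatestVersionForBackend() to 'b6325/cuda-cu12.0', while a stale
// preference such as 'rocm' yields null, is cleared, and falls back to
// determineBestBackend().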
) throw new Error('Critical setting "version_backend" not found.') } this.registerSettings(settings) let effectiveBackendString = this.config.version_backend let backendWasDownloaded = false // Handle fresh installation case where version_backend might be 'none' or invalid if ( (!effectiveBackendString || effectiveBackendString === 'none' || !effectiveBackendString.includes('/') || // If the selected backend is not in the list of supported backends // Need to reset too !version_backends.some( (e) => `${e.version}/${e.backend}` === effectiveBackendString )) && // Ensure we have a valid best available backend bestAvailableBackendString ) { effectiveBackendString = bestAvailableBackendString logger.info( `Fresh installation or invalid backend detected, using: ${effectiveBackendString}` ) // Update the config immediately this.config.version_backend = effectiveBackendString // Update the settings to reflect the change in UI const updatedSettings = await this.getSettings() await this.updateSettings( updatedSettings.map((item) => { if (item.key === 'version_backend') { item.controllerProps.value = effectiveBackendString } return item }) ) logger.info(`Updated UI settings to show: ${effectiveBackendString}`) // Emit for updating fe if (events && typeof events.emit === 'function') { logger.info( `Emitting settingsChanged event for version_backend with value: ${effectiveBackendString}` ) events.emit('settingsChanged', { key: 'version_backend', value: effectiveBackendString, }) } } // Download and install the backend if not already present if (effectiveBackendString) { const [version, backend] = effectiveBackendString.split('/') if (version && backend) { const isInstalled = await isBackendInstalled(backend, version) if (!isInstalled) { logger.info(`Installing initial backend: ${effectiveBackendString}`) await this.ensureBackendReady(backend, version) backendWasDownloaded = true logger.info( `Successfully installed initial backend: ${effectiveBackendString}` ) } } } if (this.config.auto_update_engine) { const updateResult = await this.handleAutoUpdate( bestAvailableBackendString ) if (updateResult.wasUpdated) { effectiveBackendString = updateResult.newBackend backendWasDownloaded = true } } if (!backendWasDownloaded && effectiveBackendString) { await this.ensureFinalBackendInstallation(effectiveBackendString) } } finally { this.isConfiguringBackends = false } } private determineBestBackend( version_backends: { version: string; backend: string }[] ): string { if (version_backends.length === 0) return '' // Priority list for backend types (more specific/performant ones first) const backendPriorities: string[] = [ 'cuda-cu12.0', 'cuda-cu11.7', 'vulkan', 'avx512', 'avx2', 'avx', 'noavx', 'arm64', 'x64', ] // Helper to map backend string to a priority category const getBackendCategory = (backendString: string): string | undefined => { if (backendString.includes('cu12.0')) return 'cuda-cu12.0' if (backendString.includes('cu11.7')) return 'cuda-cu11.7' if (backendString.includes('vulkan')) return 'vulkan' if (backendString.includes('avx512')) return 'avx512' if (backendString.includes('avx2')) return 'avx2' if ( backendString.includes('avx') && !backendString.includes('avx2') && !backendString.includes('avx512') ) return 'avx' if (backendString.includes('noavx')) return 'noavx' if (backendString.endsWith('arm64')) return 'arm64' if (backendString.endsWith('x64')) return 'x64' return undefined } let foundBestBackend: { version: string; backend: string } | undefined for (const priorityCategory of 
backendPriorities) { const matchingBackends = version_backends.filter((vb) => { const category = getBackendCategory(vb.backend) return category === priorityCategory }) if (matchingBackends.length > 0) { foundBestBackend = matchingBackends[0] logger.info( `Determined best available backend: ${foundBestBackend.version}/${foundBestBackend.backend} (Category: "${priorityCategory}")` ) break } } if (foundBestBackend) { return `${foundBestBackend.version}/${foundBestBackend.backend}` } else { // Fallback to newest version return `${version_backends[0].version}/${version_backends[0].backend}` } } async updateBackend( targetBackendString: string ): Promise<{ wasUpdated: boolean; newBackend: string }> { try { if (!targetBackendString) throw new Error( `Invalid backend string: ${targetBackendString} supplied to update function` ) const [version, backend] = targetBackendString.split('/') logger.info( `Updating backend to ${targetBackendString} (backend type: ${backend})` ) // Download new backend await this.ensureBackendReady(backend, version) // Add delay on Windows if (IS_WINDOWS) { await new Promise((resolve) => setTimeout(resolve, 1000)) } // Update configuration this.config.version_backend = targetBackendString // Store the backend type preference only if it changed const currentStoredBackend = this.getStoredBackendType() if (currentStoredBackend !== backend) { this.setStoredBackendType(backend) logger.info(`Updated stored backend type preference: ${backend}`) } // Update settings const settings = await this.getSettings() await this.updateSettings( settings.map((item) => { if (item.key === 'version_backend') { item.controllerProps.value = targetBackendString } return item }) ) logger.info(`Successfully updated to backend: ${targetBackendString}`) // Emit for updating frontend if (events && typeof events.emit === 'function') { logger.info( `Emitting settingsChanged event for version_backend with value: ${targetBackendString}` ) events.emit('settingsChanged', { key: 'version_backend', value: targetBackendString, }) } // Clean up old versions of the same backend type if (IS_WINDOWS) { await new Promise((resolve) => setTimeout(resolve, 500)) } await this.removeOldBackend(version, backend) return { wasUpdated: true, newBackend: targetBackendString } } catch (error) { logger.error('Backend update failed:', error) return { wasUpdated: false, newBackend: this.config.version_backend } } } private async handleAutoUpdate( bestAvailableBackendString: string ): Promise<{ wasUpdated: boolean; newBackend: string }> { logger.info( `Auto-update engine is enabled. Current backend: ${this.config.version_backend}. 
Best available: ${bestAvailableBackendString}` ) if (!bestAvailableBackendString) { logger.warn( 'Auto-update enabled, but no best available backend determined' ) return { wasUpdated: false, newBackend: this.config.version_backend } } // If version_backend is empty, invalid, or 'none', use the best available backend if ( !this.config.version_backend || this.config.version_backend === '' || this.config.version_backend === 'none' || !this.config.version_backend.includes('/') ) { logger.info( 'No valid backend currently selected, using best available backend' ) return await this.updateBackend(bestAvailableBackendString) } // Parse current backend configuration const [currentVersion, currentBackend] = ( this.config.version_backend || '' ).split('/') if (!currentVersion || !currentBackend) { logger.warn( `Invalid current backend format: ${this.config.version_backend}` ) return { wasUpdated: false, newBackend: this.config.version_backend } } // Find the latest version for the currently selected backend type const version_backends = await listSupportedBackends() const targetBackendString = this.findLatestVersionForBackend( version_backends, currentBackend ) if (!targetBackendString) { logger.warn( `No available versions found for current backend type: ${currentBackend}` ) return { wasUpdated: false, newBackend: this.config.version_backend } } const [latestVersion] = targetBackendString.split('/') // Check if update is needed (only version comparison for same backend type) if (currentVersion === latestVersion) { logger.info( 'Auto-update: Already using the latest version of the selected backend' ) return { wasUpdated: false, newBackend: this.config.version_backend } } // Perform version update for the same backend type logger.info( `Auto-updating from ${this.config.version_backend} to ${targetBackendString} (preserving backend type)` ) return await this.updateBackend(targetBackendString) } private parseBackendVersion(v: string): number { // Remove any leading non‑digit characters (e.g. the "b") const numeric = v.replace(/^[^\d]*/, '') const n = Number(numeric) return Number.isNaN(n) ? 
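// Editor's examples for parseBackendVersion (tag formats assumed from the
// '{version}/{backend}' convention used above):
//   parseBackendVersion('b6325') === 6325  // llama.cpp release tag
//   parseBackendVersion('6325')  === 6325  // already numeric
//   parseBackendVersion('vNext') === 0     // no digits -> Number('') -> 0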
0 : n } async checkBackendForUpdates(): Promise<{ updateNeeded: boolean newVersion: string }> { // Parse current backend configuration const [currentVersion, currentBackend] = ( this.config.version_backend || '' ).split('/') if (!currentVersion || !currentBackend) { logger.warn( `Invalid current backend format: ${this.config.version_backend}` ) return { updateNeeded: false, newVersion: '0' } } // Find the latest version for the currently selected backend type const version_backends = await listSupportedBackends() const targetBackendString = this.findLatestVersionForBackend( version_backends, currentBackend ) const [latestVersion] = targetBackendString.split('/') if ( this.parseBackendVersion(latestVersion) > this.parseBackendVersion(currentVersion) ) { logger.info(`New update available: ${latestVersion}`) return { updateNeeded: true, newVersion: latestVersion } } else { logger.info( `Already at latest version: ${currentVersion} = ${latestVersion}` ) return { updateNeeded: false, newVersion: '0' } } } private async removeOldBackend( latestVersion: string, backendType: string ): Promise { try { const janDataFolderPath = await getJanDataFolderPath() const backendsDir = await joinPath([ janDataFolderPath, 'llamacpp', 'backends', ]) if (!(await fs.existsSync(backendsDir))) { return } const versionDirs = await fs.readdirSync(backendsDir) for (const versionDir of versionDirs) { const versionPath = await joinPath([backendsDir, versionDir]) const versionName = await basename(versionDir) // Skip the latest version if (versionName === latestVersion) { continue } // Check if this version has the specific backend type we're interested in const backendTypePath = await joinPath([versionPath, backendType]) if (await fs.existsSync(backendTypePath)) { const isInstalled = await isBackendInstalled(backendType, versionName) if (isInstalled) { try { await fs.rm(backendTypePath) logger.info( `Removed old version of ${backendType}: ${backendTypePath}` ) } catch (e) { logger.warn( `Failed to remove old backend version: ${backendTypePath}`, e ) } } } } } catch (error) { logger.error('Error during old backend version cleanup:', error) } } private async ensureFinalBackendInstallation( backendString: string ): Promise { if (!backendString) { logger.warn('No backend specified for final installation check') return } const [selectedVersion, selectedBackend] = backendString .split('/') .map((part) => part?.trim()) if (!selectedVersion || !selectedBackend) { logger.warn(`Invalid backend format: ${backendString}`) return } try { const isInstalled = await isBackendInstalled( selectedBackend, selectedVersion ) if (!isInstalled) { logger.info(`Final check: Installing backend ${backendString}`) await this.ensureBackendReady(selectedBackend, selectedVersion) logger.info(`Successfully installed backend: ${backendString}`) } else { logger.info( `Final check: Backend ${backendString} is already installed` ) } } catch (error) { logger.error( `Failed to ensure backend ${backendString} installation:`, error ) throw error // Re-throw as this is critical } } async getProviderPath(): Promise { if (!this.providerPath) { this.providerPath = await joinPath([ await getJanDataFolderPath(), this.providerId, ]) } return this.providerPath } override async onUnload(): Promise { // Terminate all active sessions // Clean up validation event listeners if (this.unlistenValidationStarted) { this.unlistenValidationStarted() } } onSettingUpdate(key: string, value: T): void { this.config[key] = value if (key === 'version_backend') { const valueStr = 
value as string const [version, backend] = valueStr.split('/') // Store the backend type preference in localStorage only if it changed if (backend) { const currentStoredBackend = this.getStoredBackendType() if (currentStoredBackend !== backend) { this.setStoredBackendType(backend) logger.info(`Updated backend type preference to: ${backend}`) } } // Reset device setting when backend changes this.config.device = '' const closure = async () => { await this.ensureBackendReady(backend, version) } closure() } else if (key === 'auto_unload') { this.autoUnload = value as boolean } else if (key === 'llamacpp_env') { this.llamacpp_env = value as string } else if (key === 'memory_util') { this.memoryMode = value as string } } private async generateApiKey(modelId: string, port: string): Promise { const hash = await invoke('plugin:llamacpp|generate_api_key', { modelId: modelId + port, apiSecret: this.apiSecret, }) return hash } // Implement the required LocalProvider interface methods override async list(): Promise { const modelsDir = await joinPath([await this.getProviderPath(), 'models']) if (!(await fs.existsSync(modelsDir))) { await fs.mkdir(modelsDir) } await this.migrateLegacyModels() let modelIds: string[] = [] // DFS let stack = [modelsDir] while (stack.length > 0) { const currentDir = stack.pop() // check if model.yml exists const modelConfigPath = await joinPath([currentDir, 'model.yml']) if (await fs.existsSync(modelConfigPath)) { // +1 to remove the leading slash // NOTE: this does not handle Windows path \\ modelIds.push(currentDir.slice(modelsDir.length + 1)) continue } // otherwise, look into subdirectories const children = await fs.readdirSync(currentDir) for (const child of children) { // skip files const dirInfo = await fs.fileStat(child) if (!dirInfo.isDirectory) { continue } stack.push(child) } } let modelInfos: modelInfo[] = [] for (const modelId of modelIds) { const path = await joinPath([modelsDir, modelId, 'model.yml']) const modelConfig = await invoke('read_yaml', { path }) const modelInfo = { id: modelId, name: modelConfig.name ?? modelId, quant_type: undefined, // TODO: parse quantization type from model.yml or model.gguf providerId: this.provider, port: 0, // port is not known until the model is loaded sizeBytes: modelConfig.size_bytes ?? 0, } as modelInfo modelInfos.push(modelInfo) } return modelInfos } private async migrateLegacyModels() { // Attempt to migrate only once if (localStorage.getItem('cortex_models_migrated') === 'true') return const janDataFolderPath = await getJanDataFolderPath() const modelsDir = await joinPath([janDataFolderPath, 'models']) if (!(await fs.existsSync(modelsDir))) return // DFS let stack = [modelsDir] while (stack.length > 0) { const currentDir = stack.pop() const files = await fs.readdirSync(currentDir) for (const child of files) { try { const childPath = await joinPath([currentDir, child]) const stat = await fs.fileStat(childPath) if ( files.some((e) => e.endsWith('model.yml')) && !child.endsWith('model.yml') ) continue if (!stat.isDirectory && child.endsWith('.yml')) { // check if model.yml exists const modelConfigPath = child if (await fs.existsSync(modelConfigPath)) { const legacyModelConfig = await invoke<{ files: string[] model: string }>('read_yaml', { path: modelConfigPath, }) const legacyModelPath = legacyModelConfig.files?.[0] if (!legacyModelPath) continue // +1 to remove the leading slash // NOTE: this does not handle Windows path \\ let modelId = currentDir.slice(modelsDir.length + 1) modelId = modelId !== 'imported' ? 
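// Editor's examples of the legacy-id mapping below (hypothetical paths,
// relative to the legacy models dir):
//   cortex.so/llama3.2/1b    -> modelId 'llama3.2/1b'
//   huggingface.co/org/repo  -> modelId 'org/repo'
//   imported/foo.yml         -> modelId 'foo' (basename of the .yml)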
modelId.replace(/^(cortex\.so|huggingface\.co)[\/\\]/, '') : (await basename(child)).replace('.yml', '') const modelName = legacyModelConfig.model ?? modelId const configPath = await joinPath([ await this.getProviderPath(), 'models', modelId, 'model.yml', ]) if (await fs.existsSync(configPath)) continue // Don't reimport // this is relative to Jan's data folder const modelDir = `${this.providerId}/models/${modelId}` let size_bytes = ( await fs.fileStat( await joinPath([janDataFolderPath, legacyModelPath]) ) ).size const modelConfig = { model_path: legacyModelPath, mmproj_path: undefined, // legacy models do not have mmproj name: modelName, size_bytes, } as ModelConfig await fs.mkdir(await joinPath([janDataFolderPath, modelDir])) await invoke('write_yaml', { data: modelConfig, savePath: configPath, }) continue } } } catch (error) { console.error(`Error migrating model ${child}:`, error) } } // otherwise, look into subdirectories const children = await fs.readdirSync(currentDir) for (const child of children) { // skip files const dirInfo = await fs.fileStat(child) if (!dirInfo.isDirectory) { continue } stack.push(child) } } localStorage.setItem('cortex_models_migrated', 'true') } /* * Manually installs a supported backend archive * */ async installBackend(path: string): Promise { const platformName = IS_WINDOWS ? 'win' : 'linux' const re = /^llama-(b\d+)-bin-(.+?)\.tar\.gz$/ const archiveName = await basename(path) logger.info(`Installing backend from path: ${path}`) if ((await fs.existsSync(path)) && path.endsWith('tar.gz')) { const match = re.exec(archiveName) if (!match) throw new Error('Failed to parse archive name') const [, version, backend] = match if (!version && !backend) { throw new Error(`Invalid backend archive name: ${archiveName}`) } const backendDir = await getBackendDir(backend, version) try { await invoke('decompress', { path: path, outputDir: backendDir }) } catch (e) { logger.error(`Failed to install: ${String(e)}`) } const binPath = platformName === 'win' ? await joinPath([backendDir, 'build', 'bin', 'llama-server.exe']) : await joinPath([backendDir, 'build', 'bin', 'llama-server']) if (!fs.existsSync(binPath)) { await fs.rm(backendDir) throw new Error('Not a supported backend archive!') } try { await this.configureBackends() logger.info(`Backend ${backend}/${version} installed and UI refreshed`) } catch (e) { logger.error('Backend installed but failed to refresh UI', e) throw new Error( `Backend installed but failed to refresh UI: ${String(e)}` ) } } } override async import(modelId: string, opts: ImportOptions): Promise { const isValidModelId = (id: string) => { // only allow alphanumeric, underscore, hyphen, and dot characters in modelId if (!/^[a-zA-Z0-9/_\-\.]+$/.test(id)) return false // check for empty parts or path traversal const parts = id.split('/') return parts.every((s) => s !== '' && s !== '.' && s !== '..') } if (!isValidModelId(modelId)) throw new Error( `Invalid modelId: ${modelId}. Only alphanumeric and / _ - . 
characters are allowed.` ) const configPath = await joinPath([ await this.getProviderPath(), 'models', modelId, 'model.yml', ]) if (await fs.existsSync(configPath)) throw new Error(`Model ${modelId} already exists`) // this is relative to Jan's data folder const modelDir = `${this.providerId}/models/${modelId}` // we only use these from opts // opts.modelPath: URL to the model file // opts.mmprojPath: URL to the mmproj file let downloadItems: DownloadItem[] = [] const maybeDownload = async (path: string, saveName: string) => { // if URL, add to downloadItems, and return local path if (path.startsWith('https://')) { const localPath = `${modelDir}/${saveName}` downloadItems.push({ url: path, save_path: localPath, proxy: getProxyConfig(), sha256: saveName === 'model.gguf' ? opts.modelSha256 : opts.mmprojSha256, size: saveName === 'model.gguf' ? opts.modelSize : opts.mmprojSize, }) return localPath } // if local file (absolute path), check if it exists // and return the path if (!(await fs.existsSync(path))) throw new Error(`File not found: ${path}`) return path } let modelPath = await maybeDownload(opts.modelPath, 'model.gguf') let mmprojPath = opts.mmprojPath ? await maybeDownload(opts.mmprojPath, 'mmproj.gguf') : undefined if (downloadItems.length > 0) { try { // emit download update event on progress const onProgress = (transferred: number, total: number) => { events.emit(DownloadEvent.onFileDownloadUpdate, { modelId, percent: transferred / total, size: { transferred, total }, downloadType: 'Model', }) } const downloadManager = window.core.extensionManager.getByName( '@janhq/download-extension' ) await downloadManager.downloadFiles( downloadItems, this.createDownloadTaskId(modelId), onProgress ) // If we reach here, download completed successfully (including validation) // The downloadFiles function only returns successfully if all files downloaded AND validated events.emit(DownloadEvent.onFileDownloadAndVerificationSuccess, { modelId, downloadType: 'Model', }) } catch (error) { logger.error('Error downloading model:', modelId, opts, error) const errorMessage = error instanceof Error ? 
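// Editor's usage sketch for import() (hypothetical ids/URLs; `engine` is an
// instance of this class):
//
//   // remote: the file is queued as a DownloadItem and fetched first
//   await engine.import('qwen/qwen3-0-6b', {
//     modelPath: 'https://example.com/model.gguf',
//     modelSha256: '…', // optional integrity check
//   })
//
//   // local: an absolute path is validated and referenced in place
//   await engine.import('local-model', { modelPath: '/abs/path/model.gguf' })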
error.message : String(error) // Check if this is a cancellation const isCancellationError = errorMessage.includes('Download cancelled') || errorMessage.includes('Validation cancelled') || errorMessage.includes('Hash computation cancelled') || errorMessage.includes('cancelled') || errorMessage.includes('aborted') // Check if this is a validation failure const isValidationError = errorMessage.includes('Hash verification failed') || errorMessage.includes('Size verification failed') || errorMessage.includes('Failed to verify file') if (isCancellationError) { logger.info('Download cancelled for model:', modelId) // Emit download stopped event instead of error events.emit(DownloadEvent.onFileDownloadStopped, { modelId, downloadType: 'Model', }) } else if (isValidationError) { logger.error( 'Validation failed for model:', modelId, 'Error:', errorMessage ) // Cancel any other download tasks for this model try { this.abortImport(modelId) } catch (cancelError) { logger.warn('Failed to cancel download task:', cancelError) } // Emit validation failure event events.emit(DownloadEvent.onModelValidationFailed, { modelId, downloadType: 'Model', error: errorMessage, reason: 'validation_failed', }) } else { // Regular download error events.emit(DownloadEvent.onFileDownloadError, { modelId, downloadType: 'Model', error: errorMessage, }) } throw error } } // Validate GGUF files const janDataFolderPath = await getJanDataFolderPath() const fullModelPath = await joinPath([janDataFolderPath, modelPath]) try { // Validate main model file const modelMetadata = await readGgufMetadata(fullModelPath) logger.info( `Model GGUF validation successful: version ${modelMetadata.version}, tensors: ${modelMetadata.tensor_count}` ) // Validate mmproj file if present if (mmprojPath) { const fullMmprojPath = await joinPath([janDataFolderPath, mmprojPath]) const mmprojMetadata = await readGgufMetadata(fullMmprojPath) logger.info( `Mmproj GGUF validation successful: version ${mmprojMetadata.version}, tensors: ${mmprojMetadata.tensor_count}` ) } } catch (error) { logger.error('GGUF validation failed:', error) throw new Error( `Invalid GGUF file(s): ${ error.message || 'File format validation failed' }` ) } // Calculate file sizes let size_bytes = (await fs.fileStat(fullModelPath)).size if (mmprojPath) { size_bytes += ( await fs.fileStat(await joinPath([janDataFolderPath, mmprojPath])) ).size } // TODO: add name as import() argument // TODO: add updateModelConfig() method const modelConfig = { model_path: modelPath, mmproj_path: mmprojPath, name: modelId, size_bytes, model_sha256: opts.modelSha256, model_size_bytes: opts.modelSize, mmproj_sha256: opts.mmprojSha256, mmproj_size_bytes: opts.mmprojSize, } as ModelConfig await fs.mkdir(await joinPath([janDataFolderPath, modelDir])) await invoke('write_yaml', { data: modelConfig, savePath: configPath, }) events.emit(AppEvent.onModelImported, { modelId, modelPath, mmprojPath, size_bytes, model_sha256: opts.modelSha256, model_size_bytes: opts.modelSize, mmproj_sha256: opts.mmprojSha256, mmproj_size_bytes: opts.mmprojSize, }) } /** * Deletes the entire model folder for a given modelId * @param modelId The model ID to delete */ private async deleteModelFolder(modelId: string): Promise { try { const modelDir = await joinPath([ await this.getProviderPath(), 'models', modelId, ]) if (await fs.existsSync(modelDir)) { logger.info(`Cleaning up model directory: ${modelDir}`) await fs.rm(modelDir) } } catch (deleteError) { logger.warn('Failed to delete model directory:', deleteError) } } override 
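// Editor's sketch of a resulting model.yml (field names from the write above;
// values hypothetical). Note the written keys model_sha256/model_size_bytes
// differ from ModelConfig's declared sha256/size_bytes fields — the `as`
// cast hides that mismatch:
//
//   model_path: llamacpp/models/qwen3-0-6b/model.gguf
//   mmproj_path: llamacpp/models/qwen3-0-6b/mmproj.gguf
//   name: qwen3-0-6b
//   size_bytes: 1234567890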
async abortImport(modelId: string): Promise { // Cancel any active download task // prepend provider name to avoid name collision const taskId = this.createDownloadTaskId(modelId) const downloadManager = window.core.extensionManager.getByName( '@janhq/download-extension' ) try { await downloadManager.cancelDownload(taskId) } catch (cancelError) { logger.warn('Failed to cancel download task:', cancelError) } // Delete the entire model folder if it exists (for validation failures) await this.deleteModelFolder(modelId) } /** * Function to find a random port */ private async getRandomPort(): Promise { try { const port = await invoke('plugin:llamacpp|get_random_port') return port } catch { logger.error('Unable to find a suitable port') throw new Error('Unable to find a suitable port for model') } } private parseEnvFromString( target: Record, envString: string ): void { envString .split(';') .filter((pair) => pair.trim()) .forEach((pair) => { const [key, ...valueParts] = pair.split('=') const cleanKey = key?.trim() if ( cleanKey && valueParts.length > 0 && !cleanKey.startsWith('LLAMA') ) { target[cleanKey] = valueParts.join('=').trim() } }) } override async load( modelId: string, overrideSettings?: Partial, isEmbedding: boolean = false ): Promise { const sInfo = await this.findSessionByModel(modelId) if (sInfo) { throw new Error('Model already loaded!!') } // If this model is already being loaded, return the existing promise if (this.loadingModels.has(modelId)) { return this.loadingModels.get(modelId)! } // Create the loading promise const loadingPromise = this.performLoad( modelId, overrideSettings, isEmbedding ) this.loadingModels.set(modelId, loadingPromise) try { const result = await loadingPromise return result } finally { this.loadingModels.delete(modelId) } } private async performLoad( modelId: string, overrideSettings?: Partial, isEmbedding: boolean = false ): Promise { const loadedModels = await this.getLoadedModels() // Get OTHER models that are currently loading (exclude current model) const otherLoadingPromises = Array.from(this.loadingModels.entries()) .filter(([id, _]) => id !== modelId) .map(([_, promise]) => promise) if ( this.autoUnload && (loadedModels.length > 0 || otherLoadingPromises.length > 0) ) { // Wait for OTHER loading models to finish, then unload everything if (otherLoadingPromises.length > 0) { await Promise.all(otherLoadingPromises) } // Now unload all loaded models const allLoadedModels = await this.getLoadedModels() if (allLoadedModels.length > 0) { await Promise.all(allLoadedModels.map((model) => this.unload(model))) } } const args: string[] = [] const envs: Record = {} const cfg = { ...this.config, ...(overrideSettings ?? {}) } const [version, backend] = cfg.version_backend.split('/') if (!version || !backend) { throw new Error( 'Initial setup for the backend failed due to a network issue. Please restart the app!' 
) } // Ensure backend is downloaded and ready before proceeding await this.ensureBackendReady(backend, version) const janDataFolderPath = await getJanDataFolderPath() const modelConfigPath = await joinPath([ this.providerPath, 'models', modelId, 'model.yml', ]) const modelConfig = await invoke('read_yaml', { path: modelConfigPath, }) const port = await this.getRandomPort() // disable llama-server webui args.push('--no-webui') const api_key = await this.generateApiKey(modelId, String(port)) envs['LLAMA_API_KEY'] = api_key // set user envs if (this.llamacpp_env) this.parseEnvFromString(envs, this.llamacpp_env) // model option is required // NOTE: model_path and mmproj_path can be either relative to Jan's data folder or absolute path const modelPath = await joinPath([ janDataFolderPath, modelConfig.model_path, ]) args.push('--jinja') args.push('-m', modelPath) // For overriding tensor buffer type, useful where // massive MOE models can be made faster by keeping attention on the GPU // and offloading the expert FFNs to the CPU. // This is an expert level settings and should only be used by people // who knows what they are doing. // Takes a regex with matching tensor name as input if (cfg.override_tensor_buffer_t) args.push('--override-tensor', cfg.override_tensor_buffer_t) // offload multimodal projector model to the GPU by default. if there is not enough memory // turn this setting off will keep the projector model on the CPU but the image processing can // take longer if (cfg.offload_mmproj === false) args.push('--no-mmproj-offload') args.push('-a', modelId) args.push('--port', String(port)) if (modelConfig.mmproj_path) { const mmprojPath = await joinPath([ janDataFolderPath, modelConfig.mmproj_path, ]) args.push('--mmproj', mmprojPath) } // Add remaining options from the interface if (cfg.chat_template) args.push('--chat-template', cfg.chat_template) const gpu_layers = parseInt(String(cfg.n_gpu_layers)) >= 0 ? 
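// Editor's illustration: for a typical chat model the argv assembled in this
// method ends up looking roughly like (paths/port hypothetical):
//   --no-webui --jinja -m /data/jan/llamacpp/models/foo/model.gguf
//   -a foo --port 3929 -ngl 100 --ctx-size 8192 --no-mmap
// plus any flags enabled by the config checks that follow.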
cfg.n_gpu_layers : 100
    args.push('-ngl', String(gpu_layers))

    if (cfg.threads > 0) args.push('--threads', String(cfg.threads))
    if (cfg.threads_batch > 0)
      args.push('--threads-batch', String(cfg.threads_batch))
    if (cfg.batch_size > 0) args.push('--batch-size', String(cfg.batch_size))
    if (cfg.ubatch_size > 0) args.push('--ubatch-size', String(cfg.ubatch_size))
    if (cfg.device.length > 0) args.push('--device', cfg.device)
    if (cfg.split_mode.length > 0 && cfg.split_mode != 'layer')
      args.push('--split-mode', cfg.split_mode)
    if (cfg.main_gpu !== undefined && cfg.main_gpu != 0)
      args.push('--main-gpu', String(cfg.main_gpu))

    // Boolean flags
    if (!cfg.ctx_shift) args.push('--no-context-shift')
    if (Number(version.replace(/^b/, '')) >= 6325) {
      // default: auto = ON when supported
      if (!cfg.flash_attn) args.push('--flash-attn', 'off')
    } else {
      if (cfg.flash_attn) args.push('--flash-attn')
    }
    if (cfg.cont_batching) args.push('--cont-batching')
    args.push('--no-mmap')
    if (cfg.mlock) args.push('--mlock')
    if (cfg.no_kv_offload) args.push('--no-kv-offload')

    if (isEmbedding) {
      args.push('--embedding')
      // NOTE: flag and value must be separate argv entries; a single
      // '--pooling mean' string would not be parsed by llama-server
      args.push('--pooling', 'mean')
    } else {
      if (cfg.ctx_size > 0) args.push('--ctx-size', String(cfg.ctx_size))
      if (cfg.n_predict > 0) args.push('--n-predict', String(cfg.n_predict))
      if (cfg.cache_type_k && cfg.cache_type_k != 'f16')
        args.push('--cache-type-k', cfg.cache_type_k)
      if (
        cfg.flash_attn &&
        cfg.cache_type_v != 'f16' &&
        cfg.cache_type_v != 'f32'
      ) {
        args.push('--cache-type-v', cfg.cache_type_v)
      }
      if (cfg.defrag_thold && cfg.defrag_thold != 0.1)
        args.push('--defrag-thold', String(cfg.defrag_thold))
      if (cfg.rope_scaling && cfg.rope_scaling != 'none')
        args.push('--rope-scaling', cfg.rope_scaling)
      if (cfg.rope_scale && cfg.rope_scale != 1)
        args.push('--rope-scale', String(cfg.rope_scale))
      if (cfg.rope_freq_base && cfg.rope_freq_base != 0)
        args.push('--rope-freq-base', String(cfg.rope_freq_base))
      if (cfg.rope_freq_scale && cfg.rope_freq_scale != 1)
        args.push('--rope-freq-scale', String(cfg.rope_freq_scale))
    }

    logger.info('Calling Tauri command llama_load with args:', args)
    const backendPath = await getBackendExePath(backend, version)
    const libraryPath = await joinPath([await this.getProviderPath(), 'lib'])

    try {
      // TODO: add LIBRARY_PATH
      const sInfo = await invoke<SessionInfo>(
        'plugin:llamacpp|load_llama_model',
        {
          backendPath,
          libraryPath,
          args,
          envs,
        }
      )
      return sInfo
    } catch (error) {
      logger.error('Error in load command:\n', error)
      throw error
    }
  }

  override async unload(modelId: string): Promise<UnloadResult> {
    const sInfo: SessionInfo = await this.findSessionByModel(modelId)
    if (!sInfo) {
      throw new Error(`No active session found for model: ${modelId}`)
    }
    const pid = sInfo.pid
    try {
      // Pass the PID as the session_id
      const result = await invoke<UnloadResult>(
        'plugin:llamacpp|unload_llama_model',
        {
          pid: pid,
        }
      )
      // If successful, remove from active sessions
      if (result.success) {
        logger.info(`Successfully unloaded model with PID ${pid}`)
      } else {
        logger.warn(`Failed to unload model: ${result.error}`)
      }
      return result
    } catch (error) {
      logger.error('Error in unload command:', error)
      return {
        success: false,
        error: `Failed to unload model: ${error}`,
      }
    }
  }

  private createDownloadTaskId(modelId: string) {
    // prepend provider to make taskId unique across providers
    const cleanModelId = modelId.includes('.')
      ? modelId.slice(0, modelId.indexOf('.'))
      : modelId
    return `${this.provider}/${cleanModelId}`
  }
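  // Editor's examples for createDownloadTaskId (note the first-dot truncation):
  //   'org/my-model' -> 'llamacpp/org/my-model'
  //   'qwen2.5-7b'   -> 'llamacpp/qwen2'  // everything after the first '.' is
  //                                       // dropped, so ids that differ only
  //                                       // after a dot share one task id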
  private async ensureBackendReady(
    backend: string,
    version: string
  ): Promise<void> {
    const backendKey = `${version}/${backend}`

    // Check if backend is already installed
    const isInstalled = await isBackendInstalled(backend, version)
    if (isInstalled) {
      return
    }

    // Check if download is already in progress
    if (this.pendingDownloads.has(backendKey)) {
      logger.info(
        `Backend ${backendKey} download already in progress, waiting...`
      )
      await this.pendingDownloads.get(backendKey)
      return
    }

    // Start new download
    logger.info(`Backend ${backendKey} not installed, downloading...`)
    const downloadPromise = downloadBackend(backend, version).finally(() => {
      this.pendingDownloads.delete(backendKey)
    })
    this.pendingDownloads.set(backendKey, downloadPromise)
    await downloadPromise
    logger.info(`Backend ${backendKey} download completed`)
  }

  private async *handleStreamingResponse(
    url: string,
    headers: HeadersInit,
    body: string,
    abortController?: AbortController
  ): AsyncIterable<chatCompletionChunk> {
    // Only include defined signals: AbortSignal.any() throws a TypeError if
    // the array contains undefined, and abortController is optional
    const signals: AbortSignal[] = [AbortSignal.timeout(600000)]
    if (abortController) signals.push(abortController.signal)
    const response = await fetch(url, {
      method: 'POST',
      headers,
      body,
      connectTimeout: 600000, // 10 minutes
      signal: AbortSignal.any(signals),
    })

    if (!response.ok) {
      const errorData = await response.json().catch(() => null)
      throw new Error(
        `API request failed with status ${response.status}: ${JSON.stringify(
          errorData
        )}`
      )
    }
    if (!response.body) {
      throw new Error('Response body is null')
    }

    const reader = response.body.getReader()
    const decoder = new TextDecoder('utf-8')
    let buffer = ''
    let jsonStr = ''
    try {
      while (true) {
        const { done, value } = await reader.read()
        if (done) {
          break
        }
        buffer += decoder.decode(value, { stream: true })

        // Process complete lines in the buffer
        const lines = buffer.split('\n')
        buffer = lines.pop() || '' // Keep the last incomplete line in the buffer

        for (const line of lines) {
          const trimmedLine = line.trim()
          if (!trimmedLine || trimmedLine === 'data: [DONE]') {
            continue
          }
          if (trimmedLine.startsWith('data: ')) {
            jsonStr = trimmedLine.slice(6)
          } else if (trimmedLine.startsWith('error: ')) {
            jsonStr = trimmedLine.slice(7)
            const error = JSON.parse(jsonStr)
            throw new Error(error.message)
          } else {
            // it should not normally reach here
            throw new Error('Malformed chunk')
          }
          try {
            const data = JSON.parse(jsonStr)
            const chunk = data as chatCompletionChunk
            yield chunk
          } catch (e) {
            logger.error('Error parsing JSON from stream or server error:', e)
            // re-throw so the async iterator terminates with an error
            throw e
          }
        }
      }
    } finally {
      reader.releaseLock()
    }
  }

  private async findSessionByModel(modelId: string): Promise<SessionInfo> {
    try {
      let sInfo = await invoke<SessionInfo>(
        'plugin:llamacpp|find_session_by_model',
        {
          modelId,
        }
      )
      return sInfo
    } catch (e) {
      logger.error(e)
      throw new Error(String(e))
    }
  }

  override async chat(
    opts: chatCompletionRequest,
    abortController?: AbortController
  ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
    const sessionInfo = await this.findSessionByModel(opts.model)
    if (!sessionInfo) {
      throw new Error(`No active session found for model: ${opts.model}`)
    }
    // check if the process is alive
    const result = await invoke<boolean>('plugin:llamacpp|is_process_running', {
      pid: sessionInfo.pid,
    })
    if (result) {
      try {
        await fetch(`http://localhost:${sessionInfo.port}/health`)
      } catch (e) {
        this.unload(sessionInfo.model_id)
        throw new Error('Model appears to have crashed! Please reload!')
      }
    } else {
      throw new Error('Model has crashed!
Please reload!') } const baseUrl = `http://localhost:${sessionInfo.port}/v1` const url = `${baseUrl}/chat/completions` const headers = { 'Content-Type': 'application/json', 'Authorization': `Bearer ${sessionInfo.api_key}`, } const body = JSON.stringify(opts) if (opts.stream) { return this.handleStreamingResponse(url, headers, body, abortController) } // Handle non-streaming response const response = await fetch(url, { method: 'POST', headers, body, signal: abortController?.signal, }) if (!response.ok) { const errorData = await response.json().catch(() => null) throw new Error( `API request failed with status ${response.status}: ${JSON.stringify( errorData )}` ) } return (await response.json()) as chatCompletion } override async delete(modelId: string): Promise { const modelDir = await joinPath([ await this.getProviderPath(), 'models', modelId, ]) if (!(await fs.existsSync(await joinPath([modelDir, 'model.yml'])))) { throw new Error(`Model ${modelId} does not exist`) } await fs.rm(modelDir) } override async getLoadedModels(): Promise { try { let models: string[] = await invoke( 'plugin:llamacpp|get_loaded_models' ) return models } catch (e) { logger.error(e) throw new Error(e) } } /** * Check if mmproj.gguf file exists for a given model ID * @param modelId - The model ID to check for mmproj.gguf * @returns Promise - true if mmproj.gguf exists, false otherwise */ async checkMmprojExists(modelId: string): Promise { try { const modelConfigPath = await joinPath([ await this.getProviderPath(), 'models', modelId, 'model.yml', ]) const modelConfig = await invoke('read_yaml', { path: modelConfigPath, }) // If mmproj_path is not defined in YAML, return false if (modelConfig.mmproj_path) { return true } const mmprojPath = await joinPath([ await this.getProviderPath(), 'models', modelId, 'mmproj.gguf', ]) return await fs.existsSync(mmprojPath) } catch (e) { logger.error(`Error checking mmproj.gguf for model ${modelId}:`, e) return false } } async getDevices(): Promise { const cfg = this.config const [version, backend] = cfg.version_backend.split('/') if (!version || !backend) { throw new Error( 'Backend setup was not successful. Please restart the app in a stable internet connection.' 
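// Editor's sketch: consuming the streaming branch of chat() (model id
// hypothetical; chunk shape assumed to follow the OpenAI-style delta format):
//
//   const stream = await engine.chat({ model: 'foo', messages, stream: true })
//   let text = ''
//   for await (const chunk of stream as AsyncIterable<chatCompletionChunk>) {
//     text += chunk.choices?.[0]?.delta?.content ?? ''
//   }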
      )
    }

    // set envs
    const envs: Record<string, string> = {}
    if (this.llamacpp_env) this.parseEnvFromString(envs, this.llamacpp_env)

    // Ensure backend is downloaded and ready before proceeding
    await this.ensureBackendReady(backend, version)

    logger.info('Calling Tauri command getDevices with arg --list-devices')
    const backendPath = await getBackendExePath(backend, version)
    const libraryPath = await joinPath([await this.getProviderPath(), 'lib'])
    try {
      const dList = await invoke<DeviceList[]>('plugin:llamacpp|get_devices', {
        backendPath,
        libraryPath,
        envs,
      })
      return dList
    } catch (error) {
      logger.error('Failed to query devices:\n', error)
      throw new Error('Failed to load llamacpp backend')
    }
  }

  async embed(text: string[]): Promise<EmbeddingResponse> {
    let sInfo = await this.findSessionByModel('sentence-transformer-mini')
    if (!sInfo) {
      const downloadedModelList = await this.list()
      if (
        !downloadedModelList.some(
          (model) => model.id === 'sentence-transformer-mini'
        )
      ) {
        await this.import('sentence-transformer-mini', {
          modelPath:
            'https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-ggml-model-f16.gguf?download=true',
        })
      }
      sInfo = await this.load('sentence-transformer-mini')
    }
    const baseUrl = `http://localhost:${sInfo.port}/v1/embeddings`
    const headers = {
      'Content-Type': 'application/json',
      'Authorization': `Bearer ${sInfo.api_key}`,
    }
    const body = JSON.stringify({
      input: text,
      model: sInfo.model_id,
      encoding_format: 'float',
    })
    const response = await fetch(baseUrl, {
      method: 'POST',
      headers,
      body,
    })
    if (!response.ok) {
      const errorData = await response.json().catch(() => null)
      throw new Error(
        `API request failed with status ${response.status}: ${JSON.stringify(
          errorData
        )}`
      )
    }
    const responseData = await response.json()
    return responseData as EmbeddingResponse
  }

  // Optional method for direct client access
  override getChatClient(sessionId: string): any {
    throw new Error('method not implemented yet')
  }

  /**
   * Check if a tool is supported by the model
   * Currently read from GGUF chat_template
   * @param modelId
   * @returns
   */
  async isToolSupported(modelId: string): Promise<boolean> {
    const janDataFolderPath = await getJanDataFolderPath()
    const modelConfigPath = await joinPath([
      this.providerPath,
      'models',
      modelId,
      'model.yml',
    ])
    const modelConfig = await invoke<ModelConfig>('read_yaml', {
      path: modelConfigPath,
    })
    // model option is required
    // NOTE: model_path and mmproj_path can be either relative to Jan's data folder or absolute path
    const modelPath = await joinPath([
      janDataFolderPath,
      modelConfig.model_path,
    ])
    return (await readGgufMetadata(modelPath)).metadata?.[
      'tokenizer.chat_template'
    ]?.includes('tools')
  }

  /**
   * Get total system memory including both VRAM and RAM
   */
  private async getTotalSystemMemory(): Promise<SystemMemory> {
    const devices = await this.getDevices()
    let totalVRAM = 0
    if (devices.length > 0) {
      // Sum total VRAM across all GPUs
      totalVRAM = devices
        .map((d) => d.mem * 1024 * 1024)
        .reduce((a, b) => a + b, 0)
    }

    // Get system RAM (derived from the hardware plugin's used_memory reading,
    // converted from MiB to bytes)
    const sys = await getSystemUsage()
    const totalRAM = sys.used_memory * 1024 * 1024
    const totalMemory = totalVRAM + totalRAM
    logger.info(
      `Total VRAM: ${totalVRAM} bytes, Total RAM: ${totalRAM} bytes, Total Memory: ${totalMemory} bytes`
    )
    return {
      totalVRAM,
      totalRAM,
      totalMemory,
    }
  }

  private async getKVCachePerToken(
    meta: Record<string, any>
  ): Promise<number> {
    const arch = meta['general.architecture']
    const nLayer = Number(meta[`${arch}.block_count`])
    const nHead = Number(meta[`${arch}.attention.head_count`])
    // Get head dimensions
    const nHeadKV =
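    // Editor's worked example of the per-token KV size computed below
    // (hypothetical Llama-style GGUF: block_count=32, head_count=32,
    // head_count_kv=8, embedding_length=4096):
    //   headDim       = 4096 / 32 = 128
    //   bytesPerToken = 8 (KV heads) * 128 * 2 (K+V) * 2 (fp16 bytes) * 32 layers
    //                 = 131072 bytes ≈ 128 KiB per token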
Number(meta[`${arch}.attention.head_count_kv`]) || nHead const embeddingLen = Number(meta[`${arch}.embedding_length`]) const headDim = embeddingLen / nHead // KV cache uses head_count_kv (for GQA models) or head_count // Each token needs K and V, both are fp16 (2 bytes) const bytesPerToken = nHeadKV * headDim * 2 * 2 * nLayer // K+V, fp16, all layers return bytesPerToken } private async getLayerSize( path: string, meta: Record ): Promise<{ layerSize: number; totalLayers: number }> { const modelSize = await this.getModelSize(path) const arch = meta['general.architecture'] const totalLayers = Number(meta[`${arch}.block_count`]) if (!totalLayers) throw new Error('Invalid metadata: block_count not found') return { layerSize: modelSize / totalLayers, totalLayers } } async planModelLoad( path: string, requestedCtx?: number, mmprojPath?: string ): Promise { const modelSize = await this.getModelSize(path) const memoryInfo = await this.getTotalSystemMemory() const gguf = await readGgufMetadata(path) // Get mmproj size if provided let mmprojSize = 0 if (mmprojPath) { mmprojSize = await this.getModelSize(mmprojPath) } const { layerSize, totalLayers } = await this.getLayerSize( path, gguf.metadata ) // Fixed KV cache calculation const kvCachePerToken = await this.getKVCachePerToken(gguf.metadata) // Debug logging logger.info( `Model size: ${modelSize}, Layer size: ${layerSize}, Total layers: ${totalLayers}, KV cache per token: ${kvCachePerToken}` ) // Validate critical values if (!modelSize || modelSize <= 0) { throw new Error(`Invalid model size: ${modelSize}`) } if (!kvCachePerToken || kvCachePerToken <= 0) { throw new Error(`Invalid KV cache per token: ${kvCachePerToken}`) } if (!layerSize || layerSize <= 0) { throw new Error(`Invalid layer size: ${layerSize}`) } // GPU overhead factor (20% reserved for GPU operations, alignment, etc.) const GPU_OVERHEAD_FACTOR = 0.8 // VRAM budget with overhead consideration const VRAM_RESERVE_GB = 0.5 const VRAM_RESERVE_BYTES = VRAM_RESERVE_GB * 1024 * 1024 * 1024 const usableVRAM = Math.max( 0, (memoryInfo.totalVRAM - VRAM_RESERVE_BYTES) * GPU_OVERHEAD_FACTOR ) // Get model's maximum context length const arch = gguf.metadata['general.architecture'] const modelMaxContextLength = Number(gguf.metadata[`${arch}.context_length`]) || 131072 // Default fallback // Set minimum context length const MIN_CONTEXT_LENGTH = 2048 // Reduced from 4096 for better compatibility // System RAM budget const memoryPercentages = { high: 0.7, medium: 0.5, low: 0.4 } logger.info( `Memory info - Total (VRAM + RAM): ${memoryInfo.totalMemory}, Total VRAM: ${memoryInfo.totalVRAM}, Mode: ${this.memoryMode}` ) // Validate memory info if (!memoryInfo.totalMemory || isNaN(memoryInfo.totalMemory)) { throw new Error(`Invalid total memory: ${memoryInfo.totalMemory}`) } if (!memoryInfo.totalVRAM || isNaN(memoryInfo.totalVRAM)) { throw new Error(`Invalid total VRAM: ${memoryInfo.totalVRAM}`) } if (!this.memoryMode || !(this.memoryMode in memoryPercentages)) { throw new Error( `Invalid memory mode: ${this.memoryMode}. 
Must be 'high', 'medium', or 'low'` ) } // Calculate actual system RAM const actualSystemRAM = Math.max( 0, memoryInfo.totalMemory - memoryInfo.totalVRAM ) const usableSystemMemory = actualSystemRAM * memoryPercentages[this.memoryMode] logger.info( `Actual System RAM: ${actualSystemRAM}, Usable VRAM: ${usableVRAM}, Usable System Memory: ${usableSystemMemory}` ) // --- Priority 1: Allocate mmproj (if exists) --- let noOffloadMmproj = false let remainingVRAM = usableVRAM if (mmprojSize > 0) { if (mmprojSize <= remainingVRAM) { noOffloadMmproj = true remainingVRAM -= mmprojSize logger.info(`MMProj allocated to VRAM: ${mmprojSize} bytes`) } else { logger.info(`MMProj will use CPU RAM: ${mmprojSize} bytes`) } } // --- Priority 2: Calculate optimal layer/context balance --- let gpuLayers = 0 let maxContextLength = MIN_CONTEXT_LENGTH let noOffloadKVCache = false let mode: ModelPlan['mode'] = 'Unsupported' // Calculate how much VRAM we need for different context sizes const contextSizes = [2048, 4096, 8192, 16384, 32768, 65536, 131072] const targetContext = requestedCtx || modelMaxContextLength // Find the best balance of layers and context let bestConfig = { layers: 0, context: MIN_CONTEXT_LENGTH, vramUsed: 0, } for (const ctxSize of contextSizes) { if (ctxSize > targetContext) break const kvCacheSize = ctxSize * kvCachePerToken const availableForLayers = remainingVRAM - kvCacheSize if (availableForLayers <= 0) continue const possibleLayers = Math.min( Math.floor(availableForLayers / layerSize), totalLayers ) if (possibleLayers > 0) { const totalVramNeeded = possibleLayers * layerSize + kvCacheSize // Verify this fits with some margin if (totalVramNeeded <= remainingVRAM * 0.95) { bestConfig = { layers: possibleLayers, context: ctxSize, vramUsed: totalVramNeeded, } } } } // Apply the best configuration found if (bestConfig.layers > 0) { gpuLayers = bestConfig.layers maxContextLength = bestConfig.context noOffloadKVCache = false mode = gpuLayers === totalLayers ? 'GPU' : 'Hybrid' logger.info( `Best GPU config: ${gpuLayers}/${totalLayers} layers, ${maxContextLength} context, ` + `VRAM used: ${bestConfig.vramUsed}/${remainingVRAM} bytes` ) } else { // Fallback: Try minimal GPU layers with KV cache on CPU gpuLayers = Math.min( Math.floor((remainingVRAM * 0.9) / layerSize), // Use 90% for layers totalLayers ) if (gpuLayers > 0) { // Calculate available system RAM for KV cache const cpuLayers = totalLayers - gpuLayers const modelCPUSize = cpuLayers * layerSize const mmprojCPUSize = mmprojSize > 0 && !noOffloadMmproj ? mmprojSize : 0 const systemRAMUsed = modelCPUSize + mmprojCPUSize const availableSystemRAMForKVCache = Math.max( 0, usableSystemMemory - systemRAMUsed ) // Calculate context that fits in system RAM const systemRAMContext = Math.min( Math.floor(availableSystemRAMForKVCache / kvCachePerToken), targetContext ) if (systemRAMContext >= MIN_CONTEXT_LENGTH) { maxContextLength = systemRAMContext noOffloadKVCache = true mode = 'Hybrid' logger.info( `Hybrid mode: ${gpuLayers}/${totalLayers} layers on GPU, ` + `${maxContextLength} context on CPU RAM` ) } else { // Can't fit reasonable context even with CPU RAM // Reduce GPU layers further gpuLayers = Math.floor(gpuLayers / 2) maxContextLength = MIN_CONTEXT_LENGTH noOffloadKVCache = true mode = gpuLayers > 0 ? 
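// Editor's numeric sketch of the layer/context search above (hypothetical
// 8 GiB GPU, 7B model ≈ 4.5 GiB over 32 layers, 128 KiB KV per token):
//   usableVRAM ≈ (8 GiB - 0.5 GiB) * 0.8 ≈ 6 GiB
//   requestedCtx 8192  -> KV ≈ 1 GiB, ~5 GiB left -> all 32 layers -> GPU
//   requestedCtx 32768 -> KV ≈ 4 GiB, ~2 GiB left -> ~14 layers    -> Hybrid
// The loop keeps the largest context whose layer/KV split still fits within
// 95% of the remaining VRAM.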
    // Apply the best configuration found
    if (bestConfig.layers > 0) {
      gpuLayers = bestConfig.layers
      maxContextLength = bestConfig.context
      noOffloadKVCache = false
      mode = gpuLayers === totalLayers ? 'GPU' : 'Hybrid'

      logger.info(
        `Best GPU config: ${gpuLayers}/${totalLayers} layers, ${maxContextLength} context, ` +
          `VRAM used: ${bestConfig.vramUsed}/${remainingVRAM} bytes`
      )
    } else {
      // Fallback: try minimal GPU layers with the KV cache on the CPU
      gpuLayers = Math.min(
        Math.floor((remainingVRAM * 0.9) / layerSize), // use 90% for layers
        totalLayers
      )

      if (gpuLayers > 0) {
        // Calculate the system RAM available for the KV cache
        const cpuLayers = totalLayers - gpuLayers
        const modelCPUSize = cpuLayers * layerSize
        const mmprojCPUSize =
          mmprojSize > 0 && !noOffloadMmproj ? mmprojSize : 0
        const systemRAMUsed = modelCPUSize + mmprojCPUSize
        const availableSystemRAMForKVCache = Math.max(
          0,
          usableSystemMemory - systemRAMUsed
        )

        // Calculate the context that fits in system RAM
        const systemRAMContext = Math.min(
          Math.floor(availableSystemRAMForKVCache / kvCachePerToken),
          targetContext
        )

        if (systemRAMContext >= MIN_CONTEXT_LENGTH) {
          maxContextLength = systemRAMContext
          noOffloadKVCache = true
          mode = 'Hybrid'

          logger.info(
            `Hybrid mode: ${gpuLayers}/${totalLayers} layers on GPU, ` +
              `${maxContextLength} context on CPU RAM`
          )
        } else {
          // Can't fit a reasonable context even with CPU RAM;
          // reduce GPU layers further
          gpuLayers = Math.floor(gpuLayers / 2)
          maxContextLength = MIN_CONTEXT_LENGTH
          noOffloadKVCache = true
          mode = gpuLayers > 0 ? 'Hybrid' : 'CPU'
        }
      } else {
        // Pure CPU mode
        gpuLayers = 0
        noOffloadKVCache = true

        // Calculate the context budget for pure CPU mode
        const totalCPUMemoryNeeded = modelSize + (mmprojSize || 0)
        const availableForKVCache = Math.max(
          0,
          usableSystemMemory - totalCPUMemoryNeeded
        )

        maxContextLength = Math.min(
          Math.max(
            MIN_CONTEXT_LENGTH,
            Math.floor(availableForKVCache / kvCachePerToken)
          ),
          targetContext
        )

        mode = maxContextLength >= MIN_CONTEXT_LENGTH ? 'CPU' : 'Unsupported'
      }
    }

    // Safety check: verify the total GPU memory usage, re-estimating after
    // every adjustment so the loop condition sees the current value
    if (gpuLayers > 0 && !noOffloadKVCache) {
      let estimatedGPUUsage =
        gpuLayers * layerSize +
        maxContextLength * kvCachePerToken +
        (noOffloadMmproj ? mmprojSize : 0)

      if (estimatedGPUUsage > memoryInfo.totalVRAM * 0.9) {
        logger.warn(
          `GPU memory usage (${estimatedGPUUsage}) exceeds safe limit. Adjusting...`
        )

        // Reduce the context first, halving until it fits or hits the minimum
        while (
          maxContextLength > MIN_CONTEXT_LENGTH &&
          estimatedGPUUsage > memoryInfo.totalVRAM * 0.9
        ) {
          maxContextLength = Math.floor(maxContextLength / 2)
          estimatedGPUUsage =
            gpuLayers * layerSize +
            maxContextLength * kvCachePerToken +
            (noOffloadMmproj ? mmprojSize : 0)
        }

        // If that is still too much, reduce layers and move the KV cache to CPU
        if (estimatedGPUUsage > memoryInfo.totalVRAM * 0.9) {
          gpuLayers = Math.floor(gpuLayers * 0.7)
          mode = gpuLayers > 0 ? 'Hybrid' : 'CPU'
          noOffloadKVCache = true
        }
      }
    }

    // Apply the user-requested context limit, if specified
    if (requestedCtx && requestedCtx > 0) {
      maxContextLength = Math.min(maxContextLength, requestedCtx)
      logger.info(
        `User requested context: ${requestedCtx}, final: ${maxContextLength}`
      )
    }

    // Never exceed the model's maximum context
    maxContextLength = Math.min(maxContextLength, modelMaxContextLength)

    // Final validation
    if (gpuLayers <= 0 && maxContextLength < MIN_CONTEXT_LENGTH) {
      mode = 'Unsupported'
    }

    // Ensure maxContextLength is a valid number
    maxContextLength = isNaN(maxContextLength)
      ? MIN_CONTEXT_LENGTH
      : Math.max(MIN_CONTEXT_LENGTH, maxContextLength)

    // Log the final plan
    const mmprojInfo = mmprojPath
      ? `, mmprojSize=${(mmprojSize / (1024 * 1024)).toFixed(2)}MB, noOffloadMmproj=${noOffloadMmproj}`
      : ''
    logger.info(
      `Final plan for ${path}: gpuLayers=${gpuLayers}/${totalLayers}, ` +
        `maxContextLength=${maxContextLength}, noOffloadKVCache=${noOffloadKVCache}, ` +
        `mode=${mode}${mmprojInfo}`
    )

    return {
      gpuLayers,
      maxContextLength,
      noOffloadKVCache,
      mode,
      noOffloadMmproj,
    }
  }
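  // Sketch of how a caller might turn a ModelPlan into llama-server flags.
  // Illustrative only: the real argument assembly lives in the load path,
  // and the flag names assume a recent llama.cpp build.
  //
  //   const plan = await this.planModelLoad(modelPath, undefined, mmprojPath)
  //   const args = [
  //     '-ngl', String(plan.gpuLayers),       // --n-gpu-layers
  //     '-c', String(plan.maxContextLength),  // --ctx-size
  //   ]
  //   if (plan.noOffloadKVCache) args.push('--no-kv-offload') // keep the KV cache in system RAM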
  /**
   * Estimate the total KV cache size (in bytes) from GGUF metadata.
   */
  private async estimateKVCache(
    meta: Record<string, any>,
    ctx_size?: number
  ): Promise<number> {
    const arch = meta['general.architecture']
    if (!arch) throw new Error('Invalid metadata: architecture not found')

    const nLayer = Number(meta[`${arch}.block_count`])
    if (!nLayer) throw new Error('Invalid metadata: block_count not found')

    const nHead = Number(meta[`${arch}.attention.head_count`])
    if (!nHead) throw new Error('Invalid metadata: head_count not found')

    // Try the explicit key/value lengths first (more accurate)
    const keyLen = Number(meta[`${arch}.attention.key_length`])
    const valLen = Number(meta[`${arch}.attention.value_length`])

    let headDim: number
    if (keyLen && valLen) {
      // Combined K+V width per head, from the explicit metadata
      logger.info(
        `Using explicit key_length: ${keyLen}, value_length: ${valLen}`
      )
      headDim = keyLen + valLen
    } else {
      // Fall back to an embedding_length estimation
      const embeddingLen = Number(meta[`${arch}.embedding_length`])
      if (!embeddingLen)
        throw new Error('Invalid metadata: embedding_length not found')

      // Standard transformer: head_dim = embedding_dim / num_heads.
      // The KV cache stores both K and V, so 2 * head_dim per head.
      headDim = (embeddingLen / nHead) * 2
      logger.info(
        `Using embedding_length estimation: ${embeddingLen}, calculated head_dim: ${headDim}`
      )
    }

    let ctxLen: number
    if (!ctx_size) {
      ctxLen = Number(meta[`${arch}.context_length`])
    } else {
      ctxLen = ctx_size
    }

    logger.info(`ctxLen: ${ctxLen}`)
    logger.info(`nLayer: ${nLayer}`)
    logger.info(`nHead: ${nHead}`)
    logger.info(`headDim: ${headDim}`)

    // Assume f16 cache entries (2 bytes per element) by default. This could
    // be extended by checking cache-type-k and cache-type-v, but we are
    // checking overall compatibility with the default settings.
    const bytesPerElement = 2

    // KV cache size per token, per layer = nHead * headDim * bytesPerElement
    const kvPerToken = nHead * headDim * bytesPerElement

    return ctxLen * nLayer * kvPerToken
  }
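  // Note: unlike getKVCachePerToken() above, this estimate uses
  // attention.head_count rather than attention.head_count_kv, so for
  // grouped-query-attention models it overestimates the cache (a model with
  // 32 query heads and 8 KV heads would be estimated at 4x its real size).
  // Since isModelSupported() below only gates a coarse compatibility badge,
  // erring on the high side is the safe direction here.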
  private async getModelSize(path: string): Promise<number> {
    if (path.startsWith('https://')) {
      // For remote models, read the size from the Content-Length header
      const res = await fetch(path, { method: 'HEAD' })
      const len = res.headers.get('content-length')
      return len ? parseInt(len, 10) : 0
    } else {
      return (await fs.fileStat(path)).size
    }
  }

  /**
   * Check the support status of a model by its path (local/remote).
   *
   * Returns:
   * - "RED" → weights don't fit in total memory
   * - "YELLOW" → weights fit in VRAM but need system RAM, or the KV cache doesn't fit
   * - "GREEN" → both weights and KV cache fit in VRAM
   */
  async isModelSupported(
    path: string,
    ctx_size?: number
  ): Promise<'RED' | 'YELLOW' | 'GREEN'> {
    try {
      const modelSize = await this.getModelSize(path)
      const memoryInfo = await this.getTotalSystemMemory()
      logger.info(`modelSize: ${modelSize}`)
      const gguf = await readGgufMetadata(path)

      let kvCacheSize: number
      if (ctx_size) {
        kvCacheSize = await this.estimateKVCache(gguf.metadata, ctx_size)
      } else {
        kvCacheSize = await this.estimateKVCache(gguf.metadata)
      }

      // Total memory consumption = model weights + KV cache
      const totalRequired = modelSize + kvCacheSize
      logger.info(
        `isModelSupported: Total memory requirement: ${totalRequired} for ${path}`
      )

      // Use 80% of total memory as the usable limit
      const USABLE_MEMORY_PERCENTAGE = 0.8
      const usableTotalMemory =
        memoryInfo.totalMemory * USABLE_MEMORY_PERCENTAGE
      const usableVRAM = memoryInfo.totalVRAM * USABLE_MEMORY_PERCENTAGE

      // Check whether the model fits in total memory at all
      if (modelSize > usableTotalMemory) {
        return 'RED'
      }

      // Check whether everything fits in VRAM (the ideal case)
      if (totalRequired <= usableVRAM) {
        return 'GREEN'
      }

      // The model fits in VRAM but the total requirement exceeds VRAM,
      // or the total requirement fits in total memory but not in VRAM
      if (modelSize <= usableVRAM || totalRequired <= usableTotalMemory) {
        return 'YELLOW'
      }

      // Nothing fits properly
      return 'RED'
    } catch (e) {
      throw new Error(String(e))
    }
  }

  /**
   * Validate a GGUF file and check for unsupported architectures such as CLIP.
   */
  async validateGgufFile(filePath: string): Promise<{
    isValid: boolean
    error?: string
    metadata?: any
  }> {
    try {
      logger.info(`Validating GGUF file: ${filePath}`)
      const metadata = await readGgufMetadata(filePath)

      // Log the full metadata for debugging
      logger.info('Full GGUF metadata:', JSON.stringify(metadata, null, 2))

      // A 'clip' architecture cannot be used for text generation
      const architecture = metadata.metadata?.['general.architecture']
      logger.info(`Model architecture: ${architecture}`)

      if (architecture === 'clip') {
        const errorMessage =
          'This model has CLIP architecture and cannot be imported as a text generation model. CLIP models are designed for vision tasks and require different handling.'
        logger.error('CLIP architecture detected:', architecture)
        return {
          isValid: false,
          error: errorMessage,
          metadata,
        }
      }

      logger.info('Model validation passed. Architecture:', architecture)
      return {
        isValid: true,
        metadata,
      }
    } catch (error) {
      logger.error('Failed to validate GGUF file:', error)
      return {
        isValid: false,
        error: `Failed to read model metadata: ${error instanceof Error ? error.message : 'Unknown error'}`,
      }
    }
  }
}
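// Hypothetical usage sketch (illustrative only; `ext` stands in for the
// instance Jan constructs when it loads this extension):
//
//   const badge = await ext.isModelSupported(
//     'https://example.com/model.gguf', // remote paths are sized via an HTTP HEAD request
//     8192
//   )
//   // 'GREEN'  -> weights + KV cache fit in VRAM
//   // 'YELLOW' -> fits overall, but spills into system RAM
//   // 'RED'    -> does not fit in usable total memory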