From fd9e034461a9832f8c62d1d54b02168090247813 Mon Sep 17 00:00:00 2001
From: Akarshan Biswas
Date: Fri, 30 May 2025 10:25:58 +0530
Subject: [PATCH] feat: update AIEngine load method and backend path handling

- Changed the load method to accept a modelId instead of loadOptions, for a
  clearer and simpler API
- Renamed the engineBasePath parameter to backendPath for consistency with
  the backend's directory structure
- Added a getRandomPort method so each session gets a port not used by any
  other active session, preventing conflicts
- Refactored configuration and model loading logic to improve
  maintainability and reduce redundancy
---
 .../browser/extensions/engines/AIEngine.ts |   2 +-
 extensions/llamacpp-extension/src/index.ts | 205 ++++++++++++------
 .../inference_llamacpp_extension/server.rs |   2 +-
 3 files changed, 143 insertions(+), 66 deletions(-)

diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts
index be7bdb0e5..9e3fb9884 100644
--- a/core/src/browser/extensions/engines/AIEngine.ts
+++ b/core/src/browser/extensions/engines/AIEngine.ts
@@ -178,7 +178,7 @@ export abstract class AIEngine extends BaseExtension {
   /**
    * Loads a model into memory
    */
-  abstract load(opts: loadOptions): Promise<SessionInfo>
+  abstract load(modelId: string): Promise<SessionInfo>
 
   /**
    * Unloads a model from memory
diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts
index 3a573d740..455406b7a 100644
--- a/extensions/llamacpp-extension/src/index.ts
+++ b/extensions/llamacpp-extension/src/index.ts
@@ -21,34 +21,38 @@ import {
   chatCompletionRequest,
   events,
 } from '@janhq/core'
-import { listSupportedBackends, downloadBackend, isBackendInstalled } from './backend'
+import {
+  listSupportedBackends,
+  downloadBackend,
+  isBackendInstalled,
+} from './backend'
 import { invoke } from '@tauri-apps/api/core'
 
 type LlamacppConfig = {
-  backend: string;
-  n_gpu_layers: number;
-  ctx_size: number;
-  threads: number;
-  threads_batch: number;
-  n_predict: number;
-  batch_size: number;
-  ubatch_size: number;
-  device: string;
-  split_mode: string;
-  main_gpu: number;
-  flash_attn: boolean;
-  cont_batching: boolean;
-  no_mmap: boolean;
-  mlock: boolean;
-  no_kv_offload: boolean;
-  cache_type_k: string;
-  cache_type_v: string;
-  defrag_thold: number;
-  rope_scaling: string;
-  rope_scale: number;
-  rope_freq_base: number;
-  rope_freq_scale: number;
-  reasoning_budget: number;
+  backend: string
+  n_gpu_layers: number
+  ctx_size: number
+  threads: number
+  threads_batch: number
+  n_predict: number
+  batch_size: number
+  ubatch_size: number
+  device: string
+  split_mode: string
+  main_gpu: number
+  flash_attn: boolean
+  cont_batching: boolean
+  no_mmap: boolean
+  mlock: boolean
+  no_kv_offload: boolean
+  cache_type_k: string
+  cache_type_v: string
+  defrag_thold: number
+  rope_scaling: string
+  rope_scale: number
+  rope_freq_base: number
+  rope_freq_scale: number
+  reasoning_budget: number
 }
 
 interface DownloadItem {
@@ -64,7 +68,6 @@ interface ModelConfig {
   size_bytes: number
 }
 
-
 /**
  * A class that implements the InferenceExtension interface from the @janhq/core package.
  * The class provides methods for initializing and stopping a model, and for making inference requests.
@@ -85,11 +88,10 @@ export default class llamacpp_extension extends AIEngine {
   private config: LlamacppConfig
   private downloadManager
-  private downloadBackend  // for testing
+  private downloadBackend // for testing
   private activeSessions: Map<string, SessionInfo> = new Map()
   private modelsBasePath!: string
-  private enginesBasePath!: string
-  private apiSecret: string = "Jan"
+  private apiSecret: string = 'Jan'
 
   override async onLoad(): Promise<void> {
     super.onLoad() // Calls registerEngine() from AIEngine
@@ -117,33 +119,36 @@
     let config = {}
     for (const item of SETTINGS) {
       const defaultValue = item.controllerProps.value
-      config[item.key] = await this.getSetting(item.key, defaultValue)
+      config[item.key] = await this.getSetting(
+        item.key,
+        defaultValue
+      )
     }
     this.config = config as LlamacppConfig
 
-    this.downloadManager = window.core.extensionManager.getByName('@janhq/download-extension')
+    this.downloadManager = window.core.extensionManager.getByName(
+      '@janhq/download-extension'
+    )
 
     // Initialize models base path - assuming this would be retrieved from settings
     this.modelsBasePath = await joinPath([
       await getJanDataFolderPath(),
       'models',
     ])
-
-    this.enginesBasePath = await joinPath([await getJanDataFolderPath(), 'engines'])
   }
 
   override async onUnload(): Promise<void> {
     // Terminate all active sessions
     for (const [sessionId, _] of this.activeSessions) {
       try {
-        await this.unload(sessionId);
+        await this.unload(sessionId)
       } catch (error) {
-        console.error(`Failed to unload session ${sessionId}:`, error);
+        console.error(`Failed to unload session ${sessionId}:`, error)
       }
     }
 
     // Clear the sessions map
-    this.activeSessions.clear();
+    this.activeSessions.clear()
   }
 
   onSettingUpdate<T>(key: string, value: T): void {
@@ -168,7 +173,7 @@
   private async generateApiKey(modelId: string): Promise<string> {
     const hash = await invoke<string>('generate_api_key', {
       modelId: modelId,
-      apiSecret: this.apiSecret
+      apiSecret: this.apiSecret,
     })
     return hash
   }
@@ -211,7 +216,12 @@
     let modelInfos: modelInfo[] = []
 
     for (const modelId of modelIds) {
-      const path = await joinPath([this.modelsBasePath, this.provider, modelId, 'model.yml'])
+      const path = await joinPath([
+        this.modelsBasePath,
+        this.provider,
+        modelId,
+        'model.yml',
+      ])
       const modelConfig = await invoke<ModelConfig>('read_yaml', { path })
 
       const modelInfo = {
@@ -235,14 +245,21 @@
 
       // check for empty parts or path traversal
       const parts = id.split('/')
-      return parts.every(s => s !== '' && s !== '.' && s !== '..')
+      return parts.every((s) => s !== '' && s !== '.' && s !== '..')
     }
 
     if (!isValidModelId(modelId)) {
-      throw new Error(`Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`)
+      throw new Error(
+        `Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`
+      )
     }
 
-    let configPath = await joinPath([this.modelsBasePath, this.provider, modelId, 'model.yml'])
+    let configPath = await joinPath([
+      this.modelsBasePath,
+      this.provider,
+      modelId,
+      'model.yml',
+    ])
     if (await fs.existsSync(configPath)) {
       throw new Error(`Model ${modelId} already exists`)
     }
@@ -260,8 +277,11 @@
     let modelPath = opts.modelPath
     let mmprojPath = opts.mmprojPath
 
-    const modelItem = { url: opts.modelPath, save_path: `${modelDir}/model.gguf` }
-    if (opts.modelPath.startsWith("https://")) {
+    const modelItem = {
+      url: opts.modelPath,
+      save_path: `${modelDir}/model.gguf`,
+    }
+    if (opts.modelPath.startsWith('https://')) {
       downloadItems.push(modelItem)
       modelPath = modelItem.save_path
     } else {
@@ -272,8 +292,11 @@
 
     if (opts.mmprojPath) {
-      const mmprojItem = { url: opts.mmprojPath, save_path: `${modelDir}/mmproj.gguf` }
-      if (opts.mmprojPath.startsWith("https://")) {
+      const mmprojItem = {
+        url: opts.mmprojPath,
+        save_path: `${modelDir}/mmproj.gguf`,
+      }
+      if (opts.mmprojPath.startsWith('https://')) {
         downloadItems.push(mmprojItem)
         mmprojPath = mmprojItem.save_path
       } else {
@@ -298,7 +321,11 @@
         })
         downloadCompleted = transferred === total
       }
-      await this.downloadManager.downloadFiles(downloadItems, taskId, onProgress)
+      await this.downloadManager.downloadFiles(
+        downloadItems,
+        taskId,
+        onProgress
+      )
     } catch (error) {
       console.error('Error downloading model:', modelId, opts, error)
       events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
@@ -307,7 +334,9 @@
 
     // once we reach this point, it either means download finishes or it was cancelled.
     // if there was an error, it would have been caught above
-    const eventName = downloadCompleted ? 'onFileDownloadSuccess' : 'onFileDownloadStopped'
+    const eventName = downloadCompleted
+      ? 'onFileDownloadSuccess'
+      : 'onFileDownloadStopped'
     events.emit(eventName, { modelId, downloadType: 'Model' })
   }
 
@@ -315,9 +344,13 @@
     // NOTE: modelPath and mmprojPath can be either relative to Jan's data folder (if they are downloaded)
     // or absolute paths (if they are provided as local files)
     const janDataFolderPath = await getJanDataFolderPath()
-    let size_bytes = (await fs.fileStat(await joinPath([janDataFolderPath, modelPath]))).size
+    let size_bytes = (
+      await fs.fileStat(await joinPath([janDataFolderPath, modelPath]))
+    ).size
     if (mmprojPath) {
-      size_bytes += (await fs.fileStat(await joinPath([janDataFolderPath, mmprojPath]))).size
+      size_bytes += (
+        await fs.fileStat(await joinPath([janDataFolderPath, mmprojPath]))
+      ).size
     }
 
     // TODO: add name as import() argument
     const modelConfig = {
@@ -328,10 +361,10 @@
       name: modelId,
       size_bytes,
     } as ModelConfig
-    await invoke(
-      'write_yaml',
-      { data: modelConfig, savePath: `${modelDir}/model.yml` },
-    )
+    await invoke('write_yaml', {
+      data: modelConfig,
+      savePath: `${modelDir}/model.yml`,
+    })
   }
 
   override async abortImport(modelId: string): Promise<void> {
@@ -339,24 +372,62 @@
     const taskId = this.createDownloadTaskId(modelId)
     await this.downloadManager.cancelDownload(taskId)
   }
+  /**
+   * Finds a random port that is not used by another active session
+   */
+  private async getRandomPort(): Promise<number> {
+    let port: number
+    do {
+      port = Math.floor(Math.random() * 1000) + 3000
+    } while (
+      Array.from(this.activeSessions.values()).some(
+        (info) => info.port === port
+      )
+    )
+    return port
+  }
 
-  override async load(opts: loadOptions): Promise<SessionInfo> {
+  override async load(modelId: string): Promise<SessionInfo> {
     const args: string[] = []
     const cfg = this.config
+    const sysInfo = await window.core.api.getSystemInfo()
+    const [backend, version] = cfg.backend.split('-')
+    const exe_name =
+      sysInfo.os_type === 'windows' ? 'llama-server.exe' : 'llama-server'
+    const backendPath = await joinPath([
+      await getJanDataFolderPath(),
+      'llamacpp',
+      'backends',
+      backend,
+      version,
+      'build',
+      'bin',
+      exe_name,
+    ])
+    const modelPath = await joinPath([
+      this.modelsBasePath,
+      this.provider,
+      modelId,
+    ])
+    const modelConfigPath = await joinPath([modelPath, 'model.yml'])
+    const modelConfig = await invoke<ModelConfig>('read_yaml', {
+      path: modelConfigPath,
+    })
+    const port = await this.getRandomPort()
 
     // disable llama-server webui
     args.push('--no-webui')
     // update key for security; TODO: (qnixsynapse) Make it more secure
-    const api_key = this.generateApiKey(opts.modelPath)
+    const api_key = await this.generateApiKey(modelId)
     args.push(`--api-key ${api_key}`)
     // model option is required
     // TODO: llama.cpp extension lookup model path based on modelId
-    args.push('-m', opts.modelPath)
-    args.push('-a', opts.modelId)
-    args.push('--port', String(opts.port || 8080)) // Default port if not specified
-    if (opts.mmprojPath) {
-      args.push('--mmproj', opts.mmprojPath)
+    args.push('-m', modelConfig.model_path)
+    args.push('-a', modelId)
+    args.push('--port', String(port)) // unique port from getRandomPort()
+    if (modelConfig.mmproj_path) {
+      args.push('--mmproj', modelConfig.mmproj_path)
     }
 
     if (cfg.ctx_size !== undefined) {
@@ -366,14 +437,16 @@
     // Add remaining options from the interface
     if (cfg.n_gpu_layers > 0) args.push('-ngl', String(cfg.n_gpu_layers))
     if (cfg.threads > 0) args.push('--threads', String(cfg.threads))
-    if (cfg.threads_batch > 0) args.push('--threads-batch', String(cfg.threads_batch))
+    if (cfg.threads_batch > 0)
+      args.push('--threads-batch', String(cfg.threads_batch))
     if (cfg.ctx_size > 0) args.push('--ctx-size', String(cfg.ctx_size))
     if (cfg.n_predict > 0) args.push('--n-predict', String(cfg.n_predict))
     if (cfg.batch_size > 0) args.push('--batch-size', String(cfg.batch_size))
     if (cfg.ubatch_size > 0) args.push('--ubatch-size', String(cfg.ubatch_size))
     if (cfg.device.length > 0) args.push('--device', cfg.device)
     if (cfg.split_mode.length > 0) args.push('--split-mode', cfg.split_mode)
-    if (cfg.main_gpu !== undefined) args.push('--main-gpu', String(cfg.main_gpu))
+    if (cfg.main_gpu !== undefined)
+      args.push('--main-gpu', String(cfg.main_gpu))
 
     // Boolean flags
     if (cfg.flash_attn) args.push('--flash-attn')
@@ -396,7 +469,7 @@
 
     try {
       const sInfo = await invoke<SessionInfo>('load_llama_model', {
-        server_path: this.enginesBasePath,
+        backendPath: backendPath,
         args: args,
       })
@@ -545,7 +618,11 @@
 
   override async delete(modelId: string): Promise<void> {
-    const modelDir = await joinPath([this.modelsBasePath, this.provider, modelId])
+    const modelDir = await joinPath([
+      this.modelsBasePath,
+      this.provider,
+      modelId,
+    ])
 
     if (!(await fs.existsSync(await joinPath([modelDir, 'model.yml'])))) {
       throw new Error(`Model ${modelId} does not exist`)
diff --git a/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs b/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
index 5de95f099..849a14a17 100644
--- a/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
+++ b/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
@@ -58,7 +58,7 @@ pub struct unloadResult {
 pub async fn load_llama_model(
     _app_handle: AppHandle, // Get the AppHandle
     state: State<'_, AppState>, // Access the shared state
-    engineBasePath: String,
+    backendPath: String,
     args: Vec<String>, // Arguments from the frontend
) -> ServerResult<SessionInfo> {
    let mut process_lock = state.llama_server_process.lock().await;
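
Usage sketch (illustrative, not part of the patch): after this change a caller passes only a model id, and the extension resolves <Jan data>/models/<provider>/<modelId>/model.yml, picks a port, and spawns llama-server from backendPath itself. The SessionInfo shape and the model id below are assumptions inferred from the surrounding code (`info.port`, `sInfo`), not confirmed by this diff.

    // Hypothetical caller of the reworked API; SessionInfo fields are assumed.
    interface SessionInfo {
      port: number
      model_id: string
    }

    type Engine = { load(modelId: string): Promise<SessionInfo> }

    async function startModel(engine: Engine): Promise<void> {
      // Old API: engine.load({ modelId, modelPath, mmprojPath, port, ... })
      // New API: everything is resolved internally from the model id.
      const session = await engine.load('qwen3-0.6b') // placeholder id
      console.log(`llama-server ready on port ${session.port}`)
    }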
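
Port-selection note: getRandomPort draws from the fixed range 3000-3999 and retries only against ports recorded in this.activeSessions, so a port held by an unrelated process can still be returned, and the do/while loop cannot terminate once all 1000 ports belong to live sessions. The sketch below restates the patch's selection logic as a pure function for illustration; the helper name is hypothetical.

    // Same arithmetic as the patch, extracted as a standalone function.
    function pickPort(portsInUse: Set<number>): number {
      let port: number
      do {
        // Math.random() < 1, so the result always lands in 3000..3999.
        port = Math.floor(Math.random() * 1000) + 3000
      } while (portsInUse.has(port)) // guards only the extension's own sessions
      return port
    }

    // Example: with sessions already on 3001 and 3002, any other
    // port in the range may be returned.
    console.log(pickPort(new Set([3001, 3002])))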