diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts
index 91fe4dd34..92ceaad60 100644
--- a/extensions/llamacpp-extension/src/index.ts
+++ b/extensions/llamacpp-extension/src/index.ts
@@ -145,7 +145,6 @@ export default class llamacpp_extension extends AIEngine {
   readonly providerId: string = 'llamacpp'
 
   private config: LlamacppConfig
-  private activeSessions: Map<number, SessionInfo> = new Map()
   private providerPath!: string
   private apiSecret: string = 'JustAskNow'
   private pendingDownloads: Map<string, Promise<void>> = new Map()
@@ -771,16 +770,6 @@ export default class llamacpp_extension extends AIEngine {
 
   override async onUnload(): Promise<void> {
     // Terminate all active sessions
-    for (const [_, sInfo] of this.activeSessions) {
-      try {
-        await this.unload(sInfo.model_id)
-      } catch (error) {
-        logger.error(`Failed to unload model ${sInfo.model_id}:`, error)
-      }
-    }
-
-    // Clear the sessions map
-    this.activeSessions.clear()
   }
 
   onSettingUpdate<T>(key: string, value: T): void {
@@ -1104,67 +1093,13 @@ export default class llamacpp_extension extends AIEngine {
    * Function to find a random port
    */
   private async getRandomPort(): Promise<number> {
-    const MAX_ATTEMPTS = 20000
-    let attempts = 0
-
-    while (attempts < MAX_ATTEMPTS) {
-      const port = Math.floor(Math.random() * 1000) + 3000
-
-      const isAlreadyUsed = Array.from(this.activeSessions.values()).some(
-        (info) => info.port === port
-      )
-
-      if (!isAlreadyUsed) {
-        const isAvailable = await invoke('is_port_available', { port })
-        if (isAvailable) return port
-      }
-
-      attempts++
+    try {
+      const port = await invoke<number>('get_random_port')
+      return port
+    } catch {
+      logger.error('Unable to find a suitable port')
+      throw new Error('Unable to find a suitable port for model')
     }
-
-    throw new Error('Failed to find an available port for the model to load')
-  }
-
-  private async sleep(ms: number): Promise<void> {
-    return new Promise((resolve) => setTimeout(resolve, ms))
-  }
-
-  private async waitForModelLoad(
-    sInfo: SessionInfo,
-    timeoutMs = 240_000
-  ): Promise<void> {
-    await this.sleep(500) // Wait before first check
-    const start = Date.now()
-    while (Date.now() - start < timeoutMs) {
-      try {
-        const res = await fetch(`http://localhost:${sInfo.port}/health`)
-
-        if (res.status === 503) {
-          const body = await res.json()
-          const msg = body?.error?.message ?? 'Model loading'
-          logger.info(`waiting for model load... (${msg})`)
-        } else if (res.ok) {
-          const body = await res.json()
-          if (body.status === 'ok') {
-            return
-          } else {
-            logger.warn('Unexpected OK response from /health:', body)
-          }
-        } else {
-          logger.warn(`Unexpected status ${res.status} from /health`)
-        }
-      } catch (e) {
-        await this.unload(sInfo.model_id)
-        throw new Error(`Model appears to have crashed: ${e}`)
-      }
-
-      await this.sleep(800) // Retry interval
-    }
-
-    await this.unload(sInfo.model_id)
-    throw new Error(
-      `Timed out loading model after ${timeoutMs}... killing llamacpp`
-    )
   }
 
   override async load(
@@ -1172,7 +1107,7 @@ export default class llamacpp_extension extends AIEngine {
     modelId: string,
     overrideSettings?: Partial<LlamacppConfig>,
     isEmbedding: boolean = false
   ): Promise<SessionInfo> {
-    const sInfo = this.findSessionByModel(modelId)
+    const sInfo = await this.findSessionByModel(modelId)
     if (sInfo) {
       throw new Error('Model already loaded!!')
     }
@@ -1342,11 +1277,6 @@ export default class llamacpp_extension extends AIEngine {
         libraryPath,
         args,
       })
-
-      // Store the session info for later use
-      this.activeSessions.set(sInfo.pid, sInfo)
-      await this.waitForModelLoad(sInfo)
-
       return sInfo
     } catch (error) {
       logger.error('Error in load command:\n', error)
@@ -1355,13 +1285,12 @@ export default class llamacpp_extension extends AIEngine {
     }
   }
 
   override async unload(modelId: string): Promise<UnloadResult> {
-    const sInfo: SessionInfo = this.findSessionByModel(modelId)
+    const sInfo: SessionInfo = await this.findSessionByModel(modelId)
     if (!sInfo) {
       throw new Error(`No active session found for model: ${modelId}`)
     }
     const pid = sInfo.pid
     try {
-      this.activeSessions.delete(pid)
-
       // Pass the PID as the session_id
       const result = await invoke('unload_llama_model', {
@@ -1373,13 +1302,11 @@ export default class llamacpp_extension extends AIEngine {
         logger.info(`Successfully unloaded model with PID ${pid}`)
       } else {
         logger.warn(`Failed to unload model: ${result.error}`)
-        this.activeSessions.set(sInfo.pid, sInfo)
       }
 
       return result
     } catch (error) {
       logger.error('Error in unload command:', error)
-      this.activeSessions.set(sInfo.pid, sInfo)
       return {
         success: false,
         error: `Failed to unload model: ${error}`,
@@ -1502,17 +1429,21 @@ export default class llamacpp_extension extends AIEngine {
     }
   }
 
-  private findSessionByModel(modelId: string): SessionInfo | undefined {
-    return Array.from(this.activeSessions.values()).find(
-      (session) => session.model_id === modelId
-    )
+  private async findSessionByModel(modelId: string): Promise<SessionInfo | undefined> {
+    try {
+      let sInfo = await invoke<SessionInfo | undefined>('find_session_by_model', { modelId })
+      return sInfo
+    } catch (e) {
+      logger.error(e)
+      throw new Error(String(e))
+    }
   }
 
   override async chat(
     opts: chatCompletionRequest,
     abortController?: AbortController
   ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
-    const sessionInfo = this.findSessionByModel(opts.model)
+    const sessionInfo = await this.findSessionByModel(opts.model)
     if (!sessionInfo) {
       throw new Error(`No active session found for model: ${opts.model}`)
     }
@@ -1528,7 +1459,6 @@ export default class llamacpp_extension extends AIEngine {
         throw new Error('Model appears to have crashed! Please reload!')
       }
     } else {
-      this.activeSessions.delete(sessionInfo.pid)
       throw new Error('Model have crashed! Please reload!')
     }
     const baseUrl = `http://localhost:${sessionInfo.port}/v1`
@@ -1577,11 +1507,13 @@ export default class llamacpp_extension extends AIEngine {
   }
 
   override async getLoadedModels(): Promise<string[]> {
-    let lmodels: string[] = []
-    for (const [_, sInfo] of this.activeSessions) {
-      lmodels.push(sInfo.model_id)
-    }
-    return lmodels
+    try {
+      let models: string[] = await invoke<string[]>('get_loaded_models')
+      return models
+    } catch (e) {
+      logger.error(e)
+      throw new Error(String(e))
+    }
   }
 
   async getDevices(): Promise<DeviceList[]> {
@@ -1611,7 +1543,7 @@ export default class llamacpp_extension extends AIEngine {
   }
 
   async embed(text: string[]): Promise<EmbeddingResponse> {
-    let sInfo = this.findSessionByModel('sentence-transformer-mini')
+    let sInfo = await this.findSessionByModel('sentence-transformer-mini')
     if (!sInfo) {
       const downloadedModelList = await this.list()
       if (
diff --git a/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs b/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
index 191eb5c6e..b95e17010 100644
--- a/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
+++ b/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
@@ -1,7 +1,9 @@
 use base64::{engine::general_purpose, Engine as _};
 use hmac::{Hmac, Mac};
+use rand::{rngs::StdRng, Rng, SeedableRng};
 use serde::{Deserialize, Serialize};
 use sha2::Sha256;
+use std::collections::HashSet;
 use std::path::PathBuf;
 use std::process::Stdio;
 use std::time::Duration;
@@ -724,11 +726,80 @@ pub async fn is_process_running(pid: i32, state: State<'_, AppState>) -> Result<
 }
 
 // check port availability
-#[tauri::command]
-pub fn is_port_available(port: u16) -> bool {
+fn is_port_available(port: u16) -> bool {
     std::net::TcpListener::bind(("127.0.0.1", port)).is_ok()
 }
 
+#[tauri::command]
+pub async fn get_random_port(state: State<'_, AppState>) -> Result<u16, String> {
+    const MAX_ATTEMPTS: u32 = 20000;
+    let mut attempts = 0;
+    let mut rng = StdRng::from_entropy();
+
+    // Get all active ports from sessions
+    let map = state.llama_server_process.lock().await;
+
+    let used_ports: HashSet<u16> = map
+        .values()
+        .filter_map(|session| {
+            // Convert valid ports to u16 (filter out placeholder ports like -1)
+            if session.info.port > 0 && session.info.port <= u16::MAX as i32 {
+                Some(session.info.port as u16)
+            } else {
+                None
+            }
+        })
+        .collect();
+
+    drop(map); // unlock early
+
+    while attempts < MAX_ATTEMPTS {
+        let port = rng.gen_range(3000..4000);
+
+        if used_ports.contains(&port) {
+            attempts += 1;
+            continue;
+        }
+
+        if is_port_available(port) {
+            return Ok(port);
+        }
+
+        attempts += 1;
+    }
+
+    Err("Failed to find an available port for the model to load".into())
+}
+
+// find session
+#[tauri::command]
+pub async fn find_session_by_model(
+    model_id: String,
+    state: State<'_, AppState>,
+) -> Result<Option<SessionInfo>, String> {
+    let map = state.llama_server_process.lock().await;
+
+    let session_info = map
+        .values()
+        .find(|backend_session| backend_session.info.model_id == model_id)
+        .map(|backend_session| backend_session.info.clone());
+
+    Ok(session_info)
+}
+
+// get running models
+#[tauri::command]
+pub async fn get_loaded_models(state: State<'_, AppState>) -> Result<Vec<String>, String> {
+    let map = state.llama_server_process.lock().await;
+
+    let model_ids = map
+        .values()
+        .map(|backend_session| backend_session.info.model_id.clone())
+        .collect();
+
+    Ok(model_ids)
+}
+
 // tests
 // #[cfg(test)]
diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs
index 0332be3ef..e449fc739 100644
--- a/src-tauri/src/lib.rs
+++ b/src-tauri/src/lib.rs
@@ -95,7 +95,9 @@ pub fn run() {
             core::utils::extensions::inference_llamacpp_extension::server::load_llama_model,
             core::utils::extensions::inference_llamacpp_extension::server::unload_llama_model,
             core::utils::extensions::inference_llamacpp_extension::server::get_devices,
-            core::utils::extensions::inference_llamacpp_extension::server::is_port_available,
+            core::utils::extensions::inference_llamacpp_extension::server::get_random_port,
+            core::utils::extensions::inference_llamacpp_extension::server::find_session_by_model,
+            core::utils::extensions::inference_llamacpp_extension::server::get_loaded_models,
             core::utils::extensions::inference_llamacpp_extension::server::generate_api_key,
             core::utils::extensions::inference_llamacpp_extension::server::is_process_running,
         ])