refactor: move session management & port allocation to backend (#6083)

* refactor: move session management & port allocation to backend

- Remove the in‑process `activeSessions` map and its cleanup logic from the TypeScript side.
- Introduce new Tauri commands in Rust:
  - `get_random_port` – picks an unused port using a seeded RNG and checks availability.
  - `find_session_by_model` – returns the `SessionInfo` for a given model ID.
  - `get_loaded_models` – returns a list of currently loaded model IDs.
- Update the extension’s TypeScript code to use these commands via `invoke`:
  - `findSessionByModel`, `load`, `unload`, `chat`, `getLoadedModels`, and `embed` now operate asynchronously and query the backend.
  - Remove the old `is_port_available` command and the custom port‑checking loop.
  - Simplify `onUnload` – session termination is now handled by the backend.
- Drop unused helpers (`sleep`, `waitForModelLoad`) and related port‑availability code.
- Add missing Rust imports (`rand::{StdRng,Rng,SeedableRng}`, `HashSet`) and improve error handling.
- Register the new commands in `src-tauri/src/lib.rs` (replace `is_port_available` with the three new commands).

This refactor centralises session state and port allocation in the Rust backend, eliminates duplicated logic, and resolves race conditions around model loading and session cleanup.

* Use String(e) for error

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
Akarshan Biswas 2025-08-07 13:06:21 +05:30 committed by GitHub
parent 1f1605bdf9
commit 6a699d8004
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 101 additions and 96 deletions

View File

@ -145,7 +145,6 @@ export default class llamacpp_extension extends AIEngine {
readonly providerId: string = 'llamacpp'
private config: LlamacppConfig
private activeSessions: Map<number, SessionInfo> = new Map()
private providerPath!: string
private apiSecret: string = 'JustAskNow'
private pendingDownloads: Map<string, Promise<void>> = new Map()
@ -771,16 +770,6 @@ export default class llamacpp_extension extends AIEngine {
override async onUnload(): Promise<void> {
// Terminate all active sessions
for (const [_, sInfo] of this.activeSessions) {
try {
await this.unload(sInfo.model_id)
} catch (error) {
logger.error(`Failed to unload model ${sInfo.model_id}:`, error)
}
}
// Clear the sessions map
this.activeSessions.clear()
}
onSettingUpdate<T>(key: string, value: T): void {
@ -1104,67 +1093,13 @@ export default class llamacpp_extension extends AIEngine {
* Function to find a random port
*/
private async getRandomPort(): Promise<number> {
const MAX_ATTEMPTS = 20000
let attempts = 0
while (attempts < MAX_ATTEMPTS) {
const port = Math.floor(Math.random() * 1000) + 3000
const isAlreadyUsed = Array.from(this.activeSessions.values()).some(
(info) => info.port === port
)
if (!isAlreadyUsed) {
const isAvailable = await invoke<boolean>('is_port_available', { port })
if (isAvailable) return port
}
attempts++
try {
const port = await invoke<number>('get_random_port')
return port
} catch {
logger.error('Unable to find a suitable port')
throw new Error('Unable to find a suitable port for model')
}
throw new Error('Failed to find an available port for the model to load')
}
private async sleep(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms))
}
private async waitForModelLoad(
sInfo: SessionInfo,
timeoutMs = 240_000
): Promise<void> {
await this.sleep(500) // Wait before first check
const start = Date.now()
while (Date.now() - start < timeoutMs) {
try {
const res = await fetch(`http://localhost:${sInfo.port}/health`)
if (res.status === 503) {
const body = await res.json()
const msg = body?.error?.message ?? 'Model loading'
logger.info(`waiting for model load... (${msg})`)
} else if (res.ok) {
const body = await res.json()
if (body.status === 'ok') {
return
} else {
logger.warn('Unexpected OK response from /health:', body)
}
} else {
logger.warn(`Unexpected status ${res.status} from /health`)
}
} catch (e) {
await this.unload(sInfo.model_id)
throw new Error(`Model appears to have crashed: ${e}`)
}
await this.sleep(800) // Retry interval
}
await this.unload(sInfo.model_id)
throw new Error(
`Timed out loading model after ${timeoutMs}... killing llamacpp`
)
}
override async load(
@ -1172,7 +1107,7 @@ export default class llamacpp_extension extends AIEngine {
overrideSettings?: Partial<LlamacppConfig>,
isEmbedding: boolean = false
): Promise<SessionInfo> {
const sInfo = this.findSessionByModel(modelId)
const sInfo = await this.findSessionByModel(modelId)
if (sInfo) {
throw new Error('Model already loaded!!')
}
@ -1342,11 +1277,6 @@ export default class llamacpp_extension extends AIEngine {
libraryPath,
args,
})
// Store the session info for later use
this.activeSessions.set(sInfo.pid, sInfo)
await this.waitForModelLoad(sInfo)
return sInfo
} catch (error) {
logger.error('Error in load command:\n', error)
@ -1355,13 +1285,12 @@ export default class llamacpp_extension extends AIEngine {
}
override async unload(modelId: string): Promise<UnloadResult> {
const sInfo: SessionInfo = this.findSessionByModel(modelId)
const sInfo: SessionInfo = await this.findSessionByModel(modelId)
if (!sInfo) {
throw new Error(`No active session found for model: ${modelId}`)
}
const pid = sInfo.pid
try {
this.activeSessions.delete(pid)
// Pass the PID as the session_id
const result = await invoke<UnloadResult>('unload_llama_model', {
@ -1373,13 +1302,11 @@ export default class llamacpp_extension extends AIEngine {
logger.info(`Successfully unloaded model with PID ${pid}`)
} else {
logger.warn(`Failed to unload model: ${result.error}`)
this.activeSessions.set(sInfo.pid, sInfo)
}
return result
} catch (error) {
logger.error('Error in unload command:', error)
this.activeSessions.set(sInfo.pid, sInfo)
return {
success: false,
error: `Failed to unload model: ${error}`,
@ -1502,17 +1429,21 @@ export default class llamacpp_extension extends AIEngine {
}
}
private findSessionByModel(modelId: string): SessionInfo | undefined {
return Array.from(this.activeSessions.values()).find(
(session) => session.model_id === modelId
)
private async findSessionByModel(modelId: string): Promise<SessionInfo> {
try {
let sInfo = await invoke<SessionInfo>('find_session_by_model', {modelId})
return sInfo
} catch (e) {
logger.error(e)
throw new Error(String(e))
}
}
override async chat(
opts: chatCompletionRequest,
abortController?: AbortController
): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
const sessionInfo = this.findSessionByModel(opts.model)
const sessionInfo = await this.findSessionByModel(opts.model)
if (!sessionInfo) {
throw new Error(`No active session found for model: ${opts.model}`)
}
@ -1528,7 +1459,6 @@ export default class llamacpp_extension extends AIEngine {
throw new Error('Model appears to have crashed! Please reload!')
}
} else {
this.activeSessions.delete(sessionInfo.pid)
throw new Error('Model have crashed! Please reload!')
}
const baseUrl = `http://localhost:${sessionInfo.port}/v1`
@ -1577,11 +1507,13 @@ export default class llamacpp_extension extends AIEngine {
}
override async getLoadedModels(): Promise<string[]> {
let lmodels: string[] = []
for (const [_, sInfo] of this.activeSessions) {
lmodels.push(sInfo.model_id)
}
return lmodels
try {
let models: string[] = await invoke<string[]>('get_loaded_models')
return models
} catch (e) {
logger.error(e)
throw new Error(e)
}
}
async getDevices(): Promise<DeviceList[]> {
@ -1611,7 +1543,7 @@ export default class llamacpp_extension extends AIEngine {
}
async embed(text: string[]): Promise<EmbeddingResponse> {
let sInfo = this.findSessionByModel('sentence-transformer-mini')
let sInfo = await this.findSessionByModel('sentence-transformer-mini')
if (!sInfo) {
const downloadedModelList = await this.list()
if (

View File

@ -1,7 +1,9 @@
use base64::{engine::general_purpose, Engine as _};
use hmac::{Hmac, Mac};
use rand::{rngs::StdRng, Rng, SeedableRng};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use std::collections::HashSet;
use std::path::PathBuf;
use std::process::Stdio;
use std::time::Duration;
@ -724,11 +726,80 @@ pub async fn is_process_running(pid: i32, state: State<'_, AppState>) -> Result<
}
// check port availability
#[tauri::command]
pub fn is_port_available(port: u16) -> bool {
fn is_port_available(port: u16) -> bool {
std::net::TcpListener::bind(("127.0.0.1", port)).is_ok()
}
#[tauri::command]
pub async fn get_random_port(state: State<'_, AppState>) -> Result<u16, String> {
const MAX_ATTEMPTS: u32 = 20000;
let mut attempts = 0;
let mut rng = StdRng::from_entropy();
// Get all active ports from sessions
let map = state.llama_server_process.lock().await;
let used_ports: HashSet<u16> = map
.values()
.filter_map(|session| {
// Convert valid ports to u16 (filter out placeholder ports like -1)
if session.info.port > 0 && session.info.port <= u16::MAX as i32 {
Some(session.info.port as u16)
} else {
None
}
})
.collect();
drop(map); // unlock early
while attempts < MAX_ATTEMPTS {
let port = rng.gen_range(3000..4000);
if used_ports.contains(&port) {
attempts += 1;
continue;
}
if is_port_available(port) {
return Ok(port);
}
attempts += 1;
}
Err("Failed to find an available port for the model to load".into())
}
// find session
#[tauri::command]
pub async fn find_session_by_model(
model_id: String,
state: State<'_, AppState>,
) -> Result<Option<SessionInfo>, String> {
let map = state.llama_server_process.lock().await;
let session_info = map
.values()
.find(|backend_session| backend_session.info.model_id == model_id)
.map(|backend_session| backend_session.info.clone());
Ok(session_info)
}
// get running models
#[tauri::command]
pub async fn get_loaded_models(state: State<'_, AppState>) -> Result<Vec<String>, String> {
let map = state.llama_server_process.lock().await;
let model_ids = map
.values()
.map(|backend_session| backend_session.info.model_id.clone())
.collect();
Ok(model_ids)
}
// tests
//
#[cfg(test)]

View File

@ -95,7 +95,9 @@ pub fn run() {
core::utils::extensions::inference_llamacpp_extension::server::load_llama_model,
core::utils::extensions::inference_llamacpp_extension::server::unload_llama_model,
core::utils::extensions::inference_llamacpp_extension::server::get_devices,
core::utils::extensions::inference_llamacpp_extension::server::is_port_available,
core::utils::extensions::inference_llamacpp_extension::server::get_random_port,
core::utils::extensions::inference_llamacpp_extension::server::find_session_by_model,
core::utils::extensions::inference_llamacpp_extension::server::get_loaded_models,
core::utils::extensions::inference_llamacpp_extension::server::generate_api_key,
core::utils::extensions::inference_llamacpp_extension::server::is_process_running,
])