refactor: move session management & port allocation to backend (#6083)
* refactor: move session management & port allocation to backend
- Remove the in‑process `activeSessions` map and its cleanup logic from the TypeScript side.
- Introduce new Tauri commands in Rust:
- `get_random_port` – picks an unused port using a seeded RNG and checks availability.
- `find_session_by_model` – returns the `SessionInfo` for a given model ID.
- `get_loaded_models` – returns a list of currently loaded model IDs.
- Update the extension’s TypeScript code to use these commands via `invoke`:
- `findSessionByModel`, `load`, `unload`, `chat`, `getLoadedModels`, and `embed` now operate asynchronously and query the backend.
- Remove the old `is_port_available` command and the custom port‑checking loop.
- Simplify `onUnload` – session termination is now handled by the backend.
- Drop unused helpers (`sleep`, `waitForModelLoad`) and related port‑availability code.
- Add missing Rust imports (`rand::{StdRng,Rng,SeedableRng}`, `HashSet`) and improve error handling.
- Register the new commands in `src-tauri/src/lib.rs` (replace `is_port_available` with the three new commands).
This refactor centralises session state and port allocation in the Rust backend, eliminates duplicated logic, and resolves race conditions around model loading and session cleanup.
* Use String(e) for error
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
---------
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
This commit is contained in:
parent
1f1605bdf9
commit
6a699d8004
@ -145,7 +145,6 @@ export default class llamacpp_extension extends AIEngine {
|
||||
readonly providerId: string = 'llamacpp'
|
||||
|
||||
private config: LlamacppConfig
|
||||
private activeSessions: Map<number, SessionInfo> = new Map()
|
||||
private providerPath!: string
|
||||
private apiSecret: string = 'JustAskNow'
|
||||
private pendingDownloads: Map<string, Promise<void>> = new Map()
|
||||
@ -771,16 +770,6 @@ export default class llamacpp_extension extends AIEngine {
|
||||
|
||||
override async onUnload(): Promise<void> {
|
||||
// Terminate all active sessions
|
||||
for (const [_, sInfo] of this.activeSessions) {
|
||||
try {
|
||||
await this.unload(sInfo.model_id)
|
||||
} catch (error) {
|
||||
logger.error(`Failed to unload model ${sInfo.model_id}:`, error)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the sessions map
|
||||
this.activeSessions.clear()
|
||||
}
|
||||
|
||||
onSettingUpdate<T>(key: string, value: T): void {
|
||||
@ -1104,67 +1093,13 @@ export default class llamacpp_extension extends AIEngine {
|
||||
* Function to find a random port
|
||||
*/
|
||||
private async getRandomPort(): Promise<number> {
|
||||
const MAX_ATTEMPTS = 20000
|
||||
let attempts = 0
|
||||
|
||||
while (attempts < MAX_ATTEMPTS) {
|
||||
const port = Math.floor(Math.random() * 1000) + 3000
|
||||
|
||||
const isAlreadyUsed = Array.from(this.activeSessions.values()).some(
|
||||
(info) => info.port === port
|
||||
)
|
||||
|
||||
if (!isAlreadyUsed) {
|
||||
const isAvailable = await invoke<boolean>('is_port_available', { port })
|
||||
if (isAvailable) return port
|
||||
}
|
||||
|
||||
attempts++
|
||||
try {
|
||||
const port = await invoke<number>('get_random_port')
|
||||
return port
|
||||
} catch {
|
||||
logger.error('Unable to find a suitable port')
|
||||
throw new Error('Unable to find a suitable port for model')
|
||||
}
|
||||
|
||||
throw new Error('Failed to find an available port for the model to load')
|
||||
}
|
||||
|
||||
private async sleep(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms))
|
||||
}
|
||||
|
||||
private async waitForModelLoad(
|
||||
sInfo: SessionInfo,
|
||||
timeoutMs = 240_000
|
||||
): Promise<void> {
|
||||
await this.sleep(500) // Wait before first check
|
||||
const start = Date.now()
|
||||
while (Date.now() - start < timeoutMs) {
|
||||
try {
|
||||
const res = await fetch(`http://localhost:${sInfo.port}/health`)
|
||||
|
||||
if (res.status === 503) {
|
||||
const body = await res.json()
|
||||
const msg = body?.error?.message ?? 'Model loading'
|
||||
logger.info(`waiting for model load... (${msg})`)
|
||||
} else if (res.ok) {
|
||||
const body = await res.json()
|
||||
if (body.status === 'ok') {
|
||||
return
|
||||
} else {
|
||||
logger.warn('Unexpected OK response from /health:', body)
|
||||
}
|
||||
} else {
|
||||
logger.warn(`Unexpected status ${res.status} from /health`)
|
||||
}
|
||||
} catch (e) {
|
||||
await this.unload(sInfo.model_id)
|
||||
throw new Error(`Model appears to have crashed: ${e}`)
|
||||
}
|
||||
|
||||
await this.sleep(800) // Retry interval
|
||||
}
|
||||
|
||||
await this.unload(sInfo.model_id)
|
||||
throw new Error(
|
||||
`Timed out loading model after ${timeoutMs}... killing llamacpp`
|
||||
)
|
||||
}
|
||||
|
||||
override async load(
|
||||
@ -1172,7 +1107,7 @@ export default class llamacpp_extension extends AIEngine {
|
||||
overrideSettings?: Partial<LlamacppConfig>,
|
||||
isEmbedding: boolean = false
|
||||
): Promise<SessionInfo> {
|
||||
const sInfo = this.findSessionByModel(modelId)
|
||||
const sInfo = await this.findSessionByModel(modelId)
|
||||
if (sInfo) {
|
||||
throw new Error('Model already loaded!!')
|
||||
}
|
||||
@ -1342,11 +1277,6 @@ export default class llamacpp_extension extends AIEngine {
|
||||
libraryPath,
|
||||
args,
|
||||
})
|
||||
|
||||
// Store the session info for later use
|
||||
this.activeSessions.set(sInfo.pid, sInfo)
|
||||
await this.waitForModelLoad(sInfo)
|
||||
|
||||
return sInfo
|
||||
} catch (error) {
|
||||
logger.error('Error in load command:\n', error)
|
||||
@ -1355,13 +1285,12 @@ export default class llamacpp_extension extends AIEngine {
|
||||
}
|
||||
|
||||
override async unload(modelId: string): Promise<UnloadResult> {
|
||||
const sInfo: SessionInfo = this.findSessionByModel(modelId)
|
||||
const sInfo: SessionInfo = await this.findSessionByModel(modelId)
|
||||
if (!sInfo) {
|
||||
throw new Error(`No active session found for model: ${modelId}`)
|
||||
}
|
||||
const pid = sInfo.pid
|
||||
try {
|
||||
this.activeSessions.delete(pid)
|
||||
|
||||
// Pass the PID as the session_id
|
||||
const result = await invoke<UnloadResult>('unload_llama_model', {
|
||||
@ -1373,13 +1302,11 @@ export default class llamacpp_extension extends AIEngine {
|
||||
logger.info(`Successfully unloaded model with PID ${pid}`)
|
||||
} else {
|
||||
logger.warn(`Failed to unload model: ${result.error}`)
|
||||
this.activeSessions.set(sInfo.pid, sInfo)
|
||||
}
|
||||
|
||||
return result
|
||||
} catch (error) {
|
||||
logger.error('Error in unload command:', error)
|
||||
this.activeSessions.set(sInfo.pid, sInfo)
|
||||
return {
|
||||
success: false,
|
||||
error: `Failed to unload model: ${error}`,
|
||||
@ -1502,17 +1429,21 @@ export default class llamacpp_extension extends AIEngine {
|
||||
}
|
||||
}
|
||||
|
||||
private findSessionByModel(modelId: string): SessionInfo | undefined {
|
||||
return Array.from(this.activeSessions.values()).find(
|
||||
(session) => session.model_id === modelId
|
||||
)
|
||||
private async findSessionByModel(modelId: string): Promise<SessionInfo> {
|
||||
try {
|
||||
let sInfo = await invoke<SessionInfo>('find_session_by_model', {modelId})
|
||||
return sInfo
|
||||
} catch (e) {
|
||||
logger.error(e)
|
||||
throw new Error(String(e))
|
||||
}
|
||||
}
|
||||
|
||||
override async chat(
|
||||
opts: chatCompletionRequest,
|
||||
abortController?: AbortController
|
||||
): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
|
||||
const sessionInfo = this.findSessionByModel(opts.model)
|
||||
const sessionInfo = await this.findSessionByModel(opts.model)
|
||||
if (!sessionInfo) {
|
||||
throw new Error(`No active session found for model: ${opts.model}`)
|
||||
}
|
||||
@ -1528,7 +1459,6 @@ export default class llamacpp_extension extends AIEngine {
|
||||
throw new Error('Model appears to have crashed! Please reload!')
|
||||
}
|
||||
} else {
|
||||
this.activeSessions.delete(sessionInfo.pid)
|
||||
throw new Error('Model have crashed! Please reload!')
|
||||
}
|
||||
const baseUrl = `http://localhost:${sessionInfo.port}/v1`
|
||||
@ -1577,11 +1507,13 @@ export default class llamacpp_extension extends AIEngine {
|
||||
}
|
||||
|
||||
override async getLoadedModels(): Promise<string[]> {
|
||||
let lmodels: string[] = []
|
||||
for (const [_, sInfo] of this.activeSessions) {
|
||||
lmodels.push(sInfo.model_id)
|
||||
}
|
||||
return lmodels
|
||||
try {
|
||||
let models: string[] = await invoke<string[]>('get_loaded_models')
|
||||
return models
|
||||
} catch (e) {
|
||||
logger.error(e)
|
||||
throw new Error(e)
|
||||
}
|
||||
}
|
||||
|
||||
async getDevices(): Promise<DeviceList[]> {
|
||||
@ -1611,7 +1543,7 @@ export default class llamacpp_extension extends AIEngine {
|
||||
}
|
||||
|
||||
async embed(text: string[]): Promise<EmbeddingResponse> {
|
||||
let sInfo = this.findSessionByModel('sentence-transformer-mini')
|
||||
let sInfo = await this.findSessionByModel('sentence-transformer-mini')
|
||||
if (!sInfo) {
|
||||
const downloadedModelList = await this.list()
|
||||
if (
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use hmac::{Hmac, Mac};
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::Sha256;
|
||||
use std::collections::HashSet;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
@ -724,11 +726,80 @@ pub async fn is_process_running(pid: i32, state: State<'_, AppState>) -> Result<
|
||||
}
|
||||
|
||||
// check port availability
|
||||
#[tauri::command]
|
||||
pub fn is_port_available(port: u16) -> bool {
|
||||
fn is_port_available(port: u16) -> bool {
|
||||
std::net::TcpListener::bind(("127.0.0.1", port)).is_ok()
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn get_random_port(state: State<'_, AppState>) -> Result<u16, String> {
|
||||
const MAX_ATTEMPTS: u32 = 20000;
|
||||
let mut attempts = 0;
|
||||
let mut rng = StdRng::from_entropy();
|
||||
|
||||
// Get all active ports from sessions
|
||||
let map = state.llama_server_process.lock().await;
|
||||
|
||||
let used_ports: HashSet<u16> = map
|
||||
.values()
|
||||
.filter_map(|session| {
|
||||
// Convert valid ports to u16 (filter out placeholder ports like -1)
|
||||
if session.info.port > 0 && session.info.port <= u16::MAX as i32 {
|
||||
Some(session.info.port as u16)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
drop(map); // unlock early
|
||||
|
||||
while attempts < MAX_ATTEMPTS {
|
||||
let port = rng.gen_range(3000..4000);
|
||||
|
||||
if used_ports.contains(&port) {
|
||||
attempts += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if is_port_available(port) {
|
||||
return Ok(port);
|
||||
}
|
||||
|
||||
attempts += 1;
|
||||
}
|
||||
|
||||
Err("Failed to find an available port for the model to load".into())
|
||||
}
|
||||
|
||||
// find session
|
||||
#[tauri::command]
|
||||
pub async fn find_session_by_model(
|
||||
model_id: String,
|
||||
state: State<'_, AppState>,
|
||||
) -> Result<Option<SessionInfo>, String> {
|
||||
let map = state.llama_server_process.lock().await;
|
||||
|
||||
let session_info = map
|
||||
.values()
|
||||
.find(|backend_session| backend_session.info.model_id == model_id)
|
||||
.map(|backend_session| backend_session.info.clone());
|
||||
|
||||
Ok(session_info)
|
||||
}
|
||||
|
||||
// get running models
|
||||
#[tauri::command]
|
||||
pub async fn get_loaded_models(state: State<'_, AppState>) -> Result<Vec<String>, String> {
|
||||
let map = state.llama_server_process.lock().await;
|
||||
|
||||
let model_ids = map
|
||||
.values()
|
||||
.map(|backend_session| backend_session.info.model_id.clone())
|
||||
.collect();
|
||||
|
||||
Ok(model_ids)
|
||||
}
|
||||
|
||||
// tests
|
||||
//
|
||||
#[cfg(test)]
|
||||
|
||||
@ -95,7 +95,9 @@ pub fn run() {
|
||||
core::utils::extensions::inference_llamacpp_extension::server::load_llama_model,
|
||||
core::utils::extensions::inference_llamacpp_extension::server::unload_llama_model,
|
||||
core::utils::extensions::inference_llamacpp_extension::server::get_devices,
|
||||
core::utils::extensions::inference_llamacpp_extension::server::is_port_available,
|
||||
core::utils::extensions::inference_llamacpp_extension::server::get_random_port,
|
||||
core::utils::extensions::inference_llamacpp_extension::server::find_session_by_model,
|
||||
core::utils::extensions::inference_llamacpp_extension::server::get_loaded_models,
|
||||
core::utils::extensions::inference_llamacpp_extension::server::generate_api_key,
|
||||
core::utils::extensions::inference_llamacpp_extension::server::is_process_running,
|
||||
])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user