refactor: move session management & port allocation to backend (#6083)
* refactor: move session management & port allocation to backend
- Remove the in‑process `activeSessions` map and its cleanup logic from the TypeScript side.
- Introduce new Tauri commands in Rust:
  - `get_random_port` – picks an unused port with an entropy-seeded RNG, skipping ports already held by active sessions and checking OS availability.
  - `find_session_by_model` – returns the `SessionInfo` for a given model ID, if one exists.
  - `get_loaded_models` – returns the list of currently loaded model IDs.
- Update the extension’s TypeScript code to call these commands via `invoke`:
  - `findSessionByModel`, `load`, `unload`, `chat`, `getLoadedModels`, and `embed` now operate asynchronously and query the backend.
- Remove the `is_port_available` Tauri command (it survives as a private Rust helper) and the TypeScript port‑checking loop.
- Simplify `onUnload` – session termination is now handled by the backend.
- Drop unused helpers (`sleep`, `waitForModelLoad`) and the related port‑availability code.
- Add the required Rust imports (`rand::{rngs::StdRng, Rng, SeedableRng}`, `std::collections::HashSet`) and improve error handling.
- Register the new commands in `src-tauri/src/lib.rs`, replacing `is_port_available` with the three new commands.
This refactor centralises session state and port allocation in the Rust backend, eliminates duplicated logic, and resolves race conditions around model loading and session cleanup.
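For illustration, a minimal TypeScript sketch of the invoke-based flow described above (not the extension's actual code): the `@tauri-apps/api/core` import path and the simplified `SessionInfo` shape are assumptions; only the command names and argument keys come from this change.

```typescript
import { invoke } from '@tauri-apps/api/core' // assumed Tauri v2 import path

// Simplified stand-in for the extension's SessionInfo type (assumption).
interface SessionInfo {
  pid: number
  port: number
  model_id: string
}

// Ask the backend for a free port instead of probing from TypeScript.
async function pickPort(): Promise<number> {
  return await invoke<number>('get_random_port')
}

// Look up a session by model ID; the backend returns null when none exists.
async function findSession(modelId: string): Promise<SessionInfo | null> {
  try {
    return await invoke<SessionInfo | null>('find_session_by_model', { modelId })
  } catch (e) {
    throw new Error(String(e)) // surface backend errors as plain strings
  }
}

// List the model IDs the backend currently has loaded.
async function loadedModels(): Promise<string[]> {
  return await invoke<string[]>('get_loaded_models')
}
```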
* Use `String(e)` for error
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
---------
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Parent: 1f1605bdf9
Commit: 6a699d8004
@@ -145,7 +145,6 @@ export default class llamacpp_extension extends AIEngine {
   readonly providerId: string = 'llamacpp'

   private config: LlamacppConfig
-  private activeSessions: Map<number, SessionInfo> = new Map()
   private providerPath!: string
   private apiSecret: string = 'JustAskNow'
   private pendingDownloads: Map<string, Promise<void>> = new Map()
@@ -771,16 +770,6 @@ export default class llamacpp_extension extends AIEngine {

   override async onUnload(): Promise<void> {
     // Terminate all active sessions
-    for (const [_, sInfo] of this.activeSessions) {
-      try {
-        await this.unload(sInfo.model_id)
-      } catch (error) {
-        logger.error(`Failed to unload model ${sInfo.model_id}:`, error)
-      }
-    }
-
-    // Clear the sessions map
-    this.activeSessions.clear()
   }

   onSettingUpdate<T>(key: string, value: T): void {
@@ -1104,67 +1093,13 @@ export default class llamacpp_extension extends AIEngine {
    * Function to find a random port
    */
   private async getRandomPort(): Promise<number> {
-    const MAX_ATTEMPTS = 20000
-    let attempts = 0
-
-    while (attempts < MAX_ATTEMPTS) {
-      const port = Math.floor(Math.random() * 1000) + 3000
-
-      const isAlreadyUsed = Array.from(this.activeSessions.values()).some(
-        (info) => info.port === port
-      )
-
-      if (!isAlreadyUsed) {
-        const isAvailable = await invoke<boolean>('is_port_available', { port })
-        if (isAvailable) return port
-      }
-
-      attempts++
+    try {
+      const port = await invoke<number>('get_random_port')
+      return port
+    } catch {
+      logger.error('Unable to find a suitable port')
+      throw new Error('Unable to find a suitable port for model')
     }
-
-    throw new Error('Failed to find an available port for the model to load')
-  }
-
-  private async sleep(ms: number): Promise<void> {
-    return new Promise((resolve) => setTimeout(resolve, ms))
-  }
-
-  private async waitForModelLoad(
-    sInfo: SessionInfo,
-    timeoutMs = 240_000
-  ): Promise<void> {
-    await this.sleep(500) // Wait before first check
-    const start = Date.now()
-    while (Date.now() - start < timeoutMs) {
-      try {
-        const res = await fetch(`http://localhost:${sInfo.port}/health`)
-
-        if (res.status === 503) {
-          const body = await res.json()
-          const msg = body?.error?.message ?? 'Model loading'
-          logger.info(`waiting for model load... (${msg})`)
-        } else if (res.ok) {
-          const body = await res.json()
-          if (body.status === 'ok') {
-            return
-          } else {
-            logger.warn('Unexpected OK response from /health:', body)
-          }
-        } else {
-          logger.warn(`Unexpected status ${res.status} from /health`)
-        }
-      } catch (e) {
-        await this.unload(sInfo.model_id)
-        throw new Error(`Model appears to have crashed: ${e}`)
-      }
-
-      await this.sleep(800) // Retry interval
-    }
-
-    await this.unload(sInfo.model_id)
-    throw new Error(
-      `Timed out loading model after ${timeoutMs}... killing llamacpp`
-    )
   }

   override async load(
@@ -1172,7 +1107,7 @@ export default class llamacpp_extension extends AIEngine {
     overrideSettings?: Partial<LlamacppConfig>,
     isEmbedding: boolean = false
   ): Promise<SessionInfo> {
-    const sInfo = this.findSessionByModel(modelId)
+    const sInfo = await this.findSessionByModel(modelId)
     if (sInfo) {
       throw new Error('Model already loaded!!')
     }
@@ -1342,11 +1277,6 @@ export default class llamacpp_extension extends AIEngine {
         libraryPath,
         args,
       })
-
-      // Store the session info for later use
-      this.activeSessions.set(sInfo.pid, sInfo)
-      await this.waitForModelLoad(sInfo)
-
       return sInfo
     } catch (error) {
       logger.error('Error in load command:\n', error)
@@ -1355,13 +1285,12 @@ export default class llamacpp_extension extends AIEngine {
   }

   override async unload(modelId: string): Promise<UnloadResult> {
-    const sInfo: SessionInfo = this.findSessionByModel(modelId)
+    const sInfo: SessionInfo = await this.findSessionByModel(modelId)
     if (!sInfo) {
       throw new Error(`No active session found for model: ${modelId}`)
     }
     const pid = sInfo.pid
     try {
-      this.activeSessions.delete(pid)

       // Pass the PID as the session_id
       const result = await invoke<UnloadResult>('unload_llama_model', {
@@ -1373,13 +1302,11 @@ export default class llamacpp_extension extends AIEngine {
         logger.info(`Successfully unloaded model with PID ${pid}`)
       } else {
         logger.warn(`Failed to unload model: ${result.error}`)
-        this.activeSessions.set(sInfo.pid, sInfo)
       }

       return result
     } catch (error) {
       logger.error('Error in unload command:', error)
-      this.activeSessions.set(sInfo.pid, sInfo)
       return {
         success: false,
         error: `Failed to unload model: ${error}`,
@@ -1502,17 +1429,21 @@ export default class llamacpp_extension extends AIEngine {
     }
   }

-  private findSessionByModel(modelId: string): SessionInfo | undefined {
-    return Array.from(this.activeSessions.values()).find(
-      (session) => session.model_id === modelId
-    )
+  private async findSessionByModel(modelId: string): Promise<SessionInfo> {
+    try {
+      let sInfo = await invoke<SessionInfo>('find_session_by_model', {modelId})
+      return sInfo
+    } catch (e) {
+      logger.error(e)
+      throw new Error(String(e))
+    }
   }

   override async chat(
     opts: chatCompletionRequest,
     abortController?: AbortController
   ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
-    const sessionInfo = this.findSessionByModel(opts.model)
+    const sessionInfo = await this.findSessionByModel(opts.model)
     if (!sessionInfo) {
       throw new Error(`No active session found for model: ${opts.model}`)
     }
@@ -1528,7 +1459,6 @@ export default class llamacpp_extension extends AIEngine {
         throw new Error('Model appears to have crashed! Please reload!')
       }
     } else {
-      this.activeSessions.delete(sessionInfo.pid)
       throw new Error('Model have crashed! Please reload!')
     }
     const baseUrl = `http://localhost:${sessionInfo.port}/v1`
@@ -1577,11 +1507,13 @@ export default class llamacpp_extension extends AIEngine {
   }

   override async getLoadedModels(): Promise<string[]> {
-    let lmodels: string[] = []
-    for (const [_, sInfo] of this.activeSessions) {
-      lmodels.push(sInfo.model_id)
-    }
-    return lmodels
+    try {
+      let models: string[] = await invoke<string[]>('get_loaded_models')
+      return models
+    } catch (e) {
+      logger.error(e)
+      throw new Error(e)
+    }
   }

   async getDevices(): Promise<DeviceList[]> {
@@ -1611,7 +1543,7 @@ export default class llamacpp_extension extends AIEngine {
   }

   async embed(text: string[]): Promise<EmbeddingResponse> {
-    let sInfo = this.findSessionByModel('sentence-transformer-mini')
+    let sInfo = await this.findSessionByModel('sentence-transformer-mini')
     if (!sInfo) {
       const downloadedModelList = await this.list()
       if (
@@ -1,7 +1,9 @@
 use base64::{engine::general_purpose, Engine as _};
 use hmac::{Hmac, Mac};
+use rand::{rngs::StdRng, Rng, SeedableRng};
 use serde::{Deserialize, Serialize};
 use sha2::Sha256;
+use std::collections::HashSet;
 use std::path::PathBuf;
 use std::process::Stdio;
 use std::time::Duration;
@@ -724,11 +726,80 @@ pub async fn is_process_running(pid: i32, state: State<'_, AppState>) -> Result<
 }

 // check port availability
-#[tauri::command]
-pub fn is_port_available(port: u16) -> bool {
+fn is_port_available(port: u16) -> bool {
     std::net::TcpListener::bind(("127.0.0.1", port)).is_ok()
 }

+#[tauri::command]
+pub async fn get_random_port(state: State<'_, AppState>) -> Result<u16, String> {
+    const MAX_ATTEMPTS: u32 = 20000;
+    let mut attempts = 0;
+    let mut rng = StdRng::from_entropy();
+
+    // Get all active ports from sessions
+    let map = state.llama_server_process.lock().await;
+
+    let used_ports: HashSet<u16> = map
+        .values()
+        .filter_map(|session| {
+            // Convert valid ports to u16 (filter out placeholder ports like -1)
+            if session.info.port > 0 && session.info.port <= u16::MAX as i32 {
+                Some(session.info.port as u16)
+            } else {
+                None
+            }
+        })
+        .collect();
+
+    drop(map); // unlock early
+
+    while attempts < MAX_ATTEMPTS {
+        let port = rng.gen_range(3000..4000);
+
+        if used_ports.contains(&port) {
+            attempts += 1;
+            continue;
+        }
+
+        if is_port_available(port) {
+            return Ok(port);
+        }
+
+        attempts += 1;
+    }
+
+    Err("Failed to find an available port for the model to load".into())
+}
+
+// find session
+#[tauri::command]
+pub async fn find_session_by_model(
+    model_id: String,
+    state: State<'_, AppState>,
+) -> Result<Option<SessionInfo>, String> {
+    let map = state.llama_server_process.lock().await;
+
+    let session_info = map
+        .values()
+        .find(|backend_session| backend_session.info.model_id == model_id)
+        .map(|backend_session| backend_session.info.clone());
+
+    Ok(session_info)
+}
+
+// get running models
+#[tauri::command]
+pub async fn get_loaded_models(state: State<'_, AppState>) -> Result<Vec<String>, String> {
+    let map = state.llama_server_process.lock().await;
+
+    let model_ids = map
+        .values()
+        .map(|backend_session| backend_session.info.model_id.clone())
+        .collect();
+
+    Ok(model_ids)
+}
+
 // tests
 //
 #[cfg(test)]
@@ -95,7 +95,9 @@ pub fn run() {
             core::utils::extensions::inference_llamacpp_extension::server::load_llama_model,
             core::utils::extensions::inference_llamacpp_extension::server::unload_llama_model,
             core::utils::extensions::inference_llamacpp_extension::server::get_devices,
-            core::utils::extensions::inference_llamacpp_extension::server::is_port_available,
+            core::utils::extensions::inference_llamacpp_extension::server::get_random_port,
+            core::utils::extensions::inference_llamacpp_extension::server::find_session_by_model,
+            core::utils::extensions::inference_llamacpp_extension::server::get_loaded_models,
             core::utils::extensions::inference_llamacpp_extension::server::generate_api_key,
             core::utils::extensions::inference_llamacpp_extension::server::is_process_running,
         ])