refactor unload and implement a destructor to clean up sessions

This commit is contained in:
Akarshan Biswas 2025-05-26 11:42:18 +05:30 committed by Louis
parent cd36b423b6
commit c5a0ee7f6e
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
2 changed files with 23 additions and 37 deletions

View File

@ -76,7 +76,7 @@ export default class llamacpp_extension extends AIEngine {
private downloadManager
private activeSessions: Map<string, sessionInfo> = new Map()
private modelsBasePath!: string
private activeRequests: Map<string, AbortController> = new Map()
private enginesPath!: string
override async onLoad(): Promise<void> {
super.onLoad() // Calls registerEngine() from AIEngine
@ -89,6 +89,22 @@ export default class llamacpp_extension extends AIEngine {
await getJanDataFolderPath(),
'models',
])
this.enginesPath = await joinPath([await getJanDataFolderPath(), 'llamacpp', 'engines'])
}
/**
 * Extension teardown hook: unloads every session still tracked in
 * `activeSessions`, one at a time, then empties the map.
 *
 * A failure while unloading one session is logged and does not stop the
 * remaining sessions from being processed.
 */
override async onUnload(): Promise<void> {
  // Walk the live key iterator; only the session id is needed here.
  for (const activeId of this.activeSessions.keys()) {
    try {
      await this.unload(activeId)
    } catch (error) {
      console.error(`Failed to unload session ${activeId}:`, error)
    }
  }

  // Drop whatever entries remain, including sessions whose unload failed.
  this.activeSessions.clear()
}
// Implement the required LocalProvider interface methods
@ -335,6 +351,7 @@ export default class llamacpp_extension extends AIEngine {
try {
const sInfo = await invoke<sessionInfo>('load_llama_model', {
server_path: this.enginesPath,
args: args,
})
@ -348,17 +365,17 @@ export default class llamacpp_extension extends AIEngine {
}
}
override async unload(opts: unloadOptions): Promise<unloadResult> {
override async unload(sessionId: string): Promise<unloadResult> {
try {
// Pass the PID as the session_id
const result = await invoke<unloadResult>('unload_llama_model', {
session_id: opts.sessionId, // Using PID as session ID
session_id: sessionId, // Using PID as session ID
})
// If successful, remove from active sessions
if (result.success) {
this.activeSessions.delete(opts.sessionId)
console.log(`Successfully unloaded model with PID ${opts.sessionId}`)
this.activeSessions.delete(sessionId)
console.log(`Successfully unloaded model with PID ${sessionId}`)
} else {
console.warn(`Failed to unload model: ${result.error}`)
}
@ -496,8 +513,4 @@ export default class llamacpp_extension extends AIEngine {
/**
 * Placeholder for obtaining a chat client for a session.
 *
 * Not implemented: the `sessionId` argument is ignored and the method
 * always throws.
 *
 * @param sessionId - Identifier of the session (currently unused).
 * @throws Error always, with the message 'method not implemented yet'.
 */
override getChatClient(sessionId: string): any {
throw new Error('method not implemented yet')
}
/**
 * Stub unload hook: performs no cleanup and always throws.
 *
 * @throws Error always, with the message 'Method not implemented.'
 */
onUnload(): void {
throw new Error('Method not implemented.')
}
}

View File

@ -37,33 +37,6 @@ impl serde::Serialize for ServerError {
/// Shorthand for a `Result` whose error type is [`ServerError`].
type ServerResult<T> = Result<T, ServerError>;
/// Locates the `llama-server` binary via the application's path resolver.
///
/// The lookup path is `engines/<binary>`, resolved against
/// `BaseDirectory::Resource`, where `<binary>` is `llama-server.exe` when
/// compiled for Windows and `llama-server` otherwise.
///
/// # Errors
///
/// Returns [`ServerError::ResourcePathError`] containing the resolver's error
/// message when the path cannot be resolved.
// TODO: Adjust extension engine paths
// engine: static llama-server build (CUDA, VULKAN, SYCL, etc)
fn get_server_path(app_handle: &AppHandle) -> ServerResult<PathBuf> {
    // TODO: ADJUST THIS PATH
    let mut candidate = PathBuf::from("engines");
    candidate.push(if cfg!(windows) {
        "llama-server.exe"
    } else {
        "llama-server"
    });

    match app_handle.path().resolve(candidate, BaseDirectory::Resource) {
        Ok(resolved) => Ok(resolved),
        Err(e) => Err(ServerError::ResourcePathError(e.to_string())),
    }
}
#[derive(Debug, Serialize, Deserialize)]
pub struct SessionInfo {
pub session_id: String, // opaque handle for unload/chat
@ -82,6 +55,7 @@ pub struct UnloadResult {
pub async fn load_llama_model(
app_handle: AppHandle, // Get the AppHandle
state: State<'_, AppState>, // Access the shared state
server_path: String,
args: Vec<String>, // Arguments from the frontend
) -> ServerResult<SessionInfo> {
let mut process_lock = state.llama_server_process.lock().await;
@ -91,7 +65,6 @@ pub async fn load_llama_model(
return Err(ServerError::AlreadyRunning);
}
let server_path = get_server_path(&app_handle)?;
log::info!("Attempting to launch server at path: {:?}", server_path);
log::info!("Using arguments: {:?}", args);