From bbbf4779dfdb8ba8f6fa290df2176aa44d50bb41 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Tue, 20 May 2025 12:39:18 +0530 Subject: [PATCH] refactor load/unload --- .../browser/extensions/engines/AIEngine.ts | 30 -- extensions/llamacpp-extension/src/index.ts | 439 ++++++++---------- extensions/llamacpp-extension/src/types.ts | 78 ++-- .../inference_llamacpp_extension/server.rs | 109 +++-- 4 files changed, 306 insertions(+), 350 deletions(-) diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts index 4f96eb93a..b5261b4e0 100644 --- a/core/src/browser/extensions/engines/AIEngine.ts +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -16,9 +16,6 @@ export abstract class AIEngine extends BaseExtension { */ override onLoad() { this.registerEngine() - - events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model)) - events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model)) } /** @@ -27,31 +24,4 @@ export abstract class AIEngine extends BaseExtension { registerEngine() { EngineManager.instance().register(this) } - - /** - * Loads the model. - */ - async loadModel(model: Partial, abortController?: AbortController): Promise { - if (model?.engine?.toString() !== this.provider) return Promise.resolve() - events.emit(ModelEvent.OnModelReady, model) - return Promise.resolve() - } - /** - * Stops the model. - */ - async unloadModel(model?: Partial): Promise { - if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve() - events.emit(ModelEvent.OnModelStopped, model ?? {}) - return Promise.resolve() - } - - /** - * Inference request - */ - inference(data: MessageRequest) {} - - /** - * Stop inference - */ - stopInference() {} } diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts index bb1ae6b58..68a2143c6 100644 --- a/extensions/llamacpp-extension/src/index.ts +++ b/extensions/llamacpp-extension/src/index.ts @@ -6,35 +6,30 @@ * @module llamacpp-extension/src/index */ -import { - AIEngine, - getJanDataFolderPath, - fs, - Model, -} from '@janhq/core' +import { AIEngine, getJanDataFolderPath, fs, joinPath } from '@janhq/core' import { invoke } from '@tauri-apps/api/tauri' import { - LocalProvider, - ModelInfo, - ListOptions, - ListResult, - PullOptions, - PullResult, - LoadOptions, - SessionInfo, - UnloadOptions, - UnloadResult, - ChatOptions, - ChatCompletion, - ChatCompletionChunk, - DeleteOptions, - DeleteResult, - ImportOptions, - ImportResult, - AbortPullOptions, - AbortPullResult, - ChatCompletionRequest, + localProvider, + modelInfo, + listOptions, + listResult, + pullOptions, + pullResult, + loadOptions, + sessionInfo, + unloadOptions, + unloadResult, + chatOptions, + chatCompletion, + chatCompletionChunk, + deleteOptions, + deleteResult, + importOptions, + importResult, + abortPullOptions, + abortPullResult, + chatCompletionRequest, } from './types' /** @@ -61,246 +56,224 @@ function parseGGUFFileName(filename: string): { * The class provides methods for initializing and stopping a model, and for making inference requests. * It also subscribes to events emitted by the @janhq/core package and handles new message requests. 
 */
-export default class inference_llamacpp_extension
+export default class llamacpp_extension
   extends AIEngine
-  implements LocalProvider
+  implements localProvider
 {
   provider: string = 'llamacpp'
-  readonly providerId: string = 'llamcpp'
-
-  private activeSessions: Map<string, SessionInfo> = new Map()
+  readonly providerId: string = 'llamacpp'
+  private activeSessions: Map<string, sessionInfo> = new Map()
   private modelsBasePath!: string
+  private activeRequests: Map<string, AbortController> = new Map()
 
   override async onLoad(): Promise<void> {
     super.onLoad() // Calls registerEngine() from AIEngine
-    this.registerSettings(SETTINGS_DEFINITIONS)
+    this.registerSettings(SETTINGS)
 
-    const customPath = await this.getSetting(
-      LlamaCppSettings.ModelsPath,
-      ''
-    )
-    if (customPath && (await fs.exists(customPath))) {
-      this.modelsBasePath = customPath
+    // Initialize models base path - eventually this should come from settings
+    this.modelsBasePath = await joinPath([
+      await getJanDataFolderPath(),
+      'models',
+    ])
+  }
+
+  // Implement the required LocalProvider interface methods
+  async list(opts: listOptions): Promise<listResult> {
+    throw new Error('method not implemented yet')
+  }
+
+  async pull(opts: pullOptions): Promise<pullResult> {
+    throw new Error('method not implemented yet')
+  }
+
+  async load(opts: loadOptions): Promise<sessionInfo> {
+    const args: string[] = []
+
+    // model option is required
+    args.push('-m', opts.modelPath)
+    args.push('--port', String(opts.port || 8080)) // Default port if not specified
+
+    if (opts.n_gpu_layers === undefined) {
+      // in case of CPU only build, this option will be ignored
+      args.push('-ngl', '99')
     } else {
-      this.modelsBasePath = await path.join(
-        await getJanDataFolderPath(),
-        'models',
-        ENGINE_ID
-      )
+      args.push('-ngl', String(opts.n_gpu_layers))
     }
-    await fs.createDirAll(this.modelsBasePath)
-    console.log(
-      `${this.providerId} provider loaded.
Models path: ${this.modelsBasePath}` - ) - - // Optionally, list and register models with the core system if AIEngine expects it - // const models = await this.listModels({ providerId: this.providerId }); - // this.registerModels(this.mapModelInfoToCoreModel(models)); // mapModelInfoToCoreModel would be a helper - } - - async getModelsPath(): Promise { - // Ensure modelsBasePath is initialized - if (!this.modelsBasePath) { - const customPath = await this.getSetting( - LlamaCppSettings.ModelsPath, - '' - ) - if (customPath && (await fs.exists(customPath))) { - this.modelsBasePath = customPath - } else { - this.modelsBasePath = await path.join( - await getJanDataFolderPath(), - 'models', - ENGINE_ID - ) - } - await fs.createDirAll(this.modelsBasePath) + if (opts.n_ctx !== undefined) { + args.push('-c', String(opts.n_ctx)) } - return this.modelsBasePath - } - async listModels(_opts: ListOptions): Promise { - const modelsDir = await this.getModelsPath() - const result: ModelInfo[] = [] + // Add remaining options from the interface + if (opts.threads !== undefined) { + args.push('--threads', String(opts.threads)) + } + + if (opts.threads_batch !== undefined) { + args.push('--threads-batch', String(opts.threads_batch)) + } + + if (opts.ctx_size !== undefined) { + args.push('--ctx-size', String(opts.ctx_size)) + } + + if (opts.n_predict !== undefined) { + args.push('--n-predict', String(opts.n_predict)) + } + + if (opts.batch_size !== undefined) { + args.push('--batch-size', String(opts.batch_size)) + } + + if (opts.ubatch_size !== undefined) { + args.push('--ubatch-size', String(opts.ubatch_size)) + } + + if (opts.device !== undefined) { + args.push('--device', opts.device) + } + + if (opts.split_mode !== undefined) { + args.push('--split-mode', opts.split_mode) + } + + if (opts.main_gpu !== undefined) { + args.push('--main-gpu', String(opts.main_gpu)) + } + + // Boolean flags + if (opts.flash_attn === true) { + args.push('--flash-attn') + } + + if (opts.cont_batching === true) { + args.push('--cont-batching') + } + + if (opts.no_mmap === true) { + args.push('--no-mmap') + } + + if (opts.mlock === true) { + args.push('--mlock') + } + + if (opts.no_kv_offload === true) { + args.push('--no-kv-offload') + } + + if (opts.cache_type_k !== undefined) { + args.push('--cache-type-k', opts.cache_type_k) + } + + if (opts.cache_type_v !== undefined) { + args.push('--cache-type-v', opts.cache_type_v) + } + + if (opts.defrag_thold !== undefined) { + args.push('--defrag-thold', String(opts.defrag_thold)) + } + + if (opts.rope_scaling !== undefined) { + args.push('--rope-scaling', opts.rope_scaling) + } + + if (opts.rope_scale !== undefined) { + args.push('--rope-scale', String(opts.rope_scale)) + } + + if (opts.rope_freq_base !== undefined) { + args.push('--rope-freq-base', String(opts.rope_freq_base)) + } + + if (opts.rope_freq_scale !== undefined) { + args.push('--rope-freq-scale', String(opts.rope_freq_scale)) + } + console.log('Calling Tauri command load with args:', args) try { - if (!(await fs.exists(modelsDir))) { - await fs.createDirAll(modelsDir) - return [] - } - - const entries = await fs.readDir(modelsDir) - for (const entry of entries) { - if (entry.name?.endsWith('.gguf') && entry.isFile) { - const modelPath = await path.join(modelsDir, entry.name) - const stats = await fs.stat(modelPath) - const parsedName = parseGGUFFileName(entry.name) - - result.push({ - id: `${parsedName.baseModelId}${parsedName.quant ? 
`/${parsedName.quant}` : ''}`, // e.g., "mistral-7b/Q4_0" - name: entry.name.replace('.gguf', ''), // Or a more human-friendly name - quant_type: parsedName.quant, - providerId: this.providerId, - sizeBytes: stats.size, - path: modelPath, - tags: [this.providerId, parsedName.quant || 'unknown_quant'].filter( - Boolean - ) as string[], - }) - } - } - } catch (error) { - console.error(`[${this.providerId}] Error listing models:`, error) - // Depending on desired behavior, either throw or return empty/partial list - } - return result - } - - // pullModel - async pullModel(opts: PullOptions): Promise { - // TODO: Implement pullModel - return 0; - } - - // abortPull - async abortPull(opts: AbortPullOptions): Promise { - // TODO: implement abortPull - } - - async load(opts: LoadOptions): Promise { - if (opts.providerId !== this.providerId) { - throw new Error('Invalid providerId for LlamaCppProvider.loadModel') - } - - const sessionId = uuidv4() - const loadParams = { - model_path: opts.modelPath, - session_id: sessionId, // Pass sessionId to Rust for tracking - // Default llama.cpp server options, can be overridden by opts.options - port: opts.options?.port ?? 0, // 0 for dynamic port assignment by OS - n_gpu_layers: - opts.options?.n_gpu_layers ?? - (await this.getSetting(LlamaCppSettings.DefaultNGpuLayers, -1)), - n_ctx: - opts.options?.n_ctx ?? - (await this.getSetting(LlamaCppSettings.DefaultNContext, 2048)), - // Spread any other options from opts.options - ...(opts.options || {}), - } - - try { - console.log( - `[${this.providerId}] Requesting to load model: ${opts.modelPath} with options:`, - loadParams - ) - // This matches the Rust handler: core::utils::extensions::inference_llamacpp_extension::server::load - const rustResponse: { - session_id: string - port: number - model_path: string - settings: Record - } = await invoke('plugin:llamacpp|load', { params: loadParams }) // Adjust namespace if needed - - if (!rustResponse || !rustResponse.port) { - throw new Error( - 'Rust load function did not return expected port or session info.' 
-        )
-      }
-
-      const sessionInfo: SessionInfo = {
-        sessionId: rustResponse.session_id, // Use sessionId from Rust if it regenerates/confirms it
-        port: rustResponse.port,
-        modelPath: rustResponse.model_path,
-        providerId: this.providerId,
-        settings: rustResponse.settings, // Settings actually used by the server
-      }
+      const sessionInfo = await invoke<sessionInfo>('plugin:llamacpp|load', {
+        args: args,
+      })
+      // Store the session info for later use
       this.activeSessions.set(sessionInfo.sessionId, sessionInfo)
-      console.log(
-        `[${this.providerId}] Model loaded: ${sessionInfo.modelPath} on port ${sessionInfo.port}, session: ${sessionInfo.sessionId}`
-      )
+
       return sessionInfo
     } catch (error) {
-      console.error(
-        `[${this.providerId}] Error loading model ${opts.modelPath}:`,
-        error
-      )
-      throw error // Re-throw to be handled by the caller
+      console.error('Error loading llama-server:', error)
+      throw new Error(`Failed to load llama-server: ${error}`)
     }
   }
 
-  async unload(opts: UnloadOptions): Promise<UnloadResult> {
-    if (opts.providerId !== this.providerId) {
-      return { success: false, error: 'Invalid providerId' }
-    }
-    const session = this.activeSessions.get(opts.sessionId)
-    if (!session) {
+  async unload(opts: unloadOptions): Promise<unloadResult> {
+    try {
+      // Pass the PID as the session_id
+      const result = await invoke<unloadResult>('plugin:llamacpp|unload', {
+        session_id: opts.sessionId, // Using PID as session ID
+      })
+
+      // If successful, remove from active sessions
+      if (result.success) {
+        this.activeSessions.delete(opts.sessionId)
+        console.log(`Successfully unloaded model with PID ${opts.sessionId}`)
+      } else {
+        console.warn(`Failed to unload model: ${result.error}`)
+      }
+
+      return result
+    } catch (error) {
+      console.error('Error in unload command:', error)
       return {
         success: false,
-        error: `No active session found for id: ${opts.sessionId}`,
+        error: `Failed to unload model: ${error}`,
       }
     }
-
-    try {
-      console.log(
-        `[${this.providerId}] Requesting to unload model for session: ${opts.sessionId}`
-      )
-      // Matches: core::utils::extensions::inference_llamacpp_extension::server::unload
-      const rustResponse: { success: boolean; error?: string } = await invoke(
-        'plugin:llamacpp|unload',
-        { sessionId: opts.sessionId }
-      )
-
-      if (rustResponse.success) {
-        this.activeSessions.delete(opts.sessionId)
-        console.log(
-          `[${this.providerId}] Session ${opts.sessionId} unloaded successfully.`
-        )
-        return { success: true }
-      } else {
-        console.error(
-          `[${this.providerId}] Failed to unload session ${opts.sessionId}: ${rustResponse.error}`
-        )
-        return {
-          success: false,
-          error: rustResponse.error || 'Unknown error during unload',
-        }
-      }
-    } catch (error: any) {
-      console.error(
-        `[${this.providerId}] Error invoking unload for session ${opts.sessionId}:`,
-        error
-      )
-      return { success: false, error: error.message || String(error) }
-    }
   }
 
   async chat(
-    opts: ChatOptions
-  ): Promise<ChatCompletion | AsyncIterable<ChatCompletionChunk>> {}
+    opts: chatOptions
+  ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
+    const sessionInfo = this.activeSessions.get(opts.sessionId)
+    if (!sessionInfo) {
+      throw new Error(
+        `No active session found for sessionId: ${opts.sessionId}`
+      )
+    }
 
-  async deleteModel(opts: DeleteOptions): Promise<DeleteResult> {}
+    // For streaming responses
+    if (opts.payload.stream) {
+      return this.streamChat(opts)
+    }
 
-  async importModel(opts: ImportOptions): Promise<ImportResult> {}
-
-  override async loadModel(model: Model): Promise<void> {
-    if (model.engine?.toString() !== this.provider) return Promise.resolve()
-    console.log(
-      `[${this.providerId} AIEngine] Received OnModelInit for:`,
-      model.id
-    )
-    return super.load(model)
+    // For
non-streaming responses
+    try {
+      return await invoke<chatCompletion>('plugin:llamacpp|chat', { opts })
+    } catch (error) {
+      console.error('Error during chat completion:', error)
+      throw new Error(`Chat completion failed: ${error}`)
+    }
   }
 
-  override async unloadModel(model?: Model): Promise<void> {
-    if (model?.engine && model.engine.toString() !== this.provider)
-      return Promise.resolve()
-    console.log(
-      `[${this.providerId} AIEngine] Received OnModelStop for:`,
-      model?.id || 'all models'
-    )
-    return super.unload(model)
+  async delete(opts: deleteOptions): Promise<deleteResult> {
+    throw new Error('method not implemented yet')
+  }
+
+  async import(opts: importOptions): Promise<importResult> {
+    throw new Error('method not implemented yet')
+  }
+
+  async abortPull(opts: abortPullOptions): Promise<abortPullResult> {
+    throw new Error('method not implemented yet')
+  }
+
+  // Optional method for direct client access
+  getChatClient(sessionId: string): any {
+    throw new Error('method not implemented yet')
+  }
+
+  onUnload(): void {
+    throw new Error('Method not implemented.')
   }
 }
diff --git a/extensions/llamacpp-extension/src/types.ts b/extensions/llamacpp-extension/src/types.ts
index 0acfa0329..3a8837147 100644
--- a/extensions/llamacpp-extension/src/types.ts
+++ b/extensions/llamacpp-extension/src/types.ts
@@ -2,7 +2,7 @@
 
 // --- Re-using OpenAI types (minimal definitions for this example) ---
 // In a real project, you'd import these from 'openai' or a shared types package.
 
-export interface ChatCompletionRequestMessage {
+export interface chatCompletionRequestMessage {
   role: 'system' | 'user' | 'assistant' | 'tool';
   content: string | null;
   name?: string;
@@ -10,9 +10,9 @@
   tool_call_id?: string;
 }
 
-export interface ChatCompletionRequest {
+export interface chatCompletionRequest {
   model: string; // Model ID, though for local it might be implicit via sessionId
-  messages: ChatCompletionRequestMessage[];
+  messages: chatCompletionRequestMessage[];
   temperature?: number | null;
   top_p?: number | null;
   n?: number | null;
@@ -26,41 +26,41 @@
   // ... TODO: other OpenAI params
 }
 
-export interface ChatCompletionChunkChoiceDelta {
+export interface chatCompletionChunkChoiceDelta {
   content?: string | null;
   role?: 'system' | 'user' | 'assistant' | 'tool';
   tool_calls?: any[]; // Simplified
 }
 
-export interface ChatCompletionChunkChoice {
+export interface chatCompletionChunkChoice {
   index: number;
-  delta: ChatCompletionChunkChoiceDelta;
+  delta: chatCompletionChunkChoiceDelta;
   finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null;
 }
 
-export interface ChatCompletionChunk {
+export interface chatCompletionChunk {
   id: string;
   object: 'chat.completion.chunk';
   created: number;
   model: string;
-  choices: ChatCompletionChunkChoice[];
+  choices: chatCompletionChunkChoice[];
   system_fingerprint?: string;
 }
 
-export interface ChatCompletionChoice {
+export interface chatCompletionChoice {
   index: number;
-  message: ChatCompletionRequestMessage; // Response message
+  message: chatCompletionRequestMessage; // Response message
   finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call';
   logprobs?: any; // Simplified
 }
 
-export interface ChatCompletion {
+export interface chatCompletion {
   id: string;
   object: 'chat.completion';
   created: number;
   model: string; // Model ID used
-  choices: ChatCompletionChoice[];
+  choices: chatCompletionChoice[];
   usage?: {
     prompt_tokens: number;
     completion_tokens: number;
@@ -72,7 +72,7 @@
 
 
 // Shared model metadata
-export interface ModelInfo {
+export interface modelInfo {
   id: string; // e.g. "qwen3-4B" or "org/model/quant"
   name: string; // human‑readable, e.g., "Qwen3 4B Q4_0"
   quant_type?: string; // q4_0 (optional as it might be part of ID or name)
@@ -86,24 +86,24 @@
 }
 
 // 1. /list
-export interface ListOptions {
+export interface listOptions {
   providerId: string; // To specify which provider if a central manager calls this
 }
-export type ListResult = ModelInfo[];
+export type listResult = modelInfo[];
 
 // 2. /pull
-export interface PullOptions {
+export interface pullOptions {
   providerId: string;
   modelId: string; // Identifier for the model to pull (e.g., from a known registry)
   downloadUrl: string; // URL to download the model from
   /** optional callback to receive download progress */
   onProgress?: (progress: { percent: number; downloadedBytes: number; totalBytes?: number; }) => void;
 }
-export interface PullResult {
+export interface pullResult {
   success: boolean;
   path?: string; // local file path to the pulled model
   error?: string;
-  modelInfo?: ModelInfo; // Info of the pulled model
+  modelInfo?: modelInfo; // Info of the pulled model
 }
 
 // 3. /load
@@ -135,7 +135,7 @@
   rope_freq_scale?: number
 }
 
-export interface SessionInfo {
+export interface sessionInfo {
   sessionId: string; // opaque handle for unload/chat
   port: number; // llama-server output port (corrected from portid)
   modelPath: string; // path of the loaded model
@@ -143,71 +143,71 @@
 }
 
 // 4. /unload
-export interface UnloadOptions {
+export interface unloadOptions {
   providerId: string;
   sessionId: string;
 }
-export interface UnloadResult {
+export interface unloadResult {
   success: boolean;
   error?: string;
 }
 
 // 5. /chat
-export interface ChatOptions {
+export interface chatOptions {
   providerId: string;
   sessionId: string;
   /** Full OpenAI ChatCompletionRequest payload */
-  payload: ChatCompletionRequest;
+  payload: chatCompletionRequest;
 }
 // Output for /chat will be Promise<chatCompletion> for non-streaming
 // or Promise<AsyncIterable<chatCompletionChunk>> for streaming
 
 // 6. /delete
-export interface DeleteOptions {
+export interface deleteOptions {
   providerId: string;
   modelId: string; // The ID of the model to delete (implies finding its path)
   modelPath?: string; // Optionally, direct path can be provided
 }
-export interface DeleteResult {
+export interface deleteResult {
   success: boolean;
   error?: string;
 }
 
 // 7. /import
-export interface ImportOptions {
+export interface importOptions {
   providerId: string;
   sourcePath: string; // Path to the local model file to import
   desiredModelId?: string; // Optional: if user wants to name it specifically
 }
-export interface ImportResult {
+export interface importResult {
   success: boolean;
-  modelInfo?: ModelInfo;
+  modelInfo?: modelInfo;
   error?: string;
 }
 
 // 8. /abortPull
-export interface AbortPullOptions {
+export interface abortPullOptions {
   providerId: string;
   modelId: string; // The modelId whose download is to be aborted
 }
-export interface AbortPullResult {
+export interface abortPullResult {
   success: boolean;
   error?: string;
 }
 
 // The interface for any local provider
-export interface LocalProvider {
+export interface localProvider {
   readonly providerId: string;
 
-  listModels(opts: ListOptions): Promise<ListResult>;
-  pullModel(opts: PullOptions): Promise<PullResult>;
-  loadModel(opts: LoadOptions): Promise<SessionInfo>;
-  unloadModel(opts: UnloadOptions): Promise<UnloadResult>;
-  chat(opts: ChatOptions): Promise<ChatCompletion | AsyncIterable<ChatCompletionChunk>>;
-  deleteModel(opts: DeleteOptions): Promise<DeleteResult>;
-  importModel(opts: ImportOptions): Promise<ImportResult>;
-  abortPull(opts: AbortPullOptions): Promise<AbortPullResult>;
+  list(opts: listOptions): Promise<listResult>;
+  pull(opts: pullOptions): Promise<pullResult>;
+  load(opts: loadOptions): Promise<sessionInfo>;
+  unload(opts: unloadOptions): Promise<unloadResult>;
+  chat(opts: chatOptions): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>>;
+  delete(opts: deleteOptions): Promise<deleteResult>;
+  import(opts: importOptions): Promise<importResult>;
+  abortPull(opts: abortPullOptions): Promise<abortPullResult>;
 
   // Optional: for direct access to underlying client if needed for specific streaming cases
   getChatClient?(sessionId: string): any; // e.g., an OpenAI client instance configured for the session
diff --git a/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs b/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
index a5f592e80..a990b7780 100644
--- a/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
+++ b/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
@@ -3,7 +3,6 @@
 use serde::{Serialize, Deserialize};
 use tauri::path::BaseDirectory;
 use tauri::{AppHandle, Manager, State}; // Import Manager trait
 use tokio::process::Command;
-use std::collections::HashMap;
 use uuid::Uuid;
 use thiserror;
 
@@ -70,7 +69,12 @@ pub struct SessionInfo {
     pub session_id: String, // opaque handle for unload/chat
     pub port: u16,          // llama-server output port
     pub model_path: String, // path of the loaded model
-    pub settings: HashMap<String, serde_json::Value>, // The actual settings used to load
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct UnloadResult {
+    success: bool,
+    error: Option<String>,
+}
 
 // --- Load Command ---
@@ -102,40 +106,12 @@
         )));
     }
 
-    let mut port = 8080; // Default port
-    let mut model_path = String::new();
-    let mut settings: HashMap<String, serde_json::Value> = HashMap::new();
-
-    // Extract arguments into settings map and specific fields
-    let mut i = 0;
-    while i < args.len() {
-        if args[i] == "--port" && i + 1 < args.len() {
-            if let Ok(p) = args[i + 1].parse::<u16>() {
-                port = p;
-            }
-            settings.insert("port".to_string(), serde_json::Value::String(args[i + 1].clone()));
-            i += 2;
-        } else if args[i] == "-m" && i + 1 < args.len() {
-            model_path = args[i + 1].clone();
-            settings.insert("modelPath".to_string(), serde_json::Value::String(model_path.clone()));
-            i += 2;
-        } else if i + 1 < args.len() && args[i].starts_with("-") {
-            // Store other arguments as settings
-            let key = args[i].trim_start_matches("-").trim_start_matches("-");
-            settings.insert(key.to_string(), serde_json::Value::String(args[i + 1].clone()));
-            i += 2;
-        } else {
-            // Handle boolean flags
-            if args[i].starts_with("-") {
-                let key = args[i].trim_start_matches("-").trim_start_matches("-");
-                settings.insert(key.to_string(), serde_json::Value::Bool(true));
-            }
-            i += 1;
-        }
-    }
+    let port = 8080; // Default port
 
     // Configure the command to run the server
     let mut command = Command::new(server_path);
+
+    // args are pushed by the extension as ["-m", <model path>, ...], so the path is the second element
+    let model_path = args.get(1).cloned().unwrap_or_default();
     command.args(args);
 
     // Optional: Redirect stdio if needed (e.g., for logging within Jan)
@@ -145,17 +121,21 @@
     // Spawn the child process
     let child = command.spawn().map_err(ServerError::Io)?;
 
-    log::info!("Server process started with PID: {:?}", child.id());
+    // Get the PID to use as session ID
+    let pid = child.id().map(|id| id.to_string()).unwrap_or_else(|| {
+        // Fallback in case we can't get the PID for some reason
+        format!("unknown_pid_{}", Uuid::new_v4())
+    });
+
+    log::info!("Server process started with PID: {}", pid);
 
     // Store the child process handle in the state
     *process_lock = Some(child);
 
-    let session_id = format!("session_{}", Uuid::new_v4());
     let session_info = SessionInfo {
-        session_id,
+        session_id: pid, // Use PID as session ID
        port,
        model_path,
-        settings,
     };
 
     Ok(session_info)
@@ -163,32 +143,65 @@
 
 // --- Unload Command ---
 #[tauri::command]
-pub async fn unload(state: State<'_, AppState>) -> ServerResult<()> {
+pub async fn unload(session_id: String, state: State<'_, AppState>) -> ServerResult<UnloadResult> {
     let mut process_lock = state.llama_server_process.lock().await;
 
-    // Take the child process out of the Option, leaving None in its place
     if let Some(mut child) = process_lock.take() {
+        // Convert the PID to a string to compare with the session_id
+        let process_pid = child.id().map(|pid| pid.to_string()).unwrap_or_default();
+
+        // Check if the session_id matches the PID
+        if session_id != process_pid && !session_id.is_empty() && !process_pid.is_empty() {
+            // Put the process back in the lock since we're not killing it
+            *process_lock = Some(child);
+
+            log::warn!(
+                "Session ID mismatch: provided {} vs process {}",
+                session_id,
+                process_pid
+            );
+
+            return Ok(UnloadResult {
+                success: false,
+                error: Some(format!("Session ID mismatch: provided {} doesn't match process {}",
+                    session_id, process_pid)),
+            });
+        }
+
         log::info!(
             "Attempting to terminate server process with PID: {:?}",
             child.id()
         );
+
         // Kill the process
-        // `start_kill` is preferred in async contexts
         match child.start_kill() {
             Ok(_) => {
-                log::info!("Server process termination signal sent.");
-                Ok(())
+                log::info!("Server process termination signal sent successfully");
+
+                Ok(UnloadResult {
+                    success: true,
+                    error: None,
+                })
             }
             Err(e) => {
-                // For simplicity, we log and return error.
                 log::error!("Failed to kill server process: {}", e);
-                // Put it back? Maybe not useful if kill failed.
- // *process_lock = Some(child); - Err(ServerError::Io(e)) + + // Return formatted error + Ok(UnloadResult { + success: false, + error: Some(format!("Failed to kill server process: {}", e)), + }) } } } else { - log::warn!("Attempted to unload server, but it was not running."); - Ok(()) + log::warn!("Attempted to unload server, but no process was running"); + + // If no process is running but client thinks there is, + // still report success since the end state is what they wanted + Ok(UnloadResult { + success: true, + error: None, + }) } } +
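A minimal TypeScript sketch (illustrative only, not part of the diff) of how the refactored flow is meant to be driven end to end: load() spawns llama-server and returns the process PID as sessionId, chat() goes through the plugin:llamacpp|chat command, and unload() kills the server by that PID. It assumes loadOptions carries providerId, modelPath and port alongside the tuning flags shown in the hunks above; the GGUF path and the payload's model field are placeholders.

    // Illustrative usage of the refactored llamacpp_extension provider API.
    async function smokeTest(provider: llamacpp_extension): Promise<void> {
      // Spawn llama-server for a local GGUF file; sessionId is the server PID.
      const session = await provider.load({
        providerId: 'llamacpp',
        modelPath: '/path/to/model.gguf', // placeholder path
        port: 8080,
        n_gpu_layers: 99,
        n_ctx: 4096,
      })

      // Non-streaming completion via the 'plugin:llamacpp|chat' Tauri command.
      const completion = await provider.chat({
        providerId: 'llamacpp',
        sessionId: session.sessionId,
        payload: {
          model: 'local-model', // placeholder; the loaded model is implied by the session
          messages: [{ role: 'user', content: 'Hello!' }],
          stream: false,
        },
      })
      console.log(completion)

      // Kill the spawned llama-server process by PID and drop the session.
      await provider.unload({ providerId: 'llamacpp', sessionId: session.sessionId })
    }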