diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts
index 9e3fb9884..a3033a6c7 100644
--- a/core/src/browser/extensions/engines/AIEngine.ts
+++ b/core/src/browser/extensions/engines/AIEngine.ts
@@ -108,11 +108,11 @@ export interface loadOptions {
 }
 
 export interface sessionInfo {
-  sessionId: string // opaque handle for unload/chat
+  pid: string // opaque handle for unload/chat
   port: number // llama-server output port (corrected from portid)
-  modelName: string, //name of the model
+  modelId: string, // id of the model
   modelPath: string // path of the loaded model
-  api_key: string
+  apiKey: string
 }
 
 // 4. /unload
diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts
index cd8cac862..b48240e66 100644
--- a/extensions/llamacpp-extension/src/index.ts
+++ b/extensions/llamacpp-extension/src/index.ts
@@ -478,7 +478,7 @@ export default class llamacpp_extension extends AIEngine {
       const sInfo = await invoke<sessionInfo>('load_llama_model', { backendPath, args })
 
       // Store the session info for later use
-      this.activeSessions.set(sInfo.sessionId, sInfo)
+      this.activeSessions.set(sInfo.pid, sInfo)
 
       return sInfo
     } catch (error) {
@@ -487,17 +487,22 @@ export default class llamacpp_extension extends AIEngine {
     }
   }
 
-  override async unload(sessionId: string): Promise<UnloadResult> {
+  override async unload(modelId: string): Promise<UnloadResult> {
+    const sInfo: sessionInfo = this.findSessionByModel(modelId)
+    if (!sInfo) {
+      throw new Error(`No active session found for model: ${modelId}`)
+    }
+    const pid = sInfo.pid
     try {
       // Pass the PID as the session_id
       const result = await invoke<UnloadResult>('unload_llama_model', {
-        sessionId, // Using PID as session ID
+        pid
       })
 
       // If successful, remove from active sessions
       if (result.success) {
-        this.activeSessions.delete(sessionId)
-        console.log(`Successfully unloaded model with PID ${sessionId}`)
+        this.activeSessions.delete(pid)
+        console.log(`Successfully unloaded model with PID ${pid}`)
       } else {
         console.warn(`Failed to unload model: ${result.error}`)
       }
@@ -577,13 +582,9 @@ export default class llamacpp_extension extends AIEngine {
     }
   }
 
-  private findSessionByModel(modelName: string): sessionInfo | undefined {
-    for (const [, session] of this.activeSessions) {
-      if (session.modelName === modelName) {
-        return session
-      }
-    }
-    return undefined
+  private findSessionByModel(modelId: string): sessionInfo | undefined {
+    return Array.from(this.activeSessions.values())
+      .find(session => session.modelId === modelId);
   }
 
   override async chat(
@@ -595,9 +596,10 @@ export default class llamacpp_extension extends AIEngine {
     }
     const baseUrl = `http://localhost:${sessionInfo.port}/v1`
     const url = `${baseUrl}/chat/completions`
+    console.log(`Using api-key: ${sessionInfo.apiKey}`)
     const headers = {
       'Content-Type': 'application/json',
-      'Authorization': `Bearer ${sessionInfo.api_key}`,
+      'Authorization': `Bearer ${sessionInfo.apiKey}`,
     }
 
     const body = JSON.stringify(opts)
diff --git a/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs b/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
index 69586b57d..acba92fbb 100644
--- a/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
+++ b/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
@@ -41,7 +41,7 @@ type ServerResult<T> = Result<T, String>;
 #[derive(Debug, Serialize, Deserialize)]
 pub struct sessionInfo {
     pub pid: String, // opaque handle for unload/chat
-    pub port: u16, // llama-server output port
+    pub port: String, // llama-server output port
     pub modelId: String,
     pub modelPath: String, // path of the loaded model
     pub apiKey: String,
@@ -147,7 +147,7 @@ pub async fn load_llama_model(
 // --- Unload Command ---
 #[tauri::command]
 pub async fn unload_llama_model(
-    session_id: String,
+    pid: String,
     state: State<'_, AppState>,
 ) -> ServerResult<UnloadResult> {
     let mut process_lock = state.llama_server_process.lock().await;
@@ -157,13 +157,13 @@ pub async fn unload_llama_model(
     let process_pid = child.id().map(|pid| pid.to_string()).unwrap_or_default();
 
     // Check if the session_id matches the PID
-    if session_id != process_pid && !session_id.is_empty() && !process_pid.is_empty() {
+    if pid != process_pid && !pid.is_empty() && !process_pid.is_empty() {
         // Put the process back in the lock since we're not killing it
         *process_lock = Some(child);
 
         log::warn!(
            "Session ID mismatch: provided {} vs process {}",
-           session_id,
+           pid,
            process_pid
         );
 
@@ -171,7 +171,7 @@ pub async fn unload_llama_model(
            success: false,
            error: Some(format!(
                "Session ID mismatch: provided {} doesn't match process {}",
-               session_id, process_pid
+               pid, process_pid
            )),
        });
    }
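
For reference, below is a minimal standalone sketch (not part of the patch) of the session-registry pattern this diff moves to: sessions are stored keyed by the llama-server PID (an opaque handle), while callers address models by model ID and the extension resolves modelId -> pid before unloading. The sessionInfo shape comes from the AIEngine.ts hunk above; the SessionRegistry class, its method names, and the example values are hypothetical, for illustration only.

interface sessionInfo {
  pid: string       // opaque handle for unload/chat
  port: number      // llama-server output port
  modelId: string   // id of the model
  modelPath: string // path of the loaded model
  apiKey: string
}

// Hypothetical wrapper around the Map<pid, sessionInfo> used in the extension.
class SessionRegistry {
  private activeSessions = new Map<string, sessionInfo>()

  // load() stores each new session under its pid.
  register(sInfo: sessionInfo): void {
    this.activeSessions.set(sInfo.pid, sInfo)
  }

  // chat()/unload() address models by id, so scan values for a matching modelId.
  findSessionByModel(modelId: string): sessionInfo | undefined {
    return Array.from(this.activeSessions.values())
      .find(session => session.modelId === modelId)
  }

  // unload(modelId) resolves the pid first, then drops the entry.
  remove(pid: string): boolean {
    return this.activeSessions.delete(pid)
  }
}

// Usage mirroring the unload flow in index.ts (all values are made up):
const registry = new SessionRegistry()
registry.register({
  pid: '12345',
  port: 8080,
  modelId: 'example-model',
  modelPath: '/models/example-model.gguf',
  apiKey: 'secret',
})
const sInfo = registry.findSessionByModel('example-model')
if (!sInfo) {
  throw new Error('No active session found for model: example-model')
}
registry.remove(sInfo.pid)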