From 7762cea10a321ccec04868c07b2773b303a44276 Mon Sep 17 00:00:00 2001
From: Akarshan
Date: Wed, 8 Oct 2025 20:02:05 +0530
Subject: [PATCH] feat: Distinguish and preserve embedding model sessions

This commit introduces a new field, `is_embedding`, on the `SessionInfo`
structure to clearly mark sessions that are running dedicated embedding
models.

Key changes:
- Adds `is_embedding` to the `SessionInfo` interface in `AIEngine.ts` and to
  the Rust backend.
- Updates the `loadLlamaModel` command signatures to pass the new flag
  through to the Tauri plugin (a short call sketch appears further below).
- Modifies the llama.cpp extension's **auto-unload logic** to explicitly
  **filter out**, and therefore **not unload**, any currently loaded
  embedding model when a new text-generation model is loaded (sketched right
  after this list). This is a critical performance fix: it prevents the
  embedding model (e.g., the one used for RAG) from being repeatedly
  reloaded.
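The auto-unload change is the core of the patch. A minimal sketch of the new
behavior, using the extension's existing `getLoadedModels`,
`findSessionByModel`, and `unload` methods (error handling is compressed
here compared to the actual diff):

    // Resolve each loaded model id to its session so the new flag can be
    // checked; a session that cannot be resolved is treated as not
    // eligible for unloading.
    const loadedIds: string[] = await this.getLoadedModels()
    const sessions: (SessionInfo | null)[] = await Promise.all(
      loadedIds.map((id) => this.findSessionByModel(id).catch(() => null))
    )

    // Evict only text-generation models; embedding sessions stay resident.
    const toUnload: string[] = sessions
      .filter((s): s is SessionInfo => s !== null && !s.is_embedding)
      .map((s) => s.model_id)
    await Promise.all(toUnload.map((id) => this.unload(id)))

Keeping the embedding session resident means a RAG lookup arriving between
chats no longer pays the cost of reloading the embedding model.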
Also includes minor code-style cleanup/reformatting in
`jan-provider-web/provider.ts` for improved readability.
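The `loadLlamaModel` signature change is backward-compatible: the new
`isEmbedding` parameter defaults to `false`, so existing call sites compile
unchanged. A hypothetical call that starts an embedding server (the binary
path and server arguments below are illustrative, not taken from this
patch):

    const session: SessionInfo = await loadLlamaModel(
      '/path/to/llama-server',      // backendPath (illustrative)
      undefined,                    // libraryPath: none needed here
      ['-m', '/models/all-MiniLM-L6-v2.gguf', '--embedding'], // illustrative args
      true                          // isEmbedding: marks the session for the new filter
    )
    // session.is_embedding is now true, so the extension's auto-unload pass
    // will leave this server running when a chat model loads later.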
---
 .../browser/extensions/engines/AIEngine.ts    |  1 +
 .../src/jan-provider-web/provider.ts          | 32 ++++++++++---
 extensions/llamacpp-extension/src/index.ts    | 45 +++++++++++++++----
 .../tauri-plugin-llamacpp/guest-js/index.ts   |  4 +-
 .../tauri-plugin-llamacpp/src/commands.rs     |  2 +
 .../tauri-plugin-llamacpp/src/state.rs        |  1 +
 6 files changed, 68 insertions(+), 17 deletions(-)

diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts
index a4f98e71c..1be977034 100644
--- a/core/src/browser/extensions/engines/AIEngine.ts
+++ b/core/src/browser/extensions/engines/AIEngine.ts
@@ -182,6 +182,7 @@ export interface SessionInfo {
   port: number // llama-server output port (corrected from portid)
   model_id: string //name of the model
   model_path: string // path of the loaded model
+  is_embedding: boolean
   api_key: string
   mmproj_path?: string
 }
diff --git a/extensions-web/src/jan-provider-web/provider.ts b/extensions-web/src/jan-provider-web/provider.ts
index 3375fd351..67e513c3f 100644
--- a/extensions-web/src/jan-provider-web/provider.ts
+++ b/extensions-web/src/jan-provider-web/provider.ts
@@ -45,7 +45,7 @@ export default class JanProviderWeb extends AIEngine {
   // Verify Jan models capabilities in localStorage
   private validateJanModelsLocalStorage() {
     try {
-      console.log("Validating Jan models in localStorage...")
+      console.log('Validating Jan models in localStorage...')
       const storageKey = 'model-provider'
       const data = localStorage.getItem(storageKey)
       if (!data) return
@@ -60,9 +60,14 @@ export default class JanProviderWeb extends AIEngine {
        if (provider.provider === 'jan' && provider.models) {
          for (const model of provider.models) {
            console.log(`Checking Jan model: ${model.id}`, model.capabilities)
-           if (JSON.stringify(model.capabilities) !== JSON.stringify(JAN_MODEL_CAPABILITIES)) {
+           if (
+             JSON.stringify(model.capabilities) !==
+             JSON.stringify(JAN_MODEL_CAPABILITIES)
+           ) {
              hasInvalidModel = true
-             console.log(`Found invalid Jan model: ${model.id}, clearing localStorage`)
+             console.log(
+               `Found invalid Jan model: ${model.id}, clearing localStorage`
+             )
              break
            }
          }
@@ -79,9 +84,17 @@ export default class JanProviderWeb extends AIEngine {
        // If still present, try setting to empty state
        if (afterRemoval) {
          // Try alternative clearing method
-         localStorage.setItem(storageKey, JSON.stringify({ state: { providers: [] }, version: parsed.version || 3 }))
+         localStorage.setItem(
+           storageKey,
+           JSON.stringify({
+             state: { providers: [] },
+             version: parsed.version || 3,
+           })
+         )
        }
-       console.log('Cleared model-provider from localStorage due to invalid Jan capabilities')
+       console.log(
+         'Cleared model-provider from localStorage due to invalid Jan capabilities'
+       )
        // Force a page reload to ensure clean state
        window.location.reload()
      }
@@ -159,6 +172,7 @@ export default class JanProviderWeb extends AIEngine {
       port: 443, // HTTPS port
       model_id: modelId,
       model_path: `remote:${modelId}`, // Indicate this is a remote model
+      is_embedding: false, // assume false here; TODO: may need further implementation
       api_key: '', // API key handled by auth service
     }
 
@@ -193,8 +207,12 @@ export default class JanProviderWeb extends AIEngine {
      console.error(`Failed to unload Jan session ${sessionId}:`, error)
      return {
        success: false,
-       error: error instanceof ApiError ? error.message :
-         error instanceof Error ? error.message : 'Unknown error',
+       error:
+         error instanceof ApiError
+           ? error.message
+           : error instanceof Error
+           ? error.message
+           : 'Unknown error',
      }
    }
  }
diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts
index fcee9e412..4359e9fa6 100644
--- a/extensions/llamacpp-extension/src/index.ts
+++ b/extensions/llamacpp-extension/src/index.ts
@@ -333,14 +333,12 @@ export default class llamacpp_extension extends AIEngine {
         )
         // Clear the invalid stored preference
         this.clearStoredBackendType()
-        bestAvailableBackendString = await this.determineBestBackend(
-          version_backends
-        )
+        bestAvailableBackendString =
+          await this.determineBestBackend(version_backends)
       }
     } else {
-      bestAvailableBackendString = await this.determineBestBackend(
-        version_backends
-      )
+      bestAvailableBackendString =
+        await this.determineBestBackend(version_backends)
     }
 
     let settings = structuredClone(SETTINGS)
@@ -1530,6 +1528,7 @@ export default class llamacpp_extension extends AIEngine {
 
     if (
       this.autoUnload &&
+      !isEmbedding &&
       (loadedModels.length > 0 || otherLoadingPromises.length > 0)
     ) {
       // Wait for OTHER loading models to finish, then unload everything
@@ -1537,10 +1536,33 @@ export default class llamacpp_extension extends AIEngine {
         await Promise.all(otherLoadingPromises)
       }
 
-      // Now unload all loaded models
+      // Now unload all loaded text models, excluding embedding models
       const allLoadedModels = await this.getLoadedModels()
       if (allLoadedModels.length > 0) {
-        await Promise.all(allLoadedModels.map((model) => this.unload(model)))
+        const sessionInfos: (SessionInfo | null)[] = await Promise.all(
+          allLoadedModels.map(async (modelId) => {
+            try {
+              return await this.findSessionByModel(modelId)
+            } catch (e) {
+              logger.warn(`Unable to find session for model "${modelId}": ${e}`)
+              return null // treat as "not eligible for unload"
+            }
+          })
+        )
+
+        logger.info(JSON.stringify(sessionInfos))
+
+        const nonEmbeddingModels: string[] = sessionInfos
+          .filter(
+            (s): s is SessionInfo => s !== null && s.is_embedding === false
+          )
+          .map((s) => s.model_id)
+
+        if (nonEmbeddingModels.length > 0) {
+          await Promise.all(
+            nonEmbeddingModels.map((modelId) => this.unload(modelId))
+          )
+        }
       }
     }
     const args: string[] = []
@@ -1677,6 +1699,7 @@ export default class llamacpp_extension extends AIEngine {
         libraryPath,
         args,
         envs,
+        isEmbedding,
       }
     )
     return sInfo
@@ -2024,7 +2047,11 @@ export default class llamacpp_extension extends AIEngine {
     let sInfo = await this.findSessionByModel('sentence-transformer-mini')
     if (!sInfo) {
       const downloadedModelList = await this.list()
-      if (!downloadedModelList.some((model) => model.id === 'sentence-transformer-mini')) {
+      if (
+        !downloadedModelList.some(
+          (model) => model.id === 'sentence-transformer-mini'
+        )
+      ) {
         await this.import('sentence-transformer-mini', {
           modelPath:
             'https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-ggml-model-f16.gguf?download=true',
diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/guest-js/index.ts b/src-tauri/plugins/tauri-plugin-llamacpp/guest-js/index.ts
index 957839a63..7c0e3e4be 100644
--- a/src-tauri/plugins/tauri-plugin-llamacpp/guest-js/index.ts
+++ b/src-tauri/plugins/tauri-plugin-llamacpp/guest-js/index.ts
@@ -30,12 +30,14 @@ export async function cleanupLlamaProcesses(): Promise<void> {
 
 export async function loadLlamaModel(
   backendPath: string,
   libraryPath?: string,
-  args: string[] = []
+  args: string[] = [],
+  isEmbedding: boolean = false
 ): Promise<SessionInfo> {
   return await invoke<SessionInfo>('plugin:llamacpp|load_llama_model', {
     backendPath,
     libraryPath,
     args,
+    isEmbedding,
   })
 }
diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/src/commands.rs b/src-tauri/plugins/tauri-plugin-llamacpp/src/commands.rs
index 96ecb36bc..1d898b4d9 100644
--- a/src-tauri/plugins/tauri-plugin-llamacpp/src/commands.rs
+++ b/src-tauri/plugins/tauri-plugin-llamacpp/src/commands.rs
@@ -44,6 +44,7 @@ pub async fn load_llama_model(
     library_path: Option<&str>,
     mut args: Vec<String>,
     envs: HashMap<String, String>,
+    is_embedding: bool,
 ) -> ServerResult<SessionInfo> {
     let state: State = app_handle.state();
     let mut process_map = state.llama_server_process.lock().await;
@@ -223,6 +224,7 @@ pub async fn load_llama_model(
         port: port,
         model_id: model_id,
         model_path: model_path_pb.display().to_string(),
+        is_embedding: is_embedding,
         api_key: api_key,
         mmproj_path: mmproj_path_string,
     };
diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/src/state.rs b/src-tauri/plugins/tauri-plugin-llamacpp/src/state.rs
index 2aad02ecf..a299ec9c5 100644
--- a/src-tauri/plugins/tauri-plugin-llamacpp/src/state.rs
+++ b/src-tauri/plugins/tauri-plugin-llamacpp/src/state.rs
@@ -10,6 +10,7 @@ pub struct SessionInfo {
     pub port: i32, // llama-server output port
     pub model_id: String,
     pub model_path: String, // path of the loaded model
+    pub is_embedding: bool,
     pub api_key: String,
     #[serde(default)]
     pub mmproj_path: Option<String>,