feat: Distinguish and preserve embedding model sessions

This commit introduces a new field, `is_embedding`, to the `SessionInfo` structure to clearly mark sessions running dedicated embedding models.

Key changes:
- Adds `is_embedding` to the `SessionInfo` interface in `AIEngine.ts` and the Rust backend.
- Updates the `loadLlamaModel` command signatures to pass this new flag.
- Modifies the llama.cpp extension's **auto-unload logic** to **filter out** currently loaded embedding models instead of unloading them when a new text generation model is loaded. This is a critical performance fix: it keeps the embedding model (e.g., the one used for RAG) from being reloaded on every text-model load.
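
In outline, the new unload path works as follows (a simplified sketch of the extension diff below, assumed to run inside the `llamacpp_extension` class; logging and the wait-for-other-loads handling are trimmed):

```ts
// Collect session info for every loaded model, then unload only the
// non-embedding sessions; embedding sessions are left running.
const loaded = await this.getLoadedModels()
const sessions = await Promise.all(
  loaded.map((id) => this.findSessionByModel(id).catch(() => null))
)
const unloadTargets = sessions
  .filter((s): s is SessionInfo => s !== null && !s.is_embedding)
  .map((s) => s.model_id)
await Promise.all(unloadTargets.map((id) => this.unload(id)))
```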

Also includes minor code style cleanup/reformatting in `jan-provider-web/provider.ts` for improved readability.
Author: Akarshan
Date: 2025-10-08 20:02:05 +05:30
Parent: ff93dc3c5c
Commit: 7762cea10a
6 changed files with 68 additions and 17 deletions

View File

@@ -182,6 +182,7 @@ export interface SessionInfo {
   port: number // llama-server output port (corrected from portid)
   model_id: string //name of the model
   model_path: string // path of the loaded model
+  is_embedding: boolean
   api_key: string
   mmproj_path?: string
 }

View File

@@ -45,7 +45,7 @@ export default class JanProviderWeb extends AIEngine {
   // Verify Jan models capabilities in localStorage
   private validateJanModelsLocalStorage() {
     try {
-      console.log("Validating Jan models in localStorage...")
+      console.log('Validating Jan models in localStorage...')
       const storageKey = 'model-provider'
       const data = localStorage.getItem(storageKey)
       if (!data) return
@@ -60,9 +60,14 @@ export default class JanProviderWeb extends AIEngine {
         if (provider.provider === 'jan' && provider.models) {
           for (const model of provider.models) {
             console.log(`Checking Jan model: ${model.id}`, model.capabilities)
-            if (JSON.stringify(model.capabilities) !== JSON.stringify(JAN_MODEL_CAPABILITIES)) {
+            if (
+              JSON.stringify(model.capabilities) !==
+              JSON.stringify(JAN_MODEL_CAPABILITIES)
+            ) {
               hasInvalidModel = true
-              console.log(`Found invalid Jan model: ${model.id}, clearing localStorage`)
+              console.log(
+                `Found invalid Jan model: ${model.id}, clearing localStorage`
+              )
               break
             }
           }
@@ -79,9 +84,17 @@ export default class JanProviderWeb extends AIEngine {
        // If still present, try setting to empty state
        if (afterRemoval) {
          // Try alternative clearing method
-         localStorage.setItem(storageKey, JSON.stringify({ state: { providers: [] }, version: parsed.version || 3 }))
+         localStorage.setItem(
+           storageKey,
+           JSON.stringify({
+             state: { providers: [] },
+             version: parsed.version || 3,
+           })
+         )
        }
-       console.log('Cleared model-provider from localStorage due to invalid Jan capabilities')
+       console.log(
+         'Cleared model-provider from localStorage due to invalid Jan capabilities'
+       )
        // Force a page reload to ensure clean state
        window.location.reload()
      }
@@ -159,6 +172,7 @@ export default class JanProviderWeb extends AIEngine {
       port: 443, // HTTPS port
       model_id: modelId,
       model_path: `remote:${modelId}`, // Indicate this is a remote model
+      is_embedding: false, // assume false here, TODO: might need further implementation
       api_key: '', // API key handled by auth service
     }
@@ -193,8 +207,12 @@ export default class JanProviderWeb extends AIEngine {
       console.error(`Failed to unload Jan session ${sessionId}:`, error)
       return {
         success: false,
-        error: error instanceof ApiError ? error.message :
-          error instanceof Error ? error.message : 'Unknown error',
+        error:
+          error instanceof ApiError
+            ? error.message
+            : error instanceof Error
+              ? error.message
+              : 'Unknown error',
       }
     }
   }

View File

@@ -333,14 +333,12 @@ export default class llamacpp_extension extends AIEngine {
        )
        // Clear the invalid stored preference
        this.clearStoredBackendType()
-       bestAvailableBackendString = await this.determineBestBackend(
-         version_backends
-       )
+       bestAvailableBackendString =
+         await this.determineBestBackend(version_backends)
      }
    } else {
-     bestAvailableBackendString = await this.determineBestBackend(
-       version_backends
-     )
+     bestAvailableBackendString =
+       await this.determineBestBackend(version_backends)
    }
    let settings = structuredClone(SETTINGS)
@@ -1530,6 +1528,7 @@ export default class llamacpp_extension extends AIEngine {
     if (
       this.autoUnload &&
+      !isEmbedding &&
       (loadedModels.length > 0 || otherLoadingPromises.length > 0)
     ) {
       // Wait for OTHER loading models to finish, then unload everything
@@ -1537,10 +1536,33 @@
         await Promise.all(otherLoadingPromises)
       }

-      // Now unload all loaded models
+      // Now unload all loaded text models, excluding embedding models
       const allLoadedModels = await this.getLoadedModels()
       if (allLoadedModels.length > 0) {
-        await Promise.all(allLoadedModels.map((model) => this.unload(model)))
+        const sessionInfos: (SessionInfo | null)[] = await Promise.all(
+          allLoadedModels.map(async (modelId) => {
+            try {
+              return await this.findSessionByModel(modelId)
+            } catch (e) {
+              logger.warn(`Unable to find session for model "${modelId}": ${e}`)
+              return null // treat as "not eligible for unload"
+            }
+          })
+        )
+        logger.info(JSON.stringify(sessionInfos))
+
+        const nonEmbeddingModels: string[] = sessionInfos
+          .filter(
+            (s): s is SessionInfo => s !== null && s.is_embedding === false
+          )
+          .map((s) => s.model_id)
+
+        if (nonEmbeddingModels.length > 0) {
+          await Promise.all(
+            nonEmbeddingModels.map((modelId) => this.unload(modelId))
+          )
+        }
       }
     }

     const args: string[] = []
@@ -1677,6 +1699,7 @@ export default class llamacpp_extension extends AIEngine {
         libraryPath,
         args,
         envs,
+        isEmbedding,
       }
     )
     return sInfo
@@ -2024,7 +2047,11 @@ export default class llamacpp_extension extends AIEngine {
     let sInfo = await this.findSessionByModel('sentence-transformer-mini')
     if (!sInfo) {
       const downloadedModelList = await this.list()
-      if (!downloadedModelList.some((model) => model.id === 'sentence-transformer-mini')) {
+      if (
+        !downloadedModelList.some(
+          (model) => model.id === 'sentence-transformer-mini'
+        )
+      ) {
         await this.import('sentence-transformer-mini', {
           modelPath:
             'https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-ggml-model-f16.gguf?download=true',

View File

@@ -30,12 +30,14 @@ export async function cleanupLlamaProcesses(): Promise<void> {
 export async function loadLlamaModel(
   backendPath: string,
   libraryPath?: string,
-  args: string[] = []
+  args: string[] = [],
+  isEmbedding: boolean = false
 ): Promise<SessionInfo> {
   return await invoke('plugin:llamacpp|load_llama_model', {
     backendPath,
     libraryPath,
     args,
+    isEmbedding,
   })
 }
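
For reference, a hypothetical call site after this change could look like the sketch below; the server path and arguments are illustrative placeholders, and only the trailing `isEmbedding` flag is new:

```ts
// Hypothetical usage: load a dedicated embedding model and mark it as such
// so the auto-unload pass will leave it running (paths/args are placeholders).
const sInfo = await loadLlamaModel(
  '/opt/llamacpp/bin/llama-server', // backendPath (placeholder)
  undefined, // libraryPath: not needed in this sketch
  ['-m', '/models/all-MiniLM-L6-v2.gguf', '--embedding'], // illustrative args
  true // isEmbedding: session survives auto-unload
)
console.log(sInfo.is_embedding) // true
```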

View File

@@ -44,6 +44,7 @@ pub async fn load_llama_model<R: Runtime>(
     library_path: Option<&str>,
     mut args: Vec<String>,
     envs: HashMap<String, String>,
+    is_embedding: bool,
 ) -> ServerResult<SessionInfo> {
     let state: State<LlamacppState> = app_handle.state();
     let mut process_map = state.llama_server_process.lock().await;
@@ -223,6 +224,7 @@ pub async fn load_llama_model<R: Runtime>(
         port: port,
         model_id: model_id,
         model_path: model_path_pb.display().to_string(),
+        is_embedding: is_embedding,
         api_key: api_key,
         mmproj_path: mmproj_path_string,
     };

View File

@@ -10,6 +10,7 @@ pub struct SessionInfo {
     pub port: i32, // llama-server output port
     pub model_id: String,
     pub model_path: String, // path of the loaded model
+    pub is_embedding: bool,
     pub api_key: String,
     #[serde(default)]
     pub mmproj_path: Option<String>,