From 32a2ca95b6e16411bada6cee6bc7889a11c8c5bf Mon Sep 17 00:00:00 2001 From: Dinh Long Nguyen Date: Thu, 21 Aug 2025 16:17:58 +0700 Subject: [PATCH] feat: gguf file size + hash validation (#5266) (#6259) * feat: gguf file size + hash validation * fix tests fe * update cargo tests * handle asyn download for both models and mmproj * move progress tracker to models * handle file download cancelled * add cancellation mid hash run --- .../browser/extensions/engines/AIEngine.ts | 4 + core/src/types/api/index.ts | 3 + extensions/download-extension/src/index.ts | 2 + extensions/llamacpp-extension/src/index.ts | 142 ++++- src-tauri/Cargo.lock | 11 +- src-tauri/src/core/downloads/helpers.rs | 503 +++++++++++++----- src-tauri/src/core/downloads/models.rs | 32 ++ src-tauri/src/core/downloads/tests.rs | 6 + src-tauri/utils/Cargo.toml | 1 + src-tauri/utils/src/crypto.rs | 62 ++- web-app/src/containers/DownloadManegement.tsx | 74 ++- web-app/src/locales/de-DE/common.json | 16 + web-app/src/locales/de-DE/hub.json | 1 + web-app/src/locales/en/common.json | 12 + web-app/src/locales/en/hub.json | 1 + web-app/src/locales/id/common.json | 16 + web-app/src/locales/id/hub.json | 1 + web-app/src/locales/vn/common.json | 16 + web-app/src/locales/vn/hub.json | 1 + web-app/src/locales/zh-CN/common.json | 16 + web-app/src/locales/zh-CN/hub.json | 1 + web-app/src/locales/zh-TW/common.json | 16 + web-app/src/locales/zh-TW/hub.json | 1 + web-app/src/routes/hub/$modelId.tsx | 8 +- web-app/src/routes/hub/index.tsx | 14 +- web-app/src/services/__tests__/models.test.ts | 10 +- web-app/src/services/models.ts | 98 +++- 27 files changed, 915 insertions(+), 153 deletions(-) diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts index b203092ce..7a223e468 100644 --- a/core/src/browser/extensions/engines/AIEngine.ts +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -194,6 +194,10 @@ export interface chatOptions { export interface ImportOptions { modelPath: string mmprojPath?: string + modelSha256?: string + modelSize?: number + mmprojSha256?: string + mmprojSize?: number } export interface importResult { diff --git a/core/src/types/api/index.ts b/core/src/types/api/index.ts index ade6421ff..d40aab852 100644 --- a/core/src/types/api/index.ts +++ b/core/src/types/api/index.ts @@ -73,6 +73,9 @@ export enum DownloadEvent { onFileDownloadSuccess = 'onFileDownloadSuccess', onFileDownloadStopped = 'onFileDownloadStopped', onFileDownloadStarted = 'onFileDownloadStarted', + onModelValidationStarted = 'onModelValidationStarted', + onModelValidationFailed = 'onModelValidationFailed', + onFileDownloadAndVerificationSuccess = 'onFileDownloadAndVerificationSuccess', } export enum ExtensionRoute { baseExtensions = 'baseExtensions', diff --git a/extensions/download-extension/src/index.ts b/extensions/download-extension/src/index.ts index 04c34cd6c..8045b5eeb 100644 --- a/extensions/download-extension/src/index.ts +++ b/extensions/download-extension/src/index.ts @@ -10,6 +10,8 @@ interface DownloadItem { url: string save_path: string proxy?: Record + sha256?: string + size?: number } type DownloadEvent = { diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts index bf03024a0..eefdc44af 100644 --- a/extensions/llamacpp-extension/src/index.ts +++ b/extensions/llamacpp-extension/src/index.ts @@ -20,9 +20,11 @@ import { chatCompletionRequest, events, AppEvent, + DownloadEvent, } from '@janhq/core' import { error, info, warn } from '@tauri-apps/plugin-log' +import { listen } from '@tauri-apps/api/event' import { listSupportedBackends, @@ -71,6 +73,8 @@ interface DownloadItem { url: string save_path: string proxy?: Record + sha256?: string + size?: number } interface ModelConfig { @@ -79,6 +83,9 @@ interface ModelConfig { name: string // user-friendly // some model info that we cache upon import size_bytes: number + sha256?: string + mmproj_sha256?: string + mmproj_size_bytes?: number } interface EmbeddingResponse { @@ -154,6 +161,7 @@ export default class llamacpp_extension extends AIEngine { private pendingDownloads: Map> = new Map() private isConfiguringBackends: boolean = false private loadingModels = new Map>() // Track loading promises + private unlistenValidationStarted?: () => void override async onLoad(): Promise { super.onLoad() // Calls registerEngine() from AIEngine @@ -181,6 +189,19 @@ export default class llamacpp_extension extends AIEngine { await getJanDataFolderPath(), this.providerId, ]) + + // Set up validation event listeners to bridge Tauri events to frontend + this.unlistenValidationStarted = await listen<{ + modelId: string + downloadType: string + }>('onModelValidationStarted', (event) => { + console.debug( + 'LlamaCPP: bridging onModelValidationStarted event', + event.payload + ) + events.emit(DownloadEvent.onModelValidationStarted, event.payload) + }) + this.configureBackends() } @@ -774,6 +795,11 @@ export default class llamacpp_extension extends AIEngine { override async onUnload(): Promise { // Terminate all active sessions + + // Clean up validation event listeners + if (this.unlistenValidationStarted) { + this.unlistenValidationStarted() + } } onSettingUpdate(key: string, value: T): void { @@ -1006,6 +1032,9 @@ export default class llamacpp_extension extends AIEngine { url: path, save_path: localPath, proxy: getProxyConfig(), + sha256: + saveName === 'model.gguf' ? opts.modelSha256 : opts.mmprojSha256, + size: saveName === 'model.gguf' ? opts.modelSize : opts.mmprojSize, }) return localPath } @@ -1023,8 +1052,6 @@ export default class llamacpp_extension extends AIEngine { : undefined if (downloadItems.length > 0) { - let downloadCompleted = false - try { // emit download update event on progress const onProgress = (transferred: number, total: number) => { @@ -1034,7 +1061,6 @@ export default class llamacpp_extension extends AIEngine { size: { transferred, total }, downloadType: 'Model', }) - downloadCompleted = transferred === total } const downloadManager = window.core.extensionManager.getByName( '@janhq/download-extension' @@ -1045,13 +1071,67 @@ export default class llamacpp_extension extends AIEngine { onProgress ) - const eventName = downloadCompleted - ? 'onFileDownloadSuccess' - : 'onFileDownloadStopped' - events.emit(eventName, { modelId, downloadType: 'Model' }) + // If we reach here, download completed successfully (including validation) + // The downloadFiles function only returns successfully if all files downloaded AND validated + events.emit(DownloadEvent.onFileDownloadAndVerificationSuccess, { + modelId, + downloadType: 'Model' + }) } catch (error) { logger.error('Error downloading model:', modelId, opts, error) - events.emit('onFileDownloadError', { modelId, downloadType: 'Model' }) + const errorMessage = + error instanceof Error ? error.message : String(error) + + // Check if this is a cancellation + const isCancellationError = errorMessage.includes('Download cancelled') || + errorMessage.includes('Validation cancelled') || + errorMessage.includes('Hash computation cancelled') || + errorMessage.includes('cancelled') || + errorMessage.includes('aborted') + + // Check if this is a validation failure + const isValidationError = + errorMessage.includes('Hash verification failed') || + errorMessage.includes('Size verification failed') || + errorMessage.includes('Failed to verify file') + + if (isCancellationError) { + logger.info('Download cancelled for model:', modelId) + // Emit download stopped event instead of error + events.emit(DownloadEvent.onFileDownloadStopped, { + modelId, + downloadType: 'Model', + }) + } else if (isValidationError) { + logger.error( + 'Validation failed for model:', + modelId, + 'Error:', + errorMessage + ) + + // Cancel any other download tasks for this model + try { + this.abortImport(modelId) + } catch (cancelError) { + logger.warn('Failed to cancel download task:', cancelError) + } + + // Emit validation failure event + events.emit(DownloadEvent.onModelValidationFailed, { + modelId, + downloadType: 'Model', + error: errorMessage, + reason: 'validation_failed', + }) + } else { + // Regular download error + events.emit(DownloadEvent.onFileDownloadError, { + modelId, + downloadType: 'Model', + error: errorMessage, + }) + } throw error } } @@ -1078,7 +1158,9 @@ export default class llamacpp_extension extends AIEngine { } catch (error) { logger.error('GGUF validation failed:', error) throw new Error( - `Invalid GGUF file(s): ${error.message || 'File format validation failed'}` + `Invalid GGUF file(s): ${ + error.message || 'File format validation failed' + }` ) } @@ -1097,6 +1179,10 @@ export default class llamacpp_extension extends AIEngine { mmproj_path: mmprojPath, name: modelId, size_bytes, + model_sha256: opts.modelSha256, + model_size_bytes: opts.modelSize, + mmproj_sha256: opts.mmprojSha256, + mmproj_size_bytes: opts.mmprojSize, } as ModelConfig await fs.mkdir(await joinPath([janDataFolderPath, modelDir])) await invoke('write_yaml', { @@ -1108,16 +1194,50 @@ export default class llamacpp_extension extends AIEngine { modelPath, mmprojPath, size_bytes, + model_sha256: opts.modelSha256, + model_size_bytes: opts.modelSize, + mmproj_sha256: opts.mmprojSha256, + mmproj_size_bytes: opts.mmprojSize, }) } + /** + * Deletes the entire model folder for a given modelId + * @param modelId The model ID to delete + */ + private async deleteModelFolder(modelId: string): Promise { + try { + const modelDir = await joinPath([ + await this.getProviderPath(), + 'models', + modelId, + ]) + + if (await fs.existsSync(modelDir)) { + logger.info(`Cleaning up model directory: ${modelDir}`) + await fs.rm(modelDir) + } + } catch (deleteError) { + logger.warn('Failed to delete model directory:', deleteError) + } + } + override async abortImport(modelId: string): Promise { - // prepand provider name to avoid name collision + // Cancel any active download task + // prepend provider name to avoid name collision const taskId = this.createDownloadTaskId(modelId) const downloadManager = window.core.extensionManager.getByName( '@janhq/download-extension' ) - await downloadManager.cancelDownload(taskId) + + try { + await downloadManager.cancelDownload(taskId) + } catch (cancelError) { + logger.warn('Failed to cancel download task:', cancelError) + } + + // Delete the entire model folder if it exists (for validation failures) + await this.deleteModelFolder(modelId) } /** diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 32638bc56..013982c83 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -2323,6 +2323,7 @@ dependencies = [ "serde_json", "sha2", "tokio", + "tokio-util", "url", ] @@ -4019,8 +4020,9 @@ dependencies = [ [[package]] name = "rmcp" -version = "0.5.0" -source = "git+https://github.com/modelcontextprotocol/rust-sdk?rev=209dbac50f51737ad953c3a2c8e28f3619b6c277#209dbac50f51737ad953c3a2c8e28f3619b6c277" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb21cd3555f1059f27e4813827338dec44429a08ecd0011acc41d9907b160c00" dependencies = [ "base64 0.22.1", "chrono", @@ -4045,8 +4047,9 @@ dependencies = [ [[package]] name = "rmcp-macros" -version = "0.5.0" -source = "git+https://github.com/modelcontextprotocol/rust-sdk?rev=209dbac50f51737ad953c3a2c8e28f3619b6c277#209dbac50f51737ad953c3a2c8e28f3619b6c277" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5d16ae1ff3ce2c5fd86c37047b2869b75bec795d53a4b1d8257b15415a2354" dependencies = [ "darling 0.21.2", "proc-macro2", diff --git a/src-tauri/src/core/downloads/helpers.rs b/src-tauri/src/core/downloads/helpers.rs index 1fad0ea4b..137bbdd3d 100644 --- a/src-tauri/src/core/downloads/helpers.rs +++ b/src-tauri/src/core/downloads/helpers.rs @@ -1,9 +1,10 @@ -use super::models::{DownloadEvent, DownloadItem, ProxyConfig}; +use super::models::{DownloadEvent, DownloadItem, ProxyConfig, ProgressTracker}; use crate::core::app::commands::get_jan_data_folder_path; use futures_util::StreamExt; use jan_utils::normalize_path; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use std::collections::HashMap; +use std::path::Path; use std::time::Duration; use tauri::Emitter; use tokio::fs::File; @@ -11,10 +12,131 @@ use tokio::io::AsyncWriteExt; use tokio_util::sync::CancellationToken; use url::Url; +// ===== UTILITY FUNCTIONS ===== + pub fn err_to_string(e: E) -> String { format!("Error: {}", e) } + +// ===== VALIDATION FUNCTIONS ===== + +/// Validates a downloaded file against expected hash and size +async fn validate_downloaded_file( + item: &DownloadItem, + save_path: &Path, + app: &tauri::AppHandle, + cancel_token: &CancellationToken, +) -> Result<(), String> { + // Skip validation if no verification data is provided + if item.sha256.is_none() && item.size.is_none() { + log::debug!( + "No validation data provided for {}, skipping validation", + item.url + ); + return Ok(()); + } + + // Extract model ID from save path for validation events + // Path structure: llamacpp/models/{modelId}/model.gguf or llamacpp/models/{modelId}/mmproj.gguf + let model_id = save_path + .parent() // get parent directory (modelId folder) + .and_then(|p| p.file_name()) + .and_then(|n| n.to_str()) + .unwrap_or("unknown"); + + // Emit validation started event + app.emit( + "onModelValidationStarted", + serde_json::json!({ + "modelId": model_id, + "downloadType": "Model", + }), + ) + .unwrap(); + + log::info!("Starting validation for model: {}", model_id); + + // Validate size if provided (fast check first) + if let Some(expected_size) = &item.size { + log::info!("Starting size verification for {}", item.url); + + match tokio::fs::metadata(save_path).await { + Ok(metadata) => { + let actual_size = metadata.len(); + + if actual_size != *expected_size { + log::error!( + "Size verification failed for {}. Expected: {} bytes, Actual: {} bytes", + item.url, + expected_size, + actual_size + ); + return Err(format!( + "Size verification failed. Expected {} bytes but got {} bytes.", + expected_size, actual_size + )); + } + + log::info!( + "Size verification successful for {} ({} bytes)", + item.url, + actual_size + ); + } + Err(e) => { + log::error!( + "Failed to get file metadata for {}: {}", + save_path.display(), + e + ); + return Err(format!("Failed to verify file size: {}", e)); + } + } + } + + // Check for cancellation before expensive hash computation + if cancel_token.is_cancelled() { + log::info!("Validation cancelled for {}", item.url); + return Err("Validation cancelled".to_string()); + } + + // Validate hash if provided (expensive check second) + if let Some(expected_sha256) = &item.sha256 { + log::info!("Starting Hash verification for {}", item.url); + + match jan_utils::crypto::compute_file_sha256_with_cancellation(save_path, cancel_token).await { + Ok(computed_sha256) => { + if computed_sha256 != *expected_sha256 { + log::error!( + "Hash verification failed for {}. Expected: {}, Computed: {}", + item.url, + expected_sha256, + computed_sha256 + ); + + return Err(format!( + "Hash verification failed. The downloaded file is corrupted or has been tampered with." + )); + } + + log::info!("Hash verification successful for {}", item.url); + } + Err(e) => { + log::error!( + "Failed to compute SHA256 for {}: {}", + save_path.display(), + e + ); + return Err(format!("Failed to verify file integrity: {}", e)); + } + } + } + + log::info!("All validations passed for {}", item.url); + Ok(()) +} + pub fn validate_proxy_config(config: &ProxyConfig) -> Result<(), String> { // Validate proxy URL format if let Err(e) = Url::parse(&config.url) { @@ -172,6 +294,9 @@ pub async fn _get_file_size( } } +// ===== MAIN DOWNLOAD FUNCTIONS ===== + +/// Downloads multiple files in parallel with individual progress tracking pub async fn _download_files_internal( app: tauri::AppHandle, items: &[DownloadItem], @@ -184,28 +309,31 @@ pub async fn _download_files_internal( let header_map = _convert_headers(headers).map_err(err_to_string)?; - let total_size = { - let mut total_size = 0u64; - for item in items.iter() { - let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?; - total_size += _get_file_size(&client, &item.url) - .await - .map_err(err_to_string)?; - } - total_size - }; + // Calculate sizes for each file + let mut file_sizes = HashMap::new(); + for item in items.iter() { + let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?; + let size = _get_file_size(&client, &item.url) + .await + .map_err(err_to_string)?; + file_sizes.insert(item.url.clone(), size); + } + + let total_size: u64 = file_sizes.values().sum(); log::info!("Total download size: {}", total_size); - let mut evt = DownloadEvent { - transferred: 0, - total: total_size, - }; let evt_name = format!("download-{}", task_id); + // Create progress tracker + let progress_tracker = ProgressTracker::new(items, file_sizes.clone()); + // save file under Jan data folder let jan_data_folder = get_jan_data_folder_path(app.clone()); - for item in items.iter() { + // Collect download tasks for parallel execution + let mut download_tasks = Vec::new(); + + for (index, item) in items.iter().enumerate() { let save_path = jan_data_folder.join(&item.save_path); let save_path = normalize_path(&save_path); @@ -217,120 +345,251 @@ pub async fn _download_files_internal( )); } - // Create parent directories if they don't exist - if let Some(parent) = save_path.parent() { - if !parent.exists() { - tokio::fs::create_dir_all(parent) - .await - .map_err(err_to_string)?; - } - } + // Spawn download task for each file + let item_clone = item.clone(); + let app_clone = app.clone(); + let header_map_clone = header_map.clone(); + let cancel_token_clone = cancel_token.clone(); + let evt_name_clone = evt_name.clone(); + let progress_tracker_clone = progress_tracker.clone(); + let file_id = format!("{}-{}", task_id, index); + let file_size = file_sizes.get(&item.url).copied().unwrap_or(0); - let current_extension = save_path.extension().unwrap_or_default().to_string_lossy(); - let append_extension = |ext: &str| { - if current_extension.is_empty() { - ext.to_string() - } else { - format!("{}.{}", current_extension, ext) - } - }; - let tmp_save_path = save_path.with_extension(append_extension("tmp")); - let url_save_path = save_path.with_extension(append_extension("url")); - - let mut should_resume = resume - && tmp_save_path.exists() - && tokio::fs::read_to_string(&url_save_path) - .await - .map(|url| url == item.url) // check if we resume the same URL - .unwrap_or(false); - - tokio::fs::write(&url_save_path, item.url.clone()) + let task = tokio::spawn(async move { + download_single_file( + app_clone, + &item_clone, + &header_map_clone, + &save_path, + resume, + cancel_token_clone, + evt_name_clone, + progress_tracker_clone, + file_id, + file_size, + ) .await - .map_err(err_to_string)?; + }); - log::info!("Started downloading: {}", item.url); - let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?; - let mut download_delta = 0u64; - let resp = if should_resume { - let downloaded_size = tmp_save_path.metadata().map_err(err_to_string)?.len(); - match _get_maybe_resume(&client, &item.url, downloaded_size).await { - Ok(resp) => { - log::info!( - "Resume download: {}, already downloaded {} bytes", - item.url, - downloaded_size - ); - download_delta += downloaded_size; - resp - } - Err(e) => { - // fallback to normal download - log::warn!("Failed to resume download: {}", e); - should_resume = false; - _get_maybe_resume(&client, &item.url, 0).await? - } - } - } else { - _get_maybe_resume(&client, &item.url, 0).await? - }; - let mut stream = resp.bytes_stream(); - - let file = if should_resume { - // resume download, append to existing file - tokio::fs::OpenOptions::new() - .write(true) - .append(true) - .open(&tmp_save_path) - .await - .map_err(err_to_string)? - } else { - // start new download, create a new file - File::create(&tmp_save_path).await.map_err(err_to_string)? - }; - let mut writer = tokio::io::BufWriter::new(file); - - // write chunk to file - while let Some(chunk) = stream.next().await { - if cancel_token.is_cancelled() { - if !should_resume { - tokio::fs::remove_dir_all(&save_path.parent().unwrap()) - .await - .ok(); - } - log::info!("Download cancelled for task: {}", task_id); - app.emit(&evt_name, evt.clone()).unwrap(); - return Ok(()); - } - - let chunk = chunk.map_err(err_to_string)?; - writer.write_all(&chunk).await.map_err(err_to_string)?; - download_delta += chunk.len() as u64; - - // only update every 10 MB - if download_delta >= 10 * 1024 * 1024 { - evt.transferred += download_delta; - app.emit(&evt_name, evt.clone()).unwrap(); - download_delta = 0u64; - } - } - - writer.flush().await.map_err(err_to_string)?; - evt.transferred += download_delta; - - // rename tmp file to final file - tokio::fs::rename(&tmp_save_path, &save_path) - .await - .map_err(err_to_string)?; - tokio::fs::remove_file(&url_save_path) - .await - .map_err(err_to_string)?; - log::info!("Finished downloading: {}", item.url); + download_tasks.push(task); } - app.emit(&evt_name, evt.clone()).unwrap(); + // Wait for all downloads to complete + let mut validation_tasks = Vec::new(); + for (task, item) in download_tasks.into_iter().zip(items.iter()) { + let result = task.await.map_err(|e| format!("Task join error: {}", e))?; + + match result { + Ok(downloaded_path) => { + // Spawn validation task in parallel + let item_clone = item.clone(); + let app_clone = app.clone(); + let path_clone = downloaded_path.clone(); + let cancel_token_clone = cancel_token.clone(); + let validation_task = tokio::spawn(async move { + validate_downloaded_file(&item_clone, &path_clone, &app_clone, &cancel_token_clone).await + }); + validation_tasks.push((validation_task, downloaded_path, item.clone())); + } + Err(e) => return Err(e), + } + } + + // Wait for all validations to complete + for (validation_task, save_path, _item) in validation_tasks { + let validation_result = validation_task + .await + .map_err(|e| format!("Validation task join error: {}", e))?; + + if let Err(validation_error) = validation_result { + // Clean up the file if validation fails + let _ = tokio::fs::remove_file(&save_path).await; + + // Try to clean up the parent directory if it's empty + if let Some(parent) = save_path.parent() { + let _ = tokio::fs::remove_dir(parent).await; + } + + return Err(validation_error); + } + } + + // Emit final progress + let (transferred, total) = progress_tracker.get_total_progress().await; + let final_evt = DownloadEvent { transferred, total }; + app.emit(&evt_name, final_evt).unwrap(); Ok(()) } +/// Downloads a single file without blocking other downloads +async fn download_single_file( + app: tauri::AppHandle, + item: &DownloadItem, + header_map: &HeaderMap, + save_path: &std::path::Path, + resume: bool, + cancel_token: CancellationToken, + evt_name: String, + progress_tracker: ProgressTracker, + file_id: String, + _file_size: u64, +) -> Result { + // Create parent directories if they don't exist + if let Some(parent) = save_path.parent() { + if !parent.exists() { + tokio::fs::create_dir_all(parent) + .await + .map_err(err_to_string)?; + } + } + + let current_extension = save_path.extension().unwrap_or_default().to_string_lossy(); + let append_extension = |ext: &str| { + if current_extension.is_empty() { + ext.to_string() + } else { + format!("{}.{}", current_extension, ext) + } + }; + let tmp_save_path = save_path.with_extension(append_extension("tmp")); + let url_save_path = save_path.with_extension(append_extension("url")); + + let mut should_resume = resume + && tmp_save_path.exists() + && tokio::fs::read_to_string(&url_save_path) + .await + .map(|url| url == item.url) // check if we resume the same URL + .unwrap_or(false); + + tokio::fs::write(&url_save_path, item.url.clone()) + .await + .map_err(err_to_string)?; + + log::info!("Started downloading: {}", item.url); + let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?; + let mut download_delta = 0u64; + let mut initial_progress = 0u64; + + let resp = if should_resume { + let downloaded_size = tmp_save_path.metadata().map_err(err_to_string)?.len(); + match _get_maybe_resume(&client, &item.url, downloaded_size).await { + Ok(resp) => { + log::info!( + "Resume download: {}, already downloaded {} bytes", + item.url, + downloaded_size + ); + initial_progress = downloaded_size; + + // Initialize progress for resumed download + progress_tracker + .update_progress(&file_id, downloaded_size) + .await; + + // Emit initial combined progress + let (combined_transferred, combined_total) = + progress_tracker.get_total_progress().await; + let evt = DownloadEvent { + transferred: combined_transferred, + total: combined_total, + }; + app.emit(&evt_name, evt).unwrap(); + + resp + } + Err(e) => { + // fallback to normal download + log::warn!("Failed to resume download: {}", e); + should_resume = false; + _get_maybe_resume(&client, &item.url, 0).await? + } + } + } else { + _get_maybe_resume(&client, &item.url, 0).await? + }; + let mut stream = resp.bytes_stream(); + + let file = if should_resume { + // resume download, append to existing file + tokio::fs::OpenOptions::new() + .write(true) + .append(true) + .open(&tmp_save_path) + .await + .map_err(err_to_string)? + } else { + // start new download, create a new file + File::create(&tmp_save_path).await.map_err(err_to_string)? + }; + let mut writer = tokio::io::BufWriter::new(file); + let mut total_transferred = initial_progress; + + // write chunk to file + while let Some(chunk) = stream.next().await { + if cancel_token.is_cancelled() { + if !should_resume { + tokio::fs::remove_dir_all(&save_path.parent().unwrap()) + .await + .ok(); + } + log::info!("Download cancelled: {}", item.url); + return Err("Download cancelled".to_string()); + } + + let chunk = chunk.map_err(err_to_string)?; + writer.write_all(&chunk).await.map_err(err_to_string)?; + download_delta += chunk.len() as u64; + total_transferred += chunk.len() as u64; + + // Update progress every 10 MB + if download_delta >= 10 * 1024 * 1024 { + // Update individual file progress + progress_tracker + .update_progress(&file_id, total_transferred) + .await; + + // Emit combined progress event + let (combined_transferred, combined_total) = + progress_tracker.get_total_progress().await; + let evt = DownloadEvent { + transferred: combined_transferred, + total: combined_total, + }; + app.emit(&evt_name, evt).unwrap(); + + download_delta = 0u64; + } + } + + writer.flush().await.map_err(err_to_string)?; + + // Final progress update for this file + progress_tracker + .update_progress(&file_id, total_transferred) + .await; + + // Emit final combined progress + let (combined_transferred, combined_total) = progress_tracker.get_total_progress().await; + let evt = DownloadEvent { + transferred: combined_transferred, + total: combined_total, + }; + app.emit(&evt_name, evt).unwrap(); + + // rename tmp file to final file + tokio::fs::rename(&tmp_save_path, &save_path) + .await + .map_err(err_to_string)?; + tokio::fs::remove_file(&url_save_path) + .await + .map_err(err_to_string)?; + + log::info!("Finished downloading: {}", item.url); + Ok(save_path.to_path_buf()) +} + +// ===== HTTP CLIENT HELPER FUNCTIONS ===== + pub async fn _get_maybe_resume( client: &reqwest::Client, url: &str, diff --git a/src-tauri/src/core/downloads/models.rs b/src-tauri/src/core/downloads/models.rs index 61f438ec8..75a84e2b3 100644 --- a/src-tauri/src/core/downloads/models.rs +++ b/src-tauri/src/core/downloads/models.rs @@ -1,4 +1,6 @@ use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Mutex; use tokio_util::sync::CancellationToken; #[derive(Default)] @@ -20,6 +22,8 @@ pub struct DownloadItem { pub url: String, pub save_path: String, pub proxy: Option, + pub sha256: Option, + pub size: Option, } #[derive(serde::Serialize, Clone, Debug)] @@ -27,3 +31,31 @@ pub struct DownloadEvent { pub transferred: u64, pub total: u64, } + +/// Structure to track progress for each file in parallel downloads +#[derive(Clone)] +pub struct ProgressTracker { + file_progress: Arc>>, + total_size: u64, +} + +impl ProgressTracker { + pub fn new(_items: &[DownloadItem], sizes: HashMap) -> Self { + let total_size = sizes.values().sum(); + ProgressTracker { + file_progress: Arc::new(Mutex::new(HashMap::new())), + total_size, + } + } + + pub async fn update_progress(&self, file_id: &str, transferred: u64) { + let mut progress = self.file_progress.lock().await; + progress.insert(file_id.to_string(), transferred); + } + + pub async fn get_total_progress(&self) -> (u64, u64) { + let progress = self.file_progress.lock().await; + let total_transferred: u64 = progress.values().sum(); + (total_transferred, self.total_size) + } +} diff --git a/src-tauri/src/core/downloads/tests.rs b/src-tauri/src/core/downloads/tests.rs index 42e690dba..8c3b14af5 100644 --- a/src-tauri/src/core/downloads/tests.rs +++ b/src-tauri/src/core/downloads/tests.rs @@ -194,6 +194,8 @@ fn test_download_item_with_ssl_proxy() { url: "https://example.com/file.zip".to_string(), save_path: "downloads/file.zip".to_string(), proxy: Some(proxy_config), + sha256: None, + size: None, }; assert!(download_item.proxy.is_some()); @@ -211,6 +213,8 @@ fn test_client_creation_with_ssl_settings() { url: "https://example.com/file.zip".to_string(), save_path: "downloads/file.zip".to_string(), proxy: Some(proxy_config), + sha256: None, + size: None, }; let header_map = HeaderMap::new(); @@ -256,6 +260,8 @@ fn test_download_item_creation() { url: "https://example.com/file.tar.gz".to_string(), save_path: "models/test.tar.gz".to_string(), proxy: None, + sha256: None, + size: None, }; assert_eq!(item.url, "https://example.com/file.tar.gz"); diff --git a/src-tauri/utils/Cargo.toml b/src-tauri/utils/Cargo.toml index 65f4dc8e1..09fc121e4 100644 --- a/src-tauri/utils/Cargo.toml +++ b/src-tauri/utils/Cargo.toml @@ -13,6 +13,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" sha2 = "0.10" tokio = { version = "1", features = ["process"] } +tokio-util = "0.7.14" url = "2.5" [features] diff --git a/src-tauri/utils/src/crypto.rs b/src-tauri/utils/src/crypto.rs index dcc1d541b..379b10290 100644 --- a/src-tauri/utils/src/crypto.rs +++ b/src-tauri/utils/src/crypto.rs @@ -1,7 +1,11 @@ use base64::{engine::general_purpose, Engine as _}; use hmac::{Hmac, Mac}; use rand::{distributions::Alphanumeric, Rng}; -use sha2::Sha256; +use sha2::{Digest, Sha256}; +use std::path::Path; +use tokio::fs::File; +use tokio::io::AsyncReadExt; +use tokio_util::sync::CancellationToken; type HmacSha256 = Hmac; @@ -24,3 +28,59 @@ pub fn generate_api_key(model_id: String, api_secret: String) -> Result Result { + // Check for cancellation before starting + if cancel_token.is_cancelled() { + return Err("Hash computation cancelled".to_string()); + } + + let mut file = File::open(file_path) + .await + .map_err(|e| format!("Failed to open file for hashing: {}", e))?; + + let mut hasher = Sha256::new(); + let mut buffer = vec![0u8; 64 * 1024]; // 64KB chunks + let mut total_read = 0u64; + + loop { + // Check for cancellation every chunk (every 64KB) + if cancel_token.is_cancelled() { + return Err("Hash computation cancelled".to_string()); + } + + let bytes_read = file + .read(&mut buffer) + .await + .map_err(|e| format!("Failed to read file for hashing: {}", e))?; + + if bytes_read == 0 { + break; // EOF + } + + hasher.update(&buffer[..bytes_read]); + total_read += bytes_read as u64; + + // Log progress for very large files (every 100MB) + if total_read % (100 * 1024 * 1024) == 0 { + #[cfg(feature = "logging")] + log::debug!("Hash progress: {} MB processed", total_read / (1024 * 1024)); + } + } + + // Final cancellation check + if cancel_token.is_cancelled() { + return Err("Hash computation cancelled".to_string()); + } + + let hash_bytes = hasher.finalize(); + let hash_hex = format!("{:x}", hash_bytes); + + #[cfg(feature = "logging")] + log::debug!("Hash computation completed for {} bytes", total_read); + Ok(hash_hex) +} diff --git a/web-app/src/containers/DownloadManegement.tsx b/web-app/src/containers/DownloadManegement.tsx index 6044a9c80..5557b2741 100644 --- a/web-app/src/containers/DownloadManegement.tsx +++ b/web-app/src/containers/DownloadManegement.tsx @@ -168,9 +168,46 @@ export function DownloadManagement() { [removeDownload, removeLocalDownloadingModel, t] ) + const onModelValidationStarted = useCallback( + (event: { modelId: string; downloadType: string }) => { + console.debug('onModelValidationStarted', event) + + // Show validation in progress toast + toast.info(t('common:toast.modelValidationStarted.title'), { + id: `model-validation-started-${event.modelId}`, + description: t('common:toast.modelValidationStarted.description', { + modelId: event.modelId, + }), + duration: 10000, + }) + }, + [t] + ) + + const onModelValidationFailed = useCallback( + (event: { modelId: string; error: string; reason: string }) => { + console.debug('onModelValidationFailed', event) + + // Dismiss the validation started toast + toast.dismiss(`model-validation-started-${event.modelId}`) + + removeDownload(event.modelId) + removeLocalDownloadingModel(event.modelId) + + // Show specific toast for validation failure + toast.error(t('common:toast.modelValidationFailed.title'), { + description: t('common:toast.modelValidationFailed.description', { + modelId: event.modelId, + }), + duration: 30000, // Requires manual dismissal for security-critical message + }) + }, + [removeDownload, removeLocalDownloadingModel, t] + ) + const onFileDownloadStopped = useCallback( (state: DownloadState) => { - console.debug('onFileDownloadError', state) + console.debug('onFileDownloadStopped', state) removeDownload(state.modelId) removeLocalDownloadingModel(state.modelId) }, @@ -180,6 +217,10 @@ export function DownloadManagement() { const onFileDownloadSuccess = useCallback( async (state: DownloadState) => { console.debug('onFileDownloadSuccess', state) + + // Dismiss any validation started toast when download completes successfully + toast.dismiss(`model-validation-started-${state.modelId}`) + removeDownload(state.modelId) removeLocalDownloadingModel(state.modelId) toast.success(t('common:toast.downloadComplete.title'), { @@ -192,12 +233,34 @@ export function DownloadManagement() { [removeDownload, removeLocalDownloadingModel, t] ) + const onFileDownloadAndVerificationSuccess = useCallback( + async (state: DownloadState) => { + console.debug('onFileDownloadAndVerificationSuccess', state) + + // Dismiss any validation started toast when download and verification complete successfully + toast.dismiss(`model-validation-started-${state.modelId}`) + + removeDownload(state.modelId) + removeLocalDownloadingModel(state.modelId) + toast.success(t('common:toast.downloadAndVerificationComplete.title'), { + id: 'download-complete', + description: t('common:toast.downloadAndVerificationComplete.description', { + item: state.modelId, + }), + }) + }, + [removeDownload, removeLocalDownloadingModel, t] + ) + useEffect(() => { console.debug('DownloadListener: registering event listeners...') events.on(DownloadEvent.onFileDownloadUpdate, onFileDownloadUpdate) events.on(DownloadEvent.onFileDownloadError, onFileDownloadError) events.on(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess) events.on(DownloadEvent.onFileDownloadStopped, onFileDownloadStopped) + events.on(DownloadEvent.onModelValidationStarted, onModelValidationStarted) + events.on(DownloadEvent.onModelValidationFailed, onModelValidationFailed) + events.on(DownloadEvent.onFileDownloadAndVerificationSuccess, onFileDownloadAndVerificationSuccess) // Register app update event listeners events.on(AppEvent.onAppUpdateDownloadUpdate, onAppUpdateDownloadUpdate) @@ -210,6 +273,12 @@ export function DownloadManagement() { events.off(DownloadEvent.onFileDownloadError, onFileDownloadError) events.off(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess) events.off(DownloadEvent.onFileDownloadStopped, onFileDownloadStopped) + events.off( + DownloadEvent.onModelValidationStarted, + onModelValidationStarted + ) + events.off(DownloadEvent.onModelValidationFailed, onModelValidationFailed) + events.off(DownloadEvent.onFileDownloadAndVerificationSuccess, onFileDownloadAndVerificationSuccess) // Unregister app update event listeners events.off(AppEvent.onAppUpdateDownloadUpdate, onAppUpdateDownloadUpdate) @@ -224,6 +293,9 @@ export function DownloadManagement() { onFileDownloadError, onFileDownloadSuccess, onFileDownloadStopped, + onModelValidationStarted, + onModelValidationFailed, + onFileDownloadAndVerificationSuccess, onAppUpdateDownloadUpdate, onAppUpdateDownloadSuccess, onAppUpdateDownloadError, diff --git a/web-app/src/locales/de-DE/common.json b/web-app/src/locales/de-DE/common.json index f418070c8..4a7cd2b01 100644 --- a/web-app/src/locales/de-DE/common.json +++ b/web-app/src/locales/de-DE/common.json @@ -256,6 +256,22 @@ "downloadCancelled": { "title": "Download abgebrochen", "description": "Der Download-Prozess wurde abgebrochen" + }, + "downloadFailed": { + "title": "Download fehlgeschlagen", + "description": "{{item}} Download fehlgeschlagen" + }, + "modelValidationStarted": { + "title": "Modell wird validiert", + "description": "Modell \"{{modelId}}\" erfolgreich heruntergeladen. Integrität wird überprüft..." + }, + "modelValidationFailed": { + "title": "Modellvalidierung fehlgeschlagen", + "description": "Das heruntergeladene Modell \"{{modelId}}\" ist bei der Integritätsprüfung fehlgeschlagen und wurde entfernt. Die Datei könnte beschädigt oder manipuliert worden sein." + }, + "downloadAndVerificationComplete": { + "title": "Download abgeschlossen", + "description": "Modell \"{{item}}\" erfolgreich heruntergeladen und verifiziert" } } } diff --git a/web-app/src/locales/de-DE/hub.json b/web-app/src/locales/de-DE/hub.json index 4fd86b6a6..18b3f3b48 100644 --- a/web-app/src/locales/de-DE/hub.json +++ b/web-app/src/locales/de-DE/hub.json @@ -12,6 +12,7 @@ "showVariants": "Zeige Varianten", "useModel": "Nutze dieses Modell", "downloadModel": "Modell herunterladen", + "tools": "Werkzeuge", "searchPlaceholder": "Suche nach Modellen auf Hugging Face...", "editTheme": "Bearbeite Erscheinungsbild", "joyride": { diff --git a/web-app/src/locales/en/common.json b/web-app/src/locales/en/common.json index 4eaf498c5..46f2d5a8a 100644 --- a/web-app/src/locales/en/common.json +++ b/web-app/src/locales/en/common.json @@ -261,6 +261,18 @@ "downloadFailed": { "title": "Download Failed", "description": "{{item}} download failed" + }, + "modelValidationStarted": { + "title": "Validating Model", + "description": "Downloaded model \"{{modelId}}\" successfully. Verifying integrity..." + }, + "modelValidationFailed": { + "title": "Model Validation Failed", + "description": "The downloaded model \"{{modelId}}\" failed integrity verification and was removed. The file may be corrupted or tampered with." + }, + "downloadAndVerificationComplete": { + "title": "Download Complete", + "description": "Model \"{{item}}\" downloaded and verified successfully" } } } \ No newline at end of file diff --git a/web-app/src/locales/en/hub.json b/web-app/src/locales/en/hub.json index e082c05b5..4855ec868 100644 --- a/web-app/src/locales/en/hub.json +++ b/web-app/src/locales/en/hub.json @@ -12,6 +12,7 @@ "showVariants": "Show variants", "useModel": "Use this model", "downloadModel": "Download model", + "tools": "Tools", "searchPlaceholder": "Search for models on Hugging Face...", "joyride": { "recommendedModelTitle": "Recommended Model", diff --git a/web-app/src/locales/id/common.json b/web-app/src/locales/id/common.json index 4433488d0..03f526bed 100644 --- a/web-app/src/locales/id/common.json +++ b/web-app/src/locales/id/common.json @@ -249,6 +249,22 @@ "downloadCancelled": { "title": "Unduhan Dibatalkan", "description": "Proses unduhan telah dibatalkan" + }, + "downloadFailed": { + "title": "Unduhan Gagal", + "description": "Unduhan {{item}} gagal" + }, + "modelValidationStarted": { + "title": "Memvalidasi Model", + "description": "Model \"{{modelId}}\" berhasil diunduh. Memverifikasi integritas..." + }, + "modelValidationFailed": { + "title": "Validasi Model Gagal", + "description": "Model yang diunduh \"{{modelId}}\" gagal verifikasi integritas dan telah dihapus. File mungkin rusak atau telah dimanipulasi." + }, + "downloadAndVerificationComplete": { + "title": "Unduhan Selesai", + "description": "Model \"{{item}}\" berhasil diunduh dan diverifikasi" } } } diff --git a/web-app/src/locales/id/hub.json b/web-app/src/locales/id/hub.json index 5aa1e7d1c..bdecd9533 100644 --- a/web-app/src/locales/id/hub.json +++ b/web-app/src/locales/id/hub.json @@ -12,6 +12,7 @@ "showVariants": "Tampilkan Varian", "useModel": "Gunakan model ini", "downloadModel": "Unduh model", + "tools": "Alat", "searchPlaceholder": "Cari model di Hugging Face...", "joyride": { "recommendedModelTitle": "Model yang Direkomendasikan", diff --git a/web-app/src/locales/vn/common.json b/web-app/src/locales/vn/common.json index 06974aedc..9bc2b25f0 100644 --- a/web-app/src/locales/vn/common.json +++ b/web-app/src/locales/vn/common.json @@ -249,6 +249,22 @@ "downloadCancelled": { "title": "Đã hủy tải xuống", "description": "Quá trình tải xuống đã bị hủy" + }, + "downloadFailed": { + "title": "Tải xuống thất bại", + "description": "Tải xuống {{item}} thất bại" + }, + "modelValidationStarted": { + "title": "Đang xác thực mô hình", + "description": "Đã tải xuống mô hình \"{{modelId}}\" thành công. Đang xác minh tính toàn vẹn..." + }, + "modelValidationFailed": { + "title": "Xác thực mô hình thất bại", + "description": "Mô hình đã tải xuống \"{{modelId}}\" không vượt qua kiểm tra tính toàn vẹn và đã bị xóa. Tệp có thể bị hỏng hoặc bị giả mạo." + }, + "downloadAndVerificationComplete": { + "title": "Tải xuống hoàn tất", + "description": "Mô hình \"{{item}}\" đã được tải xuống và xác minh thành công" } } } diff --git a/web-app/src/locales/vn/hub.json b/web-app/src/locales/vn/hub.json index 8b38d84cc..00ec4f06d 100644 --- a/web-app/src/locales/vn/hub.json +++ b/web-app/src/locales/vn/hub.json @@ -12,6 +12,7 @@ "showVariants": "Hiển thị biến thể", "useModel": "Sử dụng mô hình này", "downloadModel": "Tải xuống mô hình", + "tools": "Công cụ", "searchPlaceholder": "Tìm kiếm các mô hình trên Hugging Face...", "joyride": { "recommendedModelTitle": "Mô hình được đề xuất", diff --git a/web-app/src/locales/zh-CN/common.json b/web-app/src/locales/zh-CN/common.json index 34af5ae95..a783d0f14 100644 --- a/web-app/src/locales/zh-CN/common.json +++ b/web-app/src/locales/zh-CN/common.json @@ -249,6 +249,22 @@ "downloadCancelled": { "title": "下载已取消", "description": "下载过程已取消" + }, + "downloadFailed": { + "title": "下载失败", + "description": "{{item}} 下载失败" + }, + "modelValidationStarted": { + "title": "正在验证模型", + "description": "模型 \"{{modelId}}\" 下载成功。正在验证完整性..." + }, + "modelValidationFailed": { + "title": "模型验证失败", + "description": "已下载的模型 \"{{modelId}}\" 未通过完整性验证并已被删除。文件可能损坏或被篡改。" + }, + "downloadAndVerificationComplete": { + "title": "下载完成", + "description": "模型 \"{{item}}\" 下载并验证成功" } } } diff --git a/web-app/src/locales/zh-CN/hub.json b/web-app/src/locales/zh-CN/hub.json index dc005611a..234107a2b 100644 --- a/web-app/src/locales/zh-CN/hub.json +++ b/web-app/src/locales/zh-CN/hub.json @@ -12,6 +12,7 @@ "showVariants": "显示变体", "useModel": "使用此模型", "downloadModel": "下载模型", + "tools": "工具", "searchPlaceholder": "在 Hugging Face 上搜索模型...", "joyride": { "recommendedModelTitle": "推荐模型", diff --git a/web-app/src/locales/zh-TW/common.json b/web-app/src/locales/zh-TW/common.json index 485c41369..055819646 100644 --- a/web-app/src/locales/zh-TW/common.json +++ b/web-app/src/locales/zh-TW/common.json @@ -249,6 +249,22 @@ "downloadCancelled": { "title": "下載已取消", "description": "下載過程已取消" + }, + "downloadFailed": { + "title": "下載失敗", + "description": "{{item}} 下載失敗" + }, + "modelValidationStarted": { + "title": "正在驗證模型", + "description": "模型 \"{{modelId}}\" 下載成功。正在驗證完整性..." + }, + "modelValidationFailed": { + "title": "模型驗證失敗", + "description": "已下載的模型 \"{{modelId}}\" 未通過完整性驗證並已被刪除。檔案可能損壞或被篡改。" + }, + "downloadAndVerificationComplete": { + "title": "下載完成", + "description": "模型 \"{{item}}\" 下載並驗證成功" } } } diff --git a/web-app/src/locales/zh-TW/hub.json b/web-app/src/locales/zh-TW/hub.json index f35a4485a..0781f1f7a 100644 --- a/web-app/src/locales/zh-TW/hub.json +++ b/web-app/src/locales/zh-TW/hub.json @@ -12,6 +12,7 @@ "showVariants": "顯示變體", "useModel": "使用此模型", "downloadModel": "下載模型", + "tools": "工具", "searchPlaceholder": "在 Hugging Face 上搜尋模型...", "joyride": { "recommendedModelTitle": "推薦模型", diff --git a/web-app/src/routes/hub/$modelId.tsx b/web-app/src/routes/hub/$modelId.tsx index f34057ae4..e5f2f44b3 100644 --- a/web-app/src/routes/hub/$modelId.tsx +++ b/web-app/src/routes/hub/$modelId.tsx @@ -22,7 +22,7 @@ import { CatalogModel, convertHfRepoToCatalogModel, fetchHuggingFaceRepo, - pullModel, + pullModelWithMetadata, } from '@/services/models' import { Progress } from '@/components/ui/progress' import { Button } from '@/components/ui/button' @@ -408,9 +408,11 @@ function HubModelDetail() { addLocalDownloadingModel( variant.model_id ) - pullModel( + pullModelWithMetadata( variant.model_id, - variant.path + variant.path, + modelData.mmproj_models?.[0]?.path, + huggingfaceToken ) }} className={cn(isDownloading && 'hidden')} diff --git a/web-app/src/routes/hub/index.tsx b/web-app/src/routes/hub/index.tsx index 2e5db8ba2..93658816b 100644 --- a/web-app/src/routes/hub/index.tsx +++ b/web-app/src/routes/hub/index.tsx @@ -41,7 +41,7 @@ import { } from '@/components/ui/dropdown-menu' import { CatalogModel, - pullModel, + pullModelWithMetadata, fetchHuggingFaceRepo, convertHfRepoToCatalogModel, } from '@/services/models' @@ -313,7 +313,12 @@ function Hub() { // Immediately set local downloading state addLocalDownloadingModel(modelId) const mmprojPath = model.mmproj_models?.[0]?.path - pullModel(modelId, modelUrl, mmprojPath) + pullModelWithMetadata( + modelId, + modelUrl, + mmprojPath, + huggingfaceToken + ) } return ( @@ -812,12 +817,13 @@ function Hub() { addLocalDownloadingModel( variant.model_id ) - pullModel( + pullModelWithMetadata( variant.model_id, variant.path, filteredModels[ virtualItem.index - ].mmproj_models?.[0]?.path + ].mmproj_models?.[0]?.path, + huggingfaceToken ) }} > diff --git a/web-app/src/services/__tests__/models.test.ts b/web-app/src/services/__tests__/models.test.ts index dc30dc54f..368fd19be 100644 --- a/web-app/src/services/__tests__/models.test.ts +++ b/web-app/src/services/__tests__/models.test.ts @@ -325,7 +325,7 @@ describe('models service', () => { expect(result).toEqual(mockRepoData) expect(fetch).toHaveBeenCalledWith( - 'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true', + 'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true', { headers: {}, } @@ -344,7 +344,7 @@ describe('models service', () => { 'https://huggingface.co/microsoft/DialoGPT-medium' ) expect(fetch).toHaveBeenCalledWith( - 'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true', + 'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true', { headers: {}, } @@ -353,7 +353,7 @@ describe('models service', () => { // Test with domain prefix await fetchHuggingFaceRepo('huggingface.co/microsoft/DialoGPT-medium') expect(fetch).toHaveBeenCalledWith( - 'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true', + 'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true', { headers: {}, } @@ -362,7 +362,7 @@ describe('models service', () => { // Test with trailing slash await fetchHuggingFaceRepo('microsoft/DialoGPT-medium/') expect(fetch).toHaveBeenCalledWith( - 'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true', + 'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true', { headers: {}, } @@ -391,7 +391,7 @@ describe('models service', () => { expect(result).toBeNull() expect(fetch).toHaveBeenCalledWith( - 'https://huggingface.co/api/models/nonexistent/model?blobs=true', + 'https://huggingface.co/api/models/nonexistent/model?blobs=true&files_metadata=true', { headers: {}, } diff --git a/web-app/src/services/models.ts b/web-app/src/services/models.ts index 790620f22..0edfe165a 100644 --- a/web-app/src/services/models.ts +++ b/web-app/src/services/models.ts @@ -62,6 +62,11 @@ export interface HuggingFaceRepo { rfilename: string size?: number blobId?: string + lfs?: { + sha256: string + size: number + pointerSize: number + } }> readme?: string } @@ -126,7 +131,7 @@ export const fetchHuggingFaceRepo = async ( } const response = await fetch( - `https://huggingface.co/api/models/${cleanRepoId}?blobs=true`, + `https://huggingface.co/api/models/${cleanRepoId}?blobs=true&files_metadata=true`, { headers: hfToken ? { @@ -237,14 +242,103 @@ export const updateModel = async ( export const pullModel = async ( id: string, modelPath: string, - mmprojPath?: string + modelSha256?: string, + modelSize?: number, + mmprojPath?: string, + mmprojSha256?: string, + mmprojSize?: number ) => { return getEngine()?.import(id, { modelPath, mmprojPath, + modelSha256, + modelSize, + mmprojSha256, + mmprojSize, }) } +/** + * Pull a model with real-time metadata fetching from HuggingFace. + * Extracts hash and size information from the model URL for both main model and mmproj files. + * @param id The model ID + * @param modelPath The model file URL (HuggingFace download URL) + * @param mmprojPath Optional mmproj file URL + * @param hfToken Optional HuggingFace token for authentication + * @returns A promise that resolves when the model download task is created. + */ +export const pullModelWithMetadata = async ( + id: string, + modelPath: string, + mmprojPath?: string, + hfToken?: string +) => { + let modelSha256: string | undefined + let modelSize: number | undefined + let mmprojSha256: string | undefined + let mmprojSize: number | undefined + + // Extract repo ID from model URL + // URL format: https://huggingface.co/{repo}/resolve/main/{filename} + const modelUrlMatch = modelPath.match( + /https:\/\/huggingface\.co\/([^/]+\/[^/]+)\/resolve\/main\/(.+)/ + ) + + if (modelUrlMatch) { + const [, repoId, modelFilename] = modelUrlMatch + + try { + // Fetch real-time metadata from HuggingFace + const repoInfo = await fetchHuggingFaceRepo(repoId, hfToken) + + if (repoInfo?.siblings) { + // Find the specific model file + const modelFile = repoInfo.siblings.find( + (file) => file.rfilename === modelFilename + ) + if (modelFile?.lfs) { + modelSha256 = modelFile.lfs.sha256 + modelSize = modelFile.lfs.size + } + + // If mmproj path provided, extract its metadata too + if (mmprojPath) { + const mmprojUrlMatch = mmprojPath.match( + /https:\/\/huggingface\.co\/[^/]+\/[^/]+\/resolve\/main\/(.+)/ + ) + if (mmprojUrlMatch) { + const [, mmprojFilename] = mmprojUrlMatch + const mmprojFile = repoInfo.siblings.find( + (file) => file.rfilename === mmprojFilename + ) + if (mmprojFile?.lfs) { + mmprojSha256 = mmprojFile.lfs.sha256 + mmprojSize = mmprojFile.lfs.size + } + } + } + } + } catch (error) { + console.warn( + 'Failed to fetch HuggingFace metadata, proceeding without hash verification:', + error + ) + // Continue with download even if metadata fetch fails + } + } + + // Call the original pullModel with the fetched metadata + return pullModel( + id, + modelPath, + modelSha256, + modelSize, + mmprojPath, + mmprojSha256, + mmprojSize + ) +} + /** * Aborts a model download. * @param id