feat: gguf file size + hash validation (#5266) (#6259)

* feat: gguf file size + hash validation

* fix tests fe

* update cargo tests

* handle asyn download for both models and mmproj

* move progress tracker to models

* handle file download cancelled

* add cancellation mid hash run
This commit is contained in:
Dinh Long Nguyen 2025-08-21 16:17:58 +07:00 committed by GitHub
parent 41b4cc3bb3
commit 32a2ca95b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 915 additions and 153 deletions

View File

@ -194,6 +194,10 @@ export interface chatOptions {
export interface ImportOptions {
modelPath: string
mmprojPath?: string
modelSha256?: string
modelSize?: number
mmprojSha256?: string
mmprojSize?: number
}
export interface importResult {

View File

@ -73,6 +73,9 @@ export enum DownloadEvent {
onFileDownloadSuccess = 'onFileDownloadSuccess',
onFileDownloadStopped = 'onFileDownloadStopped',
onFileDownloadStarted = 'onFileDownloadStarted',
onModelValidationStarted = 'onModelValidationStarted',
onModelValidationFailed = 'onModelValidationFailed',
onFileDownloadAndVerificationSuccess = 'onFileDownloadAndVerificationSuccess',
}
export enum ExtensionRoute {
baseExtensions = 'baseExtensions',

View File

@ -10,6 +10,8 @@ interface DownloadItem {
url: string
save_path: string
proxy?: Record<string, string | string[] | boolean>
sha256?: string
size?: number
}
type DownloadEvent = {

View File

@ -20,9 +20,11 @@ import {
chatCompletionRequest,
events,
AppEvent,
DownloadEvent,
} from '@janhq/core'
import { error, info, warn } from '@tauri-apps/plugin-log'
import { listen } from '@tauri-apps/api/event'
import {
listSupportedBackends,
@ -71,6 +73,8 @@ interface DownloadItem {
url: string
save_path: string
proxy?: Record<string, string | string[] | boolean>
sha256?: string
size?: number
}
interface ModelConfig {
@ -79,6 +83,9 @@ interface ModelConfig {
name: string // user-friendly
// some model info that we cache upon import
size_bytes: number
sha256?: string
mmproj_sha256?: string
mmproj_size_bytes?: number
}
interface EmbeddingResponse {
@ -154,6 +161,7 @@ export default class llamacpp_extension extends AIEngine {
private pendingDownloads: Map<string, Promise<void>> = new Map()
private isConfiguringBackends: boolean = false
private loadingModels = new Map<string, Promise<SessionInfo>>() // Track loading promises
private unlistenValidationStarted?: () => void
override async onLoad(): Promise<void> {
super.onLoad() // Calls registerEngine() from AIEngine
@ -181,6 +189,19 @@ export default class llamacpp_extension extends AIEngine {
await getJanDataFolderPath(),
this.providerId,
])
// Set up validation event listeners to bridge Tauri events to frontend
this.unlistenValidationStarted = await listen<{
modelId: string
downloadType: string
}>('onModelValidationStarted', (event) => {
console.debug(
'LlamaCPP: bridging onModelValidationStarted event',
event.payload
)
events.emit(DownloadEvent.onModelValidationStarted, event.payload)
})
this.configureBackends()
}
@ -774,6 +795,11 @@ export default class llamacpp_extension extends AIEngine {
override async onUnload(): Promise<void> {
// Terminate all active sessions
// Clean up validation event listeners
if (this.unlistenValidationStarted) {
this.unlistenValidationStarted()
}
}
onSettingUpdate<T>(key: string, value: T): void {
@ -1006,6 +1032,9 @@ export default class llamacpp_extension extends AIEngine {
url: path,
save_path: localPath,
proxy: getProxyConfig(),
sha256:
saveName === 'model.gguf' ? opts.modelSha256 : opts.mmprojSha256,
size: saveName === 'model.gguf' ? opts.modelSize : opts.mmprojSize,
})
return localPath
}
@ -1023,8 +1052,6 @@ export default class llamacpp_extension extends AIEngine {
: undefined
if (downloadItems.length > 0) {
let downloadCompleted = false
try {
// emit download update event on progress
const onProgress = (transferred: number, total: number) => {
@ -1034,7 +1061,6 @@ export default class llamacpp_extension extends AIEngine {
size: { transferred, total },
downloadType: 'Model',
})
downloadCompleted = transferred === total
}
const downloadManager = window.core.extensionManager.getByName(
'@janhq/download-extension'
@ -1045,13 +1071,67 @@ export default class llamacpp_extension extends AIEngine {
onProgress
)
const eventName = downloadCompleted
? 'onFileDownloadSuccess'
: 'onFileDownloadStopped'
events.emit(eventName, { modelId, downloadType: 'Model' })
// If we reach here, download completed successfully (including validation)
// The downloadFiles function only returns successfully if all files downloaded AND validated
events.emit(DownloadEvent.onFileDownloadAndVerificationSuccess, {
modelId,
downloadType: 'Model'
})
} catch (error) {
logger.error('Error downloading model:', modelId, opts, error)
events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
const errorMessage =
error instanceof Error ? error.message : String(error)
// Check if this is a cancellation
const isCancellationError = errorMessage.includes('Download cancelled') ||
errorMessage.includes('Validation cancelled') ||
errorMessage.includes('Hash computation cancelled') ||
errorMessage.includes('cancelled') ||
errorMessage.includes('aborted')
// Check if this is a validation failure
const isValidationError =
errorMessage.includes('Hash verification failed') ||
errorMessage.includes('Size verification failed') ||
errorMessage.includes('Failed to verify file')
if (isCancellationError) {
logger.info('Download cancelled for model:', modelId)
// Emit download stopped event instead of error
events.emit(DownloadEvent.onFileDownloadStopped, {
modelId,
downloadType: 'Model',
})
} else if (isValidationError) {
logger.error(
'Validation failed for model:',
modelId,
'Error:',
errorMessage
)
// Cancel any other download tasks for this model
try {
this.abortImport(modelId)
} catch (cancelError) {
logger.warn('Failed to cancel download task:', cancelError)
}
// Emit validation failure event
events.emit(DownloadEvent.onModelValidationFailed, {
modelId,
downloadType: 'Model',
error: errorMessage,
reason: 'validation_failed',
})
} else {
// Regular download error
events.emit(DownloadEvent.onFileDownloadError, {
modelId,
downloadType: 'Model',
error: errorMessage,
})
}
throw error
}
}
@ -1078,7 +1158,9 @@ export default class llamacpp_extension extends AIEngine {
} catch (error) {
logger.error('GGUF validation failed:', error)
throw new Error(
`Invalid GGUF file(s): ${error.message || 'File format validation failed'}`
`Invalid GGUF file(s): ${
error.message || 'File format validation failed'
}`
)
}
@ -1097,6 +1179,10 @@ export default class llamacpp_extension extends AIEngine {
mmproj_path: mmprojPath,
name: modelId,
size_bytes,
model_sha256: opts.modelSha256,
model_size_bytes: opts.modelSize,
mmproj_sha256: opts.mmprojSha256,
mmproj_size_bytes: opts.mmprojSize,
} as ModelConfig
await fs.mkdir(await joinPath([janDataFolderPath, modelDir]))
await invoke<void>('write_yaml', {
@ -1108,16 +1194,50 @@ export default class llamacpp_extension extends AIEngine {
modelPath,
mmprojPath,
size_bytes,
model_sha256: opts.modelSha256,
model_size_bytes: opts.modelSize,
mmproj_sha256: opts.mmprojSha256,
mmproj_size_bytes: opts.mmprojSize,
})
}
/**
* Deletes the entire model folder for a given modelId
* @param modelId The model ID to delete
*/
private async deleteModelFolder(modelId: string): Promise<void> {
try {
const modelDir = await joinPath([
await this.getProviderPath(),
'models',
modelId,
])
if (await fs.existsSync(modelDir)) {
logger.info(`Cleaning up model directory: ${modelDir}`)
await fs.rm(modelDir)
}
} catch (deleteError) {
logger.warn('Failed to delete model directory:', deleteError)
}
}
override async abortImport(modelId: string): Promise<void> {
// prepand provider name to avoid name collision
// Cancel any active download task
// prepend provider name to avoid name collision
const taskId = this.createDownloadTaskId(modelId)
const downloadManager = window.core.extensionManager.getByName(
'@janhq/download-extension'
)
await downloadManager.cancelDownload(taskId)
try {
await downloadManager.cancelDownload(taskId)
} catch (cancelError) {
logger.warn('Failed to cancel download task:', cancelError)
}
// Delete the entire model folder if it exists (for validation failures)
await this.deleteModelFolder(modelId)
}
/**

11
src-tauri/Cargo.lock generated
View File

@ -2323,6 +2323,7 @@ dependencies = [
"serde_json",
"sha2",
"tokio",
"tokio-util",
"url",
]
@ -4019,8 +4020,9 @@ dependencies = [
[[package]]
name = "rmcp"
version = "0.5.0"
source = "git+https://github.com/modelcontextprotocol/rust-sdk?rev=209dbac50f51737ad953c3a2c8e28f3619b6c277#209dbac50f51737ad953c3a2c8e28f3619b6c277"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb21cd3555f1059f27e4813827338dec44429a08ecd0011acc41d9907b160c00"
dependencies = [
"base64 0.22.1",
"chrono",
@ -4045,8 +4047,9 @@ dependencies = [
[[package]]
name = "rmcp-macros"
version = "0.5.0"
source = "git+https://github.com/modelcontextprotocol/rust-sdk?rev=209dbac50f51737ad953c3a2c8e28f3619b6c277#209dbac50f51737ad953c3a2c8e28f3619b6c277"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab5d16ae1ff3ce2c5fd86c37047b2869b75bec795d53a4b1d8257b15415a2354"
dependencies = [
"darling 0.21.2",
"proc-macro2",

View File

@ -1,9 +1,10 @@
use super::models::{DownloadEvent, DownloadItem, ProxyConfig};
use super::models::{DownloadEvent, DownloadItem, ProxyConfig, ProgressTracker};
use crate::core::app::commands::get_jan_data_folder_path;
use futures_util::StreamExt;
use jan_utils::normalize_path;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use std::collections::HashMap;
use std::path::Path;
use std::time::Duration;
use tauri::Emitter;
use tokio::fs::File;
@ -11,10 +12,131 @@ use tokio::io::AsyncWriteExt;
use tokio_util::sync::CancellationToken;
use url::Url;
// ===== UTILITY FUNCTIONS =====
pub fn err_to_string<E: std::fmt::Display>(e: E) -> String {
format!("Error: {}", e)
}
// ===== VALIDATION FUNCTIONS =====
/// Validates a downloaded file against expected hash and size
async fn validate_downloaded_file(
item: &DownloadItem,
save_path: &Path,
app: &tauri::AppHandle,
cancel_token: &CancellationToken,
) -> Result<(), String> {
// Skip validation if no verification data is provided
if item.sha256.is_none() && item.size.is_none() {
log::debug!(
"No validation data provided for {}, skipping validation",
item.url
);
return Ok(());
}
// Extract model ID from save path for validation events
// Path structure: llamacpp/models/{modelId}/model.gguf or llamacpp/models/{modelId}/mmproj.gguf
let model_id = save_path
.parent() // get parent directory (modelId folder)
.and_then(|p| p.file_name())
.and_then(|n| n.to_str())
.unwrap_or("unknown");
// Emit validation started event
app.emit(
"onModelValidationStarted",
serde_json::json!({
"modelId": model_id,
"downloadType": "Model",
}),
)
.unwrap();
log::info!("Starting validation for model: {}", model_id);
// Validate size if provided (fast check first)
if let Some(expected_size) = &item.size {
log::info!("Starting size verification for {}", item.url);
match tokio::fs::metadata(save_path).await {
Ok(metadata) => {
let actual_size = metadata.len();
if actual_size != *expected_size {
log::error!(
"Size verification failed for {}. Expected: {} bytes, Actual: {} bytes",
item.url,
expected_size,
actual_size
);
return Err(format!(
"Size verification failed. Expected {} bytes but got {} bytes.",
expected_size, actual_size
));
}
log::info!(
"Size verification successful for {} ({} bytes)",
item.url,
actual_size
);
}
Err(e) => {
log::error!(
"Failed to get file metadata for {}: {}",
save_path.display(),
e
);
return Err(format!("Failed to verify file size: {}", e));
}
}
}
// Check for cancellation before expensive hash computation
if cancel_token.is_cancelled() {
log::info!("Validation cancelled for {}", item.url);
return Err("Validation cancelled".to_string());
}
// Validate hash if provided (expensive check second)
if let Some(expected_sha256) = &item.sha256 {
log::info!("Starting Hash verification for {}", item.url);
match jan_utils::crypto::compute_file_sha256_with_cancellation(save_path, cancel_token).await {
Ok(computed_sha256) => {
if computed_sha256 != *expected_sha256 {
log::error!(
"Hash verification failed for {}. Expected: {}, Computed: {}",
item.url,
expected_sha256,
computed_sha256
);
return Err(format!(
"Hash verification failed. The downloaded file is corrupted or has been tampered with."
));
}
log::info!("Hash verification successful for {}", item.url);
}
Err(e) => {
log::error!(
"Failed to compute SHA256 for {}: {}",
save_path.display(),
e
);
return Err(format!("Failed to verify file integrity: {}", e));
}
}
}
log::info!("All validations passed for {}", item.url);
Ok(())
}
pub fn validate_proxy_config(config: &ProxyConfig) -> Result<(), String> {
// Validate proxy URL format
if let Err(e) = Url::parse(&config.url) {
@ -172,6 +294,9 @@ pub async fn _get_file_size(
}
}
// ===== MAIN DOWNLOAD FUNCTIONS =====
/// Downloads multiple files in parallel with individual progress tracking
pub async fn _download_files_internal(
app: tauri::AppHandle,
items: &[DownloadItem],
@ -184,28 +309,31 @@ pub async fn _download_files_internal(
let header_map = _convert_headers(headers).map_err(err_to_string)?;
let total_size = {
let mut total_size = 0u64;
for item in items.iter() {
let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?;
total_size += _get_file_size(&client, &item.url)
.await
.map_err(err_to_string)?;
}
total_size
};
// Calculate sizes for each file
let mut file_sizes = HashMap::new();
for item in items.iter() {
let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?;
let size = _get_file_size(&client, &item.url)
.await
.map_err(err_to_string)?;
file_sizes.insert(item.url.clone(), size);
}
let total_size: u64 = file_sizes.values().sum();
log::info!("Total download size: {}", total_size);
let mut evt = DownloadEvent {
transferred: 0,
total: total_size,
};
let evt_name = format!("download-{}", task_id);
// Create progress tracker
let progress_tracker = ProgressTracker::new(items, file_sizes.clone());
// save file under Jan data folder
let jan_data_folder = get_jan_data_folder_path(app.clone());
for item in items.iter() {
// Collect download tasks for parallel execution
let mut download_tasks = Vec::new();
for (index, item) in items.iter().enumerate() {
let save_path = jan_data_folder.join(&item.save_path);
let save_path = normalize_path(&save_path);
@ -217,120 +345,251 @@ pub async fn _download_files_internal(
));
}
// Create parent directories if they don't exist
if let Some(parent) = save_path.parent() {
if !parent.exists() {
tokio::fs::create_dir_all(parent)
.await
.map_err(err_to_string)?;
}
}
// Spawn download task for each file
let item_clone = item.clone();
let app_clone = app.clone();
let header_map_clone = header_map.clone();
let cancel_token_clone = cancel_token.clone();
let evt_name_clone = evt_name.clone();
let progress_tracker_clone = progress_tracker.clone();
let file_id = format!("{}-{}", task_id, index);
let file_size = file_sizes.get(&item.url).copied().unwrap_or(0);
let current_extension = save_path.extension().unwrap_or_default().to_string_lossy();
let append_extension = |ext: &str| {
if current_extension.is_empty() {
ext.to_string()
} else {
format!("{}.{}", current_extension, ext)
}
};
let tmp_save_path = save_path.with_extension(append_extension("tmp"));
let url_save_path = save_path.with_extension(append_extension("url"));
let mut should_resume = resume
&& tmp_save_path.exists()
&& tokio::fs::read_to_string(&url_save_path)
.await
.map(|url| url == item.url) // check if we resume the same URL
.unwrap_or(false);
tokio::fs::write(&url_save_path, item.url.clone())
let task = tokio::spawn(async move {
download_single_file(
app_clone,
&item_clone,
&header_map_clone,
&save_path,
resume,
cancel_token_clone,
evt_name_clone,
progress_tracker_clone,
file_id,
file_size,
)
.await
.map_err(err_to_string)?;
});
log::info!("Started downloading: {}", item.url);
let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?;
let mut download_delta = 0u64;
let resp = if should_resume {
let downloaded_size = tmp_save_path.metadata().map_err(err_to_string)?.len();
match _get_maybe_resume(&client, &item.url, downloaded_size).await {
Ok(resp) => {
log::info!(
"Resume download: {}, already downloaded {} bytes",
item.url,
downloaded_size
);
download_delta += downloaded_size;
resp
}
Err(e) => {
// fallback to normal download
log::warn!("Failed to resume download: {}", e);
should_resume = false;
_get_maybe_resume(&client, &item.url, 0).await?
}
}
} else {
_get_maybe_resume(&client, &item.url, 0).await?
};
let mut stream = resp.bytes_stream();
let file = if should_resume {
// resume download, append to existing file
tokio::fs::OpenOptions::new()
.write(true)
.append(true)
.open(&tmp_save_path)
.await
.map_err(err_to_string)?
} else {
// start new download, create a new file
File::create(&tmp_save_path).await.map_err(err_to_string)?
};
let mut writer = tokio::io::BufWriter::new(file);
// write chunk to file
while let Some(chunk) = stream.next().await {
if cancel_token.is_cancelled() {
if !should_resume {
tokio::fs::remove_dir_all(&save_path.parent().unwrap())
.await
.ok();
}
log::info!("Download cancelled for task: {}", task_id);
app.emit(&evt_name, evt.clone()).unwrap();
return Ok(());
}
let chunk = chunk.map_err(err_to_string)?;
writer.write_all(&chunk).await.map_err(err_to_string)?;
download_delta += chunk.len() as u64;
// only update every 10 MB
if download_delta >= 10 * 1024 * 1024 {
evt.transferred += download_delta;
app.emit(&evt_name, evt.clone()).unwrap();
download_delta = 0u64;
}
}
writer.flush().await.map_err(err_to_string)?;
evt.transferred += download_delta;
// rename tmp file to final file
tokio::fs::rename(&tmp_save_path, &save_path)
.await
.map_err(err_to_string)?;
tokio::fs::remove_file(&url_save_path)
.await
.map_err(err_to_string)?;
log::info!("Finished downloading: {}", item.url);
download_tasks.push(task);
}
app.emit(&evt_name, evt.clone()).unwrap();
// Wait for all downloads to complete
let mut validation_tasks = Vec::new();
for (task, item) in download_tasks.into_iter().zip(items.iter()) {
let result = task.await.map_err(|e| format!("Task join error: {}", e))?;
match result {
Ok(downloaded_path) => {
// Spawn validation task in parallel
let item_clone = item.clone();
let app_clone = app.clone();
let path_clone = downloaded_path.clone();
let cancel_token_clone = cancel_token.clone();
let validation_task = tokio::spawn(async move {
validate_downloaded_file(&item_clone, &path_clone, &app_clone, &cancel_token_clone).await
});
validation_tasks.push((validation_task, downloaded_path, item.clone()));
}
Err(e) => return Err(e),
}
}
// Wait for all validations to complete
for (validation_task, save_path, _item) in validation_tasks {
let validation_result = validation_task
.await
.map_err(|e| format!("Validation task join error: {}", e))?;
if let Err(validation_error) = validation_result {
// Clean up the file if validation fails
let _ = tokio::fs::remove_file(&save_path).await;
// Try to clean up the parent directory if it's empty
if let Some(parent) = save_path.parent() {
let _ = tokio::fs::remove_dir(parent).await;
}
return Err(validation_error);
}
}
// Emit final progress
let (transferred, total) = progress_tracker.get_total_progress().await;
let final_evt = DownloadEvent { transferred, total };
app.emit(&evt_name, final_evt).unwrap();
Ok(())
}
/// Downloads a single file without blocking other downloads
async fn download_single_file(
app: tauri::AppHandle,
item: &DownloadItem,
header_map: &HeaderMap,
save_path: &std::path::Path,
resume: bool,
cancel_token: CancellationToken,
evt_name: String,
progress_tracker: ProgressTracker,
file_id: String,
_file_size: u64,
) -> Result<std::path::PathBuf, String> {
// Create parent directories if they don't exist
if let Some(parent) = save_path.parent() {
if !parent.exists() {
tokio::fs::create_dir_all(parent)
.await
.map_err(err_to_string)?;
}
}
let current_extension = save_path.extension().unwrap_or_default().to_string_lossy();
let append_extension = |ext: &str| {
if current_extension.is_empty() {
ext.to_string()
} else {
format!("{}.{}", current_extension, ext)
}
};
let tmp_save_path = save_path.with_extension(append_extension("tmp"));
let url_save_path = save_path.with_extension(append_extension("url"));
let mut should_resume = resume
&& tmp_save_path.exists()
&& tokio::fs::read_to_string(&url_save_path)
.await
.map(|url| url == item.url) // check if we resume the same URL
.unwrap_or(false);
tokio::fs::write(&url_save_path, item.url.clone())
.await
.map_err(err_to_string)?;
log::info!("Started downloading: {}", item.url);
let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?;
let mut download_delta = 0u64;
let mut initial_progress = 0u64;
let resp = if should_resume {
let downloaded_size = tmp_save_path.metadata().map_err(err_to_string)?.len();
match _get_maybe_resume(&client, &item.url, downloaded_size).await {
Ok(resp) => {
log::info!(
"Resume download: {}, already downloaded {} bytes",
item.url,
downloaded_size
);
initial_progress = downloaded_size;
// Initialize progress for resumed download
progress_tracker
.update_progress(&file_id, downloaded_size)
.await;
// Emit initial combined progress
let (combined_transferred, combined_total) =
progress_tracker.get_total_progress().await;
let evt = DownloadEvent {
transferred: combined_transferred,
total: combined_total,
};
app.emit(&evt_name, evt).unwrap();
resp
}
Err(e) => {
// fallback to normal download
log::warn!("Failed to resume download: {}", e);
should_resume = false;
_get_maybe_resume(&client, &item.url, 0).await?
}
}
} else {
_get_maybe_resume(&client, &item.url, 0).await?
};
let mut stream = resp.bytes_stream();
let file = if should_resume {
// resume download, append to existing file
tokio::fs::OpenOptions::new()
.write(true)
.append(true)
.open(&tmp_save_path)
.await
.map_err(err_to_string)?
} else {
// start new download, create a new file
File::create(&tmp_save_path).await.map_err(err_to_string)?
};
let mut writer = tokio::io::BufWriter::new(file);
let mut total_transferred = initial_progress;
// write chunk to file
while let Some(chunk) = stream.next().await {
if cancel_token.is_cancelled() {
if !should_resume {
tokio::fs::remove_dir_all(&save_path.parent().unwrap())
.await
.ok();
}
log::info!("Download cancelled: {}", item.url);
return Err("Download cancelled".to_string());
}
let chunk = chunk.map_err(err_to_string)?;
writer.write_all(&chunk).await.map_err(err_to_string)?;
download_delta += chunk.len() as u64;
total_transferred += chunk.len() as u64;
// Update progress every 10 MB
if download_delta >= 10 * 1024 * 1024 {
// Update individual file progress
progress_tracker
.update_progress(&file_id, total_transferred)
.await;
// Emit combined progress event
let (combined_transferred, combined_total) =
progress_tracker.get_total_progress().await;
let evt = DownloadEvent {
transferred: combined_transferred,
total: combined_total,
};
app.emit(&evt_name, evt).unwrap();
download_delta = 0u64;
}
}
writer.flush().await.map_err(err_to_string)?;
// Final progress update for this file
progress_tracker
.update_progress(&file_id, total_transferred)
.await;
// Emit final combined progress
let (combined_transferred, combined_total) = progress_tracker.get_total_progress().await;
let evt = DownloadEvent {
transferred: combined_transferred,
total: combined_total,
};
app.emit(&evt_name, evt).unwrap();
// rename tmp file to final file
tokio::fs::rename(&tmp_save_path, &save_path)
.await
.map_err(err_to_string)?;
tokio::fs::remove_file(&url_save_path)
.await
.map_err(err_to_string)?;
log::info!("Finished downloading: {}", item.url);
Ok(save_path.to_path_buf())
}
// ===== HTTP CLIENT HELPER FUNCTIONS =====
pub async fn _get_maybe_resume(
client: &reqwest::Client,
url: &str,

View File

@ -1,4 +1,6 @@
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;
#[derive(Default)]
@ -20,6 +22,8 @@ pub struct DownloadItem {
pub url: String,
pub save_path: String,
pub proxy: Option<ProxyConfig>,
pub sha256: Option<String>,
pub size: Option<u64>,
}
#[derive(serde::Serialize, Clone, Debug)]
@ -27,3 +31,31 @@ pub struct DownloadEvent {
pub transferred: u64,
pub total: u64,
}
/// Structure to track progress for each file in parallel downloads
#[derive(Clone)]
pub struct ProgressTracker {
file_progress: Arc<Mutex<HashMap<String, u64>>>,
total_size: u64,
}
impl ProgressTracker {
pub fn new(_items: &[DownloadItem], sizes: HashMap<String, u64>) -> Self {
let total_size = sizes.values().sum();
ProgressTracker {
file_progress: Arc::new(Mutex::new(HashMap::new())),
total_size,
}
}
pub async fn update_progress(&self, file_id: &str, transferred: u64) {
let mut progress = self.file_progress.lock().await;
progress.insert(file_id.to_string(), transferred);
}
pub async fn get_total_progress(&self) -> (u64, u64) {
let progress = self.file_progress.lock().await;
let total_transferred: u64 = progress.values().sum();
(total_transferred, self.total_size)
}
}

View File

@ -194,6 +194,8 @@ fn test_download_item_with_ssl_proxy() {
url: "https://example.com/file.zip".to_string(),
save_path: "downloads/file.zip".to_string(),
proxy: Some(proxy_config),
sha256: None,
size: None,
};
assert!(download_item.proxy.is_some());
@ -211,6 +213,8 @@ fn test_client_creation_with_ssl_settings() {
url: "https://example.com/file.zip".to_string(),
save_path: "downloads/file.zip".to_string(),
proxy: Some(proxy_config),
sha256: None,
size: None,
};
let header_map = HeaderMap::new();
@ -256,6 +260,8 @@ fn test_download_item_creation() {
url: "https://example.com/file.tar.gz".to_string(),
save_path: "models/test.tar.gz".to_string(),
proxy: None,
sha256: None,
size: None,
};
assert_eq!(item.url, "https://example.com/file.tar.gz");

View File

@ -13,6 +13,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
sha2 = "0.10"
tokio = { version = "1", features = ["process"] }
tokio-util = "0.7.14"
url = "2.5"
[features]

View File

@ -1,7 +1,11 @@
use base64::{engine::general_purpose, Engine as _};
use hmac::{Hmac, Mac};
use rand::{distributions::Alphanumeric, Rng};
use sha2::Sha256;
use sha2::{Digest, Sha256};
use std::path::Path;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use tokio_util::sync::CancellationToken;
type HmacSha256 = Hmac<Sha256>;
@ -24,3 +28,59 @@ pub fn generate_api_key(model_id: String, api_secret: String) -> Result<String,
let hash = general_purpose::STANDARD.encode(code_bytes);
Ok(hash)
}
/// Compute SHA256 hash of a file with cancellation support by chunking the file
pub async fn compute_file_sha256_with_cancellation(
file_path: &Path,
cancel_token: &CancellationToken,
) -> Result<String, String> {
// Check for cancellation before starting
if cancel_token.is_cancelled() {
return Err("Hash computation cancelled".to_string());
}
let mut file = File::open(file_path)
.await
.map_err(|e| format!("Failed to open file for hashing: {}", e))?;
let mut hasher = Sha256::new();
let mut buffer = vec![0u8; 64 * 1024]; // 64KB chunks
let mut total_read = 0u64;
loop {
// Check for cancellation every chunk (every 64KB)
if cancel_token.is_cancelled() {
return Err("Hash computation cancelled".to_string());
}
let bytes_read = file
.read(&mut buffer)
.await
.map_err(|e| format!("Failed to read file for hashing: {}", e))?;
if bytes_read == 0 {
break; // EOF
}
hasher.update(&buffer[..bytes_read]);
total_read += bytes_read as u64;
// Log progress for very large files (every 100MB)
if total_read % (100 * 1024 * 1024) == 0 {
#[cfg(feature = "logging")]
log::debug!("Hash progress: {} MB processed", total_read / (1024 * 1024));
}
}
// Final cancellation check
if cancel_token.is_cancelled() {
return Err("Hash computation cancelled".to_string());
}
let hash_bytes = hasher.finalize();
let hash_hex = format!("{:x}", hash_bytes);
#[cfg(feature = "logging")]
log::debug!("Hash computation completed for {} bytes", total_read);
Ok(hash_hex)
}

View File

@ -168,9 +168,46 @@ export function DownloadManagement() {
[removeDownload, removeLocalDownloadingModel, t]
)
const onModelValidationStarted = useCallback(
(event: { modelId: string; downloadType: string }) => {
console.debug('onModelValidationStarted', event)
// Show validation in progress toast
toast.info(t('common:toast.modelValidationStarted.title'), {
id: `model-validation-started-${event.modelId}`,
description: t('common:toast.modelValidationStarted.description', {
modelId: event.modelId,
}),
duration: 10000,
})
},
[t]
)
const onModelValidationFailed = useCallback(
(event: { modelId: string; error: string; reason: string }) => {
console.debug('onModelValidationFailed', event)
// Dismiss the validation started toast
toast.dismiss(`model-validation-started-${event.modelId}`)
removeDownload(event.modelId)
removeLocalDownloadingModel(event.modelId)
// Show specific toast for validation failure
toast.error(t('common:toast.modelValidationFailed.title'), {
description: t('common:toast.modelValidationFailed.description', {
modelId: event.modelId,
}),
duration: 30000, // Requires manual dismissal for security-critical message
})
},
[removeDownload, removeLocalDownloadingModel, t]
)
const onFileDownloadStopped = useCallback(
(state: DownloadState) => {
console.debug('onFileDownloadError', state)
console.debug('onFileDownloadStopped', state)
removeDownload(state.modelId)
removeLocalDownloadingModel(state.modelId)
},
@ -180,6 +217,10 @@ export function DownloadManagement() {
const onFileDownloadSuccess = useCallback(
async (state: DownloadState) => {
console.debug('onFileDownloadSuccess', state)
// Dismiss any validation started toast when download completes successfully
toast.dismiss(`model-validation-started-${state.modelId}`)
removeDownload(state.modelId)
removeLocalDownloadingModel(state.modelId)
toast.success(t('common:toast.downloadComplete.title'), {
@ -192,12 +233,34 @@ export function DownloadManagement() {
[removeDownload, removeLocalDownloadingModel, t]
)
const onFileDownloadAndVerificationSuccess = useCallback(
async (state: DownloadState) => {
console.debug('onFileDownloadAndVerificationSuccess', state)
// Dismiss any validation started toast when download and verification complete successfully
toast.dismiss(`model-validation-started-${state.modelId}`)
removeDownload(state.modelId)
removeLocalDownloadingModel(state.modelId)
toast.success(t('common:toast.downloadAndVerificationComplete.title'), {
id: 'download-complete',
description: t('common:toast.downloadAndVerificationComplete.description', {
item: state.modelId,
}),
})
},
[removeDownload, removeLocalDownloadingModel, t]
)
useEffect(() => {
console.debug('DownloadListener: registering event listeners...')
events.on(DownloadEvent.onFileDownloadUpdate, onFileDownloadUpdate)
events.on(DownloadEvent.onFileDownloadError, onFileDownloadError)
events.on(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess)
events.on(DownloadEvent.onFileDownloadStopped, onFileDownloadStopped)
events.on(DownloadEvent.onModelValidationStarted, onModelValidationStarted)
events.on(DownloadEvent.onModelValidationFailed, onModelValidationFailed)
events.on(DownloadEvent.onFileDownloadAndVerificationSuccess, onFileDownloadAndVerificationSuccess)
// Register app update event listeners
events.on(AppEvent.onAppUpdateDownloadUpdate, onAppUpdateDownloadUpdate)
@ -210,6 +273,12 @@ export function DownloadManagement() {
events.off(DownloadEvent.onFileDownloadError, onFileDownloadError)
events.off(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess)
events.off(DownloadEvent.onFileDownloadStopped, onFileDownloadStopped)
events.off(
DownloadEvent.onModelValidationStarted,
onModelValidationStarted
)
events.off(DownloadEvent.onModelValidationFailed, onModelValidationFailed)
events.off(DownloadEvent.onFileDownloadAndVerificationSuccess, onFileDownloadAndVerificationSuccess)
// Unregister app update event listeners
events.off(AppEvent.onAppUpdateDownloadUpdate, onAppUpdateDownloadUpdate)
@ -224,6 +293,9 @@ export function DownloadManagement() {
onFileDownloadError,
onFileDownloadSuccess,
onFileDownloadStopped,
onModelValidationStarted,
onModelValidationFailed,
onFileDownloadAndVerificationSuccess,
onAppUpdateDownloadUpdate,
onAppUpdateDownloadSuccess,
onAppUpdateDownloadError,

View File

@ -256,6 +256,22 @@
"downloadCancelled": {
"title": "Download abgebrochen",
"description": "Der Download-Prozess wurde abgebrochen"
},
"downloadFailed": {
"title": "Download fehlgeschlagen",
"description": "{{item}} Download fehlgeschlagen"
},
"modelValidationStarted": {
"title": "Modell wird validiert",
"description": "Modell \"{{modelId}}\" erfolgreich heruntergeladen. Integrität wird überprüft..."
},
"modelValidationFailed": {
"title": "Modellvalidierung fehlgeschlagen",
"description": "Das heruntergeladene Modell \"{{modelId}}\" ist bei der Integritätsprüfung fehlgeschlagen und wurde entfernt. Die Datei könnte beschädigt oder manipuliert worden sein."
},
"downloadAndVerificationComplete": {
"title": "Download abgeschlossen",
"description": "Modell \"{{item}}\" erfolgreich heruntergeladen und verifiziert"
}
}
}

View File

@ -12,6 +12,7 @@
"showVariants": "Zeige Varianten",
"useModel": "Nutze dieses Modell",
"downloadModel": "Modell herunterladen",
"tools": "Werkzeuge",
"searchPlaceholder": "Suche nach Modellen auf Hugging Face...",
"editTheme": "Bearbeite Erscheinungsbild",
"joyride": {

View File

@ -261,6 +261,18 @@
"downloadFailed": {
"title": "Download Failed",
"description": "{{item}} download failed"
},
"modelValidationStarted": {
"title": "Validating Model",
"description": "Downloaded model \"{{modelId}}\" successfully. Verifying integrity..."
},
"modelValidationFailed": {
"title": "Model Validation Failed",
"description": "The downloaded model \"{{modelId}}\" failed integrity verification and was removed. The file may be corrupted or tampered with."
},
"downloadAndVerificationComplete": {
"title": "Download Complete",
"description": "Model \"{{item}}\" downloaded and verified successfully"
}
}
}

View File

@ -12,6 +12,7 @@
"showVariants": "Show variants",
"useModel": "Use this model",
"downloadModel": "Download model",
"tools": "Tools",
"searchPlaceholder": "Search for models on Hugging Face...",
"joyride": {
"recommendedModelTitle": "Recommended Model",

View File

@ -249,6 +249,22 @@
"downloadCancelled": {
"title": "Unduhan Dibatalkan",
"description": "Proses unduhan telah dibatalkan"
},
"downloadFailed": {
"title": "Unduhan Gagal",
"description": "Unduhan {{item}} gagal"
},
"modelValidationStarted": {
"title": "Memvalidasi Model",
"description": "Model \"{{modelId}}\" berhasil diunduh. Memverifikasi integritas..."
},
"modelValidationFailed": {
"title": "Validasi Model Gagal",
"description": "Model yang diunduh \"{{modelId}}\" gagal verifikasi integritas dan telah dihapus. File mungkin rusak atau telah dimanipulasi."
},
"downloadAndVerificationComplete": {
"title": "Unduhan Selesai",
"description": "Model \"{{item}}\" berhasil diunduh dan diverifikasi"
}
}
}

View File

@ -12,6 +12,7 @@
"showVariants": "Tampilkan Varian",
"useModel": "Gunakan model ini",
"downloadModel": "Unduh model",
"tools": "Alat",
"searchPlaceholder": "Cari model di Hugging Face...",
"joyride": {
"recommendedModelTitle": "Model yang Direkomendasikan",

View File

@ -249,6 +249,22 @@
"downloadCancelled": {
"title": "Đã hủy tải xuống",
"description": "Quá trình tải xuống đã bị hủy"
},
"downloadFailed": {
"title": "Tải xuống thất bại",
"description": "Tải xuống {{item}} thất bại"
},
"modelValidationStarted": {
"title": "Đang xác thực mô hình",
"description": "Đã tải xuống mô hình \"{{modelId}}\" thành công. Đang xác minh tính toàn vẹn..."
},
"modelValidationFailed": {
"title": "Xác thực mô hình thất bại",
"description": "Mô hình đã tải xuống \"{{modelId}}\" không vượt qua kiểm tra tính toàn vẹn và đã bị xóa. Tệp có thể bị hỏng hoặc bị giả mạo."
},
"downloadAndVerificationComplete": {
"title": "Tải xuống hoàn tất",
"description": "Mô hình \"{{item}}\" đã được tải xuống và xác minh thành công"
}
}
}

View File

@ -12,6 +12,7 @@
"showVariants": "Hiển thị biến thể",
"useModel": "Sử dụng mô hình này",
"downloadModel": "Tải xuống mô hình",
"tools": "Công cụ",
"searchPlaceholder": "Tìm kiếm các mô hình trên Hugging Face...",
"joyride": {
"recommendedModelTitle": "Mô hình được đề xuất",

View File

@ -249,6 +249,22 @@
"downloadCancelled": {
"title": "下载已取消",
"description": "下载过程已取消"
},
"downloadFailed": {
"title": "下载失败",
"description": "{{item}} 下载失败"
},
"modelValidationStarted": {
"title": "正在验证模型",
"description": "模型 \"{{modelId}}\" 下载成功。正在验证完整性..."
},
"modelValidationFailed": {
"title": "模型验证失败",
"description": "已下载的模型 \"{{modelId}}\" 未通过完整性验证并已被删除。文件可能损坏或被篡改。"
},
"downloadAndVerificationComplete": {
"title": "下载完成",
"description": "模型 \"{{item}}\" 下载并验证成功"
}
}
}

View File

@ -12,6 +12,7 @@
"showVariants": "显示变体",
"useModel": "使用此模型",
"downloadModel": "下载模型",
"tools": "工具",
"searchPlaceholder": "在 Hugging Face 上搜索模型...",
"joyride": {
"recommendedModelTitle": "推荐模型",

View File

@ -249,6 +249,22 @@
"downloadCancelled": {
"title": "下載已取消",
"description": "下載過程已取消"
},
"downloadFailed": {
"title": "下載失敗",
"description": "{{item}} 下載失敗"
},
"modelValidationStarted": {
"title": "正在驗證模型",
"description": "模型 \"{{modelId}}\" 下載成功。正在驗證完整性..."
},
"modelValidationFailed": {
"title": "模型驗證失敗",
"description": "已下載的模型 \"{{modelId}}\" 未通過完整性驗證並已被刪除。檔案可能損壞或被篡改。"
},
"downloadAndVerificationComplete": {
"title": "下載完成",
"description": "模型 \"{{item}}\" 下載並驗證成功"
}
}
}

View File

@ -12,6 +12,7 @@
"showVariants": "顯示變體",
"useModel": "使用此模型",
"downloadModel": "下載模型",
"tools": "工具",
"searchPlaceholder": "在 Hugging Face 上搜尋模型...",
"joyride": {
"recommendedModelTitle": "推薦模型",

View File

@ -22,7 +22,7 @@ import {
CatalogModel,
convertHfRepoToCatalogModel,
fetchHuggingFaceRepo,
pullModel,
pullModelWithMetadata,
} from '@/services/models'
import { Progress } from '@/components/ui/progress'
import { Button } from '@/components/ui/button'
@ -408,9 +408,11 @@ function HubModelDetail() {
addLocalDownloadingModel(
variant.model_id
)
pullModel(
pullModelWithMetadata(
variant.model_id,
variant.path
variant.path,
modelData.mmproj_models?.[0]?.path,
huggingfaceToken
)
}}
className={cn(isDownloading && 'hidden')}

View File

@ -41,7 +41,7 @@ import {
} from '@/components/ui/dropdown-menu'
import {
CatalogModel,
pullModel,
pullModelWithMetadata,
fetchHuggingFaceRepo,
convertHfRepoToCatalogModel,
} from '@/services/models'
@ -313,7 +313,12 @@ function Hub() {
// Immediately set local downloading state
addLocalDownloadingModel(modelId)
const mmprojPath = model.mmproj_models?.[0]?.path
pullModel(modelId, modelUrl, mmprojPath)
pullModelWithMetadata(
modelId,
modelUrl,
mmprojPath,
huggingfaceToken
)
}
return (
@ -812,12 +817,13 @@ function Hub() {
addLocalDownloadingModel(
variant.model_id
)
pullModel(
pullModelWithMetadata(
variant.model_id,
variant.path,
filteredModels[
virtualItem.index
].mmproj_models?.[0]?.path
].mmproj_models?.[0]?.path,
huggingfaceToken
)
}}
>

View File

@ -325,7 +325,7 @@ describe('models service', () => {
expect(result).toEqual(mockRepoData)
expect(fetch).toHaveBeenCalledWith(
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true',
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
{
headers: {},
}
@ -344,7 +344,7 @@ describe('models service', () => {
'https://huggingface.co/microsoft/DialoGPT-medium'
)
expect(fetch).toHaveBeenCalledWith(
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true',
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
{
headers: {},
}
@ -353,7 +353,7 @@ describe('models service', () => {
// Test with domain prefix
await fetchHuggingFaceRepo('huggingface.co/microsoft/DialoGPT-medium')
expect(fetch).toHaveBeenCalledWith(
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true',
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
{
headers: {},
}
@ -362,7 +362,7 @@ describe('models service', () => {
// Test with trailing slash
await fetchHuggingFaceRepo('microsoft/DialoGPT-medium/')
expect(fetch).toHaveBeenCalledWith(
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true',
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
{
headers: {},
}
@ -391,7 +391,7 @@ describe('models service', () => {
expect(result).toBeNull()
expect(fetch).toHaveBeenCalledWith(
'https://huggingface.co/api/models/nonexistent/model?blobs=true',
'https://huggingface.co/api/models/nonexistent/model?blobs=true&files_metadata=true',
{
headers: {},
}

View File

@ -62,6 +62,11 @@ export interface HuggingFaceRepo {
rfilename: string
size?: number
blobId?: string
lfs?: {
sha256: string
size: number
pointerSize: number
}
}>
readme?: string
}
@ -126,7 +131,7 @@ export const fetchHuggingFaceRepo = async (
}
const response = await fetch(
`https://huggingface.co/api/models/${cleanRepoId}?blobs=true`,
`https://huggingface.co/api/models/${cleanRepoId}?blobs=true&files_metadata=true`,
{
headers: hfToken
? {
@ -237,14 +242,103 @@ export const updateModel = async (
export const pullModel = async (
id: string,
modelPath: string,
mmprojPath?: string
modelSha256?: string,
modelSize?: number,
mmprojPath?: string,
mmprojSha256?: string,
mmprojSize?: number
) => {
return getEngine()?.import(id, {
modelPath,
mmprojPath,
modelSha256,
modelSize,
mmprojSha256,
mmprojSize,
})
}
/**
* Pull a model with real-time metadata fetching from HuggingFace.
* Extracts hash and size information from the model URL for both main model and mmproj files.
* @param id The model ID
* @param modelPath The model file URL (HuggingFace download URL)
* @param mmprojPath Optional mmproj file URL
* @param hfToken Optional HuggingFace token for authentication
* @returns A promise that resolves when the model download task is created.
*/
export const pullModelWithMetadata = async (
id: string,
modelPath: string,
mmprojPath?: string,
hfToken?: string
) => {
let modelSha256: string | undefined
let modelSize: number | undefined
let mmprojSha256: string | undefined
let mmprojSize: number | undefined
// Extract repo ID from model URL
// URL format: https://huggingface.co/{repo}/resolve/main/{filename}
const modelUrlMatch = modelPath.match(
/https:\/\/huggingface\.co\/([^/]+\/[^/]+)\/resolve\/main\/(.+)/
)
if (modelUrlMatch) {
const [, repoId, modelFilename] = modelUrlMatch
try {
// Fetch real-time metadata from HuggingFace
const repoInfo = await fetchHuggingFaceRepo(repoId, hfToken)
if (repoInfo?.siblings) {
// Find the specific model file
const modelFile = repoInfo.siblings.find(
(file) => file.rfilename === modelFilename
)
if (modelFile?.lfs) {
modelSha256 = modelFile.lfs.sha256
modelSize = modelFile.lfs.size
}
// If mmproj path provided, extract its metadata too
if (mmprojPath) {
const mmprojUrlMatch = mmprojPath.match(
/https:\/\/huggingface\.co\/[^/]+\/[^/]+\/resolve\/main\/(.+)/
)
if (mmprojUrlMatch) {
const [, mmprojFilename] = mmprojUrlMatch
const mmprojFile = repoInfo.siblings.find(
(file) => file.rfilename === mmprojFilename
)
if (mmprojFile?.lfs) {
mmprojSha256 = mmprojFile.lfs.sha256
mmprojSize = mmprojFile.lfs.size
}
}
}
}
} catch (error) {
console.warn(
'Failed to fetch HuggingFace metadata, proceeding without hash verification:',
error
)
// Continue with download even if metadata fetch fails
}
}
// Call the original pullModel with the fetched metadata
return pullModel(
id,
modelPath,
modelSha256,
modelSize,
mmprojPath,
mmprojSha256,
mmprojSize
)
}
/**
* Aborts a model download.
* @param id