* feat: gguf file size + hash validation * fix tests fe * update cargo tests * handle asyn download for both models and mmproj * move progress tracker to models * handle file download cancelled * add cancellation mid hash run
This commit is contained in:
parent
41b4cc3bb3
commit
32a2ca95b6
@ -194,6 +194,10 @@ export interface chatOptions {
|
||||
export interface ImportOptions {
|
||||
modelPath: string
|
||||
mmprojPath?: string
|
||||
modelSha256?: string
|
||||
modelSize?: number
|
||||
mmprojSha256?: string
|
||||
mmprojSize?: number
|
||||
}
|
||||
|
||||
export interface importResult {
|
||||
|
||||
@ -73,6 +73,9 @@ export enum DownloadEvent {
|
||||
onFileDownloadSuccess = 'onFileDownloadSuccess',
|
||||
onFileDownloadStopped = 'onFileDownloadStopped',
|
||||
onFileDownloadStarted = 'onFileDownloadStarted',
|
||||
onModelValidationStarted = 'onModelValidationStarted',
|
||||
onModelValidationFailed = 'onModelValidationFailed',
|
||||
onFileDownloadAndVerificationSuccess = 'onFileDownloadAndVerificationSuccess',
|
||||
}
|
||||
export enum ExtensionRoute {
|
||||
baseExtensions = 'baseExtensions',
|
||||
|
||||
@ -10,6 +10,8 @@ interface DownloadItem {
|
||||
url: string
|
||||
save_path: string
|
||||
proxy?: Record<string, string | string[] | boolean>
|
||||
sha256?: string
|
||||
size?: number
|
||||
}
|
||||
|
||||
type DownloadEvent = {
|
||||
|
||||
@ -20,9 +20,11 @@ import {
|
||||
chatCompletionRequest,
|
||||
events,
|
||||
AppEvent,
|
||||
DownloadEvent,
|
||||
} from '@janhq/core'
|
||||
|
||||
import { error, info, warn } from '@tauri-apps/plugin-log'
|
||||
import { listen } from '@tauri-apps/api/event'
|
||||
|
||||
import {
|
||||
listSupportedBackends,
|
||||
@ -71,6 +73,8 @@ interface DownloadItem {
|
||||
url: string
|
||||
save_path: string
|
||||
proxy?: Record<string, string | string[] | boolean>
|
||||
sha256?: string
|
||||
size?: number
|
||||
}
|
||||
|
||||
interface ModelConfig {
|
||||
@ -79,6 +83,9 @@ interface ModelConfig {
|
||||
name: string // user-friendly
|
||||
// some model info that we cache upon import
|
||||
size_bytes: number
|
||||
sha256?: string
|
||||
mmproj_sha256?: string
|
||||
mmproj_size_bytes?: number
|
||||
}
|
||||
|
||||
interface EmbeddingResponse {
|
||||
@ -154,6 +161,7 @@ export default class llamacpp_extension extends AIEngine {
|
||||
private pendingDownloads: Map<string, Promise<void>> = new Map()
|
||||
private isConfiguringBackends: boolean = false
|
||||
private loadingModels = new Map<string, Promise<SessionInfo>>() // Track loading promises
|
||||
private unlistenValidationStarted?: () => void
|
||||
|
||||
override async onLoad(): Promise<void> {
|
||||
super.onLoad() // Calls registerEngine() from AIEngine
|
||||
@ -181,6 +189,19 @@ export default class llamacpp_extension extends AIEngine {
|
||||
await getJanDataFolderPath(),
|
||||
this.providerId,
|
||||
])
|
||||
|
||||
// Set up validation event listeners to bridge Tauri events to frontend
|
||||
this.unlistenValidationStarted = await listen<{
|
||||
modelId: string
|
||||
downloadType: string
|
||||
}>('onModelValidationStarted', (event) => {
|
||||
console.debug(
|
||||
'LlamaCPP: bridging onModelValidationStarted event',
|
||||
event.payload
|
||||
)
|
||||
events.emit(DownloadEvent.onModelValidationStarted, event.payload)
|
||||
})
|
||||
|
||||
this.configureBackends()
|
||||
}
|
||||
|
||||
@ -774,6 +795,11 @@ export default class llamacpp_extension extends AIEngine {
|
||||
|
||||
override async onUnload(): Promise<void> {
|
||||
// Terminate all active sessions
|
||||
|
||||
// Clean up validation event listeners
|
||||
if (this.unlistenValidationStarted) {
|
||||
this.unlistenValidationStarted()
|
||||
}
|
||||
}
|
||||
|
||||
onSettingUpdate<T>(key: string, value: T): void {
|
||||
@ -1006,6 +1032,9 @@ export default class llamacpp_extension extends AIEngine {
|
||||
url: path,
|
||||
save_path: localPath,
|
||||
proxy: getProxyConfig(),
|
||||
sha256:
|
||||
saveName === 'model.gguf' ? opts.modelSha256 : opts.mmprojSha256,
|
||||
size: saveName === 'model.gguf' ? opts.modelSize : opts.mmprojSize,
|
||||
})
|
||||
return localPath
|
||||
}
|
||||
@ -1023,8 +1052,6 @@ export default class llamacpp_extension extends AIEngine {
|
||||
: undefined
|
||||
|
||||
if (downloadItems.length > 0) {
|
||||
let downloadCompleted = false
|
||||
|
||||
try {
|
||||
// emit download update event on progress
|
||||
const onProgress = (transferred: number, total: number) => {
|
||||
@ -1034,7 +1061,6 @@ export default class llamacpp_extension extends AIEngine {
|
||||
size: { transferred, total },
|
||||
downloadType: 'Model',
|
||||
})
|
||||
downloadCompleted = transferred === total
|
||||
}
|
||||
const downloadManager = window.core.extensionManager.getByName(
|
||||
'@janhq/download-extension'
|
||||
@ -1045,13 +1071,67 @@ export default class llamacpp_extension extends AIEngine {
|
||||
onProgress
|
||||
)
|
||||
|
||||
const eventName = downloadCompleted
|
||||
? 'onFileDownloadSuccess'
|
||||
: 'onFileDownloadStopped'
|
||||
events.emit(eventName, { modelId, downloadType: 'Model' })
|
||||
// If we reach here, download completed successfully (including validation)
|
||||
// The downloadFiles function only returns successfully if all files downloaded AND validated
|
||||
events.emit(DownloadEvent.onFileDownloadAndVerificationSuccess, {
|
||||
modelId,
|
||||
downloadType: 'Model'
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error downloading model:', modelId, opts, error)
|
||||
events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : String(error)
|
||||
|
||||
// Check if this is a cancellation
|
||||
const isCancellationError = errorMessage.includes('Download cancelled') ||
|
||||
errorMessage.includes('Validation cancelled') ||
|
||||
errorMessage.includes('Hash computation cancelled') ||
|
||||
errorMessage.includes('cancelled') ||
|
||||
errorMessage.includes('aborted')
|
||||
|
||||
// Check if this is a validation failure
|
||||
const isValidationError =
|
||||
errorMessage.includes('Hash verification failed') ||
|
||||
errorMessage.includes('Size verification failed') ||
|
||||
errorMessage.includes('Failed to verify file')
|
||||
|
||||
if (isCancellationError) {
|
||||
logger.info('Download cancelled for model:', modelId)
|
||||
// Emit download stopped event instead of error
|
||||
events.emit(DownloadEvent.onFileDownloadStopped, {
|
||||
modelId,
|
||||
downloadType: 'Model',
|
||||
})
|
||||
} else if (isValidationError) {
|
||||
logger.error(
|
||||
'Validation failed for model:',
|
||||
modelId,
|
||||
'Error:',
|
||||
errorMessage
|
||||
)
|
||||
|
||||
// Cancel any other download tasks for this model
|
||||
try {
|
||||
this.abortImport(modelId)
|
||||
} catch (cancelError) {
|
||||
logger.warn('Failed to cancel download task:', cancelError)
|
||||
}
|
||||
|
||||
// Emit validation failure event
|
||||
events.emit(DownloadEvent.onModelValidationFailed, {
|
||||
modelId,
|
||||
downloadType: 'Model',
|
||||
error: errorMessage,
|
||||
reason: 'validation_failed',
|
||||
})
|
||||
} else {
|
||||
// Regular download error
|
||||
events.emit(DownloadEvent.onFileDownloadError, {
|
||||
modelId,
|
||||
downloadType: 'Model',
|
||||
error: errorMessage,
|
||||
})
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
@ -1078,7 +1158,9 @@ export default class llamacpp_extension extends AIEngine {
|
||||
} catch (error) {
|
||||
logger.error('GGUF validation failed:', error)
|
||||
throw new Error(
|
||||
`Invalid GGUF file(s): ${error.message || 'File format validation failed'}`
|
||||
`Invalid GGUF file(s): ${
|
||||
error.message || 'File format validation failed'
|
||||
}`
|
||||
)
|
||||
}
|
||||
|
||||
@ -1097,6 +1179,10 @@ export default class llamacpp_extension extends AIEngine {
|
||||
mmproj_path: mmprojPath,
|
||||
name: modelId,
|
||||
size_bytes,
|
||||
model_sha256: opts.modelSha256,
|
||||
model_size_bytes: opts.modelSize,
|
||||
mmproj_sha256: opts.mmprojSha256,
|
||||
mmproj_size_bytes: opts.mmprojSize,
|
||||
} as ModelConfig
|
||||
await fs.mkdir(await joinPath([janDataFolderPath, modelDir]))
|
||||
await invoke<void>('write_yaml', {
|
||||
@ -1108,16 +1194,50 @@ export default class llamacpp_extension extends AIEngine {
|
||||
modelPath,
|
||||
mmprojPath,
|
||||
size_bytes,
|
||||
model_sha256: opts.modelSha256,
|
||||
model_size_bytes: opts.modelSize,
|
||||
mmproj_sha256: opts.mmprojSha256,
|
||||
mmproj_size_bytes: opts.mmprojSize,
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes the entire model folder for a given modelId
|
||||
* @param modelId The model ID to delete
|
||||
*/
|
||||
private async deleteModelFolder(modelId: string): Promise<void> {
|
||||
try {
|
||||
const modelDir = await joinPath([
|
||||
await this.getProviderPath(),
|
||||
'models',
|
||||
modelId,
|
||||
])
|
||||
|
||||
if (await fs.existsSync(modelDir)) {
|
||||
logger.info(`Cleaning up model directory: ${modelDir}`)
|
||||
await fs.rm(modelDir)
|
||||
}
|
||||
} catch (deleteError) {
|
||||
logger.warn('Failed to delete model directory:', deleteError)
|
||||
}
|
||||
}
|
||||
|
||||
override async abortImport(modelId: string): Promise<void> {
|
||||
// prepand provider name to avoid name collision
|
||||
// Cancel any active download task
|
||||
// prepend provider name to avoid name collision
|
||||
const taskId = this.createDownloadTaskId(modelId)
|
||||
const downloadManager = window.core.extensionManager.getByName(
|
||||
'@janhq/download-extension'
|
||||
)
|
||||
await downloadManager.cancelDownload(taskId)
|
||||
|
||||
try {
|
||||
await downloadManager.cancelDownload(taskId)
|
||||
} catch (cancelError) {
|
||||
logger.warn('Failed to cancel download task:', cancelError)
|
||||
}
|
||||
|
||||
// Delete the entire model folder if it exists (for validation failures)
|
||||
await this.deleteModelFolder(modelId)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
11
src-tauri/Cargo.lock
generated
11
src-tauri/Cargo.lock
generated
@ -2323,6 +2323,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"url",
|
||||
]
|
||||
|
||||
@ -4019,8 +4020,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rmcp"
|
||||
version = "0.5.0"
|
||||
source = "git+https://github.com/modelcontextprotocol/rust-sdk?rev=209dbac50f51737ad953c3a2c8e28f3619b6c277#209dbac50f51737ad953c3a2c8e28f3619b6c277"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bb21cd3555f1059f27e4813827338dec44429a08ecd0011acc41d9907b160c00"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
@ -4045,8 +4047,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rmcp-macros"
|
||||
version = "0.5.0"
|
||||
source = "git+https://github.com/modelcontextprotocol/rust-sdk?rev=209dbac50f51737ad953c3a2c8e28f3619b6c277#209dbac50f51737ad953c3a2c8e28f3619b6c277"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ab5d16ae1ff3ce2c5fd86c37047b2869b75bec795d53a4b1d8257b15415a2354"
|
||||
dependencies = [
|
||||
"darling 0.21.2",
|
||||
"proc-macro2",
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
use super::models::{DownloadEvent, DownloadItem, ProxyConfig};
|
||||
use super::models::{DownloadEvent, DownloadItem, ProxyConfig, ProgressTracker};
|
||||
use crate::core::app::commands::get_jan_data_folder_path;
|
||||
use futures_util::StreamExt;
|
||||
use jan_utils::normalize_path;
|
||||
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
use tauri::Emitter;
|
||||
use tokio::fs::File;
|
||||
@ -11,10 +12,131 @@ use tokio::io::AsyncWriteExt;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use url::Url;
|
||||
|
||||
// ===== UTILITY FUNCTIONS =====
|
||||
|
||||
pub fn err_to_string<E: std::fmt::Display>(e: E) -> String {
|
||||
format!("Error: {}", e)
|
||||
}
|
||||
|
||||
|
||||
// ===== VALIDATION FUNCTIONS =====
|
||||
|
||||
/// Validates a downloaded file against expected hash and size
|
||||
async fn validate_downloaded_file(
|
||||
item: &DownloadItem,
|
||||
save_path: &Path,
|
||||
app: &tauri::AppHandle,
|
||||
cancel_token: &CancellationToken,
|
||||
) -> Result<(), String> {
|
||||
// Skip validation if no verification data is provided
|
||||
if item.sha256.is_none() && item.size.is_none() {
|
||||
log::debug!(
|
||||
"No validation data provided for {}, skipping validation",
|
||||
item.url
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Extract model ID from save path for validation events
|
||||
// Path structure: llamacpp/models/{modelId}/model.gguf or llamacpp/models/{modelId}/mmproj.gguf
|
||||
let model_id = save_path
|
||||
.parent() // get parent directory (modelId folder)
|
||||
.and_then(|p| p.file_name())
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
// Emit validation started event
|
||||
app.emit(
|
||||
"onModelValidationStarted",
|
||||
serde_json::json!({
|
||||
"modelId": model_id,
|
||||
"downloadType": "Model",
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
log::info!("Starting validation for model: {}", model_id);
|
||||
|
||||
// Validate size if provided (fast check first)
|
||||
if let Some(expected_size) = &item.size {
|
||||
log::info!("Starting size verification for {}", item.url);
|
||||
|
||||
match tokio::fs::metadata(save_path).await {
|
||||
Ok(metadata) => {
|
||||
let actual_size = metadata.len();
|
||||
|
||||
if actual_size != *expected_size {
|
||||
log::error!(
|
||||
"Size verification failed for {}. Expected: {} bytes, Actual: {} bytes",
|
||||
item.url,
|
||||
expected_size,
|
||||
actual_size
|
||||
);
|
||||
return Err(format!(
|
||||
"Size verification failed. Expected {} bytes but got {} bytes.",
|
||||
expected_size, actual_size
|
||||
));
|
||||
}
|
||||
|
||||
log::info!(
|
||||
"Size verification successful for {} ({} bytes)",
|
||||
item.url,
|
||||
actual_size
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
"Failed to get file metadata for {}: {}",
|
||||
save_path.display(),
|
||||
e
|
||||
);
|
||||
return Err(format!("Failed to verify file size: {}", e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for cancellation before expensive hash computation
|
||||
if cancel_token.is_cancelled() {
|
||||
log::info!("Validation cancelled for {}", item.url);
|
||||
return Err("Validation cancelled".to_string());
|
||||
}
|
||||
|
||||
// Validate hash if provided (expensive check second)
|
||||
if let Some(expected_sha256) = &item.sha256 {
|
||||
log::info!("Starting Hash verification for {}", item.url);
|
||||
|
||||
match jan_utils::crypto::compute_file_sha256_with_cancellation(save_path, cancel_token).await {
|
||||
Ok(computed_sha256) => {
|
||||
if computed_sha256 != *expected_sha256 {
|
||||
log::error!(
|
||||
"Hash verification failed for {}. Expected: {}, Computed: {}",
|
||||
item.url,
|
||||
expected_sha256,
|
||||
computed_sha256
|
||||
);
|
||||
|
||||
return Err(format!(
|
||||
"Hash verification failed. The downloaded file is corrupted or has been tampered with."
|
||||
));
|
||||
}
|
||||
|
||||
log::info!("Hash verification successful for {}", item.url);
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
"Failed to compute SHA256 for {}: {}",
|
||||
save_path.display(),
|
||||
e
|
||||
);
|
||||
return Err(format!("Failed to verify file integrity: {}", e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("All validations passed for {}", item.url);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn validate_proxy_config(config: &ProxyConfig) -> Result<(), String> {
|
||||
// Validate proxy URL format
|
||||
if let Err(e) = Url::parse(&config.url) {
|
||||
@ -172,6 +294,9 @@ pub async fn _get_file_size(
|
||||
}
|
||||
}
|
||||
|
||||
// ===== MAIN DOWNLOAD FUNCTIONS =====
|
||||
|
||||
/// Downloads multiple files in parallel with individual progress tracking
|
||||
pub async fn _download_files_internal(
|
||||
app: tauri::AppHandle,
|
||||
items: &[DownloadItem],
|
||||
@ -184,28 +309,31 @@ pub async fn _download_files_internal(
|
||||
|
||||
let header_map = _convert_headers(headers).map_err(err_to_string)?;
|
||||
|
||||
let total_size = {
|
||||
let mut total_size = 0u64;
|
||||
for item in items.iter() {
|
||||
let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?;
|
||||
total_size += _get_file_size(&client, &item.url)
|
||||
.await
|
||||
.map_err(err_to_string)?;
|
||||
}
|
||||
total_size
|
||||
};
|
||||
// Calculate sizes for each file
|
||||
let mut file_sizes = HashMap::new();
|
||||
for item in items.iter() {
|
||||
let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?;
|
||||
let size = _get_file_size(&client, &item.url)
|
||||
.await
|
||||
.map_err(err_to_string)?;
|
||||
file_sizes.insert(item.url.clone(), size);
|
||||
}
|
||||
|
||||
let total_size: u64 = file_sizes.values().sum();
|
||||
log::info!("Total download size: {}", total_size);
|
||||
|
||||
let mut evt = DownloadEvent {
|
||||
transferred: 0,
|
||||
total: total_size,
|
||||
};
|
||||
let evt_name = format!("download-{}", task_id);
|
||||
|
||||
// Create progress tracker
|
||||
let progress_tracker = ProgressTracker::new(items, file_sizes.clone());
|
||||
|
||||
// save file under Jan data folder
|
||||
let jan_data_folder = get_jan_data_folder_path(app.clone());
|
||||
|
||||
for item in items.iter() {
|
||||
// Collect download tasks for parallel execution
|
||||
let mut download_tasks = Vec::new();
|
||||
|
||||
for (index, item) in items.iter().enumerate() {
|
||||
let save_path = jan_data_folder.join(&item.save_path);
|
||||
let save_path = normalize_path(&save_path);
|
||||
|
||||
@ -217,120 +345,251 @@ pub async fn _download_files_internal(
|
||||
));
|
||||
}
|
||||
|
||||
// Create parent directories if they don't exist
|
||||
if let Some(parent) = save_path.parent() {
|
||||
if !parent.exists() {
|
||||
tokio::fs::create_dir_all(parent)
|
||||
.await
|
||||
.map_err(err_to_string)?;
|
||||
}
|
||||
}
|
||||
// Spawn download task for each file
|
||||
let item_clone = item.clone();
|
||||
let app_clone = app.clone();
|
||||
let header_map_clone = header_map.clone();
|
||||
let cancel_token_clone = cancel_token.clone();
|
||||
let evt_name_clone = evt_name.clone();
|
||||
let progress_tracker_clone = progress_tracker.clone();
|
||||
let file_id = format!("{}-{}", task_id, index);
|
||||
let file_size = file_sizes.get(&item.url).copied().unwrap_or(0);
|
||||
|
||||
let current_extension = save_path.extension().unwrap_or_default().to_string_lossy();
|
||||
let append_extension = |ext: &str| {
|
||||
if current_extension.is_empty() {
|
||||
ext.to_string()
|
||||
} else {
|
||||
format!("{}.{}", current_extension, ext)
|
||||
}
|
||||
};
|
||||
let tmp_save_path = save_path.with_extension(append_extension("tmp"));
|
||||
let url_save_path = save_path.with_extension(append_extension("url"));
|
||||
|
||||
let mut should_resume = resume
|
||||
&& tmp_save_path.exists()
|
||||
&& tokio::fs::read_to_string(&url_save_path)
|
||||
.await
|
||||
.map(|url| url == item.url) // check if we resume the same URL
|
||||
.unwrap_or(false);
|
||||
|
||||
tokio::fs::write(&url_save_path, item.url.clone())
|
||||
let task = tokio::spawn(async move {
|
||||
download_single_file(
|
||||
app_clone,
|
||||
&item_clone,
|
||||
&header_map_clone,
|
||||
&save_path,
|
||||
resume,
|
||||
cancel_token_clone,
|
||||
evt_name_clone,
|
||||
progress_tracker_clone,
|
||||
file_id,
|
||||
file_size,
|
||||
)
|
||||
.await
|
||||
.map_err(err_to_string)?;
|
||||
});
|
||||
|
||||
log::info!("Started downloading: {}", item.url);
|
||||
let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?;
|
||||
let mut download_delta = 0u64;
|
||||
let resp = if should_resume {
|
||||
let downloaded_size = tmp_save_path.metadata().map_err(err_to_string)?.len();
|
||||
match _get_maybe_resume(&client, &item.url, downloaded_size).await {
|
||||
Ok(resp) => {
|
||||
log::info!(
|
||||
"Resume download: {}, already downloaded {} bytes",
|
||||
item.url,
|
||||
downloaded_size
|
||||
);
|
||||
download_delta += downloaded_size;
|
||||
resp
|
||||
}
|
||||
Err(e) => {
|
||||
// fallback to normal download
|
||||
log::warn!("Failed to resume download: {}", e);
|
||||
should_resume = false;
|
||||
_get_maybe_resume(&client, &item.url, 0).await?
|
||||
}
|
||||
}
|
||||
} else {
|
||||
_get_maybe_resume(&client, &item.url, 0).await?
|
||||
};
|
||||
let mut stream = resp.bytes_stream();
|
||||
|
||||
let file = if should_resume {
|
||||
// resume download, append to existing file
|
||||
tokio::fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.append(true)
|
||||
.open(&tmp_save_path)
|
||||
.await
|
||||
.map_err(err_to_string)?
|
||||
} else {
|
||||
// start new download, create a new file
|
||||
File::create(&tmp_save_path).await.map_err(err_to_string)?
|
||||
};
|
||||
let mut writer = tokio::io::BufWriter::new(file);
|
||||
|
||||
// write chunk to file
|
||||
while let Some(chunk) = stream.next().await {
|
||||
if cancel_token.is_cancelled() {
|
||||
if !should_resume {
|
||||
tokio::fs::remove_dir_all(&save_path.parent().unwrap())
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
log::info!("Download cancelled for task: {}", task_id);
|
||||
app.emit(&evt_name, evt.clone()).unwrap();
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let chunk = chunk.map_err(err_to_string)?;
|
||||
writer.write_all(&chunk).await.map_err(err_to_string)?;
|
||||
download_delta += chunk.len() as u64;
|
||||
|
||||
// only update every 10 MB
|
||||
if download_delta >= 10 * 1024 * 1024 {
|
||||
evt.transferred += download_delta;
|
||||
app.emit(&evt_name, evt.clone()).unwrap();
|
||||
download_delta = 0u64;
|
||||
}
|
||||
}
|
||||
|
||||
writer.flush().await.map_err(err_to_string)?;
|
||||
evt.transferred += download_delta;
|
||||
|
||||
// rename tmp file to final file
|
||||
tokio::fs::rename(&tmp_save_path, &save_path)
|
||||
.await
|
||||
.map_err(err_to_string)?;
|
||||
tokio::fs::remove_file(&url_save_path)
|
||||
.await
|
||||
.map_err(err_to_string)?;
|
||||
log::info!("Finished downloading: {}", item.url);
|
||||
download_tasks.push(task);
|
||||
}
|
||||
|
||||
app.emit(&evt_name, evt.clone()).unwrap();
|
||||
// Wait for all downloads to complete
|
||||
let mut validation_tasks = Vec::new();
|
||||
for (task, item) in download_tasks.into_iter().zip(items.iter()) {
|
||||
let result = task.await.map_err(|e| format!("Task join error: {}", e))?;
|
||||
|
||||
match result {
|
||||
Ok(downloaded_path) => {
|
||||
// Spawn validation task in parallel
|
||||
let item_clone = item.clone();
|
||||
let app_clone = app.clone();
|
||||
let path_clone = downloaded_path.clone();
|
||||
let cancel_token_clone = cancel_token.clone();
|
||||
let validation_task = tokio::spawn(async move {
|
||||
validate_downloaded_file(&item_clone, &path_clone, &app_clone, &cancel_token_clone).await
|
||||
});
|
||||
validation_tasks.push((validation_task, downloaded_path, item.clone()));
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all validations to complete
|
||||
for (validation_task, save_path, _item) in validation_tasks {
|
||||
let validation_result = validation_task
|
||||
.await
|
||||
.map_err(|e| format!("Validation task join error: {}", e))?;
|
||||
|
||||
if let Err(validation_error) = validation_result {
|
||||
// Clean up the file if validation fails
|
||||
let _ = tokio::fs::remove_file(&save_path).await;
|
||||
|
||||
// Try to clean up the parent directory if it's empty
|
||||
if let Some(parent) = save_path.parent() {
|
||||
let _ = tokio::fs::remove_dir(parent).await;
|
||||
}
|
||||
|
||||
return Err(validation_error);
|
||||
}
|
||||
}
|
||||
|
||||
// Emit final progress
|
||||
let (transferred, total) = progress_tracker.get_total_progress().await;
|
||||
let final_evt = DownloadEvent { transferred, total };
|
||||
app.emit(&evt_name, final_evt).unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Downloads a single file without blocking other downloads
|
||||
async fn download_single_file(
|
||||
app: tauri::AppHandle,
|
||||
item: &DownloadItem,
|
||||
header_map: &HeaderMap,
|
||||
save_path: &std::path::Path,
|
||||
resume: bool,
|
||||
cancel_token: CancellationToken,
|
||||
evt_name: String,
|
||||
progress_tracker: ProgressTracker,
|
||||
file_id: String,
|
||||
_file_size: u64,
|
||||
) -> Result<std::path::PathBuf, String> {
|
||||
// Create parent directories if they don't exist
|
||||
if let Some(parent) = save_path.parent() {
|
||||
if !parent.exists() {
|
||||
tokio::fs::create_dir_all(parent)
|
||||
.await
|
||||
.map_err(err_to_string)?;
|
||||
}
|
||||
}
|
||||
|
||||
let current_extension = save_path.extension().unwrap_or_default().to_string_lossy();
|
||||
let append_extension = |ext: &str| {
|
||||
if current_extension.is_empty() {
|
||||
ext.to_string()
|
||||
} else {
|
||||
format!("{}.{}", current_extension, ext)
|
||||
}
|
||||
};
|
||||
let tmp_save_path = save_path.with_extension(append_extension("tmp"));
|
||||
let url_save_path = save_path.with_extension(append_extension("url"));
|
||||
|
||||
let mut should_resume = resume
|
||||
&& tmp_save_path.exists()
|
||||
&& tokio::fs::read_to_string(&url_save_path)
|
||||
.await
|
||||
.map(|url| url == item.url) // check if we resume the same URL
|
||||
.unwrap_or(false);
|
||||
|
||||
tokio::fs::write(&url_save_path, item.url.clone())
|
||||
.await
|
||||
.map_err(err_to_string)?;
|
||||
|
||||
log::info!("Started downloading: {}", item.url);
|
||||
let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?;
|
||||
let mut download_delta = 0u64;
|
||||
let mut initial_progress = 0u64;
|
||||
|
||||
let resp = if should_resume {
|
||||
let downloaded_size = tmp_save_path.metadata().map_err(err_to_string)?.len();
|
||||
match _get_maybe_resume(&client, &item.url, downloaded_size).await {
|
||||
Ok(resp) => {
|
||||
log::info!(
|
||||
"Resume download: {}, already downloaded {} bytes",
|
||||
item.url,
|
||||
downloaded_size
|
||||
);
|
||||
initial_progress = downloaded_size;
|
||||
|
||||
// Initialize progress for resumed download
|
||||
progress_tracker
|
||||
.update_progress(&file_id, downloaded_size)
|
||||
.await;
|
||||
|
||||
// Emit initial combined progress
|
||||
let (combined_transferred, combined_total) =
|
||||
progress_tracker.get_total_progress().await;
|
||||
let evt = DownloadEvent {
|
||||
transferred: combined_transferred,
|
||||
total: combined_total,
|
||||
};
|
||||
app.emit(&evt_name, evt).unwrap();
|
||||
|
||||
resp
|
||||
}
|
||||
Err(e) => {
|
||||
// fallback to normal download
|
||||
log::warn!("Failed to resume download: {}", e);
|
||||
should_resume = false;
|
||||
_get_maybe_resume(&client, &item.url, 0).await?
|
||||
}
|
||||
}
|
||||
} else {
|
||||
_get_maybe_resume(&client, &item.url, 0).await?
|
||||
};
|
||||
let mut stream = resp.bytes_stream();
|
||||
|
||||
let file = if should_resume {
|
||||
// resume download, append to existing file
|
||||
tokio::fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.append(true)
|
||||
.open(&tmp_save_path)
|
||||
.await
|
||||
.map_err(err_to_string)?
|
||||
} else {
|
||||
// start new download, create a new file
|
||||
File::create(&tmp_save_path).await.map_err(err_to_string)?
|
||||
};
|
||||
let mut writer = tokio::io::BufWriter::new(file);
|
||||
let mut total_transferred = initial_progress;
|
||||
|
||||
// write chunk to file
|
||||
while let Some(chunk) = stream.next().await {
|
||||
if cancel_token.is_cancelled() {
|
||||
if !should_resume {
|
||||
tokio::fs::remove_dir_all(&save_path.parent().unwrap())
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
log::info!("Download cancelled: {}", item.url);
|
||||
return Err("Download cancelled".to_string());
|
||||
}
|
||||
|
||||
let chunk = chunk.map_err(err_to_string)?;
|
||||
writer.write_all(&chunk).await.map_err(err_to_string)?;
|
||||
download_delta += chunk.len() as u64;
|
||||
total_transferred += chunk.len() as u64;
|
||||
|
||||
// Update progress every 10 MB
|
||||
if download_delta >= 10 * 1024 * 1024 {
|
||||
// Update individual file progress
|
||||
progress_tracker
|
||||
.update_progress(&file_id, total_transferred)
|
||||
.await;
|
||||
|
||||
// Emit combined progress event
|
||||
let (combined_transferred, combined_total) =
|
||||
progress_tracker.get_total_progress().await;
|
||||
let evt = DownloadEvent {
|
||||
transferred: combined_transferred,
|
||||
total: combined_total,
|
||||
};
|
||||
app.emit(&evt_name, evt).unwrap();
|
||||
|
||||
download_delta = 0u64;
|
||||
}
|
||||
}
|
||||
|
||||
writer.flush().await.map_err(err_to_string)?;
|
||||
|
||||
// Final progress update for this file
|
||||
progress_tracker
|
||||
.update_progress(&file_id, total_transferred)
|
||||
.await;
|
||||
|
||||
// Emit final combined progress
|
||||
let (combined_transferred, combined_total) = progress_tracker.get_total_progress().await;
|
||||
let evt = DownloadEvent {
|
||||
transferred: combined_transferred,
|
||||
total: combined_total,
|
||||
};
|
||||
app.emit(&evt_name, evt).unwrap();
|
||||
|
||||
// rename tmp file to final file
|
||||
tokio::fs::rename(&tmp_save_path, &save_path)
|
||||
.await
|
||||
.map_err(err_to_string)?;
|
||||
tokio::fs::remove_file(&url_save_path)
|
||||
.await
|
||||
.map_err(err_to_string)?;
|
||||
|
||||
log::info!("Finished downloading: {}", item.url);
|
||||
Ok(save_path.to_path_buf())
|
||||
}
|
||||
|
||||
// ===== HTTP CLIENT HELPER FUNCTIONS =====
|
||||
|
||||
pub async fn _get_maybe_resume(
|
||||
client: &reqwest::Client,
|
||||
url: &str,
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
#[derive(Default)]
|
||||
@ -20,6 +22,8 @@ pub struct DownloadItem {
|
||||
pub url: String,
|
||||
pub save_path: String,
|
||||
pub proxy: Option<ProxyConfig>,
|
||||
pub sha256: Option<String>,
|
||||
pub size: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, Clone, Debug)]
|
||||
@ -27,3 +31,31 @@ pub struct DownloadEvent {
|
||||
pub transferred: u64,
|
||||
pub total: u64,
|
||||
}
|
||||
|
||||
/// Structure to track progress for each file in parallel downloads
|
||||
#[derive(Clone)]
|
||||
pub struct ProgressTracker {
|
||||
file_progress: Arc<Mutex<HashMap<String, u64>>>,
|
||||
total_size: u64,
|
||||
}
|
||||
|
||||
impl ProgressTracker {
|
||||
pub fn new(_items: &[DownloadItem], sizes: HashMap<String, u64>) -> Self {
|
||||
let total_size = sizes.values().sum();
|
||||
ProgressTracker {
|
||||
file_progress: Arc::new(Mutex::new(HashMap::new())),
|
||||
total_size,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn update_progress(&self, file_id: &str, transferred: u64) {
|
||||
let mut progress = self.file_progress.lock().await;
|
||||
progress.insert(file_id.to_string(), transferred);
|
||||
}
|
||||
|
||||
pub async fn get_total_progress(&self) -> (u64, u64) {
|
||||
let progress = self.file_progress.lock().await;
|
||||
let total_transferred: u64 = progress.values().sum();
|
||||
(total_transferred, self.total_size)
|
||||
}
|
||||
}
|
||||
|
||||
@ -194,6 +194,8 @@ fn test_download_item_with_ssl_proxy() {
|
||||
url: "https://example.com/file.zip".to_string(),
|
||||
save_path: "downloads/file.zip".to_string(),
|
||||
proxy: Some(proxy_config),
|
||||
sha256: None,
|
||||
size: None,
|
||||
};
|
||||
|
||||
assert!(download_item.proxy.is_some());
|
||||
@ -211,6 +213,8 @@ fn test_client_creation_with_ssl_settings() {
|
||||
url: "https://example.com/file.zip".to_string(),
|
||||
save_path: "downloads/file.zip".to_string(),
|
||||
proxy: Some(proxy_config),
|
||||
sha256: None,
|
||||
size: None,
|
||||
};
|
||||
|
||||
let header_map = HeaderMap::new();
|
||||
@ -256,6 +260,8 @@ fn test_download_item_creation() {
|
||||
url: "https://example.com/file.tar.gz".to_string(),
|
||||
save_path: "models/test.tar.gz".to_string(),
|
||||
proxy: None,
|
||||
sha256: None,
|
||||
size: None,
|
||||
};
|
||||
|
||||
assert_eq!(item.url, "https://example.com/file.tar.gz");
|
||||
|
||||
@ -13,6 +13,7 @@ serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
sha2 = "0.10"
|
||||
tokio = { version = "1", features = ["process"] }
|
||||
tokio-util = "0.7.14"
|
||||
url = "2.5"
|
||||
|
||||
[features]
|
||||
|
||||
@ -1,7 +1,11 @@
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use hmac::{Hmac, Mac};
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use sha2::Sha256;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::path::Path;
|
||||
use tokio::fs::File;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
@ -24,3 +28,59 @@ pub fn generate_api_key(model_id: String, api_secret: String) -> Result<String,
|
||||
let hash = general_purpose::STANDARD.encode(code_bytes);
|
||||
Ok(hash)
|
||||
}
|
||||
|
||||
/// Compute SHA256 hash of a file with cancellation support by chunking the file
|
||||
pub async fn compute_file_sha256_with_cancellation(
|
||||
file_path: &Path,
|
||||
cancel_token: &CancellationToken,
|
||||
) -> Result<String, String> {
|
||||
// Check for cancellation before starting
|
||||
if cancel_token.is_cancelled() {
|
||||
return Err("Hash computation cancelled".to_string());
|
||||
}
|
||||
|
||||
let mut file = File::open(file_path)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to open file for hashing: {}", e))?;
|
||||
|
||||
let mut hasher = Sha256::new();
|
||||
let mut buffer = vec![0u8; 64 * 1024]; // 64KB chunks
|
||||
let mut total_read = 0u64;
|
||||
|
||||
loop {
|
||||
// Check for cancellation every chunk (every 64KB)
|
||||
if cancel_token.is_cancelled() {
|
||||
return Err("Hash computation cancelled".to_string());
|
||||
}
|
||||
|
||||
let bytes_read = file
|
||||
.read(&mut buffer)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to read file for hashing: {}", e))?;
|
||||
|
||||
if bytes_read == 0 {
|
||||
break; // EOF
|
||||
}
|
||||
|
||||
hasher.update(&buffer[..bytes_read]);
|
||||
total_read += bytes_read as u64;
|
||||
|
||||
// Log progress for very large files (every 100MB)
|
||||
if total_read % (100 * 1024 * 1024) == 0 {
|
||||
#[cfg(feature = "logging")]
|
||||
log::debug!("Hash progress: {} MB processed", total_read / (1024 * 1024));
|
||||
}
|
||||
}
|
||||
|
||||
// Final cancellation check
|
||||
if cancel_token.is_cancelled() {
|
||||
return Err("Hash computation cancelled".to_string());
|
||||
}
|
||||
|
||||
let hash_bytes = hasher.finalize();
|
||||
let hash_hex = format!("{:x}", hash_bytes);
|
||||
|
||||
#[cfg(feature = "logging")]
|
||||
log::debug!("Hash computation completed for {} bytes", total_read);
|
||||
Ok(hash_hex)
|
||||
}
|
||||
|
||||
@ -168,9 +168,46 @@ export function DownloadManagement() {
|
||||
[removeDownload, removeLocalDownloadingModel, t]
|
||||
)
|
||||
|
||||
const onModelValidationStarted = useCallback(
|
||||
(event: { modelId: string; downloadType: string }) => {
|
||||
console.debug('onModelValidationStarted', event)
|
||||
|
||||
// Show validation in progress toast
|
||||
toast.info(t('common:toast.modelValidationStarted.title'), {
|
||||
id: `model-validation-started-${event.modelId}`,
|
||||
description: t('common:toast.modelValidationStarted.description', {
|
||||
modelId: event.modelId,
|
||||
}),
|
||||
duration: 10000,
|
||||
})
|
||||
},
|
||||
[t]
|
||||
)
|
||||
|
||||
const onModelValidationFailed = useCallback(
|
||||
(event: { modelId: string; error: string; reason: string }) => {
|
||||
console.debug('onModelValidationFailed', event)
|
||||
|
||||
// Dismiss the validation started toast
|
||||
toast.dismiss(`model-validation-started-${event.modelId}`)
|
||||
|
||||
removeDownload(event.modelId)
|
||||
removeLocalDownloadingModel(event.modelId)
|
||||
|
||||
// Show specific toast for validation failure
|
||||
toast.error(t('common:toast.modelValidationFailed.title'), {
|
||||
description: t('common:toast.modelValidationFailed.description', {
|
||||
modelId: event.modelId,
|
||||
}),
|
||||
duration: 30000, // Requires manual dismissal for security-critical message
|
||||
})
|
||||
},
|
||||
[removeDownload, removeLocalDownloadingModel, t]
|
||||
)
|
||||
|
||||
const onFileDownloadStopped = useCallback(
|
||||
(state: DownloadState) => {
|
||||
console.debug('onFileDownloadError', state)
|
||||
console.debug('onFileDownloadStopped', state)
|
||||
removeDownload(state.modelId)
|
||||
removeLocalDownloadingModel(state.modelId)
|
||||
},
|
||||
@ -180,6 +217,10 @@ export function DownloadManagement() {
|
||||
const onFileDownloadSuccess = useCallback(
|
||||
async (state: DownloadState) => {
|
||||
console.debug('onFileDownloadSuccess', state)
|
||||
|
||||
// Dismiss any validation started toast when download completes successfully
|
||||
toast.dismiss(`model-validation-started-${state.modelId}`)
|
||||
|
||||
removeDownload(state.modelId)
|
||||
removeLocalDownloadingModel(state.modelId)
|
||||
toast.success(t('common:toast.downloadComplete.title'), {
|
||||
@ -192,12 +233,34 @@ export function DownloadManagement() {
|
||||
[removeDownload, removeLocalDownloadingModel, t]
|
||||
)
|
||||
|
||||
const onFileDownloadAndVerificationSuccess = useCallback(
|
||||
async (state: DownloadState) => {
|
||||
console.debug('onFileDownloadAndVerificationSuccess', state)
|
||||
|
||||
// Dismiss any validation started toast when download and verification complete successfully
|
||||
toast.dismiss(`model-validation-started-${state.modelId}`)
|
||||
|
||||
removeDownload(state.modelId)
|
||||
removeLocalDownloadingModel(state.modelId)
|
||||
toast.success(t('common:toast.downloadAndVerificationComplete.title'), {
|
||||
id: 'download-complete',
|
||||
description: t('common:toast.downloadAndVerificationComplete.description', {
|
||||
item: state.modelId,
|
||||
}),
|
||||
})
|
||||
},
|
||||
[removeDownload, removeLocalDownloadingModel, t]
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
console.debug('DownloadListener: registering event listeners...')
|
||||
events.on(DownloadEvent.onFileDownloadUpdate, onFileDownloadUpdate)
|
||||
events.on(DownloadEvent.onFileDownloadError, onFileDownloadError)
|
||||
events.on(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess)
|
||||
events.on(DownloadEvent.onFileDownloadStopped, onFileDownloadStopped)
|
||||
events.on(DownloadEvent.onModelValidationStarted, onModelValidationStarted)
|
||||
events.on(DownloadEvent.onModelValidationFailed, onModelValidationFailed)
|
||||
events.on(DownloadEvent.onFileDownloadAndVerificationSuccess, onFileDownloadAndVerificationSuccess)
|
||||
|
||||
// Register app update event listeners
|
||||
events.on(AppEvent.onAppUpdateDownloadUpdate, onAppUpdateDownloadUpdate)
|
||||
@ -210,6 +273,12 @@ export function DownloadManagement() {
|
||||
events.off(DownloadEvent.onFileDownloadError, onFileDownloadError)
|
||||
events.off(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess)
|
||||
events.off(DownloadEvent.onFileDownloadStopped, onFileDownloadStopped)
|
||||
events.off(
|
||||
DownloadEvent.onModelValidationStarted,
|
||||
onModelValidationStarted
|
||||
)
|
||||
events.off(DownloadEvent.onModelValidationFailed, onModelValidationFailed)
|
||||
events.off(DownloadEvent.onFileDownloadAndVerificationSuccess, onFileDownloadAndVerificationSuccess)
|
||||
|
||||
// Unregister app update event listeners
|
||||
events.off(AppEvent.onAppUpdateDownloadUpdate, onAppUpdateDownloadUpdate)
|
||||
@ -224,6 +293,9 @@ export function DownloadManagement() {
|
||||
onFileDownloadError,
|
||||
onFileDownloadSuccess,
|
||||
onFileDownloadStopped,
|
||||
onModelValidationStarted,
|
||||
onModelValidationFailed,
|
||||
onFileDownloadAndVerificationSuccess,
|
||||
onAppUpdateDownloadUpdate,
|
||||
onAppUpdateDownloadSuccess,
|
||||
onAppUpdateDownloadError,
|
||||
|
||||
@ -256,6 +256,22 @@
|
||||
"downloadCancelled": {
|
||||
"title": "Download abgebrochen",
|
||||
"description": "Der Download-Prozess wurde abgebrochen"
|
||||
},
|
||||
"downloadFailed": {
|
||||
"title": "Download fehlgeschlagen",
|
||||
"description": "{{item}} Download fehlgeschlagen"
|
||||
},
|
||||
"modelValidationStarted": {
|
||||
"title": "Modell wird validiert",
|
||||
"description": "Modell \"{{modelId}}\" erfolgreich heruntergeladen. Integrität wird überprüft..."
|
||||
},
|
||||
"modelValidationFailed": {
|
||||
"title": "Modellvalidierung fehlgeschlagen",
|
||||
"description": "Das heruntergeladene Modell \"{{modelId}}\" ist bei der Integritätsprüfung fehlgeschlagen und wurde entfernt. Die Datei könnte beschädigt oder manipuliert worden sein."
|
||||
},
|
||||
"downloadAndVerificationComplete": {
|
||||
"title": "Download abgeschlossen",
|
||||
"description": "Modell \"{{item}}\" erfolgreich heruntergeladen und verifiziert"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
"showVariants": "Zeige Varianten",
|
||||
"useModel": "Nutze dieses Modell",
|
||||
"downloadModel": "Modell herunterladen",
|
||||
"tools": "Werkzeuge",
|
||||
"searchPlaceholder": "Suche nach Modellen auf Hugging Face...",
|
||||
"editTheme": "Bearbeite Erscheinungsbild",
|
||||
"joyride": {
|
||||
|
||||
@ -261,6 +261,18 @@
|
||||
"downloadFailed": {
|
||||
"title": "Download Failed",
|
||||
"description": "{{item}} download failed"
|
||||
},
|
||||
"modelValidationStarted": {
|
||||
"title": "Validating Model",
|
||||
"description": "Downloaded model \"{{modelId}}\" successfully. Verifying integrity..."
|
||||
},
|
||||
"modelValidationFailed": {
|
||||
"title": "Model Validation Failed",
|
||||
"description": "The downloaded model \"{{modelId}}\" failed integrity verification and was removed. The file may be corrupted or tampered with."
|
||||
},
|
||||
"downloadAndVerificationComplete": {
|
||||
"title": "Download Complete",
|
||||
"description": "Model \"{{item}}\" downloaded and verified successfully"
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -12,6 +12,7 @@
|
||||
"showVariants": "Show variants",
|
||||
"useModel": "Use this model",
|
||||
"downloadModel": "Download model",
|
||||
"tools": "Tools",
|
||||
"searchPlaceholder": "Search for models on Hugging Face...",
|
||||
"joyride": {
|
||||
"recommendedModelTitle": "Recommended Model",
|
||||
|
||||
@ -249,6 +249,22 @@
|
||||
"downloadCancelled": {
|
||||
"title": "Unduhan Dibatalkan",
|
||||
"description": "Proses unduhan telah dibatalkan"
|
||||
},
|
||||
"downloadFailed": {
|
||||
"title": "Unduhan Gagal",
|
||||
"description": "Unduhan {{item}} gagal"
|
||||
},
|
||||
"modelValidationStarted": {
|
||||
"title": "Memvalidasi Model",
|
||||
"description": "Model \"{{modelId}}\" berhasil diunduh. Memverifikasi integritas..."
|
||||
},
|
||||
"modelValidationFailed": {
|
||||
"title": "Validasi Model Gagal",
|
||||
"description": "Model yang diunduh \"{{modelId}}\" gagal verifikasi integritas dan telah dihapus. File mungkin rusak atau telah dimanipulasi."
|
||||
},
|
||||
"downloadAndVerificationComplete": {
|
||||
"title": "Unduhan Selesai",
|
||||
"description": "Model \"{{item}}\" berhasil diunduh dan diverifikasi"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
"showVariants": "Tampilkan Varian",
|
||||
"useModel": "Gunakan model ini",
|
||||
"downloadModel": "Unduh model",
|
||||
"tools": "Alat",
|
||||
"searchPlaceholder": "Cari model di Hugging Face...",
|
||||
"joyride": {
|
||||
"recommendedModelTitle": "Model yang Direkomendasikan",
|
||||
|
||||
@ -249,6 +249,22 @@
|
||||
"downloadCancelled": {
|
||||
"title": "Đã hủy tải xuống",
|
||||
"description": "Quá trình tải xuống đã bị hủy"
|
||||
},
|
||||
"downloadFailed": {
|
||||
"title": "Tải xuống thất bại",
|
||||
"description": "Tải xuống {{item}} thất bại"
|
||||
},
|
||||
"modelValidationStarted": {
|
||||
"title": "Đang xác thực mô hình",
|
||||
"description": "Đã tải xuống mô hình \"{{modelId}}\" thành công. Đang xác minh tính toàn vẹn..."
|
||||
},
|
||||
"modelValidationFailed": {
|
||||
"title": "Xác thực mô hình thất bại",
|
||||
"description": "Mô hình đã tải xuống \"{{modelId}}\" không vượt qua kiểm tra tính toàn vẹn và đã bị xóa. Tệp có thể bị hỏng hoặc bị giả mạo."
|
||||
},
|
||||
"downloadAndVerificationComplete": {
|
||||
"title": "Tải xuống hoàn tất",
|
||||
"description": "Mô hình \"{{item}}\" đã được tải xuống và xác minh thành công"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
"showVariants": "Hiển thị biến thể",
|
||||
"useModel": "Sử dụng mô hình này",
|
||||
"downloadModel": "Tải xuống mô hình",
|
||||
"tools": "Công cụ",
|
||||
"searchPlaceholder": "Tìm kiếm các mô hình trên Hugging Face...",
|
||||
"joyride": {
|
||||
"recommendedModelTitle": "Mô hình được đề xuất",
|
||||
|
||||
@ -249,6 +249,22 @@
|
||||
"downloadCancelled": {
|
||||
"title": "下载已取消",
|
||||
"description": "下载过程已取消"
|
||||
},
|
||||
"downloadFailed": {
|
||||
"title": "下载失败",
|
||||
"description": "{{item}} 下载失败"
|
||||
},
|
||||
"modelValidationStarted": {
|
||||
"title": "正在验证模型",
|
||||
"description": "模型 \"{{modelId}}\" 下载成功。正在验证完整性..."
|
||||
},
|
||||
"modelValidationFailed": {
|
||||
"title": "模型验证失败",
|
||||
"description": "已下载的模型 \"{{modelId}}\" 未通过完整性验证并已被删除。文件可能损坏或被篡改。"
|
||||
},
|
||||
"downloadAndVerificationComplete": {
|
||||
"title": "下载完成",
|
||||
"description": "模型 \"{{item}}\" 下载并验证成功"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
"showVariants": "显示变体",
|
||||
"useModel": "使用此模型",
|
||||
"downloadModel": "下载模型",
|
||||
"tools": "工具",
|
||||
"searchPlaceholder": "在 Hugging Face 上搜索模型...",
|
||||
"joyride": {
|
||||
"recommendedModelTitle": "推荐模型",
|
||||
|
||||
@ -249,6 +249,22 @@
|
||||
"downloadCancelled": {
|
||||
"title": "下載已取消",
|
||||
"description": "下載過程已取消"
|
||||
},
|
||||
"downloadFailed": {
|
||||
"title": "下載失敗",
|
||||
"description": "{{item}} 下載失敗"
|
||||
},
|
||||
"modelValidationStarted": {
|
||||
"title": "正在驗證模型",
|
||||
"description": "模型 \"{{modelId}}\" 下載成功。正在驗證完整性..."
|
||||
},
|
||||
"modelValidationFailed": {
|
||||
"title": "模型驗證失敗",
|
||||
"description": "已下載的模型 \"{{modelId}}\" 未通過完整性驗證並已被刪除。檔案可能損壞或被篡改。"
|
||||
},
|
||||
"downloadAndVerificationComplete": {
|
||||
"title": "下載完成",
|
||||
"description": "模型 \"{{item}}\" 下載並驗證成功"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
"showVariants": "顯示變體",
|
||||
"useModel": "使用此模型",
|
||||
"downloadModel": "下載模型",
|
||||
"tools": "工具",
|
||||
"searchPlaceholder": "在 Hugging Face 上搜尋模型...",
|
||||
"joyride": {
|
||||
"recommendedModelTitle": "推薦模型",
|
||||
|
||||
@ -22,7 +22,7 @@ import {
|
||||
CatalogModel,
|
||||
convertHfRepoToCatalogModel,
|
||||
fetchHuggingFaceRepo,
|
||||
pullModel,
|
||||
pullModelWithMetadata,
|
||||
} from '@/services/models'
|
||||
import { Progress } from '@/components/ui/progress'
|
||||
import { Button } from '@/components/ui/button'
|
||||
@ -408,9 +408,11 @@ function HubModelDetail() {
|
||||
addLocalDownloadingModel(
|
||||
variant.model_id
|
||||
)
|
||||
pullModel(
|
||||
pullModelWithMetadata(
|
||||
variant.model_id,
|
||||
variant.path
|
||||
variant.path,
|
||||
modelData.mmproj_models?.[0]?.path,
|
||||
huggingfaceToken
|
||||
)
|
||||
}}
|
||||
className={cn(isDownloading && 'hidden')}
|
||||
|
||||
@ -41,7 +41,7 @@ import {
|
||||
} from '@/components/ui/dropdown-menu'
|
||||
import {
|
||||
CatalogModel,
|
||||
pullModel,
|
||||
pullModelWithMetadata,
|
||||
fetchHuggingFaceRepo,
|
||||
convertHfRepoToCatalogModel,
|
||||
} from '@/services/models'
|
||||
@ -313,7 +313,12 @@ function Hub() {
|
||||
// Immediately set local downloading state
|
||||
addLocalDownloadingModel(modelId)
|
||||
const mmprojPath = model.mmproj_models?.[0]?.path
|
||||
pullModel(modelId, modelUrl, mmprojPath)
|
||||
pullModelWithMetadata(
|
||||
modelId,
|
||||
modelUrl,
|
||||
mmprojPath,
|
||||
huggingfaceToken
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
@ -812,12 +817,13 @@ function Hub() {
|
||||
addLocalDownloadingModel(
|
||||
variant.model_id
|
||||
)
|
||||
pullModel(
|
||||
pullModelWithMetadata(
|
||||
variant.model_id,
|
||||
variant.path,
|
||||
filteredModels[
|
||||
virtualItem.index
|
||||
].mmproj_models?.[0]?.path
|
||||
].mmproj_models?.[0]?.path,
|
||||
huggingfaceToken
|
||||
)
|
||||
}}
|
||||
>
|
||||
|
||||
@ -325,7 +325,7 @@ describe('models service', () => {
|
||||
|
||||
expect(result).toEqual(mockRepoData)
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true',
|
||||
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
|
||||
{
|
||||
headers: {},
|
||||
}
|
||||
@ -344,7 +344,7 @@ describe('models service', () => {
|
||||
'https://huggingface.co/microsoft/DialoGPT-medium'
|
||||
)
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true',
|
||||
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
|
||||
{
|
||||
headers: {},
|
||||
}
|
||||
@ -353,7 +353,7 @@ describe('models service', () => {
|
||||
// Test with domain prefix
|
||||
await fetchHuggingFaceRepo('huggingface.co/microsoft/DialoGPT-medium')
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true',
|
||||
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
|
||||
{
|
||||
headers: {},
|
||||
}
|
||||
@ -362,7 +362,7 @@ describe('models service', () => {
|
||||
// Test with trailing slash
|
||||
await fetchHuggingFaceRepo('microsoft/DialoGPT-medium/')
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true',
|
||||
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
|
||||
{
|
||||
headers: {},
|
||||
}
|
||||
@ -391,7 +391,7 @@ describe('models service', () => {
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
'https://huggingface.co/api/models/nonexistent/model?blobs=true',
|
||||
'https://huggingface.co/api/models/nonexistent/model?blobs=true&files_metadata=true',
|
||||
{
|
||||
headers: {},
|
||||
}
|
||||
|
||||
@ -62,6 +62,11 @@ export interface HuggingFaceRepo {
|
||||
rfilename: string
|
||||
size?: number
|
||||
blobId?: string
|
||||
lfs?: {
|
||||
sha256: string
|
||||
size: number
|
||||
pointerSize: number
|
||||
}
|
||||
}>
|
||||
readme?: string
|
||||
}
|
||||
@ -126,7 +131,7 @@ export const fetchHuggingFaceRepo = async (
|
||||
}
|
||||
|
||||
const response = await fetch(
|
||||
`https://huggingface.co/api/models/${cleanRepoId}?blobs=true`,
|
||||
`https://huggingface.co/api/models/${cleanRepoId}?blobs=true&files_metadata=true`,
|
||||
{
|
||||
headers: hfToken
|
||||
? {
|
||||
@ -237,14 +242,103 @@ export const updateModel = async (
|
||||
export const pullModel = async (
|
||||
id: string,
|
||||
modelPath: string,
|
||||
mmprojPath?: string
|
||||
modelSha256?: string,
|
||||
modelSize?: number,
|
||||
mmprojPath?: string,
|
||||
mmprojSha256?: string,
|
||||
mmprojSize?: number
|
||||
) => {
|
||||
return getEngine()?.import(id, {
|
||||
modelPath,
|
||||
mmprojPath,
|
||||
modelSha256,
|
||||
modelSize,
|
||||
mmprojSha256,
|
||||
mmprojSize,
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Pull a model with real-time metadata fetching from HuggingFace.
|
||||
* Extracts hash and size information from the model URL for both main model and mmproj files.
|
||||
* @param id The model ID
|
||||
* @param modelPath The model file URL (HuggingFace download URL)
|
||||
* @param mmprojPath Optional mmproj file URL
|
||||
* @param hfToken Optional HuggingFace token for authentication
|
||||
* @returns A promise that resolves when the model download task is created.
|
||||
*/
|
||||
export const pullModelWithMetadata = async (
|
||||
id: string,
|
||||
modelPath: string,
|
||||
mmprojPath?: string,
|
||||
hfToken?: string
|
||||
) => {
|
||||
let modelSha256: string | undefined
|
||||
let modelSize: number | undefined
|
||||
let mmprojSha256: string | undefined
|
||||
let mmprojSize: number | undefined
|
||||
|
||||
// Extract repo ID from model URL
|
||||
// URL format: https://huggingface.co/{repo}/resolve/main/{filename}
|
||||
const modelUrlMatch = modelPath.match(
|
||||
/https:\/\/huggingface\.co\/([^/]+\/[^/]+)\/resolve\/main\/(.+)/
|
||||
)
|
||||
|
||||
if (modelUrlMatch) {
|
||||
const [, repoId, modelFilename] = modelUrlMatch
|
||||
|
||||
try {
|
||||
// Fetch real-time metadata from HuggingFace
|
||||
const repoInfo = await fetchHuggingFaceRepo(repoId, hfToken)
|
||||
|
||||
if (repoInfo?.siblings) {
|
||||
// Find the specific model file
|
||||
const modelFile = repoInfo.siblings.find(
|
||||
(file) => file.rfilename === modelFilename
|
||||
)
|
||||
if (modelFile?.lfs) {
|
||||
modelSha256 = modelFile.lfs.sha256
|
||||
modelSize = modelFile.lfs.size
|
||||
}
|
||||
|
||||
// If mmproj path provided, extract its metadata too
|
||||
if (mmprojPath) {
|
||||
const mmprojUrlMatch = mmprojPath.match(
|
||||
/https:\/\/huggingface\.co\/[^/]+\/[^/]+\/resolve\/main\/(.+)/
|
||||
)
|
||||
if (mmprojUrlMatch) {
|
||||
const [, mmprojFilename] = mmprojUrlMatch
|
||||
const mmprojFile = repoInfo.siblings.find(
|
||||
(file) => file.rfilename === mmprojFilename
|
||||
)
|
||||
if (mmprojFile?.lfs) {
|
||||
mmprojSha256 = mmprojFile.lfs.sha256
|
||||
mmprojSize = mmprojFile.lfs.size
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
'Failed to fetch HuggingFace metadata, proceeding without hash verification:',
|
||||
error
|
||||
)
|
||||
// Continue with download even if metadata fetch fails
|
||||
}
|
||||
}
|
||||
|
||||
// Call the original pullModel with the fetched metadata
|
||||
return pullModel(
|
||||
id,
|
||||
modelPath,
|
||||
modelSha256,
|
||||
modelSize,
|
||||
mmprojPath,
|
||||
mmprojSha256,
|
||||
mmprojSize
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Aborts a model download.
|
||||
* @param id
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user