625 lines
20 KiB
Rust
625 lines
20 KiB
Rust
use super::models::{DownloadEvent, DownloadItem, ProxyConfig, ProgressTracker};
|
|
use crate::core::app::commands::get_jan_data_folder_path;
|
|
use futures_util::StreamExt;
|
|
use jan_utils::normalize_path;
|
|
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
|
|
use std::collections::HashMap;
|
|
use std::path::Path;
|
|
use std::time::Duration;
|
|
use tauri::{Emitter, Runtime};
|
|
use tokio::fs::File;
|
|
use tokio::io::AsyncWriteExt;
|
|
use tokio_util::sync::CancellationToken;
|
|
use url::Url;
|
|
|
|
// ===== UTILITY FUNCTIONS =====
|
|
|
|
pub fn err_to_string<E: std::fmt::Display>(e: E) -> String {
|
|
format!("Error: {}", e)
|
|
}
|
|
|
|
|
|
// ===== VALIDATION FUNCTIONS =====
|
|
|
|
/// Validates a downloaded file against expected hash and size
|
|
async fn validate_downloaded_file(
|
|
item: &DownloadItem,
|
|
save_path: &Path,
|
|
app: &tauri::AppHandle<impl Runtime>,
|
|
cancel_token: &CancellationToken,
|
|
) -> Result<(), String> {
|
|
// Skip validation if no verification data is provided
|
|
if item.sha256.is_none() && item.size.is_none() {
|
|
log::debug!(
|
|
"No validation data provided for {}, skipping validation",
|
|
item.url
|
|
);
|
|
return Ok(());
|
|
}
|
|
|
|
// Extract model ID from save path for validation events
|
|
// Path structure: llamacpp/models/{modelId}/model.gguf or llamacpp/models/{modelId}/mmproj.gguf
|
|
let model_id = save_path
|
|
.parent() // get parent directory (modelId folder)
|
|
.and_then(|p| p.file_name())
|
|
.and_then(|n| n.to_str())
|
|
.unwrap_or("unknown");
|
|
|
|
// Emit validation started event
|
|
app.emit(
|
|
"onModelValidationStarted",
|
|
serde_json::json!({
|
|
"modelId": model_id,
|
|
"downloadType": "Model",
|
|
}),
|
|
)
|
|
.unwrap();
|
|
|
|
log::info!("Starting validation for model: {}", model_id);
|
|
|
|
// Validate size if provided (fast check first)
|
|
if let Some(expected_size) = &item.size {
|
|
log::info!("Starting size verification for {}", item.url);
|
|
|
|
match tokio::fs::metadata(save_path).await {
|
|
Ok(metadata) => {
|
|
let actual_size = metadata.len();
|
|
|
|
if actual_size != *expected_size {
|
|
log::error!(
|
|
"Size verification failed for {}. Expected: {} bytes, Actual: {} bytes",
|
|
item.url,
|
|
expected_size,
|
|
actual_size
|
|
);
|
|
return Err(format!(
|
|
"Size verification failed. Expected {} bytes but got {} bytes.",
|
|
expected_size, actual_size
|
|
));
|
|
}
|
|
|
|
log::info!(
|
|
"Size verification successful for {} ({} bytes)",
|
|
item.url,
|
|
actual_size
|
|
);
|
|
}
|
|
Err(e) => {
|
|
log::error!(
|
|
"Failed to get file metadata for {}: {}",
|
|
save_path.display(),
|
|
e
|
|
);
|
|
return Err(format!("Failed to verify file size: {}", e));
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check for cancellation before expensive hash computation
|
|
if cancel_token.is_cancelled() {
|
|
log::info!("Validation cancelled for {}", item.url);
|
|
return Err("Validation cancelled".to_string());
|
|
}
|
|
|
|
// Validate hash if provided (expensive check second)
|
|
if let Some(expected_sha256) = &item.sha256 {
|
|
log::info!("Starting Hash verification for {}", item.url);
|
|
|
|
match jan_utils::crypto::compute_file_sha256_with_cancellation(save_path, cancel_token).await {
|
|
Ok(computed_sha256) => {
|
|
if computed_sha256 != *expected_sha256 {
|
|
log::error!(
|
|
"Hash verification failed for {}. Expected: {}, Computed: {}",
|
|
item.url,
|
|
expected_sha256,
|
|
computed_sha256
|
|
);
|
|
|
|
return Err(format!(
|
|
"Hash verification failed. The downloaded file is corrupted or has been tampered with."
|
|
));
|
|
}
|
|
|
|
log::info!("Hash verification successful for {}", item.url);
|
|
}
|
|
Err(e) => {
|
|
log::error!(
|
|
"Failed to compute SHA256 for {}: {}",
|
|
save_path.display(),
|
|
e
|
|
);
|
|
return Err(format!("Failed to verify file integrity: {}", e));
|
|
}
|
|
}
|
|
}
|
|
|
|
log::info!("All validations passed for {}", item.url);
|
|
Ok(())
|
|
}
|
|
|
|
pub fn validate_proxy_config(config: &ProxyConfig) -> Result<(), String> {
|
|
// Validate proxy URL format
|
|
if let Err(e) = Url::parse(&config.url) {
|
|
return Err(format!("Invalid proxy URL '{}': {}", config.url, e));
|
|
}
|
|
|
|
// Check if proxy URL has valid scheme
|
|
let url = Url::parse(&config.url).unwrap(); // Safe to unwrap as we just validated it
|
|
match url.scheme() {
|
|
"http" | "https" | "socks4" | "socks5" => {}
|
|
scheme => return Err(format!("Unsupported proxy scheme: {}", scheme)),
|
|
}
|
|
|
|
// Validate authentication credentials
|
|
if config.username.is_some() && config.password.is_none() {
|
|
return Err("Username provided without password".to_string());
|
|
}
|
|
|
|
if config.password.is_some() && config.username.is_none() {
|
|
return Err("Password provided without username".to_string());
|
|
}
|
|
|
|
// Validate no_proxy entries
|
|
if let Some(no_proxy) = &config.no_proxy {
|
|
for entry in no_proxy {
|
|
if entry.is_empty() {
|
|
return Err("Empty no_proxy entry".to_string());
|
|
}
|
|
// Basic validation for wildcard patterns
|
|
if entry.starts_with("*.") && entry.len() < 3 {
|
|
return Err(format!("Invalid wildcard pattern: {}", entry));
|
|
}
|
|
}
|
|
}
|
|
|
|
// SSL verification settings are all optional booleans, no validation needed
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn create_proxy_from_config(config: &ProxyConfig) -> Result<reqwest::Proxy, String> {
|
|
// Validate the configuration first
|
|
validate_proxy_config(config)?;
|
|
|
|
let mut proxy = reqwest::Proxy::all(&config.url).map_err(err_to_string)?;
|
|
|
|
// Add authentication if provided
|
|
if let (Some(username), Some(password)) = (&config.username, &config.password) {
|
|
proxy = proxy.basic_auth(username, password);
|
|
}
|
|
|
|
Ok(proxy)
|
|
}
|
|
|
|
pub fn should_bypass_proxy(url: &str, no_proxy: &[String]) -> bool {
|
|
if no_proxy.is_empty() {
|
|
return false;
|
|
}
|
|
|
|
// Parse the URL to get the host
|
|
let parsed_url = match Url::parse(url) {
|
|
Ok(u) => u,
|
|
Err(_) => return false,
|
|
};
|
|
|
|
let host = match parsed_url.host_str() {
|
|
Some(h) => h,
|
|
None => return false,
|
|
};
|
|
|
|
// Check if host matches any no_proxy entry
|
|
for entry in no_proxy {
|
|
if entry == "*" {
|
|
return true;
|
|
}
|
|
|
|
// Simple wildcard matching
|
|
if entry.starts_with("*.") {
|
|
let domain = &entry[2..];
|
|
if host.ends_with(domain) {
|
|
return true;
|
|
}
|
|
} else if host == entry {
|
|
return true;
|
|
}
|
|
}
|
|
|
|
false
|
|
}
|
|
|
|
pub fn _get_client_for_item(
|
|
item: &DownloadItem,
|
|
header_map: &HeaderMap,
|
|
) -> Result<reqwest::Client, String> {
|
|
let mut client_builder = reqwest::Client::builder()
|
|
.http2_keep_alive_timeout(Duration::from_secs(15))
|
|
.default_headers(header_map.clone());
|
|
|
|
// Add proxy configuration if provided
|
|
if let Some(proxy_config) = &item.proxy {
|
|
// Handle SSL verification settings
|
|
if proxy_config.ignore_ssl.unwrap_or(false) {
|
|
client_builder = client_builder.danger_accept_invalid_certs(true);
|
|
log::info!("SSL certificate verification disabled for URL {}", item.url);
|
|
}
|
|
|
|
// Note: reqwest doesn't have fine-grained SSL verification controls
|
|
// for verify_proxy_ssl, verify_proxy_host_ssl, verify_peer_ssl, verify_host_ssl
|
|
// These settings are handled by the underlying TLS implementation
|
|
|
|
// Check if this URL should bypass proxy
|
|
let no_proxy = proxy_config.no_proxy.as_deref().unwrap_or(&[]);
|
|
if !should_bypass_proxy(&item.url, no_proxy) {
|
|
let proxy = create_proxy_from_config(proxy_config)?;
|
|
client_builder = client_builder.proxy(proxy);
|
|
log::info!("Using proxy {} for URL {}", proxy_config.url, item.url);
|
|
} else {
|
|
log::info!("Bypassing proxy for URL {}", item.url);
|
|
}
|
|
}
|
|
|
|
client_builder.build().map_err(err_to_string)
|
|
}
|
|
|
|
pub fn _convert_headers(
|
|
headers: &HashMap<String, String>,
|
|
) -> Result<HeaderMap, Box<dyn std::error::Error>> {
|
|
let mut header_map = HeaderMap::new();
|
|
for (k, v) in headers {
|
|
let key = HeaderName::from_bytes(k.as_bytes())?;
|
|
let value = HeaderValue::from_str(v)?;
|
|
header_map.insert(key, value);
|
|
}
|
|
Ok(header_map)
|
|
}
|
|
|
|
pub async fn _get_file_size(
|
|
client: &reqwest::Client,
|
|
url: &str,
|
|
) -> Result<u64, Box<dyn std::error::Error>> {
|
|
let resp = client.head(url).send().await?;
|
|
if !resp.status().is_success() {
|
|
return Err(format!("Failed to get file size: HTTP status {}", resp.status()).into());
|
|
}
|
|
// this is buggy, always return 0 for HEAD request
|
|
// Ok(resp.content_length().unwrap_or(0))
|
|
|
|
match resp.headers().get("content-length") {
|
|
Some(value) => {
|
|
let value_str = value.to_str()?;
|
|
let value_u64: u64 = value_str.parse()?;
|
|
Ok(value_u64)
|
|
}
|
|
None => Ok(0),
|
|
}
|
|
}
|
|
|
|
// ===== MAIN DOWNLOAD FUNCTIONS =====
|
|
|
|
/// Downloads multiple files in parallel with individual progress tracking
|
|
pub async fn _download_files_internal(
|
|
app: tauri::AppHandle<impl Runtime>,
|
|
items: &[DownloadItem],
|
|
headers: &HashMap<String, String>,
|
|
task_id: &str,
|
|
resume: bool,
|
|
cancel_token: CancellationToken,
|
|
) -> Result<(), String> {
|
|
log::info!("Start download task: {}", task_id);
|
|
|
|
let header_map = _convert_headers(headers).map_err(err_to_string)?;
|
|
|
|
// Calculate sizes for each file
|
|
let mut file_sizes = HashMap::new();
|
|
for item in items.iter() {
|
|
let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?;
|
|
let size = _get_file_size(&client, &item.url)
|
|
.await
|
|
.map_err(err_to_string)?;
|
|
file_sizes.insert(item.url.clone(), size);
|
|
}
|
|
|
|
let total_size: u64 = file_sizes.values().sum();
|
|
log::info!("Total download size: {}", total_size);
|
|
|
|
let evt_name = format!("download-{}", task_id);
|
|
|
|
// Create progress tracker
|
|
let progress_tracker = ProgressTracker::new(items, file_sizes.clone());
|
|
|
|
// save file under Jan data folder
|
|
let jan_data_folder = get_jan_data_folder_path(app.clone());
|
|
|
|
// Collect download tasks for parallel execution
|
|
let mut download_tasks = Vec::new();
|
|
|
|
for (index, item) in items.iter().enumerate() {
|
|
let save_path = jan_data_folder.join(&item.save_path);
|
|
let save_path = normalize_path(&save_path);
|
|
|
|
if !save_path.starts_with(&jan_data_folder) {
|
|
return Err(format!(
|
|
"Path {} is outside of Jan data folder {}",
|
|
save_path.display(),
|
|
jan_data_folder.display()
|
|
));
|
|
}
|
|
|
|
// Spawn download task for each file
|
|
let item_clone = item.clone();
|
|
let app_clone = app.clone();
|
|
let header_map_clone = header_map.clone();
|
|
let cancel_token_clone = cancel_token.clone();
|
|
let evt_name_clone = evt_name.clone();
|
|
let progress_tracker_clone = progress_tracker.clone();
|
|
let file_id = format!("{}-{}", task_id, index);
|
|
let file_size = file_sizes.get(&item.url).copied().unwrap_or(0);
|
|
|
|
let task = tokio::spawn(async move {
|
|
download_single_file(
|
|
app_clone,
|
|
&item_clone,
|
|
&header_map_clone,
|
|
&save_path,
|
|
resume,
|
|
cancel_token_clone,
|
|
evt_name_clone,
|
|
progress_tracker_clone,
|
|
file_id,
|
|
file_size,
|
|
)
|
|
.await
|
|
});
|
|
|
|
download_tasks.push(task);
|
|
}
|
|
|
|
// Wait for all downloads to complete
|
|
let mut validation_tasks = Vec::new();
|
|
for (task, item) in download_tasks.into_iter().zip(items.iter()) {
|
|
let result = task.await.map_err(|e| format!("Task join error: {}", e))?;
|
|
|
|
match result {
|
|
Ok(downloaded_path) => {
|
|
// Spawn validation task in parallel
|
|
let item_clone = item.clone();
|
|
let app_clone = app.clone();
|
|
let path_clone = downloaded_path.clone();
|
|
let cancel_token_clone = cancel_token.clone();
|
|
let validation_task = tokio::spawn(async move {
|
|
validate_downloaded_file(&item_clone, &path_clone, &app_clone, &cancel_token_clone).await
|
|
});
|
|
validation_tasks.push((validation_task, downloaded_path, item.clone()));
|
|
}
|
|
Err(e) => return Err(e),
|
|
}
|
|
}
|
|
|
|
// Wait for all validations to complete
|
|
for (validation_task, save_path, _item) in validation_tasks {
|
|
let validation_result = validation_task
|
|
.await
|
|
.map_err(|e| format!("Validation task join error: {}", e))?;
|
|
|
|
if let Err(validation_error) = validation_result {
|
|
// Clean up the file if validation fails
|
|
let _ = tokio::fs::remove_file(&save_path).await;
|
|
|
|
// Try to clean up the parent directory if it's empty
|
|
if let Some(parent) = save_path.parent() {
|
|
let _ = tokio::fs::remove_dir(parent).await;
|
|
}
|
|
|
|
return Err(validation_error);
|
|
}
|
|
}
|
|
|
|
// Emit final progress
|
|
let (transferred, total) = progress_tracker.get_total_progress().await;
|
|
let final_evt = DownloadEvent { transferred, total };
|
|
app.emit(&evt_name, final_evt).unwrap();
|
|
Ok(())
|
|
}
|
|
|
|
/// Downloads a single file without blocking other downloads
|
|
async fn download_single_file(
|
|
app: tauri::AppHandle<impl Runtime>,
|
|
item: &DownloadItem,
|
|
header_map: &HeaderMap,
|
|
save_path: &std::path::Path,
|
|
resume: bool,
|
|
cancel_token: CancellationToken,
|
|
evt_name: String,
|
|
progress_tracker: ProgressTracker,
|
|
file_id: String,
|
|
_file_size: u64,
|
|
) -> Result<std::path::PathBuf, String> {
|
|
// Create parent directories if they don't exist
|
|
if let Some(parent) = save_path.parent() {
|
|
if !parent.exists() {
|
|
tokio::fs::create_dir_all(parent)
|
|
.await
|
|
.map_err(err_to_string)?;
|
|
}
|
|
}
|
|
|
|
let current_extension = save_path.extension().unwrap_or_default().to_string_lossy();
|
|
let append_extension = |ext: &str| {
|
|
if current_extension.is_empty() {
|
|
ext.to_string()
|
|
} else {
|
|
format!("{}.{}", current_extension, ext)
|
|
}
|
|
};
|
|
let tmp_save_path = save_path.with_extension(append_extension("tmp"));
|
|
let url_save_path = save_path.with_extension(append_extension("url"));
|
|
|
|
let mut should_resume = resume
|
|
&& tmp_save_path.exists()
|
|
&& tokio::fs::read_to_string(&url_save_path)
|
|
.await
|
|
.map(|url| url == item.url) // check if we resume the same URL
|
|
.unwrap_or(false);
|
|
|
|
tokio::fs::write(&url_save_path, item.url.clone())
|
|
.await
|
|
.map_err(err_to_string)?;
|
|
|
|
log::info!("Started downloading: {}", item.url);
|
|
let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?;
|
|
let mut download_delta = 0u64;
|
|
let mut initial_progress = 0u64;
|
|
|
|
let resp = if should_resume {
|
|
let downloaded_size = tmp_save_path.metadata().map_err(err_to_string)?.len();
|
|
match _get_maybe_resume(&client, &item.url, downloaded_size).await {
|
|
Ok(resp) => {
|
|
log::info!(
|
|
"Resume download: {}, already downloaded {} bytes",
|
|
item.url,
|
|
downloaded_size
|
|
);
|
|
initial_progress = downloaded_size;
|
|
|
|
// Initialize progress for resumed download
|
|
progress_tracker
|
|
.update_progress(&file_id, downloaded_size)
|
|
.await;
|
|
|
|
// Emit initial combined progress
|
|
let (combined_transferred, combined_total) =
|
|
progress_tracker.get_total_progress().await;
|
|
let evt = DownloadEvent {
|
|
transferred: combined_transferred,
|
|
total: combined_total,
|
|
};
|
|
app.emit(&evt_name, evt).unwrap();
|
|
|
|
resp
|
|
}
|
|
Err(e) => {
|
|
// fallback to normal download
|
|
log::warn!("Failed to resume download: {}", e);
|
|
should_resume = false;
|
|
_get_maybe_resume(&client, &item.url, 0).await?
|
|
}
|
|
}
|
|
} else {
|
|
_get_maybe_resume(&client, &item.url, 0).await?
|
|
};
|
|
let mut stream = resp.bytes_stream();
|
|
|
|
let file = if should_resume {
|
|
// resume download, append to existing file
|
|
tokio::fs::OpenOptions::new()
|
|
.write(true)
|
|
.append(true)
|
|
.open(&tmp_save_path)
|
|
.await
|
|
.map_err(err_to_string)?
|
|
} else {
|
|
// start new download, create a new file
|
|
File::create(&tmp_save_path).await.map_err(err_to_string)?
|
|
};
|
|
let mut writer = tokio::io::BufWriter::new(file);
|
|
let mut total_transferred = initial_progress;
|
|
|
|
// write chunk to file
|
|
while let Some(chunk) = stream.next().await {
|
|
if cancel_token.is_cancelled() {
|
|
if !should_resume {
|
|
tokio::fs::remove_dir_all(&save_path.parent().unwrap())
|
|
.await
|
|
.ok();
|
|
}
|
|
log::info!("Download cancelled: {}", item.url);
|
|
return Err("Download cancelled".to_string());
|
|
}
|
|
|
|
let chunk = chunk.map_err(err_to_string)?;
|
|
writer.write_all(&chunk).await.map_err(err_to_string)?;
|
|
download_delta += chunk.len() as u64;
|
|
total_transferred += chunk.len() as u64;
|
|
|
|
// Update progress every 10 MB
|
|
if download_delta >= 10 * 1024 * 1024 {
|
|
// Update individual file progress
|
|
progress_tracker
|
|
.update_progress(&file_id, total_transferred)
|
|
.await;
|
|
|
|
// Emit combined progress event
|
|
let (combined_transferred, combined_total) =
|
|
progress_tracker.get_total_progress().await;
|
|
let evt = DownloadEvent {
|
|
transferred: combined_transferred,
|
|
total: combined_total,
|
|
};
|
|
app.emit(&evt_name, evt).unwrap();
|
|
|
|
download_delta = 0u64;
|
|
}
|
|
}
|
|
|
|
writer.flush().await.map_err(err_to_string)?;
|
|
|
|
// Final progress update for this file
|
|
progress_tracker
|
|
.update_progress(&file_id, total_transferred)
|
|
.await;
|
|
|
|
// Emit final combined progress
|
|
let (combined_transferred, combined_total) = progress_tracker.get_total_progress().await;
|
|
let evt = DownloadEvent {
|
|
transferred: combined_transferred,
|
|
total: combined_total,
|
|
};
|
|
app.emit(&evt_name, evt).unwrap();
|
|
|
|
// rename tmp file to final file
|
|
tokio::fs::rename(&tmp_save_path, &save_path)
|
|
.await
|
|
.map_err(err_to_string)?;
|
|
tokio::fs::remove_file(&url_save_path)
|
|
.await
|
|
.map_err(err_to_string)?;
|
|
|
|
log::info!("Finished downloading: {}", item.url);
|
|
Ok(save_path.to_path_buf())
|
|
}
|
|
|
|
// ===== HTTP CLIENT HELPER FUNCTIONS =====
|
|
|
|
pub async fn _get_maybe_resume(
|
|
client: &reqwest::Client,
|
|
url: &str,
|
|
start_bytes: u64,
|
|
) -> Result<reqwest::Response, String> {
|
|
if start_bytes > 0 {
|
|
let resp = client
|
|
.get(url)
|
|
.header("Range", format!("bytes={}-", start_bytes))
|
|
.send()
|
|
.await
|
|
.map_err(err_to_string)?;
|
|
if resp.status() != reqwest::StatusCode::PARTIAL_CONTENT {
|
|
return Err(format!(
|
|
"Failed to resume download: HTTP status {}, {}",
|
|
resp.status(),
|
|
resp.text().await.unwrap_or_default()
|
|
));
|
|
}
|
|
Ok(resp)
|
|
} else {
|
|
let resp = client.get(url).send().await.map_err(err_to_string)?;
|
|
if !resp.status().is_success() {
|
|
return Err(format!(
|
|
"Failed to download: HTTP status {}, {}",
|
|
resp.status(),
|
|
resp.text().await.unwrap_or_default()
|
|
));
|
|
}
|
|
Ok(resp)
|
|
}
|
|
}
|