From 56f4ec3b61397168209b92f6588393284c7c0186 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 23 May 2025 16:49:41 +0800 Subject: [PATCH] feat: improve download extension (#5073) --- extensions/download-extension/src/index.ts | 117 ++---- src-tauri/src/core/utils/download.rs | 411 ++++++--------------- src-tauri/src/lib.rs | 3 +- 3 files changed, 144 insertions(+), 387 deletions(-) diff --git a/extensions/download-extension/src/index.ts b/extensions/download-extension/src/index.ts index 639fd677f..11315ba85 100644 --- a/extensions/download-extension/src/index.ts +++ b/extensions/download-extension/src/index.ts @@ -6,104 +6,59 @@ export enum Settings { hfToken = 'hf-token', } +interface DownloadItem { + url: string + save_path: string +} + type DownloadEvent = { - task_id: string - total_size: number - downloaded_size: number - download_type: string - event_type: string + transferred: number + total: number } export default class DownloadManager extends BaseExtension { - hf_token?: string + hfToken?: string async onLoad() { this.registerSettings(SETTINGS) - this.hf_token = await this.getSetting(Settings.hfToken, undefined) + this.hfToken = await this.getSetting(Settings.hfToken, undefined) } async onUnload() { } - async downloadFile(url: string, path: string, taskId: string) { - // relay tauri events to Jan events - const unlisten = await listen('download', (event) => { - let payload = event.payload - let eventName = { - Updated: 'onFileDownloadUpdate', - Error: 'onFileDownloadError', - Success: 'onFileDownloadSuccess', - Stopped: 'onFileDownloadStopped', - Started: 'onFileDownloadStarted', - }[payload.event_type] - - // remove this once event system is back in web-app - console.log(taskId, payload.event_type, payload.downloaded_size / payload.total_size) - - events.emit(eventName, { - modelId: taskId, - percent: payload.downloaded_size / payload.total_size, - size: { - transferred: payload.downloaded_size, - total: payload.total_size, - }, - downloadType: payload.download_type, - }) - }) - - try { - await invoke( - "download_file", - { url, path, taskId, headers: this._getHeaders() }, - ) - } catch (error) { - console.error("Error downloading file:", error) - events.emit('onFileDownloadError', { - modelId: url, - downloadType: 'Model', - }) - throw error - } finally { - unlisten() - } + async downloadFile( + url: string, + savePath: string, + taskId: string, + onProgress?: (transferred: number, total: number) => void + ) { + return await this.downloadFiles( + [{ url, save_path: savePath }], + taskId, + onProgress + ) } - async downloadHfRepo(modelId: string, saveDir: string, taskId: string, branch?: string) { - // relay tauri events to Jan events - const unlisten = await listen('download', (event) => { - let payload = event.payload - let eventName = { - Updated: 'onFileDownloadUpdate', - Error: 'onFileDownloadError', - Success: 'onFileDownloadSuccess', - Stopped: 'onFileDownloadStopped', - Started: 'onFileDownloadStarted', - }[payload.event_type] - - // remove this once event system is back in web-app - console.log(taskId, payload.event_type, payload.downloaded_size / payload.total_size) - - events.emit(eventName, { - modelId: taskId, - percent: payload.downloaded_size / payload.total_size, - size: { - transferred: payload.downloaded_size, - total: payload.total_size, - }, - downloadType: payload.download_type, - }) + async downloadFiles( + items: DownloadItem[], + taskId: string, + onProgress?: (transferred: number, total: number) => void + ) { + // relay tauri events to onProgress callback + const unlisten = await listen(`download-${taskId}`, (event) => { + if (onProgress) { + let payload = event.payload + onProgress(payload.transferred, payload.total) + } }) try { await invoke( - "download_hf_repo", - { modelId, saveDir, taskId, branch, headers: this._getHeaders() }, - ) + "download_files", + { items, taskId, headers: this._getHeaders() }, + ) } catch (error) { - console.error("Error downloading file:", error) - events.emit('onFileDownloadError', { - modelId: modelId, - downloadType: 'Model', - }) + console.error("Error downloading task", taskId, error) throw error } finally { unlisten() @@ -121,7 +76,7 @@ export default class DownloadManager extends BaseExtension { _getHeaders() { return { - ...(this.hf_token && { Authorization: `Bearer ${this.hf_token}` }) + ...(this.hfToken && { Authorization: `Bearer ${this.hfToken}` }) } } } diff --git a/src-tauri/src/core/utils/download.rs b/src-tauri/src/core/utils/download.rs index 4ec4d057b..5aa685ac8 100644 --- a/src-tauri/src/core/utils/download.rs +++ b/src-tauri/src/core/utils/download.rs @@ -4,7 +4,6 @@ use crate::core::utils::normalize_path; use futures_util::StreamExt; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use std::collections::HashMap; -use std::path::Path; use std::time::Duration; use tauri::{Emitter, State}; use tokio::fs::File; @@ -16,24 +15,16 @@ pub struct DownloadManagerState { pub cancel_tokens: HashMap, } -// this is to emulate the current way of downloading files by Cortex + Jan -// we can change this later -#[derive(serde::Serialize, Clone, Debug, PartialEq)] -pub enum DownloadEventType { - Started, - Updated, - Success, - // Error, // we don't need to emit an Error event. just return an error directly - Stopped, +#[derive(serde::Deserialize, Clone, Debug)] +pub struct DownloadItem { + pub url: String, + pub save_path: String, } #[derive(serde::Serialize, Clone, Debug)] pub struct DownloadEvent { - pub task_id: String, - pub total_size: u64, - pub downloaded_size: u64, - pub download_type: String, // TODO: make this an enum as well - pub event_type: DownloadEventType, + pub transferred: u64, + pub total: u64, } fn err_to_string(e: E) -> String { @@ -41,11 +32,10 @@ fn err_to_string(e: E) -> String { } #[tauri::command] -pub async fn download_file( +pub async fn download_files( app: tauri::AppHandle, state: State<'_, AppState>, - url: &str, - path: &Path, + items: Vec, task_id: &str, headers: HashMap, ) -> Result<(), String> { @@ -53,172 +43,34 @@ pub async fn download_file( let cancel_token = CancellationToken::new(); { let mut download_manager = state.download_manager.lock().await; - if download_manager.cancel_tokens.contains_key(url) { - return Err(format!("URL {} is already being downloaded", url)); + if download_manager.cancel_tokens.contains_key(task_id) { + return Err(format!("task_id {} exists", task_id)); } download_manager .cancel_tokens .insert(task_id.to_string(), cancel_token.clone()); } - let header_map = _convert_headers(headers).map_err(err_to_string)?; - let total_size = _get_file_size(url, header_map.clone()) - .await - .map_err(err_to_string)?; - log::info!("File size: {}", total_size); - let mut evt = DownloadEvent { - task_id: task_id.to_string(), - total_size, - downloaded_size: 0, - download_type: "Model".to_string(), - event_type: DownloadEventType::Started, - }; - app.emit("download", evt.clone()).unwrap(); - - // save file under Jan data folder - let data_dir = get_jan_data_folder_path(app.clone()); - let save_path = data_dir.join(path); - - let mut has_error = false; - let mut error_msg = String::new(); - match _download_file_internal( - app.clone(), - url, - &save_path, - header_map.clone(), - evt, - cancel_token.clone(), - ) - .await - { - Ok(evt_) => { - evt = evt_; // reassign ownership - } - Err((evt_, e)) => { - evt = evt_; // reassign ownership - error_msg = format!("Failed to download file: {}", e); - log::error!("{}", error_msg); - has_error = true; - } - } + let result = + _download_files_internal(app.clone(), &items, &headers, task_id, cancel_token.clone()) + .await; // cleanup { let mut download_manager = state.download_manager.lock().await; - download_manager.cancel_tokens.remove(url); - } - if has_error { - let _ = std::fs::remove_file(&save_path); // don't check error - return Err(error_msg); + download_manager.cancel_tokens.remove(task_id); } - // emit final event - if evt.event_type == DownloadEventType::Stopped { - let _ = std::fs::remove_file(&save_path); // don't check error - } else { - evt.event_type = DownloadEventType::Success; - } - app.emit("download", evt.clone()).unwrap(); - - Ok(()) -} - -#[tauri::command] -pub async fn download_hf_repo( - app: tauri::AppHandle, - state: State<'_, AppState>, - model_id: &str, - save_dir: &Path, - task_id: &str, - branch: Option<&str>, - headers: HashMap, -) -> Result<(), String> { - let branch_str = branch.unwrap_or("main"); - let header_map = _convert_headers(headers).map_err(err_to_string)?; - - log::info!("Downloading HF repo: {}, branch {}", model_id, branch_str); - - // get all files from repo, including subdirs - let items = _list_hf_repo_files(model_id, branch, header_map.clone()) - .await - .map_err(err_to_string)?; - - // insert cancel tokens - let cancel_token = CancellationToken::new(); - { - let mut download_manager = state.download_manager.lock().await; - if download_manager.cancel_tokens.contains_key(model_id) { - return Err(format!("model_id {} is already being downloaded", model_id)); - } - download_manager - .cancel_tokens - .insert(task_id.to_string(), cancel_token.clone()); - } - - let total_size = items.iter().map(|f| f.size).sum::(); - let mut evt = DownloadEvent { - task_id: task_id.to_string(), - total_size, - downloaded_size: 0, - download_type: "Model".to_string(), - event_type: DownloadEventType::Started, - }; - app.emit("download", evt.clone()).unwrap(); - - let local_dir = get_jan_data_folder_path(app.clone()).join(save_dir); - let mut has_error = false; - let mut error_msg = String::new(); - for item in items { - let url = format!( - "https://huggingface.co/{}/resolve/{}/{}", - model_id, branch_str, item.path - ); - let save_path = local_dir.join(&item.path); - match _download_file_internal( - app.clone(), - &url, - &save_path, - header_map.clone(), - evt, - cancel_token.clone(), - ) - .await - { - Ok(evt_) => { - evt = evt_; // reassign ownership - if evt.event_type == DownloadEventType::Stopped { - break; - } - } - Err((evt_, e)) => { - evt = evt_; // reassign ownership - error_msg = format!("Failed to download file: {}", e); - log::error!("{}", error_msg); - has_error = true; - break; - } + // delete files if cancelled + if cancel_token.is_cancelled() { + let jan_data_folder = get_jan_data_folder_path(app.clone()); + for item in items { + let save_path = jan_data_folder.join(&item.save_path); + let _ = std::fs::remove_file(&save_path); // don't check error } } - // cleanup - { - let mut download_manager = state.download_manager.lock().await; - download_manager.cancel_tokens.remove(model_id); - } - if has_error { - let _ = std::fs::remove_dir_all(&local_dir); // don't check error - return Err(error_msg); - } - - // emit final event - if evt.event_type == DownloadEventType::Stopped { - let _ = std::fs::remove_dir_all(&local_dir); // don't check error - } else { - evt.event_type = DownloadEventType::Success; - } - app.emit("download", evt.clone()).unwrap(); - - Ok(()) + result.map_err(err_to_string) } #[tauri::command] @@ -227,31 +79,30 @@ pub async fn cancel_download_task(state: State<'_, AppState>, task_id: &str) -> let mut download_manager = state.download_manager.lock().await; if let Some(token) = download_manager.cancel_tokens.remove(task_id) { token.cancel(); - log::info!("Cancelled download task_id: {}", task_id); + log::info!("Cancelled download task: {}", task_id); Ok(()) } else { - Err(format!("No download task_id: {}", task_id)) + Err(format!("No download task: {}", task_id)) } } fn _convert_headers( - headers: HashMap, + headers: &HashMap, ) -> Result> { let mut header_map = HeaderMap::new(); for (k, v) in headers { let key = HeaderName::from_bytes(k.as_bytes())?; - let value = HeaderValue::from_str(&v)?; + let value = HeaderValue::from_str(v)?; header_map.insert(key, value); } Ok(header_map) } async fn _get_file_size( + client: &reqwest::Client, url: &str, - header_map: HeaderMap, ) -> Result> { - let client = reqwest::Client::new(); - let resp = client.head(url).headers(header_map).send().await?; + let resp = client.head(url).send().await?; if !resp.status().is_success() { return Err(format!("Failed to get file size: HTTP status {}", resp.status()).into()); } @@ -268,33 +119,16 @@ async fn _get_file_size( } } -// NOTE: Caller of this function should pass ownership of `evt` to this function -// (no .clone()) and obtain it back. Both Ok and Err will return ownership of -// the modified `evt` object back to the caller. -async fn _download_file_internal( +async fn _download_files_internal( app: tauri::AppHandle, - url: &str, - path: &Path, // this is absolute path - header_map: HeaderMap, - mut evt: DownloadEvent, + items: &[DownloadItem], + headers: &HashMap, + task_id: &str, cancel_token: CancellationToken, -) -> Result)> { - log::info!("Downloading file: {}", url); +) -> Result<(), String> { + log::info!("Start download task: {}", task_id); - // normalize and enforce scope - let path = normalize_path(path); - let jan_data_folder = get_jan_data_folder_path(app.clone()); - if !path.starts_with(&jan_data_folder) { - return Err(( - evt.clone(), - format!( - "Path {} is outside of Jan data folder {}", - path.display(), - jan_data_folder.display() - ) - .into(), - )); - } + let header_map = _convert_headers(headers).map_err(err_to_string)?; // .read_timeout() and .connect_timeout() requires reqwest 0.12, which is not // compatible with hyper 0.14 @@ -302,120 +136,89 @@ async fn _download_file_internal( .http2_keep_alive_timeout(Duration::from_secs(15)) // .read_timeout(Duration::from_secs(10)) // timeout between chunks // .connect_timeout(Duration::from_secs(10)) // timeout for first connection + .default_headers(header_map.clone()) .build() - .map_err(|e| (evt.clone(), e.into()))?; + .map_err(err_to_string)?; - let resp = client - .get(url) - .headers(header_map) - .send() - .await - .map_err(|e| (evt.clone(), e.into()))?; + let total_size = { + let mut total_size = 0u64; + for item in items.iter() { + total_size += _get_file_size(&client, &item.url) + .await + .map_err(err_to_string)?; + } + total_size + }; + log::info!("Total download size: {}", total_size); - if !resp.status().is_success() { - return Err(( - evt, - format!( + let mut evt = DownloadEvent { + transferred: 0, + total: total_size, + }; + let evt_name = format!("download-{}", task_id); + + // save file under Jan data folder + let jan_data_folder = get_jan_data_folder_path(app.clone()); + + for item in items.iter() { + let save_path = jan_data_folder.join(&item.save_path); + let save_path = normalize_path(&save_path); + + // enforce scope + if !save_path.starts_with(&jan_data_folder) { + return Err(format!( + "Path {} is outside of Jan data folder {}", + save_path.display(), + jan_data_folder.display() + )); + } + + log::info!("Started downloading: {}", item.url); + let resp = client.get(&item.url).send().await.map_err(err_to_string)?; + if !resp.status().is_success() { + return Err(format!( "Failed to download: HTTP status {}, {}", resp.status(), resp.text().await.unwrap_or_default() - ) - .into(), - )); - } - - // Create parent directories if they don't exist - if let Some(parent) = path.parent() { - if !parent.exists() { - std::fs::create_dir_all(parent).map_err(|e| (evt.clone(), e.into()))?; - } - } - let mut file = File::create(&path) - .await - .map_err(|e| (evt.clone(), e.into()))?; - - // write chunk to file - let mut stream = resp.bytes_stream(); - let mut download_delta = 0u64; - evt.event_type = DownloadEventType::Updated; - - while let Some(chunk) = stream.next().await { - if cancel_token.is_cancelled() { - log::info!("Download cancelled: {}", url); - evt.event_type = DownloadEventType::Stopped; - break; + )); } - let chunk = chunk.map_err(|e| (evt.clone(), e.into()))?; - file.write_all(&chunk) - .await - .map_err(|e| (evt.clone(), e.into()))?; - download_delta += chunk.len() as u64; - - // only update every 1MB - if download_delta >= 1024 * 1024 { - evt.downloaded_size += download_delta; - app.emit("download", evt.clone()).unwrap(); - download_delta = 0u64; - } - } - - // cleanup - file.flush().await.map_err(|e| (evt.clone(), e.into()))?; - if evt.event_type == DownloadEventType::Stopped { - let _ = std::fs::remove_file(&path); // don't check error - } - - // caller should emit a final event after calling this function - evt.downloaded_size += download_delta; - - Ok(evt) -} - -#[derive(serde::Deserialize)] -struct HfItem { - r#type: String, - // oid: String, // unused - path: String, - size: u64, -} - -async fn _list_hf_repo_files( - model_id: &str, - branch: Option<&str>, - header_map: HeaderMap, -) -> Result, Box> { - let branch_str = branch.unwrap_or("main"); - - let mut files = vec![]; - - // DFS - let mut stack = vec!["".to_string()]; - let client = reqwest::Client::new(); - while let Some(subdir) = stack.pop() { - let url = format!( - "https://huggingface.co/api/models/{}/tree/{}/{}", - model_id, branch_str, subdir - ); - let resp = client.get(&url).headers(header_map.clone()).send().await?; - - if !resp.status().is_success() { - return Err(format!( - "Failed to list files: HTTP status {}, {}", - resp.status(), - resp.text().await.unwrap_or_default(), - ) - .into()); - } - - for item in resp.json::>().await?.into_iter() { - if item.r#type == "directory" { - stack.push(item.path); - } else { - files.push(item); + // Create parent directories if they don't exist + if let Some(parent) = save_path.parent() { + if !parent.exists() { + tokio::fs::create_dir_all(parent).await.map_err(err_to_string)?; } } + let mut file = File::create(&save_path).await.map_err(err_to_string)?; + + // write chunk to file + let mut stream = resp.bytes_stream(); + let mut download_delta = 0u64; + + while let Some(chunk) = stream.next().await { + if cancel_token.is_cancelled() { + log::info!("Download cancelled for task: {}", task_id); + app.emit(&evt_name, evt.clone()).unwrap(); + return Ok(()); + } + + let chunk = chunk.map_err(err_to_string)?; + file.write_all(&chunk).await.map_err(err_to_string)?; + download_delta += chunk.len() as u64; + + // only update every 1MB + if download_delta >= 1024 * 1024 { + evt.transferred += download_delta; + app.emit(&evt_name, evt.clone()).unwrap(); + download_delta = 0u64; + } + } + + file.flush().await.map_err(err_to_string)?; + evt.transferred += download_delta; + log::info!("Finished downloading: {}", item.url); } - Ok(files) + app.emit(&evt_name, evt.clone()).unwrap(); + Ok(()) } diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index cc689d97c..432ae1358 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -70,8 +70,7 @@ pub fn run() { core::threads::create_thread_assistant, core::threads::modify_thread_assistant, // Download - core::utils::download::download_file, - core::utils::download::download_hf_repo, + core::utils::download::download_files, core::utils::download::cancel_download_task, // hardware core::hardware::get_system_info,