feat: improve download extension (#5073)

This commit is contained in:
Thien Tran 2025-05-23 16:49:41 +08:00 committed by GitHub
parent dfe15fac32
commit 56f4ec3b61
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 144 additions and 387 deletions

View File

@ -6,104 +6,59 @@ export enum Settings {
hfToken = 'hf-token',
}
/**
 * One file to fetch as part of a download task.
 *
 * Sent as an element of the `items` argument of the `download_files`
 * command. The snake_case `save_path` key matches the field name of the
 * Rust-side `DownloadItem` struct that deserializes it.
 */
interface DownloadItem {
/** Remote URL the file is downloaded from. */
url: string
/** Destination path; the backend joins it onto the Jan data folder. */
save_path: string
}
type DownloadEvent = {
task_id: string
total_size: number
downloaded_size: number
download_type: string
event_type: string
transferred: number
total: number
}
export default class DownloadManager extends BaseExtension {
hf_token?: string
hfToken?: string
async onLoad() {
this.registerSettings(SETTINGS)
this.hf_token = await this.getSetting<string>(Settings.hfToken, undefined)
this.hfToken = await this.getSetting<string>(Settings.hfToken, undefined)
}
async onUnload() { }
async downloadFile(url: string, path: string, taskId: string) {
// relay tauri events to Jan events
const unlisten = await listen<DownloadEvent>('download', (event) => {
async downloadFile(
url: string,
savePath: string,
taskId: string,
onProgress?: (transferred: number, total: number) => void
) {
return await this.downloadFiles(
[{ url, save_path: savePath }],
taskId,
onProgress
)
}
async downloadFiles(
items: DownloadItem[],
taskId: string,
onProgress?: (transferred: number, total: number) => void
) {
// relay tauri events to onProgress callback
const unlisten = await listen<DownloadEvent>(`download-${taskId}`, (event) => {
if (onProgress) {
let payload = event.payload
let eventName = {
Updated: 'onFileDownloadUpdate',
Error: 'onFileDownloadError',
Success: 'onFileDownloadSuccess',
Stopped: 'onFileDownloadStopped',
Started: 'onFileDownloadStarted',
}[payload.event_type]
// remove this once event system is back in web-app
console.log(taskId, payload.event_type, payload.downloaded_size / payload.total_size)
events.emit(eventName, {
modelId: taskId,
percent: payload.downloaded_size / payload.total_size,
size: {
transferred: payload.downloaded_size,
total: payload.total_size,
},
downloadType: payload.download_type,
})
onProgress(payload.transferred, payload.total)
}
})
try {
await invoke<void>(
"download_file",
{ url, path, taskId, headers: this._getHeaders() },
"download_files",
{ items, taskId, headers: this._getHeaders() },
)
} catch (error) {
console.error("Error downloading file:", error)
events.emit('onFileDownloadError', {
modelId: url,
downloadType: 'Model',
})
throw error
} finally {
unlisten()
}
}
async downloadHfRepo(modelId: string, saveDir: string, taskId: string, branch?: string) {
// relay tauri events to Jan events
const unlisten = await listen<DownloadEvent>('download', (event) => {
let payload = event.payload
let eventName = {
Updated: 'onFileDownloadUpdate',
Error: 'onFileDownloadError',
Success: 'onFileDownloadSuccess',
Stopped: 'onFileDownloadStopped',
Started: 'onFileDownloadStarted',
}[payload.event_type]
// remove this once event system is back in web-app
console.log(taskId, payload.event_type, payload.downloaded_size / payload.total_size)
events.emit(eventName, {
modelId: taskId,
percent: payload.downloaded_size / payload.total_size,
size: {
transferred: payload.downloaded_size,
total: payload.total_size,
},
downloadType: payload.download_type,
})
})
try {
await invoke<void>(
"download_hf_repo",
{ modelId, saveDir, taskId, branch, headers: this._getHeaders() },
)
} catch (error) {
console.error("Error downloading file:", error)
events.emit('onFileDownloadError', {
modelId: modelId,
downloadType: 'Model',
})
console.error("Error downloading task", taskId, error)
throw error
} finally {
unlisten()
@ -121,7 +76,7 @@ export default class DownloadManager extends BaseExtension {
_getHeaders() {
return {
...(this.hf_token && { Authorization: `Bearer ${this.hf_token}` })
...(this.hfToken && { Authorization: `Bearer ${this.hfToken}` })
}
}
}

View File

@ -4,7 +4,6 @@ use crate::core::utils::normalize_path;
use futures_util::StreamExt;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use std::collections::HashMap;
use std::path::Path;
use std::time::Duration;
use tauri::{Emitter, State};
use tokio::fs::File;
@ -16,24 +15,16 @@ pub struct DownloadManagerState {
pub cancel_tokens: HashMap<String, CancellationToken>,
}
// this is to emulate the current way of downloading files by Cortex + Jan
// we can change this later
#[derive(serde::Serialize, Clone, Debug, PartialEq)]
pub enum DownloadEventType {
Started,
Updated,
Success,
// Error, // we don't need to emit an Error event. just return an error directly
Stopped,
/// One file to fetch as part of a download task.
///
/// Deserialized from the `items` argument of the `download_files` command.
#[derive(serde::Deserialize, Clone, Debug)]
pub struct DownloadItem {
/// Remote URL; queried with a HEAD request for its size, then fetched with GET.
pub url: String,
/// Destination path, joined onto the Jan data folder. The download is
/// rejected if the normalized result falls outside that folder.
pub save_path: String,
}
#[derive(serde::Serialize, Clone, Debug)]
pub struct DownloadEvent {
pub task_id: String,
pub total_size: u64,
pub downloaded_size: u64,
pub download_type: String, // TODO: make this an enum as well
pub event_type: DownloadEventType,
pub transferred: u64,
pub total: u64,
}
fn err_to_string<E: std::fmt::Display>(e: E) -> String {
@ -41,11 +32,10 @@ fn err_to_string<E: std::fmt::Display>(e: E) -> String {
}
#[tauri::command]
pub async fn download_file(
pub async fn download_files(
app: tauri::AppHandle,
state: State<'_, AppState>,
url: &str,
path: &Path,
items: Vec<DownloadItem>,
task_id: &str,
headers: HashMap<String, String>,
) -> Result<(), String> {
@ -53,172 +43,34 @@ pub async fn download_file(
let cancel_token = CancellationToken::new();
{
let mut download_manager = state.download_manager.lock().await;
if download_manager.cancel_tokens.contains_key(url) {
return Err(format!("URL {} is already being downloaded", url));
if download_manager.cancel_tokens.contains_key(task_id) {
return Err(format!("task_id {} exists", task_id));
}
download_manager
.cancel_tokens
.insert(task_id.to_string(), cancel_token.clone());
}
let header_map = _convert_headers(headers).map_err(err_to_string)?;
let total_size = _get_file_size(url, header_map.clone())
.await
.map_err(err_to_string)?;
log::info!("File size: {}", total_size);
let mut evt = DownloadEvent {
task_id: task_id.to_string(),
total_size,
downloaded_size: 0,
download_type: "Model".to_string(),
event_type: DownloadEventType::Started,
};
app.emit("download", evt.clone()).unwrap();
// save file under Jan data folder
let data_dir = get_jan_data_folder_path(app.clone());
let save_path = data_dir.join(path);
let mut has_error = false;
let mut error_msg = String::new();
match _download_file_internal(
app.clone(),
url,
&save_path,
header_map.clone(),
evt,
cancel_token.clone(),
)
.await
{
Ok(evt_) => {
evt = evt_; // reassign ownership
}
Err((evt_, e)) => {
evt = evt_; // reassign ownership
error_msg = format!("Failed to download file: {}", e);
log::error!("{}", error_msg);
has_error = true;
}
}
let result =
_download_files_internal(app.clone(), &items, &headers, task_id, cancel_token.clone())
.await;
// cleanup
{
let mut download_manager = state.download_manager.lock().await;
download_manager.cancel_tokens.remove(url);
}
if has_error {
let _ = std::fs::remove_file(&save_path); // don't check error
return Err(error_msg);
download_manager.cancel_tokens.remove(task_id);
}
// emit final event
if evt.event_type == DownloadEventType::Stopped {
let _ = std::fs::remove_file(&save_path); // don't check error
} else {
evt.event_type = DownloadEventType::Success;
}
app.emit("download", evt.clone()).unwrap();
Ok(())
}
#[tauri::command]
pub async fn download_hf_repo(
app: tauri::AppHandle,
state: State<'_, AppState>,
model_id: &str,
save_dir: &Path,
task_id: &str,
branch: Option<&str>,
headers: HashMap<String, String>,
) -> Result<(), String> {
let branch_str = branch.unwrap_or("main");
let header_map = _convert_headers(headers).map_err(err_to_string)?;
log::info!("Downloading HF repo: {}, branch {}", model_id, branch_str);
// get all files from repo, including subdirs
let items = _list_hf_repo_files(model_id, branch, header_map.clone())
.await
.map_err(err_to_string)?;
// insert cancel tokens
let cancel_token = CancellationToken::new();
{
let mut download_manager = state.download_manager.lock().await;
if download_manager.cancel_tokens.contains_key(model_id) {
return Err(format!("model_id {} is already being downloaded", model_id));
}
download_manager
.cancel_tokens
.insert(task_id.to_string(), cancel_token.clone());
}
let total_size = items.iter().map(|f| f.size).sum::<u64>();
let mut evt = DownloadEvent {
task_id: task_id.to_string(),
total_size,
downloaded_size: 0,
download_type: "Model".to_string(),
event_type: DownloadEventType::Started,
};
app.emit("download", evt.clone()).unwrap();
let local_dir = get_jan_data_folder_path(app.clone()).join(save_dir);
let mut has_error = false;
let mut error_msg = String::new();
// delete files if cancelled
if cancel_token.is_cancelled() {
let jan_data_folder = get_jan_data_folder_path(app.clone());
for item in items {
let url = format!(
"https://huggingface.co/{}/resolve/{}/{}",
model_id, branch_str, item.path
);
let save_path = local_dir.join(&item.path);
match _download_file_internal(
app.clone(),
&url,
&save_path,
header_map.clone(),
evt,
cancel_token.clone(),
)
.await
{
Ok(evt_) => {
evt = evt_; // reassign ownership
if evt.event_type == DownloadEventType::Stopped {
break;
}
}
Err((evt_, e)) => {
evt = evt_; // reassign ownership
error_msg = format!("Failed to download file: {}", e);
log::error!("{}", error_msg);
has_error = true;
break;
}
let save_path = jan_data_folder.join(&item.save_path);
let _ = std::fs::remove_file(&save_path); // don't check error
}
}
// cleanup
{
let mut download_manager = state.download_manager.lock().await;
download_manager.cancel_tokens.remove(model_id);
}
if has_error {
let _ = std::fs::remove_dir_all(&local_dir); // don't check error
return Err(error_msg);
}
// emit final event
if evt.event_type == DownloadEventType::Stopped {
let _ = std::fs::remove_dir_all(&local_dir); // don't check error
} else {
evt.event_type = DownloadEventType::Success;
}
app.emit("download", evt.clone()).unwrap();
Ok(())
result.map_err(err_to_string)
}
#[tauri::command]
@ -227,31 +79,30 @@ pub async fn cancel_download_task(state: State<'_, AppState>, task_id: &str) ->
let mut download_manager = state.download_manager.lock().await;
if let Some(token) = download_manager.cancel_tokens.remove(task_id) {
token.cancel();
log::info!("Cancelled download task_id: {}", task_id);
log::info!("Cancelled download task: {}", task_id);
Ok(())
} else {
Err(format!("No download task_id: {}", task_id))
Err(format!("No download task: {}", task_id))
}
}
fn _convert_headers(
headers: HashMap<String, String>,
headers: &HashMap<String, String>,
) -> Result<HeaderMap, Box<dyn std::error::Error>> {
let mut header_map = HeaderMap::new();
for (k, v) in headers {
let key = HeaderName::from_bytes(k.as_bytes())?;
let value = HeaderValue::from_str(&v)?;
let value = HeaderValue::from_str(v)?;
header_map.insert(key, value);
}
Ok(header_map)
}
async fn _get_file_size(
client: &reqwest::Client,
url: &str,
header_map: HeaderMap,
) -> Result<u64, Box<dyn std::error::Error>> {
let client = reqwest::Client::new();
let resp = client.head(url).headers(header_map).send().await?;
let resp = client.head(url).send().await?;
if !resp.status().is_success() {
return Err(format!("Failed to get file size: HTTP status {}", resp.status()).into());
}
@ -268,33 +119,16 @@ async fn _get_file_size(
}
}
// NOTE: Caller of this function should pass ownership of `evt` to this function
// (no .clone()) and obtain it back. Both Ok and Err will return ownership of
// the modified `evt` object back to the caller.
async fn _download_file_internal(
async fn _download_files_internal(
app: tauri::AppHandle,
url: &str,
path: &Path, // this is absolute path
header_map: HeaderMap,
mut evt: DownloadEvent,
items: &[DownloadItem],
headers: &HashMap<String, String>,
task_id: &str,
cancel_token: CancellationToken,
) -> Result<DownloadEvent, (DownloadEvent, Box<dyn std::error::Error>)> {
log::info!("Downloading file: {}", url);
) -> Result<(), String> {
log::info!("Start download task: {}", task_id);
// normalize and enforce scope
let path = normalize_path(path);
let jan_data_folder = get_jan_data_folder_path(app.clone());
if !path.starts_with(&jan_data_folder) {
return Err((
evt.clone(),
format!(
"Path {} is outside of Jan data folder {}",
path.display(),
jan_data_folder.display()
)
.into(),
));
}
let header_map = _convert_headers(headers).map_err(err_to_string)?;
// .read_timeout() and .connect_timeout() require reqwest 0.12, which is not
// compatible with hyper 0.14
@ -302,120 +136,89 @@ async fn _download_file_internal(
.http2_keep_alive_timeout(Duration::from_secs(15))
// .read_timeout(Duration::from_secs(10)) // timeout between chunks
// .connect_timeout(Duration::from_secs(10)) // timeout for first connection
.default_headers(header_map.clone())
.build()
.map_err(|e| (evt.clone(), e.into()))?;
.map_err(err_to_string)?;
let resp = client
.get(url)
.headers(header_map)
.send()
let total_size = {
let mut total_size = 0u64;
for item in items.iter() {
total_size += _get_file_size(&client, &item.url)
.await
.map_err(|e| (evt.clone(), e.into()))?;
.map_err(err_to_string)?;
}
total_size
};
log::info!("Total download size: {}", total_size);
let mut evt = DownloadEvent {
transferred: 0,
total: total_size,
};
let evt_name = format!("download-{}", task_id);
// save file under Jan data folder
let jan_data_folder = get_jan_data_folder_path(app.clone());
for item in items.iter() {
let save_path = jan_data_folder.join(&item.save_path);
let save_path = normalize_path(&save_path);
// enforce scope
if !save_path.starts_with(&jan_data_folder) {
return Err(format!(
"Path {} is outside of Jan data folder {}",
save_path.display(),
jan_data_folder.display()
));
}
log::info!("Started downloading: {}", item.url);
let resp = client.get(&item.url).send().await.map_err(err_to_string)?;
if !resp.status().is_success() {
return Err((
evt,
format!(
return Err(format!(
"Failed to download: HTTP status {}, {}",
resp.status(),
resp.text().await.unwrap_or_default()
)
.into(),
));
}
// Create parent directories if they don't exist
if let Some(parent) = path.parent() {
if let Some(parent) = save_path.parent() {
if !parent.exists() {
std::fs::create_dir_all(parent).map_err(|e| (evt.clone(), e.into()))?;
tokio::fs::create_dir_all(parent).await.map_err(err_to_string)?;
}
}
let mut file = File::create(&path)
.await
.map_err(|e| (evt.clone(), e.into()))?;
let mut file = File::create(&save_path).await.map_err(err_to_string)?;
// write chunk to file
let mut stream = resp.bytes_stream();
let mut download_delta = 0u64;
evt.event_type = DownloadEventType::Updated;
while let Some(chunk) = stream.next().await {
if cancel_token.is_cancelled() {
log::info!("Download cancelled: {}", url);
evt.event_type = DownloadEventType::Stopped;
break;
log::info!("Download cancelled for task: {}", task_id);
app.emit(&evt_name, evt.clone()).unwrap();
return Ok(());
}
let chunk = chunk.map_err(|e| (evt.clone(), e.into()))?;
file.write_all(&chunk)
.await
.map_err(|e| (evt.clone(), e.into()))?;
let chunk = chunk.map_err(err_to_string)?;
file.write_all(&chunk).await.map_err(err_to_string)?;
download_delta += chunk.len() as u64;
// only update every 1MB
if download_delta >= 1024 * 1024 {
evt.downloaded_size += download_delta;
app.emit("download", evt.clone()).unwrap();
evt.transferred += download_delta;
app.emit(&evt_name, evt.clone()).unwrap();
download_delta = 0u64;
}
}
// cleanup
file.flush().await.map_err(|e| (evt.clone(), e.into()))?;
if evt.event_type == DownloadEventType::Stopped {
let _ = std::fs::remove_file(&path); // don't check error
file.flush().await.map_err(err_to_string)?;
evt.transferred += download_delta;
log::info!("Finished downloading: {}", item.url);
}
// caller should emit a final event after calling this function
evt.downloaded_size += download_delta;
Ok(evt)
}
#[derive(serde::Deserialize)]
struct HfItem {
r#type: String,
// oid: String, // unused
path: String,
size: u64,
}
async fn _list_hf_repo_files(
model_id: &str,
branch: Option<&str>,
header_map: HeaderMap,
) -> Result<Vec<HfItem>, Box<dyn std::error::Error>> {
let branch_str = branch.unwrap_or("main");
let mut files = vec![];
// DFS
let mut stack = vec!["".to_string()];
let client = reqwest::Client::new();
while let Some(subdir) = stack.pop() {
let url = format!(
"https://huggingface.co/api/models/{}/tree/{}/{}",
model_id, branch_str, subdir
);
let resp = client.get(&url).headers(header_map.clone()).send().await?;
if !resp.status().is_success() {
return Err(format!(
"Failed to list files: HTTP status {}, {}",
resp.status(),
resp.text().await.unwrap_or_default(),
)
.into());
}
for item in resp.json::<Vec<HfItem>>().await?.into_iter() {
if item.r#type == "directory" {
stack.push(item.path);
} else {
files.push(item);
}
}
}
Ok(files)
app.emit(&evt_name, evt.clone()).unwrap();
Ok(())
}

View File

@ -70,8 +70,7 @@ pub fn run() {
core::threads::create_thread_assistant,
core::threads::modify_thread_assistant,
// Download
core::utils::download::download_file,
core::utils::download::download_hf_repo,
core::utils::download::download_files,
core::utils::download::cancel_download_task,
// hardware
core::hardware::get_system_info,