feat: improve download extension (#5073)
parent dfe15fac32
commit 56f4ec3b61
@@ -6,104 +6,59 @@ export enum Settings {
   hfToken = 'hf-token',
 }
 
+interface DownloadItem {
+  url: string
+  save_path: string
+}
+
 type DownloadEvent = {
-  task_id: string
-  total_size: number
-  downloaded_size: number
-  download_type: string
-  event_type: string
+  transferred: number
+  total: number
 }
 
 export default class DownloadManager extends BaseExtension {
-  hf_token?: string
+  hfToken?: string
 
   async onLoad() {
     this.registerSettings(SETTINGS)
-    this.hf_token = await this.getSetting<string>(Settings.hfToken, undefined)
+    this.hfToken = await this.getSetting<string>(Settings.hfToken, undefined)
   }
 
   async onUnload() { }
 
-  async downloadFile(url: string, path: string, taskId: string) {
-    // relay tauri events to Jan events
-    const unlisten = await listen<DownloadEvent>('download', (event) => {
-      let payload = event.payload
-      let eventName = {
-        Updated: 'onFileDownloadUpdate',
-        Error: 'onFileDownloadError',
-        Success: 'onFileDownloadSuccess',
-        Stopped: 'onFileDownloadStopped',
-        Started: 'onFileDownloadStarted',
-      }[payload.event_type]
-
-      // remove this once event system is back in web-app
-      console.log(taskId, payload.event_type, payload.downloaded_size / payload.total_size)
-
-      events.emit(eventName, {
-        modelId: taskId,
-        percent: payload.downloaded_size / payload.total_size,
-        size: {
-          transferred: payload.downloaded_size,
-          total: payload.total_size,
-        },
-        downloadType: payload.download_type,
-      })
-    })
-
-    try {
-      await invoke<void>(
-        "download_file",
-        { url, path, taskId, headers: this._getHeaders() },
-      )
-    } catch (error) {
-      console.error("Error downloading file:", error)
-      events.emit('onFileDownloadError', {
-        modelId: url,
-        downloadType: 'Model',
-      })
-      throw error
-    } finally {
-      unlisten()
-    }
+  async downloadFile(
+    url: string,
+    savePath: string,
+    taskId: string,
+    onProgress?: (transferred: number, total: number) => void
+  ) {
+    return await this.downloadFiles(
+      [{ url, save_path: savePath }],
+      taskId,
+      onProgress
+    )
   }
 
-  async downloadHfRepo(modelId: string, saveDir: string, taskId: string, branch?: string) {
-    // relay tauri events to Jan events
-    const unlisten = await listen<DownloadEvent>('download', (event) => {
-      let payload = event.payload
-      let eventName = {
-        Updated: 'onFileDownloadUpdate',
-        Error: 'onFileDownloadError',
-        Success: 'onFileDownloadSuccess',
-        Stopped: 'onFileDownloadStopped',
-        Started: 'onFileDownloadStarted',
-      }[payload.event_type]
-
-      // remove this once event system is back in web-app
-      console.log(taskId, payload.event_type, payload.downloaded_size / payload.total_size)
-
-      events.emit(eventName, {
-        modelId: taskId,
-        percent: payload.downloaded_size / payload.total_size,
-        size: {
-          transferred: payload.downloaded_size,
-          total: payload.total_size,
-        },
-        downloadType: payload.download_type,
-      })
+  async downloadFiles(
+    items: DownloadItem[],
+    taskId: string,
+    onProgress?: (transferred: number, total: number) => void
+  ) {
+    // relay tauri events to onProgress callback
+    const unlisten = await listen<DownloadEvent>(`download-${taskId}`, (event) => {
+      if (onProgress) {
+        let payload = event.payload
+        onProgress(payload.transferred, payload.total)
+      }
     })
 
     try {
       await invoke<void>(
-        "download_hf_repo",
-        { modelId, saveDir, taskId, branch, headers: this._getHeaders() },
+        "download_files",
+        { items, taskId, headers: this._getHeaders() },
       )
     } catch (error) {
-      console.error("Error downloading file:", error)
-      events.emit('onFileDownloadError', {
-        modelId: modelId,
-        downloadType: 'Model',
-      })
+      console.error("Error downloading task", taskId, error)
       throw error
     } finally {
       unlisten()
@@ -121,7 +76,7 @@ export default class DownloadManager extends BaseExtension {
 
   _getHeaders() {
     return {
-      ...(this.hf_token && { Authorization: `Bearer ${this.hf_token}` })
+      ...(this.hfToken && { Authorization: `Bearer ${this.hfToken}` })
     }
   }
 }
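Note (not part of the diff): a minimal sketch of how a caller might drive the new downloadFiles API. The downloadManager instance, item URLs, save paths, and task id below are illustrative, not taken from this change.

    const items = [
      { url: 'https://example.com/model.gguf', save_path: 'models/example/model.gguf' },
      { url: 'https://example.com/model.json', save_path: 'models/example/model.json' },
    ]

    await downloadManager.downloadFiles(items, 'example-task', (transferred, total) => {
      // progress is now delivered per task through the onProgress callback
      console.log(`downloaded ${transferred}/${total} bytes`)
    })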
@@ -4,7 +4,6 @@ use crate::core::utils::normalize_path;
 use futures_util::StreamExt;
 use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
 use std::collections::HashMap;
-use std::path::Path;
 use std::time::Duration;
 use tauri::{Emitter, State};
 use tokio::fs::File;
@@ -16,24 +15,16 @@ pub struct DownloadManagerState {
     pub cancel_tokens: HashMap<String, CancellationToken>,
 }
 
-// this is to emulate the current way of downloading files by Cortex + Jan
-// we can change this later
-#[derive(serde::Serialize, Clone, Debug, PartialEq)]
-pub enum DownloadEventType {
-    Started,
-    Updated,
-    Success,
-    // Error, // we don't need to emit an Error event. just return an error directly
-    Stopped,
+#[derive(serde::Deserialize, Clone, Debug)]
+pub struct DownloadItem {
+    pub url: String,
+    pub save_path: String,
 }
 
 #[derive(serde::Serialize, Clone, Debug)]
 pub struct DownloadEvent {
-    pub task_id: String,
-    pub total_size: u64,
-    pub downloaded_size: u64,
-    pub download_type: String, // TODO: make this an enum as well
-    pub event_type: DownloadEventType,
+    pub transferred: u64,
+    pub total: u64,
 }
 
 fn err_to_string<E: std::fmt::Display>(e: E) -> String {
@@ -41,11 +32,10 @@ fn err_to_string<E: std::fmt::Display>(e: E) -> String {
 }
 
 #[tauri::command]
-pub async fn download_file(
+pub async fn download_files(
     app: tauri::AppHandle,
     state: State<'_, AppState>,
-    url: &str,
-    path: &Path,
+    items: Vec<DownloadItem>,
     task_id: &str,
     headers: HashMap<String, String>,
 ) -> Result<(), String> {
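Note (not part of the diff): the renamed download_files command emits progress on a per-task channel named download-<taskId>. A sketch of invoking it directly, assuming the standard Tauri v2 @tauri-apps/api entry points; the task id, item, and headers are illustrative.

    import { invoke } from '@tauri-apps/api/core'
    import { listen } from '@tauri-apps/api/event'

    type DownloadEvent = { transferred: number; total: number }

    const taskId = 'example-task'
    const unlisten = await listen<DownloadEvent>(`download-${taskId}`, (event) => {
      console.log(event.payload.transferred, event.payload.total)
    })
    try {
      await invoke<void>('download_files', {
        items: [{ url: 'https://example.com/file.bin', save_path: 'downloads/file.bin' }],
        taskId,
        headers: {},
      })
    } finally {
      unlisten()
    }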
@@ -53,172 +43,34 @@ pub async fn download_file(
     let cancel_token = CancellationToken::new();
     {
         let mut download_manager = state.download_manager.lock().await;
-        if download_manager.cancel_tokens.contains_key(url) {
-            return Err(format!("URL {} is already being downloaded", url));
+        if download_manager.cancel_tokens.contains_key(task_id) {
+            return Err(format!("task_id {} exists", task_id));
         }
         download_manager
             .cancel_tokens
             .insert(task_id.to_string(), cancel_token.clone());
     }
 
-    let header_map = _convert_headers(headers).map_err(err_to_string)?;
-    let total_size = _get_file_size(url, header_map.clone())
-        .await
-        .map_err(err_to_string)?;
-    log::info!("File size: {}", total_size);
-    let mut evt = DownloadEvent {
-        task_id: task_id.to_string(),
-        total_size,
-        downloaded_size: 0,
-        download_type: "Model".to_string(),
-        event_type: DownloadEventType::Started,
-    };
-    app.emit("download", evt.clone()).unwrap();
-
-    // save file under Jan data folder
-    let data_dir = get_jan_data_folder_path(app.clone());
-    let save_path = data_dir.join(path);
-
-    let mut has_error = false;
-    let mut error_msg = String::new();
-    match _download_file_internal(
-        app.clone(),
-        url,
-        &save_path,
-        header_map.clone(),
-        evt,
-        cancel_token.clone(),
-    )
-    .await
-    {
-        Ok(evt_) => {
-            evt = evt_; // reassign ownership
-        }
-        Err((evt_, e)) => {
-            evt = evt_; // reassign ownership
-            error_msg = format!("Failed to download file: {}", e);
-            log::error!("{}", error_msg);
-            has_error = true;
-        }
-    }
+    let result =
+        _download_files_internal(app.clone(), &items, &headers, task_id, cancel_token.clone())
+            .await;
 
     // cleanup
     {
         let mut download_manager = state.download_manager.lock().await;
-        download_manager.cancel_tokens.remove(url);
+        download_manager.cancel_tokens.remove(task_id);
     }
-    if has_error {
-        let _ = std::fs::remove_file(&save_path); // don't check error
-        return Err(error_msg);
-    }
 
-    // emit final event
-    if evt.event_type == DownloadEventType::Stopped {
-        let _ = std::fs::remove_file(&save_path); // don't check error
-    } else {
-        evt.event_type = DownloadEventType::Success;
-    }
-    app.emit("download", evt.clone()).unwrap();
-
-    Ok(())
-}
-
-#[tauri::command]
-pub async fn download_hf_repo(
-    app: tauri::AppHandle,
-    state: State<'_, AppState>,
-    model_id: &str,
-    save_dir: &Path,
-    task_id: &str,
-    branch: Option<&str>,
-    headers: HashMap<String, String>,
-) -> Result<(), String> {
-    let branch_str = branch.unwrap_or("main");
-    let header_map = _convert_headers(headers).map_err(err_to_string)?;
-
-    log::info!("Downloading HF repo: {}, branch {}", model_id, branch_str);
-
-    // get all files from repo, including subdirs
-    let items = _list_hf_repo_files(model_id, branch, header_map.clone())
-        .await
-        .map_err(err_to_string)?;
-
-    // insert cancel tokens
-    let cancel_token = CancellationToken::new();
-    {
-        let mut download_manager = state.download_manager.lock().await;
-        if download_manager.cancel_tokens.contains_key(model_id) {
-            return Err(format!("model_id {} is already being downloaded", model_id));
-        }
-        download_manager
-            .cancel_tokens
-            .insert(task_id.to_string(), cancel_token.clone());
-    }
-
-    let total_size = items.iter().map(|f| f.size).sum::<u64>();
-    let mut evt = DownloadEvent {
-        task_id: task_id.to_string(),
-        total_size,
-        downloaded_size: 0,
-        download_type: "Model".to_string(),
-        event_type: DownloadEventType::Started,
-    };
-    app.emit("download", evt.clone()).unwrap();
-
-    let local_dir = get_jan_data_folder_path(app.clone()).join(save_dir);
-    let mut has_error = false;
-    let mut error_msg = String::new();
-    for item in items {
-        let url = format!(
-            "https://huggingface.co/{}/resolve/{}/{}",
-            model_id, branch_str, item.path
-        );
-        let save_path = local_dir.join(&item.path);
-        match _download_file_internal(
-            app.clone(),
-            &url,
-            &save_path,
-            header_map.clone(),
-            evt,
-            cancel_token.clone(),
-        )
-        .await
-        {
-            Ok(evt_) => {
-                evt = evt_; // reassign ownership
-                if evt.event_type == DownloadEventType::Stopped {
-                    break;
-                }
-            }
-            Err((evt_, e)) => {
-                evt = evt_; // reassign ownership
-                error_msg = format!("Failed to download file: {}", e);
-                log::error!("{}", error_msg);
-                has_error = true;
-                break;
-            }
-        }
-    }
-
-    // cleanup
-    {
-        let mut download_manager = state.download_manager.lock().await;
-        download_manager.cancel_tokens.remove(model_id);
-    }
-    if has_error {
-        let _ = std::fs::remove_dir_all(&local_dir); // don't check error
-        return Err(error_msg);
-    }
-
-    // emit final event
-    if evt.event_type == DownloadEventType::Stopped {
-        let _ = std::fs::remove_dir_all(&local_dir); // don't check error
-    } else {
-        evt.event_type = DownloadEventType::Success;
-    }
-    app.emit("download", evt.clone()).unwrap();
-
-    Ok(())
-}
+    // delete files if cancelled
+    if cancel_token.is_cancelled() {
+        let jan_data_folder = get_jan_data_folder_path(app.clone());
+        for item in items {
+            let save_path = jan_data_folder.join(&item.save_path);
+            let _ = std::fs::remove_file(&save_path); // don't check error
+        }
+    }
+
+    result.map_err(err_to_string)
+}
 
 #[tauri::command]
@@ -227,31 +79,30 @@ pub async fn cancel_download_task(state: State<'_, AppState>, task_id: &str) ->
     let mut download_manager = state.download_manager.lock().await;
     if let Some(token) = download_manager.cancel_tokens.remove(task_id) {
         token.cancel();
-        log::info!("Cancelled download task_id: {}", task_id);
+        log::info!("Cancelled download task: {}", task_id);
         Ok(())
     } else {
-        Err(format!("No download task_id: {}", task_id))
+        Err(format!("No download task: {}", task_id))
     }
 }
 
 fn _convert_headers(
-    headers: HashMap<String, String>,
+    headers: &HashMap<String, String>,
 ) -> Result<HeaderMap, Box<dyn std::error::Error>> {
     let mut header_map = HeaderMap::new();
     for (k, v) in headers {
         let key = HeaderName::from_bytes(k.as_bytes())?;
-        let value = HeaderValue::from_str(&v)?;
+        let value = HeaderValue::from_str(v)?;
         header_map.insert(key, value);
     }
     Ok(header_map)
 }
 
 async fn _get_file_size(
+    client: &reqwest::Client,
     url: &str,
-    header_map: HeaderMap,
 ) -> Result<u64, Box<dyn std::error::Error>> {
-    let client = reqwest::Client::new();
-    let resp = client.head(url).headers(header_map).send().await?;
+    let resp = client.head(url).send().await?;
     if !resp.status().is_success() {
         return Err(format!("Failed to get file size: HTTP status {}", resp.status()).into());
     }
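Note (not part of the diff): cancellation is keyed by the same task id that was passed to download_files. A sketch, with an illustrative task id and the standard Tauri v2 invoke entry point assumed:

    import { invoke } from '@tauri-apps/api/core'

    // Cancels the matching download task; download_files then observes the cancelled
    // token, emits a final download-<taskId> event, and removes the partially
    // written files under the Jan data folder.
    await invoke<void>('cancel_download_task', { taskId: 'example-task' })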
@@ -268,33 +119,16 @@ async fn _get_file_size(
     }
 }
 
-// NOTE: Caller of this function should pass ownership of `evt` to this function
-// (no .clone()) and obtain it back. Both Ok and Err will return ownership of
-// the modified `evt` object back to the caller.
-async fn _download_file_internal(
+async fn _download_files_internal(
     app: tauri::AppHandle,
-    url: &str,
-    path: &Path, // this is absolute path
-    header_map: HeaderMap,
-    mut evt: DownloadEvent,
+    items: &[DownloadItem],
+    headers: &HashMap<String, String>,
+    task_id: &str,
     cancel_token: CancellationToken,
-) -> Result<DownloadEvent, (DownloadEvent, Box<dyn std::error::Error>)> {
-    log::info!("Downloading file: {}", url);
+) -> Result<(), String> {
+    log::info!("Start download task: {}", task_id);
 
-    // normalize and enforce scope
-    let path = normalize_path(path);
-    let jan_data_folder = get_jan_data_folder_path(app.clone());
-    if !path.starts_with(&jan_data_folder) {
-        return Err((
-            evt.clone(),
-            format!(
-                "Path {} is outside of Jan data folder {}",
-                path.display(),
-                jan_data_folder.display()
-            )
-            .into(),
-        ));
-    }
+    let header_map = _convert_headers(headers).map_err(err_to_string)?;
 
     // .read_timeout() and .connect_timeout() requires reqwest 0.12, which is not
     // compatible with hyper 0.14
@@ -302,120 +136,89 @@ async fn _download_file_internal(
         .http2_keep_alive_timeout(Duration::from_secs(15))
         // .read_timeout(Duration::from_secs(10)) // timeout between chunks
         // .connect_timeout(Duration::from_secs(10)) // timeout for first connection
+        .default_headers(header_map.clone())
         .build()
-        .map_err(|e| (evt.clone(), e.into()))?;
+        .map_err(err_to_string)?;
 
-    let resp = client
-        .get(url)
-        .headers(header_map)
-        .send()
-        .await
-        .map_err(|e| (evt.clone(), e.into()))?;
-
-    if !resp.status().is_success() {
-        return Err((
-            evt,
-            format!(
-                "Failed to download: HTTP status {}, {}",
-                resp.status(),
-                resp.text().await.unwrap_or_default()
-            )
-            .into(),
-        ));
-    }
-
-    // Create parent directories if they don't exist
-    if let Some(parent) = path.parent() {
-        if !parent.exists() {
-            std::fs::create_dir_all(parent).map_err(|e| (evt.clone(), e.into()))?;
-        }
-    }
-    let mut file = File::create(&path)
-        .await
-        .map_err(|e| (evt.clone(), e.into()))?;
-
-    // write chunk to file
-    let mut stream = resp.bytes_stream();
-    let mut download_delta = 0u64;
-    evt.event_type = DownloadEventType::Updated;
-
-    while let Some(chunk) = stream.next().await {
-        if cancel_token.is_cancelled() {
-            log::info!("Download cancelled: {}", url);
-            evt.event_type = DownloadEventType::Stopped;
-            break;
-        }
-
-        let chunk = chunk.map_err(|e| (evt.clone(), e.into()))?;
-        file.write_all(&chunk)
-            .await
-            .map_err(|e| (evt.clone(), e.into()))?;
-        download_delta += chunk.len() as u64;
-
-        // only update every 1MB
-        if download_delta >= 1024 * 1024 {
-            evt.downloaded_size += download_delta;
-            app.emit("download", evt.clone()).unwrap();
-            download_delta = 0u64;
-        }
-    }
-
-    // cleanup
-    file.flush().await.map_err(|e| (evt.clone(), e.into()))?;
-    if evt.event_type == DownloadEventType::Stopped {
-        let _ = std::fs::remove_file(&path); // don't check error
-    }
-
-    // caller should emit a final event after calling this function
-    evt.downloaded_size += download_delta;
-
-    Ok(evt)
-}
-
-#[derive(serde::Deserialize)]
-struct HfItem {
-    r#type: String,
-    // oid: String, // unused
-    path: String,
-    size: u64,
-}
-
-async fn _list_hf_repo_files(
-    model_id: &str,
-    branch: Option<&str>,
-    header_map: HeaderMap,
-) -> Result<Vec<HfItem>, Box<dyn std::error::Error>> {
-    let branch_str = branch.unwrap_or("main");
-
-    let mut files = vec![];
-
-    // DFS
-    let mut stack = vec!["".to_string()];
-    let client = reqwest::Client::new();
-    while let Some(subdir) = stack.pop() {
-        let url = format!(
-            "https://huggingface.co/api/models/{}/tree/{}/{}",
-            model_id, branch_str, subdir
-        );
-        let resp = client.get(&url).headers(header_map.clone()).send().await?;
-
-        if !resp.status().is_success() {
-            return Err(format!(
-                "Failed to list files: HTTP status {}, {}",
-                resp.status(),
-                resp.text().await.unwrap_or_default(),
-            )
-            .into());
-        }
-
-        for item in resp.json::<Vec<HfItem>>().await?.into_iter() {
-            if item.r#type == "directory" {
-                stack.push(item.path);
-            } else {
-                files.push(item);
-            }
-        }
-    }
-
-    Ok(files)
+    let total_size = {
+        let mut total_size = 0u64;
+        for item in items.iter() {
+            total_size += _get_file_size(&client, &item.url)
+                .await
+                .map_err(err_to_string)?;
+        }
+        total_size
+    };
+    log::info!("Total download size: {}", total_size);
+
+    let mut evt = DownloadEvent {
+        transferred: 0,
+        total: total_size,
+    };
+    let evt_name = format!("download-{}", task_id);
+
+    // save file under Jan data folder
+    let jan_data_folder = get_jan_data_folder_path(app.clone());
+
+    for item in items.iter() {
+        let save_path = jan_data_folder.join(&item.save_path);
+        let save_path = normalize_path(&save_path);
+
+        // enforce scope
+        if !save_path.starts_with(&jan_data_folder) {
+            return Err(format!(
+                "Path {} is outside of Jan data folder {}",
+                save_path.display(),
+                jan_data_folder.display()
+            ));
+        }
+
+        log::info!("Started downloading: {}", item.url);
+        let resp = client.get(&item.url).send().await.map_err(err_to_string)?;
+        if !resp.status().is_success() {
+            return Err(format!(
+                "Failed to download: HTTP status {}, {}",
+                resp.status(),
+                resp.text().await.unwrap_or_default()
+            ));
+        }
+
+        // Create parent directories if they don't exist
+        if let Some(parent) = save_path.parent() {
+            if !parent.exists() {
+                tokio::fs::create_dir_all(parent).await.map_err(err_to_string)?;
+            }
+        }
+        let mut file = File::create(&save_path).await.map_err(err_to_string)?;
+
+        // write chunk to file
+        let mut stream = resp.bytes_stream();
+        let mut download_delta = 0u64;
+
+        while let Some(chunk) = stream.next().await {
+            if cancel_token.is_cancelled() {
+                log::info!("Download cancelled for task: {}", task_id);
+                app.emit(&evt_name, evt.clone()).unwrap();
+                return Ok(());
+            }
+
+            let chunk = chunk.map_err(err_to_string)?;
+            file.write_all(&chunk).await.map_err(err_to_string)?;
+            download_delta += chunk.len() as u64;
+
+            // only update every 1MB
+            if download_delta >= 1024 * 1024 {
+                evt.transferred += download_delta;
+                app.emit(&evt_name, evt.clone()).unwrap();
+                download_delta = 0u64;
+            }
+        }
+
+        file.flush().await.map_err(err_to_string)?;
+        evt.transferred += download_delta;
+        log::info!("Finished downloading: {}", item.url);
+    }
+
+    app.emit(&evt_name, evt.clone()).unwrap();
+    Ok(())
 }
@@ -70,8 +70,7 @@ pub fn run() {
             core::threads::create_thread_assistant,
             core::threads::modify_thread_assistant,
             // Download
-            core::utils::download::download_file,
-            core::utils::download::download_hf_repo,
+            core::utils::download::download_files,
             core::utils::download::cancel_download_task,
             // hardware
             core::hardware::get_system_info,
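Note (not part of the diff): with download_hf_repo removed from the command registry, a Hugging Face repo download is now expressed by the caller as a plain list of DownloadItems. A sketch of that, reusing the resolve URL scheme the removed backend helper used; the model id, branch, and file list are illustrative and would come from the Hugging Face tree API that _list_hf_repo_files previously queried.

    const modelId = 'org/repo'
    const branch = 'main'
    const filePaths = ['config.json', 'model.gguf'] // e.g. listed via
    // https://huggingface.co/api/models/<modelId>/tree/<branch>

    const items = filePaths.map((p) => ({
      url: `https://huggingface.co/${modelId}/resolve/${branch}/${p}`,
      save_path: `models/${modelId}/${p}`,
    }))

    await downloadManager.downloadFiles(items, `hf-${modelId}`)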