diff --git a/extensions/download-extension/package.json b/extensions/download-extension/package.json
new file mode 100644
index 000000000..750934594
--- /dev/null
+++ b/extensions/download-extension/package.json
@@ -0,0 +1,36 @@
+{
+  "name": "@janhq/download-extension",
+  "productName": "Download Manager",
+  "version": "1.0.0",
+  "description": "Handle downloads",
+  "main": "dist/index.js",
+  "author": "Jan ",
+  "license": "AGPL-3.0",
+  "scripts": {
+    "test": "vitest run",
+    "build": "rolldown -c rolldown.config.mjs",
+    "build:publish": "rimraf *.tgz --glob || true && yarn build && npm pack && cpx *.tgz ../../pre-install"
+  },
+  "devDependencies": {
+    "cpx": "^1.5.0",
+    "rimraf": "^3.0.2",
+    "rolldown": "1.0.0-beta.1",
+    "run-script-os": "^1.1.6",
+    "typescript": "5.3.3",
+    "vitest": "^3.0.6"
+  },
+  "files": [
+    "dist/*",
+    "package.json",
+    "README.md"
+  ],
+  "dependencies": {
+    "@janhq/core": "../../core/package.tgz",
+    "@tauri-apps/api": "^2.5.0"
+  },
+  "bundleDependencies": [],
+  "installConfig": {
+    "hoistingLimits": "workspaces"
+  },
+  "packageManager": "yarn@4.5.3"
+}
diff --git a/extensions/download-extension/rolldown.config.mjs b/extensions/download-extension/rolldown.config.mjs
new file mode 100644
index 000000000..e9b190546
--- /dev/null
+++ b/extensions/download-extension/rolldown.config.mjs
@@ -0,0 +1,14 @@
+import { defineConfig } from 'rolldown'
+import settingJson from './settings.json' with { type: 'json' }
+
+export default defineConfig({
+  input: 'src/index.ts',
+  output: {
+    format: 'esm',
+    file: 'dist/index.js',
+  },
+  platform: 'browser',
+  define: {
+    SETTINGS: JSON.stringify(settingJson),
+  },
+})
diff --git a/extensions/download-extension/settings.json b/extensions/download-extension/settings.json
new file mode 100644
index 000000000..f2f50762f
--- /dev/null
+++ b/extensions/download-extension/settings.json
@@ -0,0 +1,14 @@
+[
+  {
+    "key": "hf-token",
+    "title": "Hugging Face Access Token",
+    "description": "Access tokens programmatically authenticate your identity to the Hugging Face Hub, allowing applications to perform specific actions specified by the scope of permissions granted.",
+    "controllerType": "input",
+    "controllerProps": {
+      "value": "",
+      "placeholder": "hf_**********************************",
+      "type": "password",
+      "inputActions": ["unobscure", "copy"]
+    }
+  }
+]
diff --git a/extensions/download-extension/src/@types/global.d.ts b/extensions/download-extension/src/@types/global.d.ts
new file mode 100644
index 000000000..4ff21449c
--- /dev/null
+++ b/extensions/download-extension/src/@types/global.d.ts
@@ -0,0 +1 @@
+declare const SETTINGS: SettingComponentProps[]
diff --git a/extensions/download-extension/src/index.ts b/extensions/download-extension/src/index.ts
new file mode 100644
index 000000000..639fd677f
--- /dev/null
+++ b/extensions/download-extension/src/index.ts
@@ -0,0 +1,133 @@
+import { invoke } from '@tauri-apps/api/core';
+import { listen } from '@tauri-apps/api/event';
+import { BaseExtension, events } from '@janhq/core';
+
+export enum Settings {
+  hfToken = 'hf-token',
+}
+
+// Payload of the tauri `download` event (mirrors the Rust `DownloadEvent`)
+type DownloadEvent = {
+  task_id: string
+  total_size: number
+  downloaded_size: number
+  download_type: string
+  event_type: string
+}
+
+// tauri event type -> Jan event name
+const JAN_EVENT_NAMES: Record<string, string> = {
+  Updated: 'onFileDownloadUpdate',
+  Error: 'onFileDownloadError',
+  Success: 'onFileDownloadSuccess',
+  Stopped: 'onFileDownloadStopped',
+  Started: 'onFileDownloadStarted',
+}
+
+export default class DownloadManager extends BaseExtension {
+  hf_token?: string
+
+  async onLoad() {
+    this.registerSettings(SETTINGS)
+    this.hf_token = await this.getSetting<string>(Settings.hfToken, undefined)
+  }
+
+  async onUnload() { }
+
+  /**
+   * Download `url` to `path` (relative to the Jan data folder).
+   * Progress is relayed as Jan events whose `modelId` is `taskId`.
+   */
+  async downloadFile(url: string, path: string, taskId: string) {
+    const unlisten = await this._relayDownloadEvents(taskId)
+    try {
+      await invoke(
+        "download_file",
+        { url, path, taskId, headers: this._getHeaders() },
+      )
+    } catch (error) {
+      console.error("Error downloading file:", error)
+      // use taskId so the error matches the id used by the progress events
+      events.emit('onFileDownloadError', {
+        modelId: taskId,
+        downloadType: 'Model',
+      })
+      throw error
+    } finally {
+      unlisten()
+    }
+  }
+
+  /**
+   * Download every file of a Hugging Face repo into `saveDir` (relative to
+   * the Jan data folder). `branch` is optional; the backend falls back to `main`.
+   */
+  async downloadHfRepo(modelId: string, saveDir: string, taskId: string, branch?: string) {
+    const unlisten = await this._relayDownloadEvents(taskId)
+    try {
+      await invoke(
+        "download_hf_repo",
+        { modelId, saveDir, taskId, branch, headers: this._getHeaders() },
+      )
+    } catch (error) {
+      console.error("Error downloading file:", error)
+      // use taskId so the error matches the id used by the progress events
+      events.emit('onFileDownloadError', {
+        modelId: taskId,
+        downloadType: 'Model',
+      })
+      throw error
+    } finally {
+      unlisten()
+    }
+  }
+
+  /** Ask the backend to cancel the download registered under `taskId`. */
+  async cancelDownload(taskId: string) {
+    try {
+      await invoke("cancel_download_task", { taskId })
+    } catch (error) {
+      console.error("Error cancelling download:", error)
+      throw error
+    }
+  }
+
+  /**
+   * Relay the tauri `download` events of one task to Jan events.
+   * Resolves to the function that stops the relay.
+   */
+  async _relayDownloadEvents(taskId: string) {
+    return listen<DownloadEvent>('download', (event) => {
+      const payload = event.payload
+      // the channel is shared by all downloads: skip events of other tasks
+      if (payload.task_id !== taskId) return
+      const eventName = JAN_EVENT_NAMES[payload.event_type]
+      if (!eventName) return
+
+      // total_size is 0 when the server sends no content-length
+      const percent = payload.total_size > 0
+        ? payload.downloaded_size / payload.total_size
+        : 0
+
+      // remove this once event system is back in web-app
+      console.log(taskId, payload.event_type, percent)
+
+      events.emit(eventName, {
+        modelId: taskId,
+        percent,
+        size: {
+          transferred: payload.downloaded_size,
+          total: payload.total_size,
+        },
+        downloadType: payload.download_type,
+      })
+    })
+  }
+
+  // Authorization header for Hugging Face, only when a token is configured
+  _getHeaders() {
+    return {
+      ...(this.hf_token && { Authorization: `Bearer ${this.hf_token}` })
+    }
+  }
+}
diff --git a/extensions/download-extension/tsconfig.json b/extensions/download-extension/tsconfig.json
new file mode 100644
index 000000000..1d3c112d4
--- /dev/null
+++ b/extensions/download-extension/tsconfig.json
@@ -0,0 +1,15 @@
+{
+  "compilerOptions": {
+    "target": "es2016",
+    "module": "esnext",
+    "moduleResolution": "node",
+    "outDir": "./dist",
+    "esModuleInterop": true,
+    "forceConsistentCasingInFileNames": true,
+    "strict": false,
+    "skipLibCheck": true,
+    "rootDir": "./src"
+  },
+  "include": ["./src"],
+  "exclude": ["**/*.test.ts", "vite.config.ts"]
+}
diff --git a/extensions/download-extension/vite.config.ts b/extensions/download-extension/vite.config.ts
new file mode 100644
index 000000000..a8ad5615f
--- /dev/null
+++ b/extensions/download-extension/vite.config.ts
@@ -0,0 +1,8 @@
+import { defineConfig } from "vite"
+export default defineConfig(({ mode }) => ({
+  define: process.env.VITEST ? {} : { global: 'window' },
+  test: {
+    environment: 'jsdom',
+  },
+}))
+
diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml
index 01a7f3030..8f5e19b2d 100644
--- a/src-tauri/Cargo.toml
+++ b/src-tauri/Cargo.toml
@@ -31,7 +31,7 @@ rand = "0.8"
 tauri-plugin-http = { version = "2", features = ["unsafe-headers"] }
 tauri-plugin-store = "2"
 hyper = { version = "0.14", features = ["server"] }
-reqwest = { version = "0.11", features = ["json", "blocking"] }
+reqwest = { version = "0.11", features = ["json", "blocking", "stream"] }
 tokio = { version = "1", features = ["full"] }
 rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [
     "client",
@@ -41,6 +41,8 @@ rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "mai
 ] }
 uuid = { version = "1.7", features = ["v4"] }
 env = "1.0.1"
+futures-util = "0.3.31"
+tokio-util = "0.7.14"
 
 [target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
 tauri-plugin-updater = "2"
diff --git a/src-tauri/src/core/mod.rs b/src-tauri/src/core/mod.rs
index 8d4edde3c..36f84f627 100644
--- a/src-tauri/src/core/mod.rs
+++ b/src-tauri/src/core/mod.rs
@@ -5,4 +5,4 @@ pub mod server;
 pub mod setup;
 pub mod state;
 pub mod threads;
-pub mod utils;
\ No newline at end of file
+pub mod utils;
diff --git a/src-tauri/src/core/state.rs b/src-tauri/src/core/state.rs
index 925030085..09a724c0c 100644
--- a/src-tauri/src/core/state.rs
+++ b/src-tauri/src/core/state.rs
@@ -1,5 +1,6 @@
 use std::{collections::HashMap, sync::Arc};
 
+use crate::core::utils::download::DownloadManagerState;
 use rand::{distributions::Alphanumeric, Rng};
 use rmcp::{service::RunningService, RoleClient};
 use tokio::sync::Mutex;
@@ -8,6 +9,7 @@ use tokio::sync::Mutex;
 pub struct AppState {
     pub app_token: Option<String>,
     pub mcp_servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
+    pub download_manager: Arc<Mutex<DownloadManagerState>>,
 }
 pub fn generate_app_token() -> String {
     rand::thread_rng()
diff --git a/src-tauri/src/core/utils/download.rs b/src-tauri/src/core/utils/download.rs
new file mode 100644
index 000000000..4ec4d057b
--- /dev/null
+++ b/src-tauri/src/core/utils/download.rs
@@ -0,0 +1,468 @@
+use crate::core::cmd::get_jan_data_folder_path;
+use crate::core::state::AppState;
+use crate::core::utils::normalize_path;
+use futures_util::StreamExt;
+use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
+use std::collections::HashMap;
+use std::path::{Path, PathBuf};
+use std::time::Duration;
+use tauri::{Emitter, State};
+use tokio::fs::File;
+use tokio::io::AsyncWriteExt;
+use tokio_util::sync::CancellationToken;
+
+/// Cancellation tokens of the downloads in flight, keyed by task id.
+#[derive(Default)]
+pub struct DownloadManagerState {
+    pub cancel_tokens: HashMap<String, CancellationToken>,
+}
+
+// this is to emulate the current way of downloading files by Cortex + Jan
+// we can change this later
+#[derive(serde::Serialize, Clone, Debug, PartialEq)]
+pub enum DownloadEventType {
+    Started,
+    Updated,
+    Success,
+    // Error, // we don't need to emit an Error event. just return an error directly
+    Stopped,
+}
+
+/// Payload of the `download` event emitted to the frontend.
+#[derive(serde::Serialize, Clone, Debug)]
+pub struct DownloadEvent {
+    pub task_id: String,
+    pub total_size: u64,
+    pub downloaded_size: u64,
+    pub download_type: String, // TODO: make this an enum as well
+    pub event_type: DownloadEventType,
+}
+
+fn err_to_string<E: std::fmt::Display>(e: E) -> String {
+    format!("Error: {}", e)
+}
+
+/// Joins `relative` onto the Jan data folder and rejects any result that is not
+/// strictly inside it. Resolving the path up front guarantees that the cleanup
+/// of a failed or cancelled download can never delete anything outside of the
+/// data folder (or the data folder itself).
+fn _resolve_in_data_folder(app: &tauri::AppHandle, relative: &Path) -> Result<PathBuf, String> {
+    let jan_data_folder = get_jan_data_folder_path(app.clone());
+    let resolved = normalize_path(&jan_data_folder.join(relative));
+    if resolved == jan_data_folder || !resolved.starts_with(&jan_data_folder) {
+        return Err(format!(
+            "Path {} is not inside Jan data folder {}",
+            resolved.display(),
+            jan_data_folder.display()
+        ));
+    }
+    Ok(resolved)
+}
+
+/// Downloads `url` to `path` (relative to the Jan data folder) and reports the
+/// progress through `download` events. Use `cancel_download_task` with the same
+/// `task_id` to stop it.
+#[tauri::command]
+pub async fn download_file(
+    app: tauri::AppHandle,
+    state: State<'_, AppState>,
+    url: &str,
+    path: &Path,
+    task_id: &str,
+    headers: HashMap<String, String>,
+) -> Result<(), String> {
+    // everything that can fail early runs *before* the cancel token is
+    // registered, so an early return cannot leave a stale token behind
+    let save_path = _resolve_in_data_folder(&app, path)?;
+    let header_map = _convert_headers(headers).map_err(err_to_string)?;
+    let total_size = _get_file_size(url, header_map.clone())
+        .await
+        .map_err(err_to_string)?;
+    log::info!("File size: {}", total_size);
+
+    // insert cancel tokens; the key must be task_id because that is what
+    // cancel_download_task and the cleanup below look up
+    let cancel_token = CancellationToken::new();
+    {
+        let mut download_manager = state.download_manager.lock().await;
+        if download_manager.cancel_tokens.contains_key(task_id) {
+            return Err(format!("task_id {} is already being downloaded", task_id));
+        }
+        download_manager
+            .cancel_tokens
+            .insert(task_id.to_string(), cancel_token.clone());
+    }
+
+    let mut evt = DownloadEvent {
+        task_id: task_id.to_string(),
+        total_size,
+        downloaded_size: 0,
+        download_type: "Model".to_string(),
+        event_type: DownloadEventType::Started,
+    };
+    app.emit("download", evt.clone()).unwrap();
+
+    let mut has_error = false;
+    let mut error_msg = String::new();
+    match _download_file_internal(
+        app.clone(),
+        url,
+        &save_path,
+        header_map.clone(),
+        evt,
+        cancel_token.clone(),
+    )
+    .await
+    {
+        Ok(evt_) => {
+            evt = evt_; // reassign ownership
+        }
+        Err((evt_, e)) => {
+            evt = evt_; // reassign ownership
+            error_msg = format!("Failed to download file: {}", e);
+            log::error!("{}", error_msg);
+            has_error = true;
+        }
+    }
+
+    // cleanup
+    {
+        let mut download_manager = state.download_manager.lock().await;
+        download_manager.cancel_tokens.remove(task_id);
+    }
+    if has_error {
+        let _ = std::fs::remove_file(&save_path); // don't check error
+        return Err(error_msg);
+    }
+
+    // emit final event
+    if evt.event_type == DownloadEventType::Stopped {
+        let _ = std::fs::remove_file(&save_path); // don't check error
+    } else {
+        evt.event_type = DownloadEventType::Success;
+    }
+    app.emit("download", evt.clone()).unwrap();
+
+    Ok(())
+}
+
+/// Downloads every file of the Hugging Face repo `model_id` (branch `branch`,
+/// `main` by default) into `save_dir` (relative to the Jan data folder). All
+/// files share one task: the reported sizes are totals over the whole repo.
+#[tauri::command]
+pub async fn download_hf_repo(
+    app: tauri::AppHandle,
+    state: State<'_, AppState>,
+    model_id: &str,
+    save_dir: &Path,
+    task_id: &str,
+    branch: Option<&str>,
+    headers: HashMap<String, String>,
+) -> Result<(), String> {
+    let branch_str = branch.unwrap_or("main");
+    // resolved first: `local_dir` is removed recursively when the download
+    // fails or is cancelled, so it has to be inside the data folder
+    let local_dir = _resolve_in_data_folder(&app, save_dir)?;
+    let header_map = _convert_headers(headers).map_err(err_to_string)?;
+
+    log::info!("Downloading HF repo: {}, branch {}", model_id, branch_str);
+
+    // get all files from repo, including subdirs
+    let items = _list_hf_repo_files(model_id, branch, header_map.clone())
+        .await
+        .map_err(err_to_string)?;
+
+    // insert cancel tokens; the key must be task_id because that is what
+    // cancel_download_task and the cleanup below look up
+    let cancel_token = CancellationToken::new();
+    {
+        let mut download_manager = state.download_manager.lock().await;
+        if download_manager.cancel_tokens.contains_key(task_id) {
+            return Err(format!("task_id {} is already being downloaded", task_id));
+        }
+        download_manager
+            .cancel_tokens
+            .insert(task_id.to_string(), cancel_token.clone());
+    }
+
+    let total_size = items.iter().map(|f| f.size).sum::<u64>();
+    let mut evt = DownloadEvent {
+        task_id: task_id.to_string(),
+        total_size,
+        downloaded_size: 0,
+        download_type: "Model".to_string(),
+        event_type: DownloadEventType::Started,
+    };
+    app.emit("download", evt.clone()).unwrap();
+
+    let mut has_error = false;
+    let mut error_msg = String::new();
+    for item in items {
+        let url = format!(
+            "https://huggingface.co/{}/resolve/{}/{}",
+            model_id, branch_str, item.path
+        );
+        let save_path = local_dir.join(&item.path);
+        match _download_file_internal(
+            app.clone(),
+            &url,
+            &save_path,
+            header_map.clone(),
+            evt,
+            cancel_token.clone(),
+        )
+        .await
+        {
+            Ok(evt_) => {
+                evt = evt_; // reassign ownership
+                if evt.event_type == DownloadEventType::Stopped {
+                    break;
+                }
+            }
+            Err((evt_, e)) => {
+                evt = evt_; // reassign ownership
+                error_msg = format!("Failed to download file: {}", e);
+                log::error!("{}", error_msg);
+                has_error = true;
+                break;
+            }
+        }
+    }
+
+    // cleanup
+    {
+        let mut download_manager = state.download_manager.lock().await;
+        download_manager.cancel_tokens.remove(task_id);
+    }
+    if has_error {
+        let _ = std::fs::remove_dir_all(&local_dir); // don't check error
+        return Err(error_msg);
+    }
+
+    // emit final event
+    if evt.event_type == DownloadEventType::Stopped {
+        let _ = std::fs::remove_dir_all(&local_dir); // don't check error
+    } else {
+        evt.event_type = DownloadEventType::Success;
+    }
+    app.emit("download", evt.clone()).unwrap();
+
+    Ok(())
+}
+
+/// Cancels the download registered under `task_id`.
+#[tauri::command]
+pub async fn cancel_download_task(state: State<'_, AppState>, task_id: &str) -> Result<(), String> {
+    // the download loop holds its own clone of the token, so the entry can be
+    // dropped from the map right away
+    let mut download_manager = state.download_manager.lock().await;
+    if let Some(token) = download_manager.cancel_tokens.remove(task_id) {
+        token.cancel();
+        log::info!("Cancelled download task_id: {}", task_id);
+        Ok(())
+    } else {
+        Err(format!("No download task_id: {}", task_id))
+    }
+}
+
+/// Converts the string headers received from the frontend into a `HeaderMap`.
+fn _convert_headers(
+    headers: HashMap<String, String>,
+) -> Result<HeaderMap, Box<dyn std::error::Error>> {
+    let mut header_map = HeaderMap::new();
+    for (k, v) in headers {
+        let key = HeaderName::from_bytes(k.as_bytes())?;
+        let value = HeaderValue::from_str(&v)?;
+        header_map.insert(key, value);
+    }
+    Ok(header_map)
+}
+
+/// Returns the `content-length` of a HEAD request to `url`, or 0 when the
+/// server does not send one.
+async fn _get_file_size(
+    url: &str,
+    header_map: HeaderMap,
+) -> Result<u64, Box<dyn std::error::Error>> {
+    // NOTE: might want to add User-Agent header
+    let client = reqwest::Client::new();
+    let resp = client.head(url).headers(header_map).send().await?;
+    if !resp.status().is_success() {
+        return Err(format!("Failed to get file size: HTTP status {}", resp.status()).into());
+    }
+    // this is buggy, always return 0 for HEAD request
+    // Ok(resp.content_length().unwrap_or(0))
+
+    match resp.headers().get("content-length") {
+        Some(value) => {
+            let value_str = value.to_str()?;
+            let value_u64: u64 = value_str.parse()?;
+            Ok(value_u64)
+        }
+        None => Ok(0),
+    }
+}
+
+// NOTE: Caller of this function should pass ownership of `evt` to this function
+// (no .clone()) and obtain it back. Both Ok and Err will return ownership of
+// the modified `evt` object back to the caller.
+async fn _download_file_internal(
+    app: tauri::AppHandle,
+    url: &str,
+    path: &Path, // this is absolute path
+    header_map: HeaderMap,
+    mut evt: DownloadEvent,
+    cancel_token: CancellationToken,
+) -> Result<DownloadEvent, (DownloadEvent, Box<dyn std::error::Error>)> {
+    log::info!("Downloading file: {}", url);
+
+    // normalize and enforce scope
+    let path = normalize_path(path);
+    let jan_data_folder = get_jan_data_folder_path(app.clone());
+    if !path.starts_with(&jan_data_folder) {
+        return Err((
+            evt.clone(),
+            format!(
+                "Path {} is outside of Jan data folder {}",
+                path.display(),
+                jan_data_folder.display()
+            )
+            .into(),
+        ));
+    }
+
+    // .read_timeout() and .connect_timeout() requires reqwest 0.12, which is not
+    // compatible with hyper 0.14
+    let client = reqwest::Client::builder()
+        .http2_keep_alive_timeout(Duration::from_secs(15))
+        // .read_timeout(Duration::from_secs(10)) // timeout between chunks
+        // .connect_timeout(Duration::from_secs(10)) // timeout for first connection
+        .build()
+        .map_err(|e| (evt.clone(), e.into()))?;
+
+    let resp = client
+        .get(url)
+        .headers(header_map)
+        .send()
+        .await
+        .map_err(|e| (evt.clone(), e.into()))?;
+
+    if !resp.status().is_success() {
+        return Err((
+            evt,
+            format!(
+                "Failed to download: HTTP status {}, {}",
+                resp.status(),
+                resp.text().await.unwrap_or_default()
+            )
+            .into(),
+        ));
+    }
+
+    // Create parent directories if they don't exist
+    if let Some(parent) = path.parent() {
+        if !parent.exists() {
+            std::fs::create_dir_all(parent).map_err(|e| (evt.clone(), e.into()))?;
+        }
+    }
+    let mut file = File::create(&path)
+        .await
+        .map_err(|e| (evt.clone(), e.into()))?;
+
+    // write chunk to file
+    let mut stream = resp.bytes_stream();
+    let mut download_delta = 0u64;
+    evt.event_type = DownloadEventType::Updated;
+
+    loop {
+        // wait for the next chunk, but wake up as soon as the task is cancelled:
+        // without read timeouts a stalled connection would otherwise never
+        // notice the cancellation
+        let next = tokio::select! {
+            _ = cancel_token.cancelled() => None,
+            next = stream.next() => next,
+        };
+        if cancel_token.is_cancelled() {
+            log::info!("Download cancelled: {}", url);
+            evt.event_type = DownloadEventType::Stopped;
+            break;
+        }
+
+        let chunk = match next {
+            Some(chunk) => chunk.map_err(|e| (evt.clone(), e.into()))?,
+            None => break, // end of stream
+        };
+        file.write_all(&chunk)
+            .await
+            .map_err(|e| (evt.clone(), e.into()))?;
+        download_delta += chunk.len() as u64;
+
+        // only update every 1MB
+        if download_delta >= 1024 * 1024 {
+            evt.downloaded_size += download_delta;
+            app.emit("download", evt.clone()).unwrap();
+            download_delta = 0u64;
+        }
+    }
+
+    // cleanup
+    file.flush().await.map_err(|e| (evt.clone(), e.into()))?;
+    if evt.event_type == DownloadEventType::Stopped {
+        let _ = std::fs::remove_file(&path); // don't check error
+    }
+
+    // caller should emit a final event after calling this function
+    evt.downloaded_size += download_delta;
+
+    Ok(evt)
+}
+
+/// One entry of the Hugging Face `tree` API response.
+#[derive(serde::Deserialize)]
+struct HfItem {
+    r#type: String,
+    // oid: String, // unused
+    path: String,
+    size: u64,
+}
+
+/// Lists all files of a Hugging Face repo, walking into its sub-directories.
+async fn _list_hf_repo_files(
+    model_id: &str,
+    branch: Option<&str>,
+    header_map: HeaderMap,
+) -> Result<Vec<HfItem>, Box<dyn std::error::Error>> {
+    let branch_str = branch.unwrap_or("main");
+
+    let mut files = vec![];
+
+    // DFS
+    let mut stack = vec!["".to_string()];
+    let client = reqwest::Client::new();
+    while let Some(subdir) = stack.pop() {
+        let url = format!(
+            "https://huggingface.co/api/models/{}/tree/{}/{}",
+            model_id, branch_str, subdir
+        );
+        let resp = client.get(&url).headers(header_map.clone()).send().await?;
+
+        if !resp.status().is_success() {
+            return Err(format!(
+                "Failed to list files: HTTP status {}, {}",
+                resp.status(),
+                resp.text().await.unwrap_or_default(),
+            )
+            .into());
+        }
+
+        for item in resp.json::<Vec<HfItem>>().await?.into_iter() {
+            if item.r#type == "directory" {
+                stack.push(item.path);
+            } else {
+                files.push(item);
+            }
+        }
+    }
+
+    Ok(files)
+}
diff --git a/src-tauri/src/core/utils/mod.rs b/src-tauri/src/core/utils/mod.rs
index 7f80e6f3a..04bfd12b0 100644
--- a/src-tauri/src/core/utils/mod.rs
+++ b/src-tauri/src/core/utils/mod.rs
@@ -1,5 +1,7 @@
+pub mod download;
+
 use std::fs;
-use std::path::PathBuf;
+use std::path::{Component, Path, PathBuf};
 use tauri::Runtime;
 
 use super::cmd::get_jan_data_folder_path;
@@ -46,3 +48,32 @@ pub fn ensure_thread_dir_exists(
     }
     Ok(())
 }
+
+/// Resolves `.` and `..` components of `path` lexically, without touching the filesystem.
+// https://github.com/rust-lang/cargo/blob/rust-1.67.0/crates/cargo-util/src/paths.rs#L82-L107
+pub fn normalize_path(path: &Path) -> PathBuf {
+    let mut components = path.components().peekable();
+    let mut ret = if let Some(c @ Component::Prefix(..)) = components.peek().cloned() {
+        components.next();
+        PathBuf::from(c.as_os_str())
+    } else {
+        PathBuf::new()
+    };
+
+    for component in components {
+        match component {
+            Component::Prefix(..) => unreachable!(),
+            Component::RootDir => {
+                ret.push(component.as_os_str());
+            }
+            Component::CurDir => {}
+            Component::ParentDir => {
+                ret.pop();
+            }
+            Component::Normal(c) => {
+                ret.push(c);
+            }
+        }
+    }
+    ret
+}
diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs
index 0709a6a2a..5c04261b1 100644
--- a/src-tauri/src/lib.rs
+++ b/src-tauri/src/lib.rs
@@ -3,6 +3,7 @@ use core::{
     cmd::get_jan_data_folder_path,
     setup::{self, setup_engine_binaries, setup_mcp, setup_sidecar},
     state::{generate_app_token, AppState},
+    utils::download::DownloadManagerState,
 };
 use std::{collections::HashMap, sync::Arc};
 
@@ -60,11 +61,16 @@ pub fn run() {
             core::threads::delete_message,
             core::threads::get_thread_assistant,
             core::threads::create_thread_assistant,
-            core::threads::modify_thread_assistant
+            core::threads::modify_thread_assistant,
+            // Download
+            core::utils::download::download_file,
+            core::utils::download::download_hf_repo,
+            core::utils::download::cancel_download_task,
         ])
         .manage(AppState {
             app_token: Some(generate_app_token()),
             mcp_servers: Arc::new(Mutex::new(HashMap::new())),
+            download_manager: Arc::new(Mutex::new(DownloadManagerState::default())),
         })
         .setup(|app| {
             app.handle().plugin(