feat: Download manager for llama.cpp extension (#4933)

This commit is contained in:
Thien Tran 2025-05-16 15:01:42 +08:00 committed by GitHub
parent e9f37e98d1
commit 4bde6645d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 680 additions and 4 deletions

View File

@ -0,0 +1,36 @@
{
"name": "@janhq/download-extension",
"productName": "Download Manager",
"version": "1.0.0",
"description": "Handle downloads",
"main": "dist/index.js",
"author": "Jan <service@jan.ai>",
"license": "AGPL-3.0",
"scripts": {
"test": "vitest run",
"build": "rolldown -c rolldown.config.mjs",
"build:publish": "rimraf *.tgz --glob || true && yarn build && npm pack && cpx *.tgz ../../pre-install"
},
"devDependencies": {
"cpx": "^1.5.0",
"rimraf": "^3.0.2",
"rolldown": "1.0.0-beta.1",
"run-script-os": "^1.1.6",
"typescript": "5.3.3",
"vitest": "^3.0.6"
},
"files": [
"dist/*",
"package.json",
"README.md"
],
"dependencies": {
"@janhq/core": "../../core/package.tgz",
"@tauri-apps/api": "^2.5.0"
},
"bundleDependencies": [],
"installConfig": {
"hoistingLimits": "workspaces"
},
"packageManager": "yarn@4.5.3"
}

View File

@ -0,0 +1,14 @@
import { defineConfig } from 'rolldown'
import settingJson from './settings.json' with { type: 'json' }
export default defineConfig({
input: 'src/index.ts',
output: {
format: 'esm',
file: 'dist/index.js',
},
platform: 'browser',
define: {
SETTINGS: JSON.stringify(settingJson),
},
})

View File

@ -0,0 +1,14 @@
[
{
"key": "hf-token",
"title": "Hugging Face Access Token",
"description": "Access tokens programmatically authenticate your identity to the Hugging Face Hub, allowing applications to perform specific actions specified by the scope of permissions granted.",
"controllerType": "input",
"controllerProps": {
"value": "",
"placeholder": "hf_**********************************",
"type": "password",
"inputActions": ["unobscure", "copy"]
}
}
]

View File

@ -0,0 +1 @@
declare const SETTINGS: SettingComponentProps[]

View File

@ -0,0 +1,127 @@
import { invoke } from '@tauri-apps/api/core';
import { listen } from '@tauri-apps/api/event';
import { BaseExtension, events } from '@janhq/core';
export enum Settings {
hfToken = 'hf-token',
}
type DownloadEvent = {
task_id: string
total_size: number
downloaded_size: number
download_type: string
event_type: string
}
export default class DownloadManager extends BaseExtension {
hf_token?: string
async onLoad() {
this.registerSettings(SETTINGS)
this.hf_token = await this.getSetting<string>(Settings.hfToken, undefined)
}
async onUnload() { }
async downloadFile(url: string, path: string, taskId: string) {
// relay tauri events to Jan events
const unlisten = await listen<DownloadEvent>('download', (event) => {
let payload = event.payload
let eventName = {
Updated: 'onFileDownloadUpdate',
Error: 'onFileDownloadError',
Success: 'onFileDownloadSuccess',
Stopped: 'onFileDownloadStopped',
Started: 'onFileDownloadStarted',
}[payload.event_type]
// remove this once event system is back in web-app
console.log(taskId, payload.event_type, payload.downloaded_size / payload.total_size)
events.emit(eventName, {
modelId: taskId,
percent: payload.downloaded_size / payload.total_size,
size: {
transferred: payload.downloaded_size,
total: payload.total_size,
},
downloadType: payload.download_type,
})
})
try {
await invoke<void>(
"download_file",
{ url, path, taskId, headers: this._getHeaders() },
)
} catch (error) {
console.error("Error downloading file:", error)
events.emit('onFileDownloadError', {
modelId: url,
downloadType: 'Model',
})
throw error
} finally {
unlisten()
}
}
async downloadHfRepo(modelId: string, saveDir: string, taskId: string, branch?: string) {
// relay tauri events to Jan events
const unlisten = await listen<DownloadEvent>('download', (event) => {
let payload = event.payload
let eventName = {
Updated: 'onFileDownloadUpdate',
Error: 'onFileDownloadError',
Success: 'onFileDownloadSuccess',
Stopped: 'onFileDownloadStopped',
Started: 'onFileDownloadStarted',
}[payload.event_type]
// remove this once event system is back in web-app
console.log(taskId, payload.event_type, payload.downloaded_size / payload.total_size)
events.emit(eventName, {
modelId: taskId,
percent: payload.downloaded_size / payload.total_size,
size: {
transferred: payload.downloaded_size,
total: payload.total_size,
},
downloadType: payload.download_type,
})
})
try {
await invoke<void>(
"download_hf_repo",
{ modelId, saveDir, taskId, branch, headers: this._getHeaders() },
)
} catch (error) {
console.error("Error downloading file:", error)
events.emit('onFileDownloadError', {
modelId: modelId,
downloadType: 'Model',
})
throw error
} finally {
unlisten()
}
}
async cancelDownload(taskId: string) {
try {
await invoke<void>("cancel_download_task", { taskId })
} catch (error) {
console.error("Error cancelling download:", error)
throw error
}
}
_getHeaders() {
return {
...(this.hf_token && { Authorization: `Bearer ${this.hf_token}` })
}
}
}

View File

@ -0,0 +1,15 @@
{
"compilerOptions": {
"target": "es2016",
"module": "esnext",
"moduleResolution": "node",
"outDir": "./dist",
"esModuleInterop": true,
"forceConsistentCasingInFileNames": true,
"strict": false,
"skipLibCheck": true,
"rootDir": "./src"
},
"include": ["./src"],
"exclude": ["**/*.test.ts", "vite.config.ts"]
}

View File

@ -0,0 +1,8 @@
import { defineConfig } from "vite"
export default defineConfig(({ mode }) => ({
define: process.env.VITEST ? {} : { global: 'window' },
test: {
environment: 'jsdom',
},
}))

View File

@ -31,7 +31,7 @@ rand = "0.8"
tauri-plugin-http = { version = "2", features = ["unsafe-headers"] }
tauri-plugin-store = "2"
hyper = { version = "0.14", features = ["server"] }
reqwest = { version = "0.11", features = ["json", "blocking"] }
reqwest = { version = "0.11", features = ["json", "blocking", "stream"] }
tokio = { version = "1", features = ["full"] }
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [
"client",
@ -41,6 +41,8 @@ rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "mai
] }
uuid = { version = "1.7", features = ["v4"] }
env = "1.0.1"
futures-util = "0.3.31"
tokio-util = "0.7.14"
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
tauri-plugin-updater = "2"

View File

@ -5,4 +5,4 @@ pub mod server;
pub mod setup;
pub mod state;
pub mod threads;
pub mod utils;
pub mod utils;

View File

@ -1,5 +1,6 @@
use std::{collections::HashMap, sync::Arc};
use crate::core::utils::download::DownloadManagerState;
use rand::{distributions::Alphanumeric, Rng};
use rmcp::{service::RunningService, RoleClient};
use tokio::sync::Mutex;
@ -8,6 +9,7 @@ use tokio::sync::Mutex;
pub struct AppState {
pub app_token: Option<String>,
pub mcp_servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
pub download_manager: Arc<Mutex<DownloadManagerState>>,
}
pub fn generate_app_token() -> String {
rand::thread_rng()

View File

@ -0,0 +1,421 @@
use crate::core::cmd::get_jan_data_folder_path;
use crate::core::state::AppState;
use crate::core::utils::normalize_path;
use futures_util::StreamExt;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use std::collections::HashMap;
use std::path::Path;
use std::time::Duration;
use tauri::{Emitter, State};
use tokio::fs::File;
use tokio::io::AsyncWriteExt;
use tokio_util::sync::CancellationToken;
#[derive(Default)]
pub struct DownloadManagerState {
pub cancel_tokens: HashMap<String, CancellationToken>,
}
// this is to emulate the current way of downloading files by Cortex + Jan
// we can change this later
#[derive(serde::Serialize, Clone, Debug, PartialEq)]
pub enum DownloadEventType {
Started,
Updated,
Success,
// Error, // we don't need to emit an Error event. just return an error directly
Stopped,
}
#[derive(serde::Serialize, Clone, Debug)]
pub struct DownloadEvent {
pub task_id: String,
pub total_size: u64,
pub downloaded_size: u64,
pub download_type: String, // TODO: make this an enum as well
pub event_type: DownloadEventType,
}
fn err_to_string<E: std::fmt::Display>(e: E) -> String {
format!("Error: {}", e)
}
#[tauri::command]
pub async fn download_file(
app: tauri::AppHandle,
state: State<'_, AppState>,
url: &str,
path: &Path,
task_id: &str,
headers: HashMap<String, String>,
) -> Result<(), String> {
// insert cancel tokens
let cancel_token = CancellationToken::new();
{
let mut download_manager = state.download_manager.lock().await;
if download_manager.cancel_tokens.contains_key(url) {
return Err(format!("URL {} is already being downloaded", url));
}
download_manager
.cancel_tokens
.insert(task_id.to_string(), cancel_token.clone());
}
let header_map = _convert_headers(headers).map_err(err_to_string)?;
let total_size = _get_file_size(url, header_map.clone())
.await
.map_err(err_to_string)?;
log::info!("File size: {}", total_size);
let mut evt = DownloadEvent {
task_id: task_id.to_string(),
total_size,
downloaded_size: 0,
download_type: "Model".to_string(),
event_type: DownloadEventType::Started,
};
app.emit("download", evt.clone()).unwrap();
// save file under Jan data folder
let data_dir = get_jan_data_folder_path(app.clone());
let save_path = data_dir.join(path);
let mut has_error = false;
let mut error_msg = String::new();
match _download_file_internal(
app.clone(),
url,
&save_path,
header_map.clone(),
evt,
cancel_token.clone(),
)
.await
{
Ok(evt_) => {
evt = evt_; // reassign ownership
}
Err((evt_, e)) => {
evt = evt_; // reassign ownership
error_msg = format!("Failed to download file: {}", e);
log::error!("{}", error_msg);
has_error = true;
}
}
// cleanup
{
let mut download_manager = state.download_manager.lock().await;
download_manager.cancel_tokens.remove(url);
}
if has_error {
let _ = std::fs::remove_file(&save_path); // don't check error
return Err(error_msg);
}
// emit final event
if evt.event_type == DownloadEventType::Stopped {
let _ = std::fs::remove_file(&save_path); // don't check error
} else {
evt.event_type = DownloadEventType::Success;
}
app.emit("download", evt.clone()).unwrap();
Ok(())
}
#[tauri::command]
pub async fn download_hf_repo(
app: tauri::AppHandle,
state: State<'_, AppState>,
model_id: &str,
save_dir: &Path,
task_id: &str,
branch: Option<&str>,
headers: HashMap<String, String>,
) -> Result<(), String> {
let branch_str = branch.unwrap_or("main");
let header_map = _convert_headers(headers).map_err(err_to_string)?;
log::info!("Downloading HF repo: {}, branch {}", model_id, branch_str);
// get all files from repo, including subdirs
let items = _list_hf_repo_files(model_id, branch, header_map.clone())
.await
.map_err(err_to_string)?;
// insert cancel tokens
let cancel_token = CancellationToken::new();
{
let mut download_manager = state.download_manager.lock().await;
if download_manager.cancel_tokens.contains_key(model_id) {
return Err(format!("model_id {} is already being downloaded", model_id));
}
download_manager
.cancel_tokens
.insert(task_id.to_string(), cancel_token.clone());
}
let total_size = items.iter().map(|f| f.size).sum::<u64>();
let mut evt = DownloadEvent {
task_id: task_id.to_string(),
total_size,
downloaded_size: 0,
download_type: "Model".to_string(),
event_type: DownloadEventType::Started,
};
app.emit("download", evt.clone()).unwrap();
let local_dir = get_jan_data_folder_path(app.clone()).join(save_dir);
let mut has_error = false;
let mut error_msg = String::new();
for item in items {
let url = format!(
"https://huggingface.co/{}/resolve/{}/{}",
model_id, branch_str, item.path
);
let save_path = local_dir.join(&item.path);
match _download_file_internal(
app.clone(),
&url,
&save_path,
header_map.clone(),
evt,
cancel_token.clone(),
)
.await
{
Ok(evt_) => {
evt = evt_; // reassign ownership
if evt.event_type == DownloadEventType::Stopped {
break;
}
}
Err((evt_, e)) => {
evt = evt_; // reassign ownership
error_msg = format!("Failed to download file: {}", e);
log::error!("{}", error_msg);
has_error = true;
break;
}
}
}
// cleanup
{
let mut download_manager = state.download_manager.lock().await;
download_manager.cancel_tokens.remove(model_id);
}
if has_error {
let _ = std::fs::remove_dir_all(&local_dir); // don't check error
return Err(error_msg);
}
// emit final event
if evt.event_type == DownloadEventType::Stopped {
let _ = std::fs::remove_dir_all(&local_dir); // don't check error
} else {
evt.event_type = DownloadEventType::Success;
}
app.emit("download", evt.clone()).unwrap();
Ok(())
}
#[tauri::command]
pub async fn cancel_download_task(state: State<'_, AppState>, task_id: &str) -> Result<(), String> {
// NOTE: might want to add User-Agent header
let mut download_manager = state.download_manager.lock().await;
if let Some(token) = download_manager.cancel_tokens.remove(task_id) {
token.cancel();
log::info!("Cancelled download task_id: {}", task_id);
Ok(())
} else {
Err(format!("No download task_id: {}", task_id))
}
}
fn _convert_headers(
headers: HashMap<String, String>,
) -> Result<HeaderMap, Box<dyn std::error::Error>> {
let mut header_map = HeaderMap::new();
for (k, v) in headers {
let key = HeaderName::from_bytes(k.as_bytes())?;
let value = HeaderValue::from_str(&v)?;
header_map.insert(key, value);
}
Ok(header_map)
}
async fn _get_file_size(
url: &str,
header_map: HeaderMap,
) -> Result<u64, Box<dyn std::error::Error>> {
let client = reqwest::Client::new();
let resp = client.head(url).headers(header_map).send().await?;
if !resp.status().is_success() {
return Err(format!("Failed to get file size: HTTP status {}", resp.status()).into());
}
// this is buggy, always return 0 for HEAD request
// Ok(resp.content_length().unwrap_or(0))
match resp.headers().get("content-length") {
Some(value) => {
let value_str = value.to_str()?;
let value_u64: u64 = value_str.parse()?;
Ok(value_u64)
}
None => Ok(0),
}
}
// NOTE: Caller of this function should pass ownership of `evt` to this function
// (no .clone()) and obtain it back. Both Ok and Err will return ownership of
// the modified `evt` object back to the caller.
async fn _download_file_internal(
app: tauri::AppHandle,
url: &str,
path: &Path, // this is absolute path
header_map: HeaderMap,
mut evt: DownloadEvent,
cancel_token: CancellationToken,
) -> Result<DownloadEvent, (DownloadEvent, Box<dyn std::error::Error>)> {
log::info!("Downloading file: {}", url);
// normalize and enforce scope
let path = normalize_path(path);
let jan_data_folder = get_jan_data_folder_path(app.clone());
if !path.starts_with(&jan_data_folder) {
return Err((
evt.clone(),
format!(
"Path {} is outside of Jan data folder {}",
path.display(),
jan_data_folder.display()
)
.into(),
));
}
// .read_timeout() and .connect_timeout() requires reqwest 0.12, which is not
// compatible with hyper 0.14
let client = reqwest::Client::builder()
.http2_keep_alive_timeout(Duration::from_secs(15))
// .read_timeout(Duration::from_secs(10)) // timeout between chunks
// .connect_timeout(Duration::from_secs(10)) // timeout for first connection
.build()
.map_err(|e| (evt.clone(), e.into()))?;
let resp = client
.get(url)
.headers(header_map)
.send()
.await
.map_err(|e| (evt.clone(), e.into()))?;
if !resp.status().is_success() {
return Err((
evt,
format!(
"Failed to download: HTTP status {}, {}",
resp.status(),
resp.text().await.unwrap_or_default()
)
.into(),
));
}
// Create parent directories if they don't exist
if let Some(parent) = path.parent() {
if !parent.exists() {
std::fs::create_dir_all(parent).map_err(|e| (evt.clone(), e.into()))?;
}
}
let mut file = File::create(&path)
.await
.map_err(|e| (evt.clone(), e.into()))?;
// write chunk to file
let mut stream = resp.bytes_stream();
let mut download_delta = 0u64;
evt.event_type = DownloadEventType::Updated;
while let Some(chunk) = stream.next().await {
if cancel_token.is_cancelled() {
log::info!("Download cancelled: {}", url);
evt.event_type = DownloadEventType::Stopped;
break;
}
let chunk = chunk.map_err(|e| (evt.clone(), e.into()))?;
file.write_all(&chunk)
.await
.map_err(|e| (evt.clone(), e.into()))?;
download_delta += chunk.len() as u64;
// only update every 1MB
if download_delta >= 1024 * 1024 {
evt.downloaded_size += download_delta;
app.emit("download", evt.clone()).unwrap();
download_delta = 0u64;
}
}
// cleanup
file.flush().await.map_err(|e| (evt.clone(), e.into()))?;
if evt.event_type == DownloadEventType::Stopped {
let _ = std::fs::remove_file(&path); // don't check error
}
// caller should emit a final event after calling this function
evt.downloaded_size += download_delta;
Ok(evt)
}
#[derive(serde::Deserialize)]
struct HfItem {
r#type: String,
// oid: String, // unused
path: String,
size: u64,
}
async fn _list_hf_repo_files(
model_id: &str,
branch: Option<&str>,
header_map: HeaderMap,
) -> Result<Vec<HfItem>, Box<dyn std::error::Error>> {
let branch_str = branch.unwrap_or("main");
let mut files = vec![];
// DFS
let mut stack = vec!["".to_string()];
let client = reqwest::Client::new();
while let Some(subdir) = stack.pop() {
let url = format!(
"https://huggingface.co/api/models/{}/tree/{}/{}",
model_id, branch_str, subdir
);
let resp = client.get(&url).headers(header_map.clone()).send().await?;
if !resp.status().is_success() {
return Err(format!(
"Failed to list files: HTTP status {}, {}",
resp.status(),
resp.text().await.unwrap_or_default(),
)
.into());
}
for item in resp.json::<Vec<HfItem>>().await?.into_iter() {
if item.r#type == "directory" {
stack.push(item.path);
} else {
files.push(item);
}
}
}
Ok(files)
}

View File

@ -1,5 +1,7 @@
pub mod download;
use std::fs;
use std::path::PathBuf;
use std::path::{Component, Path, PathBuf};
use tauri::Runtime;
use super::cmd::get_jan_data_folder_path;
@ -46,3 +48,31 @@ pub fn ensure_thread_dir_exists<R: Runtime>(
}
Ok(())
}
// https://github.com/rust-lang/cargo/blob/rust-1.67.0/crates/cargo-util/src/paths.rs#L82-L107
pub fn normalize_path(path: &Path) -> PathBuf {
let mut components = path.components().peekable();
let mut ret = if let Some(c @ Component::Prefix(..)) = components.peek().cloned() {
components.next();
PathBuf::from(c.as_os_str())
} else {
PathBuf::new()
};
for component in components {
match component {
Component::Prefix(..) => unreachable!(),
Component::RootDir => {
ret.push(component.as_os_str());
}
Component::CurDir => {}
Component::ParentDir => {
ret.pop();
}
Component::Normal(c) => {
ret.push(c);
}
}
}
ret
}

View File

@ -3,6 +3,7 @@ use core::{
cmd::get_jan_data_folder_path,
setup::{self, setup_engine_binaries, setup_mcp, setup_sidecar},
state::{generate_app_token, AppState},
utils::download::DownloadManagerState,
};
use std::{collections::HashMap, sync::Arc};
@ -60,11 +61,16 @@ pub fn run() {
core::threads::delete_message,
core::threads::get_thread_assistant,
core::threads::create_thread_assistant,
core::threads::modify_thread_assistant
core::threads::modify_thread_assistant,
// Download
core::utils::download::download_file,
core::utils::download::download_hf_repo,
core::utils::download::cancel_download_task,
])
.manage(AppState {
app_token: Some(generate_app_token()),
mcp_servers: Arc::new(Mutex::new(HashMap::new())),
download_manager: Arc::new(Mutex::new(DownloadManagerState::default())),
})
.setup(|app| {
app.handle().plugin(