feat: Download manager for llama.cpp extension (#4933)
This commit is contained in:
parent
e9f37e98d1
commit
4bde6645d0
36
extensions/download-extension/package.json
Normal file
36
extensions/download-extension/package.json
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
{
|
||||||
|
"name": "@janhq/download-extension",
|
||||||
|
"productName": "Download Manager",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"description": "Handle downloads",
|
||||||
|
"main": "dist/index.js",
|
||||||
|
"author": "Jan <service@jan.ai>",
|
||||||
|
"license": "AGPL-3.0",
|
||||||
|
"scripts": {
|
||||||
|
"test": "vitest run",
|
||||||
|
"build": "rolldown -c rolldown.config.mjs",
|
||||||
|
"build:publish": "rimraf *.tgz --glob || true && yarn build && npm pack && cpx *.tgz ../../pre-install"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"cpx": "^1.5.0",
|
||||||
|
"rimraf": "^3.0.2",
|
||||||
|
"rolldown": "1.0.0-beta.1",
|
||||||
|
"run-script-os": "^1.1.6",
|
||||||
|
"typescript": "5.3.3",
|
||||||
|
"vitest": "^3.0.6"
|
||||||
|
},
|
||||||
|
"files": [
|
||||||
|
"dist/*",
|
||||||
|
"package.json",
|
||||||
|
"README.md"
|
||||||
|
],
|
||||||
|
"dependencies": {
|
||||||
|
"@janhq/core": "../../core/package.tgz",
|
||||||
|
"@tauri-apps/api": "^2.5.0"
|
||||||
|
},
|
||||||
|
"bundleDependencies": [],
|
||||||
|
"installConfig": {
|
||||||
|
"hoistingLimits": "workspaces"
|
||||||
|
},
|
||||||
|
"packageManager": "yarn@4.5.3"
|
||||||
|
}
|
||||||
14
extensions/download-extension/rolldown.config.mjs
Normal file
14
extensions/download-extension/rolldown.config.mjs
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import { defineConfig } from 'rolldown'
|
||||||
|
import settingJson from './settings.json' with { type: 'json' }
|
||||||
|
|
||||||
|
export default defineConfig({
|
||||||
|
input: 'src/index.ts',
|
||||||
|
output: {
|
||||||
|
format: 'esm',
|
||||||
|
file: 'dist/index.js',
|
||||||
|
},
|
||||||
|
platform: 'browser',
|
||||||
|
define: {
|
||||||
|
SETTINGS: JSON.stringify(settingJson),
|
||||||
|
},
|
||||||
|
})
|
||||||
14
extensions/download-extension/settings.json
Normal file
14
extensions/download-extension/settings.json
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"key": "hf-token",
|
||||||
|
"title": "Hugging Face Access Token",
|
||||||
|
"description": "Access tokens programmatically authenticate your identity to the Hugging Face Hub, allowing applications to perform specific actions specified by the scope of permissions granted.",
|
||||||
|
"controllerType": "input",
|
||||||
|
"controllerProps": {
|
||||||
|
"value": "",
|
||||||
|
"placeholder": "hf_**********************************",
|
||||||
|
"type": "password",
|
||||||
|
"inputActions": ["unobscure", "copy"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
1
extensions/download-extension/src/@types/global.d.ts
vendored
Normal file
1
extensions/download-extension/src/@types/global.d.ts
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
declare const SETTINGS: SettingComponentProps[]
|
||||||
127
extensions/download-extension/src/index.ts
Normal file
127
extensions/download-extension/src/index.ts
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
import { invoke } from '@tauri-apps/api/core';
|
||||||
|
import { listen } from '@tauri-apps/api/event';
|
||||||
|
import { BaseExtension, events } from '@janhq/core';
|
||||||
|
|
||||||
|
export enum Settings {
|
||||||
|
hfToken = 'hf-token',
|
||||||
|
}
|
||||||
|
|
||||||
|
type DownloadEvent = {
|
||||||
|
task_id: string
|
||||||
|
total_size: number
|
||||||
|
downloaded_size: number
|
||||||
|
download_type: string
|
||||||
|
event_type: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export default class DownloadManager extends BaseExtension {
|
||||||
|
hf_token?: string
|
||||||
|
|
||||||
|
async onLoad() {
|
||||||
|
this.registerSettings(SETTINGS)
|
||||||
|
this.hf_token = await this.getSetting<string>(Settings.hfToken, undefined)
|
||||||
|
}
|
||||||
|
|
||||||
|
async onUnload() { }
|
||||||
|
|
||||||
|
async downloadFile(url: string, path: string, taskId: string) {
|
||||||
|
// relay tauri events to Jan events
|
||||||
|
const unlisten = await listen<DownloadEvent>('download', (event) => {
|
||||||
|
let payload = event.payload
|
||||||
|
let eventName = {
|
||||||
|
Updated: 'onFileDownloadUpdate',
|
||||||
|
Error: 'onFileDownloadError',
|
||||||
|
Success: 'onFileDownloadSuccess',
|
||||||
|
Stopped: 'onFileDownloadStopped',
|
||||||
|
Started: 'onFileDownloadStarted',
|
||||||
|
}[payload.event_type]
|
||||||
|
|
||||||
|
// remove this once event system is back in web-app
|
||||||
|
console.log(taskId, payload.event_type, payload.downloaded_size / payload.total_size)
|
||||||
|
|
||||||
|
events.emit(eventName, {
|
||||||
|
modelId: taskId,
|
||||||
|
percent: payload.downloaded_size / payload.total_size,
|
||||||
|
size: {
|
||||||
|
transferred: payload.downloaded_size,
|
||||||
|
total: payload.total_size,
|
||||||
|
},
|
||||||
|
downloadType: payload.download_type,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
try {
|
||||||
|
await invoke<void>(
|
||||||
|
"download_file",
|
||||||
|
{ url, path, taskId, headers: this._getHeaders() },
|
||||||
|
)
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error downloading file:", error)
|
||||||
|
events.emit('onFileDownloadError', {
|
||||||
|
modelId: url,
|
||||||
|
downloadType: 'Model',
|
||||||
|
})
|
||||||
|
throw error
|
||||||
|
} finally {
|
||||||
|
unlisten()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async downloadHfRepo(modelId: string, saveDir: string, taskId: string, branch?: string) {
|
||||||
|
// relay tauri events to Jan events
|
||||||
|
const unlisten = await listen<DownloadEvent>('download', (event) => {
|
||||||
|
let payload = event.payload
|
||||||
|
let eventName = {
|
||||||
|
Updated: 'onFileDownloadUpdate',
|
||||||
|
Error: 'onFileDownloadError',
|
||||||
|
Success: 'onFileDownloadSuccess',
|
||||||
|
Stopped: 'onFileDownloadStopped',
|
||||||
|
Started: 'onFileDownloadStarted',
|
||||||
|
}[payload.event_type]
|
||||||
|
|
||||||
|
// remove this once event system is back in web-app
|
||||||
|
console.log(taskId, payload.event_type, payload.downloaded_size / payload.total_size)
|
||||||
|
|
||||||
|
events.emit(eventName, {
|
||||||
|
modelId: taskId,
|
||||||
|
percent: payload.downloaded_size / payload.total_size,
|
||||||
|
size: {
|
||||||
|
transferred: payload.downloaded_size,
|
||||||
|
total: payload.total_size,
|
||||||
|
},
|
||||||
|
downloadType: payload.download_type,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
try {
|
||||||
|
await invoke<void>(
|
||||||
|
"download_hf_repo",
|
||||||
|
{ modelId, saveDir, taskId, branch, headers: this._getHeaders() },
|
||||||
|
)
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error downloading file:", error)
|
||||||
|
events.emit('onFileDownloadError', {
|
||||||
|
modelId: modelId,
|
||||||
|
downloadType: 'Model',
|
||||||
|
})
|
||||||
|
throw error
|
||||||
|
} finally {
|
||||||
|
unlisten()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async cancelDownload(taskId: string) {
|
||||||
|
try {
|
||||||
|
await invoke<void>("cancel_download_task", { taskId })
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error cancelling download:", error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_getHeaders() {
|
||||||
|
return {
|
||||||
|
...(this.hf_token && { Authorization: `Bearer ${this.hf_token}` })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
15
extensions/download-extension/tsconfig.json
Normal file
15
extensions/download-extension/tsconfig.json
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"target": "es2016",
|
||||||
|
"module": "esnext",
|
||||||
|
"moduleResolution": "node",
|
||||||
|
"outDir": "./dist",
|
||||||
|
"esModuleInterop": true,
|
||||||
|
"forceConsistentCasingInFileNames": true,
|
||||||
|
"strict": false,
|
||||||
|
"skipLibCheck": true,
|
||||||
|
"rootDir": "./src"
|
||||||
|
},
|
||||||
|
"include": ["./src"],
|
||||||
|
"exclude": ["**/*.test.ts", "vite.config.ts"]
|
||||||
|
}
|
||||||
8
extensions/download-extension/vite.config.ts
Normal file
8
extensions/download-extension/vite.config.ts
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
import { defineConfig } from "vite"
|
||||||
|
export default defineConfig(({ mode }) => ({
|
||||||
|
define: process.env.VITEST ? {} : { global: 'window' },
|
||||||
|
test: {
|
||||||
|
environment: 'jsdom',
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ rand = "0.8"
|
|||||||
tauri-plugin-http = { version = "2", features = ["unsafe-headers"] }
|
tauri-plugin-http = { version = "2", features = ["unsafe-headers"] }
|
||||||
tauri-plugin-store = "2"
|
tauri-plugin-store = "2"
|
||||||
hyper = { version = "0.14", features = ["server"] }
|
hyper = { version = "0.14", features = ["server"] }
|
||||||
reqwest = { version = "0.11", features = ["json", "blocking"] }
|
reqwest = { version = "0.11", features = ["json", "blocking", "stream"] }
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [
|
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [
|
||||||
"client",
|
"client",
|
||||||
@ -41,6 +41,8 @@ rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "mai
|
|||||||
] }
|
] }
|
||||||
uuid = { version = "1.7", features = ["v4"] }
|
uuid = { version = "1.7", features = ["v4"] }
|
||||||
env = "1.0.1"
|
env = "1.0.1"
|
||||||
|
futures-util = "0.3.31"
|
||||||
|
tokio-util = "0.7.14"
|
||||||
|
|
||||||
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
|
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
|
||||||
tauri-plugin-updater = "2"
|
tauri-plugin-updater = "2"
|
||||||
|
|||||||
@ -5,4 +5,4 @@ pub mod server;
|
|||||||
pub mod setup;
|
pub mod setup;
|
||||||
pub mod state;
|
pub mod state;
|
||||||
pub mod threads;
|
pub mod threads;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
use std::{collections::HashMap, sync::Arc};
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
|
use crate::core::utils::download::DownloadManagerState;
|
||||||
use rand::{distributions::Alphanumeric, Rng};
|
use rand::{distributions::Alphanumeric, Rng};
|
||||||
use rmcp::{service::RunningService, RoleClient};
|
use rmcp::{service::RunningService, RoleClient};
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
@ -8,6 +9,7 @@ use tokio::sync::Mutex;
|
|||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub app_token: Option<String>,
|
pub app_token: Option<String>,
|
||||||
pub mcp_servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
|
pub mcp_servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
|
||||||
|
pub download_manager: Arc<Mutex<DownloadManagerState>>,
|
||||||
}
|
}
|
||||||
pub fn generate_app_token() -> String {
|
pub fn generate_app_token() -> String {
|
||||||
rand::thread_rng()
|
rand::thread_rng()
|
||||||
|
|||||||
421
src-tauri/src/core/utils/download.rs
Normal file
421
src-tauri/src/core/utils/download.rs
Normal file
@ -0,0 +1,421 @@
|
|||||||
|
use crate::core::cmd::get_jan_data_folder_path;
|
||||||
|
use crate::core::state::AppState;
|
||||||
|
use crate::core::utils::normalize_path;
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::Path;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tauri::{Emitter, State};
|
||||||
|
use tokio::fs::File;
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
use tokio_util::sync::CancellationToken;
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct DownloadManagerState {
|
||||||
|
pub cancel_tokens: HashMap<String, CancellationToken>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is to emulate the current way of downloading files by Cortex + Jan
|
||||||
|
// we can change this later
|
||||||
|
#[derive(serde::Serialize, Clone, Debug, PartialEq)]
|
||||||
|
pub enum DownloadEventType {
|
||||||
|
Started,
|
||||||
|
Updated,
|
||||||
|
Success,
|
||||||
|
// Error, // we don't need to emit an Error event. just return an error directly
|
||||||
|
Stopped,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize, Clone, Debug)]
|
||||||
|
pub struct DownloadEvent {
|
||||||
|
pub task_id: String,
|
||||||
|
pub total_size: u64,
|
||||||
|
pub downloaded_size: u64,
|
||||||
|
pub download_type: String, // TODO: make this an enum as well
|
||||||
|
pub event_type: DownloadEventType,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn err_to_string<E: std::fmt::Display>(e: E) -> String {
|
||||||
|
format!("Error: {}", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tauri::command]
|
||||||
|
pub async fn download_file(
|
||||||
|
app: tauri::AppHandle,
|
||||||
|
state: State<'_, AppState>,
|
||||||
|
url: &str,
|
||||||
|
path: &Path,
|
||||||
|
task_id: &str,
|
||||||
|
headers: HashMap<String, String>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
// insert cancel tokens
|
||||||
|
let cancel_token = CancellationToken::new();
|
||||||
|
{
|
||||||
|
let mut download_manager = state.download_manager.lock().await;
|
||||||
|
if download_manager.cancel_tokens.contains_key(url) {
|
||||||
|
return Err(format!("URL {} is already being downloaded", url));
|
||||||
|
}
|
||||||
|
download_manager
|
||||||
|
.cancel_tokens
|
||||||
|
.insert(task_id.to_string(), cancel_token.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let header_map = _convert_headers(headers).map_err(err_to_string)?;
|
||||||
|
let total_size = _get_file_size(url, header_map.clone())
|
||||||
|
.await
|
||||||
|
.map_err(err_to_string)?;
|
||||||
|
log::info!("File size: {}", total_size);
|
||||||
|
let mut evt = DownloadEvent {
|
||||||
|
task_id: task_id.to_string(),
|
||||||
|
total_size,
|
||||||
|
downloaded_size: 0,
|
||||||
|
download_type: "Model".to_string(),
|
||||||
|
event_type: DownloadEventType::Started,
|
||||||
|
};
|
||||||
|
app.emit("download", evt.clone()).unwrap();
|
||||||
|
|
||||||
|
// save file under Jan data folder
|
||||||
|
let data_dir = get_jan_data_folder_path(app.clone());
|
||||||
|
let save_path = data_dir.join(path);
|
||||||
|
|
||||||
|
let mut has_error = false;
|
||||||
|
let mut error_msg = String::new();
|
||||||
|
match _download_file_internal(
|
||||||
|
app.clone(),
|
||||||
|
url,
|
||||||
|
&save_path,
|
||||||
|
header_map.clone(),
|
||||||
|
evt,
|
||||||
|
cancel_token.clone(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(evt_) => {
|
||||||
|
evt = evt_; // reassign ownership
|
||||||
|
}
|
||||||
|
Err((evt_, e)) => {
|
||||||
|
evt = evt_; // reassign ownership
|
||||||
|
error_msg = format!("Failed to download file: {}", e);
|
||||||
|
log::error!("{}", error_msg);
|
||||||
|
has_error = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup
|
||||||
|
{
|
||||||
|
let mut download_manager = state.download_manager.lock().await;
|
||||||
|
download_manager.cancel_tokens.remove(url);
|
||||||
|
}
|
||||||
|
if has_error {
|
||||||
|
let _ = std::fs::remove_file(&save_path); // don't check error
|
||||||
|
return Err(error_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
// emit final event
|
||||||
|
if evt.event_type == DownloadEventType::Stopped {
|
||||||
|
let _ = std::fs::remove_file(&save_path); // don't check error
|
||||||
|
} else {
|
||||||
|
evt.event_type = DownloadEventType::Success;
|
||||||
|
}
|
||||||
|
app.emit("download", evt.clone()).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tauri::command]
|
||||||
|
pub async fn download_hf_repo(
|
||||||
|
app: tauri::AppHandle,
|
||||||
|
state: State<'_, AppState>,
|
||||||
|
model_id: &str,
|
||||||
|
save_dir: &Path,
|
||||||
|
task_id: &str,
|
||||||
|
branch: Option<&str>,
|
||||||
|
headers: HashMap<String, String>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let branch_str = branch.unwrap_or("main");
|
||||||
|
let header_map = _convert_headers(headers).map_err(err_to_string)?;
|
||||||
|
|
||||||
|
log::info!("Downloading HF repo: {}, branch {}", model_id, branch_str);
|
||||||
|
|
||||||
|
// get all files from repo, including subdirs
|
||||||
|
let items = _list_hf_repo_files(model_id, branch, header_map.clone())
|
||||||
|
.await
|
||||||
|
.map_err(err_to_string)?;
|
||||||
|
|
||||||
|
// insert cancel tokens
|
||||||
|
let cancel_token = CancellationToken::new();
|
||||||
|
{
|
||||||
|
let mut download_manager = state.download_manager.lock().await;
|
||||||
|
if download_manager.cancel_tokens.contains_key(model_id) {
|
||||||
|
return Err(format!("model_id {} is already being downloaded", model_id));
|
||||||
|
}
|
||||||
|
download_manager
|
||||||
|
.cancel_tokens
|
||||||
|
.insert(task_id.to_string(), cancel_token.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let total_size = items.iter().map(|f| f.size).sum::<u64>();
|
||||||
|
let mut evt = DownloadEvent {
|
||||||
|
task_id: task_id.to_string(),
|
||||||
|
total_size,
|
||||||
|
downloaded_size: 0,
|
||||||
|
download_type: "Model".to_string(),
|
||||||
|
event_type: DownloadEventType::Started,
|
||||||
|
};
|
||||||
|
app.emit("download", evt.clone()).unwrap();
|
||||||
|
|
||||||
|
let local_dir = get_jan_data_folder_path(app.clone()).join(save_dir);
|
||||||
|
let mut has_error = false;
|
||||||
|
let mut error_msg = String::new();
|
||||||
|
for item in items {
|
||||||
|
let url = format!(
|
||||||
|
"https://huggingface.co/{}/resolve/{}/{}",
|
||||||
|
model_id, branch_str, item.path
|
||||||
|
);
|
||||||
|
let save_path = local_dir.join(&item.path);
|
||||||
|
match _download_file_internal(
|
||||||
|
app.clone(),
|
||||||
|
&url,
|
||||||
|
&save_path,
|
||||||
|
header_map.clone(),
|
||||||
|
evt,
|
||||||
|
cancel_token.clone(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(evt_) => {
|
||||||
|
evt = evt_; // reassign ownership
|
||||||
|
if evt.event_type == DownloadEventType::Stopped {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err((evt_, e)) => {
|
||||||
|
evt = evt_; // reassign ownership
|
||||||
|
error_msg = format!("Failed to download file: {}", e);
|
||||||
|
log::error!("{}", error_msg);
|
||||||
|
has_error = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup
|
||||||
|
{
|
||||||
|
let mut download_manager = state.download_manager.lock().await;
|
||||||
|
download_manager.cancel_tokens.remove(model_id);
|
||||||
|
}
|
||||||
|
if has_error {
|
||||||
|
let _ = std::fs::remove_dir_all(&local_dir); // don't check error
|
||||||
|
return Err(error_msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
// emit final event
|
||||||
|
if evt.event_type == DownloadEventType::Stopped {
|
||||||
|
let _ = std::fs::remove_dir_all(&local_dir); // don't check error
|
||||||
|
} else {
|
||||||
|
evt.event_type = DownloadEventType::Success;
|
||||||
|
}
|
||||||
|
app.emit("download", evt.clone()).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tauri::command]
|
||||||
|
pub async fn cancel_download_task(state: State<'_, AppState>, task_id: &str) -> Result<(), String> {
|
||||||
|
// NOTE: might want to add User-Agent header
|
||||||
|
let mut download_manager = state.download_manager.lock().await;
|
||||||
|
if let Some(token) = download_manager.cancel_tokens.remove(task_id) {
|
||||||
|
token.cancel();
|
||||||
|
log::info!("Cancelled download task_id: {}", task_id);
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(format!("No download task_id: {}", task_id))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn _convert_headers(
|
||||||
|
headers: HashMap<String, String>,
|
||||||
|
) -> Result<HeaderMap, Box<dyn std::error::Error>> {
|
||||||
|
let mut header_map = HeaderMap::new();
|
||||||
|
for (k, v) in headers {
|
||||||
|
let key = HeaderName::from_bytes(k.as_bytes())?;
|
||||||
|
let value = HeaderValue::from_str(&v)?;
|
||||||
|
header_map.insert(key, value);
|
||||||
|
}
|
||||||
|
Ok(header_map)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn _get_file_size(
|
||||||
|
url: &str,
|
||||||
|
header_map: HeaderMap,
|
||||||
|
) -> Result<u64, Box<dyn std::error::Error>> {
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let resp = client.head(url).headers(header_map).send().await?;
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
return Err(format!("Failed to get file size: HTTP status {}", resp.status()).into());
|
||||||
|
}
|
||||||
|
// this is buggy, always return 0 for HEAD request
|
||||||
|
// Ok(resp.content_length().unwrap_or(0))
|
||||||
|
|
||||||
|
match resp.headers().get("content-length") {
|
||||||
|
Some(value) => {
|
||||||
|
let value_str = value.to_str()?;
|
||||||
|
let value_u64: u64 = value_str.parse()?;
|
||||||
|
Ok(value_u64)
|
||||||
|
}
|
||||||
|
None => Ok(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: Caller of this function should pass ownership of `evt` to this function
|
||||||
|
// (no .clone()) and obtain it back. Both Ok and Err will return ownership of
|
||||||
|
// the modified `evt` object back to the caller.
|
||||||
|
async fn _download_file_internal(
|
||||||
|
app: tauri::AppHandle,
|
||||||
|
url: &str,
|
||||||
|
path: &Path, // this is absolute path
|
||||||
|
header_map: HeaderMap,
|
||||||
|
mut evt: DownloadEvent,
|
||||||
|
cancel_token: CancellationToken,
|
||||||
|
) -> Result<DownloadEvent, (DownloadEvent, Box<dyn std::error::Error>)> {
|
||||||
|
log::info!("Downloading file: {}", url);
|
||||||
|
|
||||||
|
// normalize and enforce scope
|
||||||
|
let path = normalize_path(path);
|
||||||
|
let jan_data_folder = get_jan_data_folder_path(app.clone());
|
||||||
|
if !path.starts_with(&jan_data_folder) {
|
||||||
|
return Err((
|
||||||
|
evt.clone(),
|
||||||
|
format!(
|
||||||
|
"Path {} is outside of Jan data folder {}",
|
||||||
|
path.display(),
|
||||||
|
jan_data_folder.display()
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// .read_timeout() and .connect_timeout() requires reqwest 0.12, which is not
|
||||||
|
// compatible with hyper 0.14
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.http2_keep_alive_timeout(Duration::from_secs(15))
|
||||||
|
// .read_timeout(Duration::from_secs(10)) // timeout between chunks
|
||||||
|
// .connect_timeout(Duration::from_secs(10)) // timeout for first connection
|
||||||
|
.build()
|
||||||
|
.map_err(|e| (evt.clone(), e.into()))?;
|
||||||
|
|
||||||
|
let resp = client
|
||||||
|
.get(url)
|
||||||
|
.headers(header_map)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| (evt.clone(), e.into()))?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
return Err((
|
||||||
|
evt,
|
||||||
|
format!(
|
||||||
|
"Failed to download: HTTP status {}, {}",
|
||||||
|
resp.status(),
|
||||||
|
resp.text().await.unwrap_or_default()
|
||||||
|
)
|
||||||
|
.into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create parent directories if they don't exist
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
if !parent.exists() {
|
||||||
|
std::fs::create_dir_all(parent).map_err(|e| (evt.clone(), e.into()))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut file = File::create(&path)
|
||||||
|
.await
|
||||||
|
.map_err(|e| (evt.clone(), e.into()))?;
|
||||||
|
|
||||||
|
// write chunk to file
|
||||||
|
let mut stream = resp.bytes_stream();
|
||||||
|
let mut download_delta = 0u64;
|
||||||
|
evt.event_type = DownloadEventType::Updated;
|
||||||
|
|
||||||
|
while let Some(chunk) = stream.next().await {
|
||||||
|
if cancel_token.is_cancelled() {
|
||||||
|
log::info!("Download cancelled: {}", url);
|
||||||
|
evt.event_type = DownloadEventType::Stopped;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunk = chunk.map_err(|e| (evt.clone(), e.into()))?;
|
||||||
|
file.write_all(&chunk)
|
||||||
|
.await
|
||||||
|
.map_err(|e| (evt.clone(), e.into()))?;
|
||||||
|
download_delta += chunk.len() as u64;
|
||||||
|
|
||||||
|
// only update every 1MB
|
||||||
|
if download_delta >= 1024 * 1024 {
|
||||||
|
evt.downloaded_size += download_delta;
|
||||||
|
app.emit("download", evt.clone()).unwrap();
|
||||||
|
download_delta = 0u64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup
|
||||||
|
file.flush().await.map_err(|e| (evt.clone(), e.into()))?;
|
||||||
|
if evt.event_type == DownloadEventType::Stopped {
|
||||||
|
let _ = std::fs::remove_file(&path); // don't check error
|
||||||
|
}
|
||||||
|
|
||||||
|
// caller should emit a final event after calling this function
|
||||||
|
evt.downloaded_size += download_delta;
|
||||||
|
|
||||||
|
Ok(evt)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
struct HfItem {
|
||||||
|
r#type: String,
|
||||||
|
// oid: String, // unused
|
||||||
|
path: String,
|
||||||
|
size: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn _list_hf_repo_files(
|
||||||
|
model_id: &str,
|
||||||
|
branch: Option<&str>,
|
||||||
|
header_map: HeaderMap,
|
||||||
|
) -> Result<Vec<HfItem>, Box<dyn std::error::Error>> {
|
||||||
|
let branch_str = branch.unwrap_or("main");
|
||||||
|
|
||||||
|
let mut files = vec![];
|
||||||
|
|
||||||
|
// DFS
|
||||||
|
let mut stack = vec!["".to_string()];
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
while let Some(subdir) = stack.pop() {
|
||||||
|
let url = format!(
|
||||||
|
"https://huggingface.co/api/models/{}/tree/{}/{}",
|
||||||
|
model_id, branch_str, subdir
|
||||||
|
);
|
||||||
|
let resp = client.get(&url).headers(header_map.clone()).send().await?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
return Err(format!(
|
||||||
|
"Failed to list files: HTTP status {}, {}",
|
||||||
|
resp.status(),
|
||||||
|
resp.text().await.unwrap_or_default(),
|
||||||
|
)
|
||||||
|
.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
for item in resp.json::<Vec<HfItem>>().await?.into_iter() {
|
||||||
|
if item.r#type == "directory" {
|
||||||
|
stack.push(item.path);
|
||||||
|
} else {
|
||||||
|
files.push(item);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(files)
|
||||||
|
}
|
||||||
@ -1,5 +1,7 @@
|
|||||||
|
pub mod download;
|
||||||
|
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::path::PathBuf;
|
use std::path::{Component, Path, PathBuf};
|
||||||
use tauri::Runtime;
|
use tauri::Runtime;
|
||||||
|
|
||||||
use super::cmd::get_jan_data_folder_path;
|
use super::cmd::get_jan_data_folder_path;
|
||||||
@ -46,3 +48,31 @@ pub fn ensure_thread_dir_exists<R: Runtime>(
|
|||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://github.com/rust-lang/cargo/blob/rust-1.67.0/crates/cargo-util/src/paths.rs#L82-L107
|
||||||
|
pub fn normalize_path(path: &Path) -> PathBuf {
|
||||||
|
let mut components = path.components().peekable();
|
||||||
|
let mut ret = if let Some(c @ Component::Prefix(..)) = components.peek().cloned() {
|
||||||
|
components.next();
|
||||||
|
PathBuf::from(c.as_os_str())
|
||||||
|
} else {
|
||||||
|
PathBuf::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
for component in components {
|
||||||
|
match component {
|
||||||
|
Component::Prefix(..) => unreachable!(),
|
||||||
|
Component::RootDir => {
|
||||||
|
ret.push(component.as_os_str());
|
||||||
|
}
|
||||||
|
Component::CurDir => {}
|
||||||
|
Component::ParentDir => {
|
||||||
|
ret.pop();
|
||||||
|
}
|
||||||
|
Component::Normal(c) => {
|
||||||
|
ret.push(c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ret
|
||||||
|
}
|
||||||
|
|||||||
@ -3,6 +3,7 @@ use core::{
|
|||||||
cmd::get_jan_data_folder_path,
|
cmd::get_jan_data_folder_path,
|
||||||
setup::{self, setup_engine_binaries, setup_mcp, setup_sidecar},
|
setup::{self, setup_engine_binaries, setup_mcp, setup_sidecar},
|
||||||
state::{generate_app_token, AppState},
|
state::{generate_app_token, AppState},
|
||||||
|
utils::download::DownloadManagerState,
|
||||||
};
|
};
|
||||||
use std::{collections::HashMap, sync::Arc};
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
@ -60,11 +61,16 @@ pub fn run() {
|
|||||||
core::threads::delete_message,
|
core::threads::delete_message,
|
||||||
core::threads::get_thread_assistant,
|
core::threads::get_thread_assistant,
|
||||||
core::threads::create_thread_assistant,
|
core::threads::create_thread_assistant,
|
||||||
core::threads::modify_thread_assistant
|
core::threads::modify_thread_assistant,
|
||||||
|
// Download
|
||||||
|
core::utils::download::download_file,
|
||||||
|
core::utils::download::download_hf_repo,
|
||||||
|
core::utils::download::cancel_download_task,
|
||||||
])
|
])
|
||||||
.manage(AppState {
|
.manage(AppState {
|
||||||
app_token: Some(generate_app_token()),
|
app_token: Some(generate_app_token()),
|
||||||
mcp_servers: Arc::new(Mutex::new(HashMap::new())),
|
mcp_servers: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
download_manager: Arc::new(Mutex::new(DownloadManagerState::default())),
|
||||||
})
|
})
|
||||||
.setup(|app| {
|
.setup(|app| {
|
||||||
app.handle().plugin(
|
app.handle().plugin(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user