feat: Download manager for llama.cpp extension (#4933)
This commit is contained in:
parent
e9f37e98d1
commit
4bde6645d0
36
extensions/download-extension/package.json
Normal file
36
extensions/download-extension/package.json
Normal file
@ -0,0 +1,36 @@
|
||||
{
|
||||
"name": "@janhq/download-extension",
|
||||
"productName": "Download Manager",
|
||||
"version": "1.0.0",
|
||||
"description": "Handle downloads",
|
||||
"main": "dist/index.js",
|
||||
"author": "Jan <service@jan.ai>",
|
||||
"license": "AGPL-3.0",
|
||||
"scripts": {
|
||||
"test": "vitest run",
|
||||
"build": "rolldown -c rolldown.config.mjs",
|
||||
"build:publish": "rimraf *.tgz --glob || true && yarn build && npm pack && cpx *.tgz ../../pre-install"
|
||||
},
|
||||
"devDependencies": {
|
||||
"cpx": "^1.5.0",
|
||||
"rimraf": "^3.0.2",
|
||||
"rolldown": "1.0.0-beta.1",
|
||||
"run-script-os": "^1.1.6",
|
||||
"typescript": "5.3.3",
|
||||
"vitest": "^3.0.6"
|
||||
},
|
||||
"files": [
|
||||
"dist/*",
|
||||
"package.json",
|
||||
"README.md"
|
||||
],
|
||||
"dependencies": {
|
||||
"@janhq/core": "../../core/package.tgz",
|
||||
"@tauri-apps/api": "^2.5.0"
|
||||
},
|
||||
"bundleDependencies": [],
|
||||
"installConfig": {
|
||||
"hoistingLimits": "workspaces"
|
||||
},
|
||||
"packageManager": "yarn@4.5.3"
|
||||
}
|
||||
14
extensions/download-extension/rolldown.config.mjs
Normal file
14
extensions/download-extension/rolldown.config.mjs
Normal file
@ -0,0 +1,14 @@
|
||||
import { defineConfig } from 'rolldown'
|
||||
import settingJson from './settings.json' with { type: 'json' }
|
||||
|
||||
export default defineConfig({
|
||||
input: 'src/index.ts',
|
||||
output: {
|
||||
format: 'esm',
|
||||
file: 'dist/index.js',
|
||||
},
|
||||
platform: 'browser',
|
||||
define: {
|
||||
SETTINGS: JSON.stringify(settingJson),
|
||||
},
|
||||
})
|
||||
14
extensions/download-extension/settings.json
Normal file
14
extensions/download-extension/settings.json
Normal file
@ -0,0 +1,14 @@
|
||||
[
|
||||
{
|
||||
"key": "hf-token",
|
||||
"title": "Hugging Face Access Token",
|
||||
"description": "Access tokens programmatically authenticate your identity to the Hugging Face Hub, allowing applications to perform specific actions specified by the scope of permissions granted.",
|
||||
"controllerType": "input",
|
||||
"controllerProps": {
|
||||
"value": "",
|
||||
"placeholder": "hf_**********************************",
|
||||
"type": "password",
|
||||
"inputActions": ["unobscure", "copy"]
|
||||
}
|
||||
}
|
||||
]
|
||||
1
extensions/download-extension/src/@types/global.d.ts
vendored
Normal file
1
extensions/download-extension/src/@types/global.d.ts
vendored
Normal file
@ -0,0 +1 @@
|
||||
declare const SETTINGS: SettingComponentProps[]
|
||||
127
extensions/download-extension/src/index.ts
Normal file
127
extensions/download-extension/src/index.ts
Normal file
@ -0,0 +1,127 @@
|
||||
import { invoke } from '@tauri-apps/api/core';
|
||||
import { listen } from '@tauri-apps/api/event';
|
||||
import { BaseExtension, events } from '@janhq/core';
|
||||
|
||||
export enum Settings {
|
||||
hfToken = 'hf-token',
|
||||
}
|
||||
|
||||
type DownloadEvent = {
|
||||
task_id: string
|
||||
total_size: number
|
||||
downloaded_size: number
|
||||
download_type: string
|
||||
event_type: string
|
||||
}
|
||||
|
||||
export default class DownloadManager extends BaseExtension {
|
||||
hf_token?: string
|
||||
|
||||
async onLoad() {
|
||||
this.registerSettings(SETTINGS)
|
||||
this.hf_token = await this.getSetting<string>(Settings.hfToken, undefined)
|
||||
}
|
||||
|
||||
async onUnload() { }
|
||||
|
||||
async downloadFile(url: string, path: string, taskId: string) {
|
||||
// relay tauri events to Jan events
|
||||
const unlisten = await listen<DownloadEvent>('download', (event) => {
|
||||
let payload = event.payload
|
||||
let eventName = {
|
||||
Updated: 'onFileDownloadUpdate',
|
||||
Error: 'onFileDownloadError',
|
||||
Success: 'onFileDownloadSuccess',
|
||||
Stopped: 'onFileDownloadStopped',
|
||||
Started: 'onFileDownloadStarted',
|
||||
}[payload.event_type]
|
||||
|
||||
// remove this once event system is back in web-app
|
||||
console.log(taskId, payload.event_type, payload.downloaded_size / payload.total_size)
|
||||
|
||||
events.emit(eventName, {
|
||||
modelId: taskId,
|
||||
percent: payload.downloaded_size / payload.total_size,
|
||||
size: {
|
||||
transferred: payload.downloaded_size,
|
||||
total: payload.total_size,
|
||||
},
|
||||
downloadType: payload.download_type,
|
||||
})
|
||||
})
|
||||
|
||||
try {
|
||||
await invoke<void>(
|
||||
"download_file",
|
||||
{ url, path, taskId, headers: this._getHeaders() },
|
||||
)
|
||||
} catch (error) {
|
||||
console.error("Error downloading file:", error)
|
||||
events.emit('onFileDownloadError', {
|
||||
modelId: url,
|
||||
downloadType: 'Model',
|
||||
})
|
||||
throw error
|
||||
} finally {
|
||||
unlisten()
|
||||
}
|
||||
}
|
||||
|
||||
async downloadHfRepo(modelId: string, saveDir: string, taskId: string, branch?: string) {
|
||||
// relay tauri events to Jan events
|
||||
const unlisten = await listen<DownloadEvent>('download', (event) => {
|
||||
let payload = event.payload
|
||||
let eventName = {
|
||||
Updated: 'onFileDownloadUpdate',
|
||||
Error: 'onFileDownloadError',
|
||||
Success: 'onFileDownloadSuccess',
|
||||
Stopped: 'onFileDownloadStopped',
|
||||
Started: 'onFileDownloadStarted',
|
||||
}[payload.event_type]
|
||||
|
||||
// remove this once event system is back in web-app
|
||||
console.log(taskId, payload.event_type, payload.downloaded_size / payload.total_size)
|
||||
|
||||
events.emit(eventName, {
|
||||
modelId: taskId,
|
||||
percent: payload.downloaded_size / payload.total_size,
|
||||
size: {
|
||||
transferred: payload.downloaded_size,
|
||||
total: payload.total_size,
|
||||
},
|
||||
downloadType: payload.download_type,
|
||||
})
|
||||
})
|
||||
|
||||
try {
|
||||
await invoke<void>(
|
||||
"download_hf_repo",
|
||||
{ modelId, saveDir, taskId, branch, headers: this._getHeaders() },
|
||||
)
|
||||
} catch (error) {
|
||||
console.error("Error downloading file:", error)
|
||||
events.emit('onFileDownloadError', {
|
||||
modelId: modelId,
|
||||
downloadType: 'Model',
|
||||
})
|
||||
throw error
|
||||
} finally {
|
||||
unlisten()
|
||||
}
|
||||
}
|
||||
|
||||
async cancelDownload(taskId: string) {
|
||||
try {
|
||||
await invoke<void>("cancel_download_task", { taskId })
|
||||
} catch (error) {
|
||||
console.error("Error cancelling download:", error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
_getHeaders() {
|
||||
return {
|
||||
...(this.hf_token && { Authorization: `Bearer ${this.hf_token}` })
|
||||
}
|
||||
}
|
||||
}
|
||||
15
extensions/download-extension/tsconfig.json
Normal file
15
extensions/download-extension/tsconfig.json
Normal file
@ -0,0 +1,15 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "es2016",
|
||||
"module": "esnext",
|
||||
"moduleResolution": "node",
|
||||
"outDir": "./dist",
|
||||
"esModuleInterop": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"strict": false,
|
||||
"skipLibCheck": true,
|
||||
"rootDir": "./src"
|
||||
},
|
||||
"include": ["./src"],
|
||||
"exclude": ["**/*.test.ts", "vite.config.ts"]
|
||||
}
|
||||
8
extensions/download-extension/vite.config.ts
Normal file
8
extensions/download-extension/vite.config.ts
Normal file
@ -0,0 +1,8 @@
|
||||
import { defineConfig } from "vite"
|
||||
export default defineConfig(({ mode }) => ({
|
||||
define: process.env.VITEST ? {} : { global: 'window' },
|
||||
test: {
|
||||
environment: 'jsdom',
|
||||
},
|
||||
}))
|
||||
|
||||
@ -31,7 +31,7 @@ rand = "0.8"
|
||||
tauri-plugin-http = { version = "2", features = ["unsafe-headers"] }
|
||||
tauri-plugin-store = "2"
|
||||
hyper = { version = "0.14", features = ["server"] }
|
||||
reqwest = { version = "0.11", features = ["json", "blocking"] }
|
||||
reqwest = { version = "0.11", features = ["json", "blocking", "stream"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [
|
||||
"client",
|
||||
@ -41,6 +41,8 @@ rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "mai
|
||||
] }
|
||||
uuid = { version = "1.7", features = ["v4"] }
|
||||
env = "1.0.1"
|
||||
futures-util = "0.3.31"
|
||||
tokio-util = "0.7.14"
|
||||
|
||||
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
|
||||
tauri-plugin-updater = "2"
|
||||
|
||||
@ -5,4 +5,4 @@ pub mod server;
|
||||
pub mod setup;
|
||||
pub mod state;
|
||||
pub mod threads;
|
||||
pub mod utils;
|
||||
pub mod utils;
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use crate::core::utils::download::DownloadManagerState;
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use rmcp::{service::RunningService, RoleClient};
|
||||
use tokio::sync::Mutex;
|
||||
@ -8,6 +9,7 @@ use tokio::sync::Mutex;
|
||||
pub struct AppState {
|
||||
pub app_token: Option<String>,
|
||||
pub mcp_servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
|
||||
pub download_manager: Arc<Mutex<DownloadManagerState>>,
|
||||
}
|
||||
pub fn generate_app_token() -> String {
|
||||
rand::thread_rng()
|
||||
|
||||
421
src-tauri/src/core/utils/download.rs
Normal file
421
src-tauri/src/core/utils/download.rs
Normal file
@ -0,0 +1,421 @@
|
||||
use crate::core::cmd::get_jan_data_folder_path;
|
||||
use crate::core::state::AppState;
|
||||
use crate::core::utils::normalize_path;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
use tauri::{Emitter, State};
|
||||
use tokio::fs::File;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct DownloadManagerState {
|
||||
pub cancel_tokens: HashMap<String, CancellationToken>,
|
||||
}
|
||||
|
||||
// this is to emulate the current way of downloading files by Cortex + Jan
|
||||
// we can change this later
|
||||
#[derive(serde::Serialize, Clone, Debug, PartialEq)]
|
||||
pub enum DownloadEventType {
|
||||
Started,
|
||||
Updated,
|
||||
Success,
|
||||
// Error, // we don't need to emit an Error event. just return an error directly
|
||||
Stopped,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, Clone, Debug)]
|
||||
pub struct DownloadEvent {
|
||||
pub task_id: String,
|
||||
pub total_size: u64,
|
||||
pub downloaded_size: u64,
|
||||
pub download_type: String, // TODO: make this an enum as well
|
||||
pub event_type: DownloadEventType,
|
||||
}
|
||||
|
||||
fn err_to_string<E: std::fmt::Display>(e: E) -> String {
|
||||
format!("Error: {}", e)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn download_file(
|
||||
app: tauri::AppHandle,
|
||||
state: State<'_, AppState>,
|
||||
url: &str,
|
||||
path: &Path,
|
||||
task_id: &str,
|
||||
headers: HashMap<String, String>,
|
||||
) -> Result<(), String> {
|
||||
// insert cancel tokens
|
||||
let cancel_token = CancellationToken::new();
|
||||
{
|
||||
let mut download_manager = state.download_manager.lock().await;
|
||||
if download_manager.cancel_tokens.contains_key(url) {
|
||||
return Err(format!("URL {} is already being downloaded", url));
|
||||
}
|
||||
download_manager
|
||||
.cancel_tokens
|
||||
.insert(task_id.to_string(), cancel_token.clone());
|
||||
}
|
||||
|
||||
let header_map = _convert_headers(headers).map_err(err_to_string)?;
|
||||
let total_size = _get_file_size(url, header_map.clone())
|
||||
.await
|
||||
.map_err(err_to_string)?;
|
||||
log::info!("File size: {}", total_size);
|
||||
let mut evt = DownloadEvent {
|
||||
task_id: task_id.to_string(),
|
||||
total_size,
|
||||
downloaded_size: 0,
|
||||
download_type: "Model".to_string(),
|
||||
event_type: DownloadEventType::Started,
|
||||
};
|
||||
app.emit("download", evt.clone()).unwrap();
|
||||
|
||||
// save file under Jan data folder
|
||||
let data_dir = get_jan_data_folder_path(app.clone());
|
||||
let save_path = data_dir.join(path);
|
||||
|
||||
let mut has_error = false;
|
||||
let mut error_msg = String::new();
|
||||
match _download_file_internal(
|
||||
app.clone(),
|
||||
url,
|
||||
&save_path,
|
||||
header_map.clone(),
|
||||
evt,
|
||||
cancel_token.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(evt_) => {
|
||||
evt = evt_; // reassign ownership
|
||||
}
|
||||
Err((evt_, e)) => {
|
||||
evt = evt_; // reassign ownership
|
||||
error_msg = format!("Failed to download file: {}", e);
|
||||
log::error!("{}", error_msg);
|
||||
has_error = true;
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup
|
||||
{
|
||||
let mut download_manager = state.download_manager.lock().await;
|
||||
download_manager.cancel_tokens.remove(url);
|
||||
}
|
||||
if has_error {
|
||||
let _ = std::fs::remove_file(&save_path); // don't check error
|
||||
return Err(error_msg);
|
||||
}
|
||||
|
||||
// emit final event
|
||||
if evt.event_type == DownloadEventType::Stopped {
|
||||
let _ = std::fs::remove_file(&save_path); // don't check error
|
||||
} else {
|
||||
evt.event_type = DownloadEventType::Success;
|
||||
}
|
||||
app.emit("download", evt.clone()).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn download_hf_repo(
|
||||
app: tauri::AppHandle,
|
||||
state: State<'_, AppState>,
|
||||
model_id: &str,
|
||||
save_dir: &Path,
|
||||
task_id: &str,
|
||||
branch: Option<&str>,
|
||||
headers: HashMap<String, String>,
|
||||
) -> Result<(), String> {
|
||||
let branch_str = branch.unwrap_or("main");
|
||||
let header_map = _convert_headers(headers).map_err(err_to_string)?;
|
||||
|
||||
log::info!("Downloading HF repo: {}, branch {}", model_id, branch_str);
|
||||
|
||||
// get all files from repo, including subdirs
|
||||
let items = _list_hf_repo_files(model_id, branch, header_map.clone())
|
||||
.await
|
||||
.map_err(err_to_string)?;
|
||||
|
||||
// insert cancel tokens
|
||||
let cancel_token = CancellationToken::new();
|
||||
{
|
||||
let mut download_manager = state.download_manager.lock().await;
|
||||
if download_manager.cancel_tokens.contains_key(model_id) {
|
||||
return Err(format!("model_id {} is already being downloaded", model_id));
|
||||
}
|
||||
download_manager
|
||||
.cancel_tokens
|
||||
.insert(task_id.to_string(), cancel_token.clone());
|
||||
}
|
||||
|
||||
let total_size = items.iter().map(|f| f.size).sum::<u64>();
|
||||
let mut evt = DownloadEvent {
|
||||
task_id: task_id.to_string(),
|
||||
total_size,
|
||||
downloaded_size: 0,
|
||||
download_type: "Model".to_string(),
|
||||
event_type: DownloadEventType::Started,
|
||||
};
|
||||
app.emit("download", evt.clone()).unwrap();
|
||||
|
||||
let local_dir = get_jan_data_folder_path(app.clone()).join(save_dir);
|
||||
let mut has_error = false;
|
||||
let mut error_msg = String::new();
|
||||
for item in items {
|
||||
let url = format!(
|
||||
"https://huggingface.co/{}/resolve/{}/{}",
|
||||
model_id, branch_str, item.path
|
||||
);
|
||||
let save_path = local_dir.join(&item.path);
|
||||
match _download_file_internal(
|
||||
app.clone(),
|
||||
&url,
|
||||
&save_path,
|
||||
header_map.clone(),
|
||||
evt,
|
||||
cancel_token.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(evt_) => {
|
||||
evt = evt_; // reassign ownership
|
||||
if evt.event_type == DownloadEventType::Stopped {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err((evt_, e)) => {
|
||||
evt = evt_; // reassign ownership
|
||||
error_msg = format!("Failed to download file: {}", e);
|
||||
log::error!("{}", error_msg);
|
||||
has_error = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup
|
||||
{
|
||||
let mut download_manager = state.download_manager.lock().await;
|
||||
download_manager.cancel_tokens.remove(model_id);
|
||||
}
|
||||
if has_error {
|
||||
let _ = std::fs::remove_dir_all(&local_dir); // don't check error
|
||||
return Err(error_msg);
|
||||
}
|
||||
|
||||
// emit final event
|
||||
if evt.event_type == DownloadEventType::Stopped {
|
||||
let _ = std::fs::remove_dir_all(&local_dir); // don't check error
|
||||
} else {
|
||||
evt.event_type = DownloadEventType::Success;
|
||||
}
|
||||
app.emit("download", evt.clone()).unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn cancel_download_task(state: State<'_, AppState>, task_id: &str) -> Result<(), String> {
|
||||
// NOTE: might want to add User-Agent header
|
||||
let mut download_manager = state.download_manager.lock().await;
|
||||
if let Some(token) = download_manager.cancel_tokens.remove(task_id) {
|
||||
token.cancel();
|
||||
log::info!("Cancelled download task_id: {}", task_id);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("No download task_id: {}", task_id))
|
||||
}
|
||||
}
|
||||
|
||||
fn _convert_headers(
|
||||
headers: HashMap<String, String>,
|
||||
) -> Result<HeaderMap, Box<dyn std::error::Error>> {
|
||||
let mut header_map = HeaderMap::new();
|
||||
for (k, v) in headers {
|
||||
let key = HeaderName::from_bytes(k.as_bytes())?;
|
||||
let value = HeaderValue::from_str(&v)?;
|
||||
header_map.insert(key, value);
|
||||
}
|
||||
Ok(header_map)
|
||||
}
|
||||
|
||||
async fn _get_file_size(
|
||||
url: &str,
|
||||
header_map: HeaderMap,
|
||||
) -> Result<u64, Box<dyn std::error::Error>> {
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client.head(url).headers(header_map).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(format!("Failed to get file size: HTTP status {}", resp.status()).into());
|
||||
}
|
||||
// this is buggy, always return 0 for HEAD request
|
||||
// Ok(resp.content_length().unwrap_or(0))
|
||||
|
||||
match resp.headers().get("content-length") {
|
||||
Some(value) => {
|
||||
let value_str = value.to_str()?;
|
||||
let value_u64: u64 = value_str.parse()?;
|
||||
Ok(value_u64)
|
||||
}
|
||||
None => Ok(0),
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: Caller of this function should pass ownership of `evt` to this function
|
||||
// (no .clone()) and obtain it back. Both Ok and Err will return ownership of
|
||||
// the modified `evt` object back to the caller.
|
||||
async fn _download_file_internal(
|
||||
app: tauri::AppHandle,
|
||||
url: &str,
|
||||
path: &Path, // this is absolute path
|
||||
header_map: HeaderMap,
|
||||
mut evt: DownloadEvent,
|
||||
cancel_token: CancellationToken,
|
||||
) -> Result<DownloadEvent, (DownloadEvent, Box<dyn std::error::Error>)> {
|
||||
log::info!("Downloading file: {}", url);
|
||||
|
||||
// normalize and enforce scope
|
||||
let path = normalize_path(path);
|
||||
let jan_data_folder = get_jan_data_folder_path(app.clone());
|
||||
if !path.starts_with(&jan_data_folder) {
|
||||
return Err((
|
||||
evt.clone(),
|
||||
format!(
|
||||
"Path {} is outside of Jan data folder {}",
|
||||
path.display(),
|
||||
jan_data_folder.display()
|
||||
)
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
|
||||
// .read_timeout() and .connect_timeout() requires reqwest 0.12, which is not
|
||||
// compatible with hyper 0.14
|
||||
let client = reqwest::Client::builder()
|
||||
.http2_keep_alive_timeout(Duration::from_secs(15))
|
||||
// .read_timeout(Duration::from_secs(10)) // timeout between chunks
|
||||
// .connect_timeout(Duration::from_secs(10)) // timeout for first connection
|
||||
.build()
|
||||
.map_err(|e| (evt.clone(), e.into()))?;
|
||||
|
||||
let resp = client
|
||||
.get(url)
|
||||
.headers(header_map)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| (evt.clone(), e.into()))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err((
|
||||
evt,
|
||||
format!(
|
||||
"Failed to download: HTTP status {}, {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default()
|
||||
)
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Create parent directories if they don't exist
|
||||
if let Some(parent) = path.parent() {
|
||||
if !parent.exists() {
|
||||
std::fs::create_dir_all(parent).map_err(|e| (evt.clone(), e.into()))?;
|
||||
}
|
||||
}
|
||||
let mut file = File::create(&path)
|
||||
.await
|
||||
.map_err(|e| (evt.clone(), e.into()))?;
|
||||
|
||||
// write chunk to file
|
||||
let mut stream = resp.bytes_stream();
|
||||
let mut download_delta = 0u64;
|
||||
evt.event_type = DownloadEventType::Updated;
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
if cancel_token.is_cancelled() {
|
||||
log::info!("Download cancelled: {}", url);
|
||||
evt.event_type = DownloadEventType::Stopped;
|
||||
break;
|
||||
}
|
||||
|
||||
let chunk = chunk.map_err(|e| (evt.clone(), e.into()))?;
|
||||
file.write_all(&chunk)
|
||||
.await
|
||||
.map_err(|e| (evt.clone(), e.into()))?;
|
||||
download_delta += chunk.len() as u64;
|
||||
|
||||
// only update every 1MB
|
||||
if download_delta >= 1024 * 1024 {
|
||||
evt.downloaded_size += download_delta;
|
||||
app.emit("download", evt.clone()).unwrap();
|
||||
download_delta = 0u64;
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup
|
||||
file.flush().await.map_err(|e| (evt.clone(), e.into()))?;
|
||||
if evt.event_type == DownloadEventType::Stopped {
|
||||
let _ = std::fs::remove_file(&path); // don't check error
|
||||
}
|
||||
|
||||
// caller should emit a final event after calling this function
|
||||
evt.downloaded_size += download_delta;
|
||||
|
||||
Ok(evt)
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct HfItem {
|
||||
r#type: String,
|
||||
// oid: String, // unused
|
||||
path: String,
|
||||
size: u64,
|
||||
}
|
||||
|
||||
async fn _list_hf_repo_files(
|
||||
model_id: &str,
|
||||
branch: Option<&str>,
|
||||
header_map: HeaderMap,
|
||||
) -> Result<Vec<HfItem>, Box<dyn std::error::Error>> {
|
||||
let branch_str = branch.unwrap_or("main");
|
||||
|
||||
let mut files = vec![];
|
||||
|
||||
// DFS
|
||||
let mut stack = vec!["".to_string()];
|
||||
let client = reqwest::Client::new();
|
||||
while let Some(subdir) = stack.pop() {
|
||||
let url = format!(
|
||||
"https://huggingface.co/api/models/{}/tree/{}/{}",
|
||||
model_id, branch_str, subdir
|
||||
);
|
||||
let resp = client.get(&url).headers(header_map.clone()).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(format!(
|
||||
"Failed to list files: HTTP status {}, {}",
|
||||
resp.status(),
|
||||
resp.text().await.unwrap_or_default(),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
for item in resp.json::<Vec<HfItem>>().await?.into_iter() {
|
||||
if item.r#type == "directory" {
|
||||
stack.push(item.path);
|
||||
} else {
|
||||
files.push(item);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(files)
|
||||
}
|
||||
@ -1,5 +1,7 @@
|
||||
pub mod download;
|
||||
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::path::{Component, Path, PathBuf};
|
||||
use tauri::Runtime;
|
||||
|
||||
use super::cmd::get_jan_data_folder_path;
|
||||
@ -46,3 +48,31 @@ pub fn ensure_thread_dir_exists<R: Runtime>(
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://github.com/rust-lang/cargo/blob/rust-1.67.0/crates/cargo-util/src/paths.rs#L82-L107
|
||||
pub fn normalize_path(path: &Path) -> PathBuf {
|
||||
let mut components = path.components().peekable();
|
||||
let mut ret = if let Some(c @ Component::Prefix(..)) = components.peek().cloned() {
|
||||
components.next();
|
||||
PathBuf::from(c.as_os_str())
|
||||
} else {
|
||||
PathBuf::new()
|
||||
};
|
||||
|
||||
for component in components {
|
||||
match component {
|
||||
Component::Prefix(..) => unreachable!(),
|
||||
Component::RootDir => {
|
||||
ret.push(component.as_os_str());
|
||||
}
|
||||
Component::CurDir => {}
|
||||
Component::ParentDir => {
|
||||
ret.pop();
|
||||
}
|
||||
Component::Normal(c) => {
|
||||
ret.push(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
@ -3,6 +3,7 @@ use core::{
|
||||
cmd::get_jan_data_folder_path,
|
||||
setup::{self, setup_engine_binaries, setup_mcp, setup_sidecar},
|
||||
state::{generate_app_token, AppState},
|
||||
utils::download::DownloadManagerState,
|
||||
};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
@ -60,11 +61,16 @@ pub fn run() {
|
||||
core::threads::delete_message,
|
||||
core::threads::get_thread_assistant,
|
||||
core::threads::create_thread_assistant,
|
||||
core::threads::modify_thread_assistant
|
||||
core::threads::modify_thread_assistant,
|
||||
// Download
|
||||
core::utils::download::download_file,
|
||||
core::utils::download::download_hf_repo,
|
||||
core::utils::download::cancel_download_task,
|
||||
])
|
||||
.manage(AppState {
|
||||
app_token: Some(generate_app_token()),
|
||||
mcp_servers: Arc::new(Mutex::new(HashMap::new())),
|
||||
download_manager: Arc::new(Mutex::new(DownloadManagerState::default())),
|
||||
})
|
||||
.setup(|app| {
|
||||
app.handle().plugin(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user