feat: gguf file size + hash validation

* fix tests fe
* update cargo tests
* handle async download for both models and mmproj
* move progress tracker to models
* handle file download cancelled
* add cancellation mid hash run
This commit is contained in:
parent 41b4cc3bb3
commit 32a2ca95b6
@@ -194,6 +194,10 @@ export interface chatOptions {
 export interface ImportOptions {
   modelPath: string
   mmprojPath?: string
+  modelSha256?: string
+  modelSize?: number
+  mmprojSha256?: string
+  mmprojSize?: number
 }

 export interface importResult {
@@ -73,6 +73,9 @@ export enum DownloadEvent {
   onFileDownloadSuccess = 'onFileDownloadSuccess',
   onFileDownloadStopped = 'onFileDownloadStopped',
   onFileDownloadStarted = 'onFileDownloadStarted',
+  onModelValidationStarted = 'onModelValidationStarted',
+  onModelValidationFailed = 'onModelValidationFailed',
+  onFileDownloadAndVerificationSuccess = 'onFileDownloadAndVerificationSuccess',
 }

 export enum ExtensionRoute {
   baseExtensions = 'baseExtensions',
@@ -10,6 +10,8 @@ interface DownloadItem {
   url: string
   save_path: string
   proxy?: Record<string, string | string[] | boolean>
+  sha256?: string
+  size?: number
 }

 type DownloadEvent = {
@@ -20,9 +20,11 @@ import {
   chatCompletionRequest,
   events,
   AppEvent,
+  DownloadEvent,
 } from '@janhq/core'

 import { error, info, warn } from '@tauri-apps/plugin-log'
+import { listen } from '@tauri-apps/api/event'

 import {
   listSupportedBackends,
@@ -71,6 +73,8 @@ interface DownloadItem {
   url: string
   save_path: string
   proxy?: Record<string, string | string[] | boolean>
+  sha256?: string
+  size?: number
 }

 interface ModelConfig {
@@ -79,6 +83,9 @@ interface ModelConfig {
   name: string // user-friendly
   // some model info that we cache upon import
   size_bytes: number
+  sha256?: string
+  mmproj_sha256?: string
+  mmproj_size_bytes?: number
 }

 interface EmbeddingResponse {
@@ -154,6 +161,7 @@ export default class llamacpp_extension extends AIEngine {
   private pendingDownloads: Map<string, Promise<void>> = new Map()
   private isConfiguringBackends: boolean = false
   private loadingModels = new Map<string, Promise<SessionInfo>>() // Track loading promises
+  private unlistenValidationStarted?: () => void

   override async onLoad(): Promise<void> {
     super.onLoad() // Calls registerEngine() from AIEngine
@@ -181,6 +189,19 @@ export default class llamacpp_extension extends AIEngine {
       await getJanDataFolderPath(),
       this.providerId,
     ])

+    // Set up validation event listeners to bridge Tauri events to frontend
+    this.unlistenValidationStarted = await listen<{
+      modelId: string
+      downloadType: string
+    }>('onModelValidationStarted', (event) => {
+      console.debug(
+        'LlamaCPP: bridging onModelValidationStarted event',
+        event.payload
+      )
+      events.emit(DownloadEvent.onModelValidationStarted, event.payload)
+    })
+
     this.configureBackends()
   }
@@ -774,6 +795,11 @@ export default class llamacpp_extension extends AIEngine {

   override async onUnload(): Promise<void> {
     // Terminate all active sessions

+    // Clean up validation event listeners
+    if (this.unlistenValidationStarted) {
+      this.unlistenValidationStarted()
+    }
   }

   onSettingUpdate<T>(key: string, value: T): void {
@@ -1006,6 +1032,9 @@ export default class llamacpp_extension extends AIEngine {
       url: path,
       save_path: localPath,
       proxy: getProxyConfig(),
+      sha256:
+        saveName === 'model.gguf' ? opts.modelSha256 : opts.mmprojSha256,
+      size: saveName === 'model.gguf' ? opts.modelSize : opts.mmprojSize,
     })
     return localPath
   }
@@ -1023,8 +1052,6 @@ export default class llamacpp_extension extends AIEngine {
         : undefined

     if (downloadItems.length > 0) {
-      let downloadCompleted = false
-
       try {
         // emit download update event on progress
         const onProgress = (transferred: number, total: number) => {
@@ -1034,7 +1061,6 @@ export default class llamacpp_extension extends AIEngine {
             size: { transferred, total },
             downloadType: 'Model',
           })
-          downloadCompleted = transferred === total
         }
         const downloadManager = window.core.extensionManager.getByName(
           '@janhq/download-extension'
@@ -1045,13 +1071,67 @@ export default class llamacpp_extension extends AIEngine {
          onProgress
        )

-        const eventName = downloadCompleted
-          ? 'onFileDownloadSuccess'
-          : 'onFileDownloadStopped'
-        events.emit(eventName, { modelId, downloadType: 'Model' })
+        // If we reach here, download completed successfully (including validation)
+        // The downloadFiles function only returns successfully if all files downloaded AND validated
+        events.emit(DownloadEvent.onFileDownloadAndVerificationSuccess, {
+          modelId,
+          downloadType: 'Model'
+        })
      } catch (error) {
        logger.error('Error downloading model:', modelId, opts, error)
-        events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
+        const errorMessage =
+          error instanceof Error ? error.message : String(error)
+
+        // Check if this is a cancellation
+        const isCancellationError =
+          errorMessage.includes('Download cancelled') ||
+          errorMessage.includes('Validation cancelled') ||
+          errorMessage.includes('Hash computation cancelled') ||
+          errorMessage.includes('cancelled') ||
+          errorMessage.includes('aborted')
+
+        // Check if this is a validation failure
+        const isValidationError =
+          errorMessage.includes('Hash verification failed') ||
+          errorMessage.includes('Size verification failed') ||
+          errorMessage.includes('Failed to verify file')
+
+        if (isCancellationError) {
+          logger.info('Download cancelled for model:', modelId)
+          // Emit download stopped event instead of error
+          events.emit(DownloadEvent.onFileDownloadStopped, {
+            modelId,
+            downloadType: 'Model',
+          })
+        } else if (isValidationError) {
+          logger.error(
+            'Validation failed for model:',
+            modelId,
+            'Error:',
+            errorMessage
+          )
+
+          // Cancel any other download tasks for this model
+          try {
+            this.abortImport(modelId)
+          } catch (cancelError) {
+            logger.warn('Failed to cancel download task:', cancelError)
+          }
+
+          // Emit validation failure event
+          events.emit(DownloadEvent.onModelValidationFailed, {
+            modelId,
+            downloadType: 'Model',
+            error: errorMessage,
+            reason: 'validation_failed',
+          })
+        } else {
+          // Regular download error
+          events.emit(DownloadEvent.onFileDownloadError, {
+            modelId,
+            downloadType: 'Model',
+            error: errorMessage,
+          })
+        }
        throw error
      }
    }
@@ -1078,7 +1158,9 @@ export default class llamacpp_extension extends AIEngine {
    } catch (error) {
      logger.error('GGUF validation failed:', error)
      throw new Error(
-        `Invalid GGUF file(s): ${error.message || 'File format validation failed'}`
+        `Invalid GGUF file(s): ${
+          error.message || 'File format validation failed'
+        }`
      )
    }

@@ -1097,6 +1179,10 @@ export default class llamacpp_extension extends AIEngine {
      mmproj_path: mmprojPath,
      name: modelId,
      size_bytes,
+      model_sha256: opts.modelSha256,
+      model_size_bytes: opts.modelSize,
+      mmproj_sha256: opts.mmprojSha256,
+      mmproj_size_bytes: opts.mmprojSize,
    } as ModelConfig
    await fs.mkdir(await joinPath([janDataFolderPath, modelDir]))
    await invoke<void>('write_yaml', {
@@ -1108,16 +1194,50 @@ export default class llamacpp_extension extends AIEngine {
      modelPath,
      mmprojPath,
      size_bytes,
+      model_sha256: opts.modelSha256,
+      model_size_bytes: opts.modelSize,
+      mmproj_sha256: opts.mmprojSha256,
+      mmproj_size_bytes: opts.mmprojSize,
    })
  }

+  /**
+   * Deletes the entire model folder for a given modelId
+   * @param modelId The model ID to delete
+   */
+  private async deleteModelFolder(modelId: string): Promise<void> {
+    try {
+      const modelDir = await joinPath([
+        await this.getProviderPath(),
+        'models',
+        modelId,
+      ])
+
+      if (await fs.existsSync(modelDir)) {
+        logger.info(`Cleaning up model directory: ${modelDir}`)
+        await fs.rm(modelDir)
+      }
+    } catch (deleteError) {
+      logger.warn('Failed to delete model directory:', deleteError)
+    }
+  }
+
  override async abortImport(modelId: string): Promise<void> {
-    // prepand provider name to avoid name collision
+    // Cancel any active download task
+    // prepend provider name to avoid name collision
    const taskId = this.createDownloadTaskId(modelId)
    const downloadManager = window.core.extensionManager.getByName(
      '@janhq/download-extension'
    )
-    await downloadManager.cancelDownload(taskId)
+    try {
+      await downloadManager.cancelDownload(taskId)
+    } catch (cancelError) {
+      logger.warn('Failed to cancel download task:', cancelError)
+    }
+
+    // Delete the entire model folder if it exists (for validation failures)
+    await this.deleteModelFolder(modelId)
  }

  /**
src-tauri/Cargo.lock (generated, 11 changed lines)
@@ -2323,6 +2323,7 @@ dependencies = [
  "serde_json",
  "sha2",
  "tokio",
+ "tokio-util",
  "url",
 ]
@@ -4019,8 +4020,9 @@ dependencies = [

 [[package]]
 name = "rmcp"
-version = "0.5.0"
-source = "git+https://github.com/modelcontextprotocol/rust-sdk?rev=209dbac50f51737ad953c3a2c8e28f3619b6c277#209dbac50f51737ad953c3a2c8e28f3619b6c277"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bb21cd3555f1059f27e4813827338dec44429a08ecd0011acc41d9907b160c00"
 dependencies = [
  "base64 0.22.1",
  "chrono",
@@ -4045,8 +4047,9 @@ dependencies = [

 [[package]]
 name = "rmcp-macros"
-version = "0.5.0"
-source = "git+https://github.com/modelcontextprotocol/rust-sdk?rev=209dbac50f51737ad953c3a2c8e28f3619b6c277#209dbac50f51737ad953c3a2c8e28f3619b6c277"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ab5d16ae1ff3ce2c5fd86c37047b2869b75bec795d53a4b1d8257b15415a2354"
 dependencies = [
  "darling 0.21.2",
  "proc-macro2",
@@ -1,9 +1,10 @@
-use super::models::{DownloadEvent, DownloadItem, ProxyConfig};
+use super::models::{DownloadEvent, DownloadItem, ProxyConfig, ProgressTracker};
 use crate::core::app::commands::get_jan_data_folder_path;
 use futures_util::StreamExt;
 use jan_utils::normalize_path;
 use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
 use std::collections::HashMap;
+use std::path::Path;
 use std::time::Duration;
 use tauri::Emitter;
 use tokio::fs::File;
@@ -11,10 +12,131 @@ use tokio::io::AsyncWriteExt;
 use tokio_util::sync::CancellationToken;
 use url::Url;

+// ===== UTILITY FUNCTIONS =====
+
 pub fn err_to_string<E: std::fmt::Display>(e: E) -> String {
     format!("Error: {}", e)
 }

+// ===== VALIDATION FUNCTIONS =====
+
+/// Validates a downloaded file against expected hash and size
+async fn validate_downloaded_file(
+    item: &DownloadItem,
+    save_path: &Path,
+    app: &tauri::AppHandle,
+    cancel_token: &CancellationToken,
+) -> Result<(), String> {
+    // Skip validation if no verification data is provided
+    if item.sha256.is_none() && item.size.is_none() {
+        log::debug!(
+            "No validation data provided for {}, skipping validation",
+            item.url
+        );
+        return Ok(());
+    }
+
+    // Extract model ID from save path for validation events
+    // Path structure: llamacpp/models/{modelId}/model.gguf or llamacpp/models/{modelId}/mmproj.gguf
+    let model_id = save_path
+        .parent() // get parent directory (modelId folder)
+        .and_then(|p| p.file_name())
+        .and_then(|n| n.to_str())
+        .unwrap_or("unknown");
+
+    // Emit validation started event
+    app.emit(
+        "onModelValidationStarted",
+        serde_json::json!({
+            "modelId": model_id,
+            "downloadType": "Model",
+        }),
+    )
+    .unwrap();
+
+    log::info!("Starting validation for model: {}", model_id);
+
+    // Validate size if provided (fast check first)
+    if let Some(expected_size) = &item.size {
+        log::info!("Starting size verification for {}", item.url);
+
+        match tokio::fs::metadata(save_path).await {
+            Ok(metadata) => {
+                let actual_size = metadata.len();
+
+                if actual_size != *expected_size {
+                    log::error!(
+                        "Size verification failed for {}. Expected: {} bytes, Actual: {} bytes",
+                        item.url,
+                        expected_size,
+                        actual_size
+                    );
+                    return Err(format!(
+                        "Size verification failed. Expected {} bytes but got {} bytes.",
+                        expected_size, actual_size
+                    ));
+                }
+
+                log::info!(
+                    "Size verification successful for {} ({} bytes)",
+                    item.url,
+                    actual_size
+                );
+            }
+            Err(e) => {
+                log::error!(
+                    "Failed to get file metadata for {}: {}",
+                    save_path.display(),
+                    e
+                );
+                return Err(format!("Failed to verify file size: {}", e));
+            }
+        }
+    }
+
+    // Check for cancellation before expensive hash computation
+    if cancel_token.is_cancelled() {
+        log::info!("Validation cancelled for {}", item.url);
+        return Err("Validation cancelled".to_string());
+    }
+
+    // Validate hash if provided (expensive check second)
+    if let Some(expected_sha256) = &item.sha256 {
+        log::info!("Starting Hash verification for {}", item.url);
+
+        match jan_utils::crypto::compute_file_sha256_with_cancellation(save_path, cancel_token).await {
+            Ok(computed_sha256) => {
+                if computed_sha256 != *expected_sha256 {
+                    log::error!(
+                        "Hash verification failed for {}. Expected: {}, Computed: {}",
+                        item.url,
+                        expected_sha256,
+                        computed_sha256
+                    );
+
+                    return Err(format!(
+                        "Hash verification failed. The downloaded file is corrupted or has been tampered with."
+                    ));
+                }
+
+                log::info!("Hash verification successful for {}", item.url);
+            }
+            Err(e) => {
+                log::error!(
+                    "Failed to compute SHA256 for {}: {}",
+                    save_path.display(),
+                    e
+                );
+                return Err(format!("Failed to verify file integrity: {}", e));
+            }
+        }
+    }
+
+    log::info!("All validations passed for {}", item.url);
+    Ok(())
+}
+
 pub fn validate_proxy_config(config: &ProxyConfig) -> Result<(), String> {
     // Validate proxy URL format
     if let Err(e) = Url::parse(&config.url) {
@@ -172,6 +294,9 @@ pub async fn _get_file_size(
     }
 }

+// ===== MAIN DOWNLOAD FUNCTIONS =====
+
+/// Downloads multiple files in parallel with individual progress tracking
 pub async fn _download_files_internal(
     app: tauri::AppHandle,
     items: &[DownloadItem],
@@ -184,28 +309,31 @@ pub async fn _download_files_internal(

     let header_map = _convert_headers(headers).map_err(err_to_string)?;

-    let total_size = {
-        let mut total_size = 0u64;
-        for item in items.iter() {
-            let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?;
-            total_size += _get_file_size(&client, &item.url)
-                .await
-                .map_err(err_to_string)?;
-        }
-        total_size
-    };
+    // Calculate sizes for each file
+    let mut file_sizes = HashMap::new();
+    for item in items.iter() {
+        let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?;
+        let size = _get_file_size(&client, &item.url)
+            .await
+            .map_err(err_to_string)?;
+        file_sizes.insert(item.url.clone(), size);
+    }
+
+    let total_size: u64 = file_sizes.values().sum();
     log::info!("Total download size: {}", total_size);

-    let mut evt = DownloadEvent {
-        transferred: 0,
-        total: total_size,
-    };
     let evt_name = format!("download-{}", task_id);

+    // Create progress tracker
+    let progress_tracker = ProgressTracker::new(items, file_sizes.clone());
+
     // save file under Jan data folder
     let jan_data_folder = get_jan_data_folder_path(app.clone());

-    for item in items.iter() {
+    // Collect download tasks for parallel execution
+    let mut download_tasks = Vec::new();
+
+    for (index, item) in items.iter().enumerate() {
         let save_path = jan_data_folder.join(&item.save_path);
         let save_path = normalize_path(&save_path);
@@ -217,120 +345,251 @@ pub async fn _download_files_internal(
             ));
         }

-        // Create parent directories if they don't exist
-        if let Some(parent) = save_path.parent() {
-            if !parent.exists() {
-                tokio::fs::create_dir_all(parent)
-                    .await
-                    .map_err(err_to_string)?;
-            }
-        }
-
-        let current_extension = save_path.extension().unwrap_or_default().to_string_lossy();
-        let append_extension = |ext: &str| {
-            if current_extension.is_empty() {
-                ext.to_string()
-            } else {
-                format!("{}.{}", current_extension, ext)
-            }
-        };
-        let tmp_save_path = save_path.with_extension(append_extension("tmp"));
-        let url_save_path = save_path.with_extension(append_extension("url"));
-
-        let mut should_resume = resume
-            && tmp_save_path.exists()
-            && tokio::fs::read_to_string(&url_save_path)
-                .await
-                .map(|url| url == item.url) // check if we resume the same URL
-                .unwrap_or(false);
-
-        tokio::fs::write(&url_save_path, item.url.clone())
-            .await
-            .map_err(err_to_string)?;
-
-        log::info!("Started downloading: {}", item.url);
-        let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?;
-        let mut download_delta = 0u64;
-        let resp = if should_resume {
-            let downloaded_size = tmp_save_path.metadata().map_err(err_to_string)?.len();
-            match _get_maybe_resume(&client, &item.url, downloaded_size).await {
-                Ok(resp) => {
-                    log::info!(
-                        "Resume download: {}, already downloaded {} bytes",
-                        item.url,
-                        downloaded_size
-                    );
-                    download_delta += downloaded_size;
-                    resp
-                }
-                Err(e) => {
-                    // fallback to normal download
-                    log::warn!("Failed to resume download: {}", e);
-                    should_resume = false;
-                    _get_maybe_resume(&client, &item.url, 0).await?
-                }
-            }
-        } else {
-            _get_maybe_resume(&client, &item.url, 0).await?
-        };
-        let mut stream = resp.bytes_stream();
-
-        let file = if should_resume {
-            // resume download, append to existing file
-            tokio::fs::OpenOptions::new()
-                .write(true)
-                .append(true)
-                .open(&tmp_save_path)
-                .await
-                .map_err(err_to_string)?
-        } else {
-            // start new download, create a new file
-            File::create(&tmp_save_path).await.map_err(err_to_string)?
-        };
-        let mut writer = tokio::io::BufWriter::new(file);
-
-        // write chunk to file
-        while let Some(chunk) = stream.next().await {
-            if cancel_token.is_cancelled() {
-                if !should_resume {
-                    tokio::fs::remove_dir_all(&save_path.parent().unwrap())
-                        .await
-                        .ok();
-                }
-                log::info!("Download cancelled for task: {}", task_id);
-                app.emit(&evt_name, evt.clone()).unwrap();
-                return Ok(());
-            }
-
-            let chunk = chunk.map_err(err_to_string)?;
-            writer.write_all(&chunk).await.map_err(err_to_string)?;
-            download_delta += chunk.len() as u64;
-
-            // only update every 10 MB
-            if download_delta >= 10 * 1024 * 1024 {
-                evt.transferred += download_delta;
-                app.emit(&evt_name, evt.clone()).unwrap();
-                download_delta = 0u64;
-            }
-        }
-
-        writer.flush().await.map_err(err_to_string)?;
-        evt.transferred += download_delta;
-
-        // rename tmp file to final file
-        tokio::fs::rename(&tmp_save_path, &save_path)
-            .await
-            .map_err(err_to_string)?;
-        tokio::fs::remove_file(&url_save_path)
-            .await
-            .map_err(err_to_string)?;
-        log::info!("Finished downloading: {}", item.url);
+        // Spawn download task for each file
+        let item_clone = item.clone();
+        let app_clone = app.clone();
+        let header_map_clone = header_map.clone();
+        let cancel_token_clone = cancel_token.clone();
+        let evt_name_clone = evt_name.clone();
+        let progress_tracker_clone = progress_tracker.clone();
+        let file_id = format!("{}-{}", task_id, index);
+        let file_size = file_sizes.get(&item.url).copied().unwrap_or(0);
+
+        let task = tokio::spawn(async move {
+            download_single_file(
+                app_clone,
+                &item_clone,
+                &header_map_clone,
+                &save_path,
+                resume,
+                cancel_token_clone,
+                evt_name_clone,
+                progress_tracker_clone,
+                file_id,
+                file_size,
+            )
+            .await
+        });
+
+        download_tasks.push(task);
     }

-    app.emit(&evt_name, evt.clone()).unwrap();
+    // Wait for all downloads to complete
+    let mut validation_tasks = Vec::new();
+    for (task, item) in download_tasks.into_iter().zip(items.iter()) {
+        let result = task.await.map_err(|e| format!("Task join error: {}", e))?;
+
+        match result {
+            Ok(downloaded_path) => {
+                // Spawn validation task in parallel
+                let item_clone = item.clone();
+                let app_clone = app.clone();
+                let path_clone = downloaded_path.clone();
+                let cancel_token_clone = cancel_token.clone();
+                let validation_task = tokio::spawn(async move {
+                    validate_downloaded_file(&item_clone, &path_clone, &app_clone, &cancel_token_clone).await
+                });
+                validation_tasks.push((validation_task, downloaded_path, item.clone()));
+            }
+            Err(e) => return Err(e),
+        }
+    }
+
+    // Wait for all validations to complete
+    for (validation_task, save_path, _item) in validation_tasks {
+        let validation_result = validation_task
+            .await
+            .map_err(|e| format!("Validation task join error: {}", e))?;
+
+        if let Err(validation_error) = validation_result {
+            // Clean up the file if validation fails
+            let _ = tokio::fs::remove_file(&save_path).await;
+
+            // Try to clean up the parent directory if it's empty
+            if let Some(parent) = save_path.parent() {
+                let _ = tokio::fs::remove_dir(parent).await;
+            }
+
+            return Err(validation_error);
+        }
+    }
+
+    // Emit final progress
+    let (transferred, total) = progress_tracker.get_total_progress().await;
+    let final_evt = DownloadEvent { transferred, total };
+    app.emit(&evt_name, final_evt).unwrap();
     Ok(())
 }
+
+/// Downloads a single file without blocking other downloads
+async fn download_single_file(
+    app: tauri::AppHandle,
+    item: &DownloadItem,
+    header_map: &HeaderMap,
+    save_path: &std::path::Path,
+    resume: bool,
+    cancel_token: CancellationToken,
+    evt_name: String,
+    progress_tracker: ProgressTracker,
+    file_id: String,
+    _file_size: u64,
+) -> Result<std::path::PathBuf, String> {
+    // Create parent directories if they don't exist
+    if let Some(parent) = save_path.parent() {
+        if !parent.exists() {
+            tokio::fs::create_dir_all(parent)
+                .await
+                .map_err(err_to_string)?;
+        }
+    }
+
+    let current_extension = save_path.extension().unwrap_or_default().to_string_lossy();
+    let append_extension = |ext: &str| {
+        if current_extension.is_empty() {
+            ext.to_string()
+        } else {
+            format!("{}.{}", current_extension, ext)
+        }
+    };
+    let tmp_save_path = save_path.with_extension(append_extension("tmp"));
+    let url_save_path = save_path.with_extension(append_extension("url"));
+
+    let mut should_resume = resume
+        && tmp_save_path.exists()
+        && tokio::fs::read_to_string(&url_save_path)
+            .await
+            .map(|url| url == item.url) // check if we resume the same URL
+            .unwrap_or(false);
+
+    tokio::fs::write(&url_save_path, item.url.clone())
+        .await
+        .map_err(err_to_string)?;
+
+    log::info!("Started downloading: {}", item.url);
+    let client = _get_client_for_item(item, &header_map).map_err(err_to_string)?;
+    let mut download_delta = 0u64;
+    let mut initial_progress = 0u64;
+
+    let resp = if should_resume {
+        let downloaded_size = tmp_save_path.metadata().map_err(err_to_string)?.len();
+        match _get_maybe_resume(&client, &item.url, downloaded_size).await {
+            Ok(resp) => {
+                log::info!(
+                    "Resume download: {}, already downloaded {} bytes",
+                    item.url,
+                    downloaded_size
+                );
+                initial_progress = downloaded_size;
+
+                // Initialize progress for resumed download
+                progress_tracker
+                    .update_progress(&file_id, downloaded_size)
+                    .await;
+
+                // Emit initial combined progress
+                let (combined_transferred, combined_total) =
+                    progress_tracker.get_total_progress().await;
+                let evt = DownloadEvent {
+                    transferred: combined_transferred,
+                    total: combined_total,
+                };
+                app.emit(&evt_name, evt).unwrap();
+
+                resp
+            }
+            Err(e) => {
+                // fallback to normal download
+                log::warn!("Failed to resume download: {}", e);
+                should_resume = false;
+                _get_maybe_resume(&client, &item.url, 0).await?
+            }
+        }
+    } else {
+        _get_maybe_resume(&client, &item.url, 0).await?
+    };
+    let mut stream = resp.bytes_stream();
+
+    let file = if should_resume {
+        // resume download, append to existing file
+        tokio::fs::OpenOptions::new()
+            .write(true)
+            .append(true)
+            .open(&tmp_save_path)
+            .await
+            .map_err(err_to_string)?
+    } else {
+        // start new download, create a new file
+        File::create(&tmp_save_path).await.map_err(err_to_string)?
+    };
+    let mut writer = tokio::io::BufWriter::new(file);
+    let mut total_transferred = initial_progress;
+
+    // write chunk to file
+    while let Some(chunk) = stream.next().await {
+        if cancel_token.is_cancelled() {
+            if !should_resume {
+                tokio::fs::remove_dir_all(&save_path.parent().unwrap())
+                    .await
+                    .ok();
+            }
+            log::info!("Download cancelled: {}", item.url);
+            return Err("Download cancelled".to_string());
+        }
+
+        let chunk = chunk.map_err(err_to_string)?;
+        writer.write_all(&chunk).await.map_err(err_to_string)?;
+        download_delta += chunk.len() as u64;
+        total_transferred += chunk.len() as u64;
+
+        // Update progress every 10 MB
+        if download_delta >= 10 * 1024 * 1024 {
+            // Update individual file progress
+            progress_tracker
+                .update_progress(&file_id, total_transferred)
+                .await;
+
+            // Emit combined progress event
+            let (combined_transferred, combined_total) =
+                progress_tracker.get_total_progress().await;
+            let evt = DownloadEvent {
+                transferred: combined_transferred,
+                total: combined_total,
+            };
+            app.emit(&evt_name, evt).unwrap();
+
+            download_delta = 0u64;
+        }
+    }
+
+    writer.flush().await.map_err(err_to_string)?;
+
+    // Final progress update for this file
+    progress_tracker
+        .update_progress(&file_id, total_transferred)
+        .await;
+
+    // Emit final combined progress
+    let (combined_transferred, combined_total) = progress_tracker.get_total_progress().await;
+    let evt = DownloadEvent {
+        transferred: combined_transferred,
+        total: combined_total,
+    };
+    app.emit(&evt_name, evt).unwrap();
+
+    // rename tmp file to final file
+    tokio::fs::rename(&tmp_save_path, &save_path)
+        .await
+        .map_err(err_to_string)?;
+    tokio::fs::remove_file(&url_save_path)
+        .await
+        .map_err(err_to_string)?;
+
+    log::info!("Finished downloading: {}", item.url);
+    Ok(save_path.to_path_buf())
+}
+
+// ===== HTTP CLIENT HELPER FUNCTIONS =====
+
 pub async fn _get_maybe_resume(
     client: &reqwest::Client,
     url: &str,
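Note on the refactor above: `_download_files_internal` now follows the usual spawn-then-join shape, one `tokio::spawn` per file, with the handles awaited in order so the first failure aborts the whole batch. A minimal standalone sketch of that pattern (values are illustrative only, not taken from the Jan codebase):

use tokio::task::JoinHandle;

#[tokio::main]
async fn main() -> Result<(), String> {
    // One task per file; errors surface when the handles are joined in order.
    let tasks: Vec<JoinHandle<Result<u64, String>>> = (0..3u64)
        .map(|i| tokio::spawn(async move { Ok(i * 10) }))
        .collect();

    for task in tasks {
        // Outer `?` handles a panicked task, inner `?` a download error.
        let bytes = task.await.map_err(|e| format!("Task join error: {}", e))??;
        println!("downloaded {} bytes", bytes);
    }
    Ok(())
}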
@@ -1,4 +1,6 @@
 use std::collections::HashMap;
+use std::sync::Arc;
+use tokio::sync::Mutex;
 use tokio_util::sync::CancellationToken;

 #[derive(Default)]
@@ -20,6 +22,8 @@ pub struct DownloadItem {
     pub url: String,
     pub save_path: String,
     pub proxy: Option<ProxyConfig>,
+    pub sha256: Option<String>,
+    pub size: Option<u64>,
 }

 #[derive(serde::Serialize, Clone, Debug)]
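With the two new optional fields, callers attach per-file verification metadata directly to the item; leaving both as `None` skips validation entirely, which keeps the change backward compatible. A hypothetical example (URL, path, digest, and size are invented for illustration):

// Hypothetical DownloadItem carrying the new verification fields.
let item = DownloadItem {
    url: "https://example.com/model.gguf".to_string(),
    save_path: "llamacpp/models/example-model/model.gguf".to_string(),
    proxy: None,
    // Illustrative digest/size only; real values come from the model's metadata.
    sha256: Some("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855".to_string()),
    size: Some(4_368_439_296),
};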
@@ -27,3 +31,31 @@ pub struct DownloadEvent {
     pub transferred: u64,
     pub total: u64,
 }
+
+/// Structure to track progress for each file in parallel downloads
+#[derive(Clone)]
+pub struct ProgressTracker {
+    file_progress: Arc<Mutex<HashMap<String, u64>>>,
+    total_size: u64,
+}
+
+impl ProgressTracker {
+    pub fn new(_items: &[DownloadItem], sizes: HashMap<String, u64>) -> Self {
+        let total_size = sizes.values().sum();
+        ProgressTracker {
+            file_progress: Arc::new(Mutex::new(HashMap::new())),
+            total_size,
+        }
+    }
+
+    pub async fn update_progress(&self, file_id: &str, transferred: u64) {
+        let mut progress = self.file_progress.lock().await;
+        progress.insert(file_id.to_string(), transferred);
+    }
+
+    pub async fn get_total_progress(&self) -> (u64, u64) {
+        let progress = self.file_progress.lock().await;
+        let total_transferred: u64 = progress.values().sum();
+        (total_transferred, self.total_size)
+    }
+}
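A minimal usage sketch for the tracker (assumes the `ProgressTracker` defined above; the ids and byte counts are invented):

use std::collections::HashMap;

#[tokio::main]
async fn main() {
    // Two files totalling 150 bytes.
    let mut sizes = HashMap::new();
    sizes.insert("https://example.com/model.gguf".to_string(), 100u64);
    sizes.insert("https://example.com/mmproj.gguf".to_string(), 50u64);
    let tracker = ProgressTracker::new(&[], sizes);

    // Each download task reports its own absolute progress under a unique id...
    tracker.update_progress("task-0", 60).await;
    tracker.update_progress("task-1", 25).await;

    // ...and the combined totals feed a single download event.
    let (transferred, total) = tracker.get_total_progress().await;
    assert_eq!((transferred, total), (85, 150));
}

Because each file reports an absolute byte count rather than a delta, a repeated or out-of-order update cannot double-count bytes.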
@@ -194,6 +194,8 @@ fn test_download_item_with_ssl_proxy() {
         url: "https://example.com/file.zip".to_string(),
         save_path: "downloads/file.zip".to_string(),
         proxy: Some(proxy_config),
+        sha256: None,
+        size: None,
     };

     assert!(download_item.proxy.is_some());
@@ -211,6 +213,8 @@ fn test_client_creation_with_ssl_settings() {
         url: "https://example.com/file.zip".to_string(),
         save_path: "downloads/file.zip".to_string(),
         proxy: Some(proxy_config),
+        sha256: None,
+        size: None,
     };

     let header_map = HeaderMap::new();
@@ -256,6 +260,8 @@ fn test_download_item_creation() {
         url: "https://example.com/file.tar.gz".to_string(),
         save_path: "models/test.tar.gz".to_string(),
         proxy: None,
+        sha256: None,
+        size: None,
     };

     assert_eq!(item.url, "https://example.com/file.tar.gz");
@@ -13,6 +13,7 @@ serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
 sha2 = "0.10"
 tokio = { version = "1", features = ["process"] }
+tokio-util = "0.7.14"
 url = "2.5"

 [features]
@@ -1,7 +1,11 @@
 use base64::{engine::general_purpose, Engine as _};
 use hmac::{Hmac, Mac};
 use rand::{distributions::Alphanumeric, Rng};
-use sha2::Sha256;
+use sha2::{Digest, Sha256};
+use std::path::Path;
+use tokio::fs::File;
+use tokio::io::AsyncReadExt;
+use tokio_util::sync::CancellationToken;

 type HmacSha256 = Hmac<Sha256>;

@@ -24,3 +28,59 @@ pub fn generate_api_key(model_id: String, api_secret: String) -> Result<String,
     let hash = general_purpose::STANDARD.encode(code_bytes);
     Ok(hash)
 }
+
+/// Compute SHA256 hash of a file with cancellation support by chunking the file
+pub async fn compute_file_sha256_with_cancellation(
+    file_path: &Path,
+    cancel_token: &CancellationToken,
+) -> Result<String, String> {
+    // Check for cancellation before starting
+    if cancel_token.is_cancelled() {
+        return Err("Hash computation cancelled".to_string());
+    }
+
+    let mut file = File::open(file_path)
+        .await
+        .map_err(|e| format!("Failed to open file for hashing: {}", e))?;
+
+    let mut hasher = Sha256::new();
+    let mut buffer = vec![0u8; 64 * 1024]; // 64KB chunks
+    let mut total_read = 0u64;
+
+    loop {
+        // Check for cancellation every chunk (every 64KB)
+        if cancel_token.is_cancelled() {
+            return Err("Hash computation cancelled".to_string());
+        }
+
+        let bytes_read = file
+            .read(&mut buffer)
+            .await
+            .map_err(|e| format!("Failed to read file for hashing: {}", e))?;
+
+        if bytes_read == 0 {
+            break; // EOF
+        }
+
+        hasher.update(&buffer[..bytes_read]);
+        total_read += bytes_read as u64;
+
+        // Log progress for very large files (every 100MB)
+        if total_read % (100 * 1024 * 1024) == 0 {
+            #[cfg(feature = "logging")]
+            log::debug!("Hash progress: {} MB processed", total_read / (1024 * 1024));
+        }
+    }
+
+    // Final cancellation check
+    if cancel_token.is_cancelled() {
+        return Err("Hash computation cancelled".to_string());
+    }
+
+    let hash_bytes = hasher.finalize();
+    let hash_hex = format!("{:x}", hash_bytes);
+
+    #[cfg(feature = "logging")]
+    log::debug!("Hash computation completed for {} bytes", total_read);
+    Ok(hash_hex)
+}
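A minimal sketch of the cancellation path (hypothetical file path; assumes the function is exported as defined above). The token is checked before the file is even opened, so a pre-cancelled token returns immediately:

use std::path::Path;
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();
    token.cancel(); // e.g. the user aborted the import mid-flight

    let result = jan_utils::crypto::compute_file_sha256_with_cancellation(
        Path::new("models/example/model.gguf"), // hypothetical path
        &token,
    )
    .await;

    assert_eq!(result, Err("Hash computation cancelled".to_string()));
}

During a live run the same check fires between 64 KB chunks, which is what lets a user cancel mid-hash without waiting for a multi-gigabyte file to finish.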
@@ -168,9 +168,46 @@ export function DownloadManagement() {
    [removeDownload, removeLocalDownloadingModel, t]
  )

+  const onModelValidationStarted = useCallback(
+    (event: { modelId: string; downloadType: string }) => {
+      console.debug('onModelValidationStarted', event)
+
+      // Show validation in progress toast
+      toast.info(t('common:toast.modelValidationStarted.title'), {
+        id: `model-validation-started-${event.modelId}`,
+        description: t('common:toast.modelValidationStarted.description', {
+          modelId: event.modelId,
+        }),
+        duration: 10000,
+      })
+    },
+    [t]
+  )
+
+  const onModelValidationFailed = useCallback(
+    (event: { modelId: string; error: string; reason: string }) => {
+      console.debug('onModelValidationFailed', event)
+
+      // Dismiss the validation started toast
+      toast.dismiss(`model-validation-started-${event.modelId}`)
+
+      removeDownload(event.modelId)
+      removeLocalDownloadingModel(event.modelId)
+
+      // Show specific toast for validation failure
+      toast.error(t('common:toast.modelValidationFailed.title'), {
+        description: t('common:toast.modelValidationFailed.description', {
+          modelId: event.modelId,
+        }),
+        duration: 30000, // Requires manual dismissal for security-critical message
+      })
+    },
+    [removeDownload, removeLocalDownloadingModel, t]
+  )
+
  const onFileDownloadStopped = useCallback(
    (state: DownloadState) => {
-      console.debug('onFileDownloadError', state)
+      console.debug('onFileDownloadStopped', state)
      removeDownload(state.modelId)
      removeLocalDownloadingModel(state.modelId)
    },
@@ -180,6 +217,10 @@ export function DownloadManagement() {
  const onFileDownloadSuccess = useCallback(
    async (state: DownloadState) => {
      console.debug('onFileDownloadSuccess', state)
+
+      // Dismiss any validation started toast when download completes successfully
+      toast.dismiss(`model-validation-started-${state.modelId}`)
+
      removeDownload(state.modelId)
      removeLocalDownloadingModel(state.modelId)
      toast.success(t('common:toast.downloadComplete.title'), {
@@ -192,12 +233,34 @@ export function DownloadManagement() {
    [removeDownload, removeLocalDownloadingModel, t]
  )

+  const onFileDownloadAndVerificationSuccess = useCallback(
+    async (state: DownloadState) => {
+      console.debug('onFileDownloadAndVerificationSuccess', state)
+
+      // Dismiss any validation started toast when download and verification complete successfully
+      toast.dismiss(`model-validation-started-${state.modelId}`)
+
+      removeDownload(state.modelId)
+      removeLocalDownloadingModel(state.modelId)
+      toast.success(t('common:toast.downloadAndVerificationComplete.title'), {
+        id: 'download-complete',
+        description: t('common:toast.downloadAndVerificationComplete.description', {
+          item: state.modelId,
+        }),
+      })
+    },
+    [removeDownload, removeLocalDownloadingModel, t]
+  )
+
  useEffect(() => {
    console.debug('DownloadListener: registering event listeners...')
    events.on(DownloadEvent.onFileDownloadUpdate, onFileDownloadUpdate)
    events.on(DownloadEvent.onFileDownloadError, onFileDownloadError)
    events.on(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess)
    events.on(DownloadEvent.onFileDownloadStopped, onFileDownloadStopped)
+    events.on(DownloadEvent.onModelValidationStarted, onModelValidationStarted)
+    events.on(DownloadEvent.onModelValidationFailed, onModelValidationFailed)
+    events.on(DownloadEvent.onFileDownloadAndVerificationSuccess, onFileDownloadAndVerificationSuccess)

    // Register app update event listeners
    events.on(AppEvent.onAppUpdateDownloadUpdate, onAppUpdateDownloadUpdate)
@@ -210,6 +273,12 @@ export function DownloadManagement() {
      events.off(DownloadEvent.onFileDownloadError, onFileDownloadError)
      events.off(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess)
      events.off(DownloadEvent.onFileDownloadStopped, onFileDownloadStopped)
+      events.off(
+        DownloadEvent.onModelValidationStarted,
+        onModelValidationStarted
+      )
+      events.off(DownloadEvent.onModelValidationFailed, onModelValidationFailed)
+      events.off(DownloadEvent.onFileDownloadAndVerificationSuccess, onFileDownloadAndVerificationSuccess)

      // Unregister app update event listeners
      events.off(AppEvent.onAppUpdateDownloadUpdate, onAppUpdateDownloadUpdate)
@@ -224,6 +293,9 @@ export function DownloadManagement() {
    onFileDownloadError,
    onFileDownloadSuccess,
    onFileDownloadStopped,
+    onModelValidationStarted,
+    onModelValidationFailed,
+    onFileDownloadAndVerificationSuccess,
    onAppUpdateDownloadUpdate,
    onAppUpdateDownloadSuccess,
    onAppUpdateDownloadError,
@@ -256,6 +256,22 @@
    "downloadCancelled": {
      "title": "Download abgebrochen",
      "description": "Der Download-Prozess wurde abgebrochen"
+    },
+    "downloadFailed": {
+      "title": "Download fehlgeschlagen",
+      "description": "{{item}} Download fehlgeschlagen"
+    },
+    "modelValidationStarted": {
+      "title": "Modell wird validiert",
+      "description": "Modell \"{{modelId}}\" erfolgreich heruntergeladen. Integrität wird überprüft..."
+    },
+    "modelValidationFailed": {
+      "title": "Modellvalidierung fehlgeschlagen",
+      "description": "Das heruntergeladene Modell \"{{modelId}}\" ist bei der Integritätsprüfung fehlgeschlagen und wurde entfernt. Die Datei könnte beschädigt oder manipuliert worden sein."
+    },
+    "downloadAndVerificationComplete": {
+      "title": "Download abgeschlossen",
+      "description": "Modell \"{{item}}\" erfolgreich heruntergeladen und verifiziert"
    }
  }
}
@@ -12,6 +12,7 @@
  "showVariants": "Zeige Varianten",
  "useModel": "Nutze dieses Modell",
  "downloadModel": "Modell herunterladen",
+  "tools": "Werkzeuge",
  "searchPlaceholder": "Suche nach Modellen auf Hugging Face...",
  "editTheme": "Bearbeite Erscheinungsbild",
  "joyride": {
@@ -261,6 +261,18 @@
    "downloadFailed": {
      "title": "Download Failed",
      "description": "{{item}} download failed"
+    },
+    "modelValidationStarted": {
+      "title": "Validating Model",
+      "description": "Downloaded model \"{{modelId}}\" successfully. Verifying integrity..."
+    },
+    "modelValidationFailed": {
+      "title": "Model Validation Failed",
+      "description": "The downloaded model \"{{modelId}}\" failed integrity verification and was removed. The file may be corrupted or tampered with."
+    },
+    "downloadAndVerificationComplete": {
+      "title": "Download Complete",
+      "description": "Model \"{{item}}\" downloaded and verified successfully"
    }
  }
}
@@ -12,6 +12,7 @@
  "showVariants": "Show variants",
  "useModel": "Use this model",
  "downloadModel": "Download model",
+  "tools": "Tools",
  "searchPlaceholder": "Search for models on Hugging Face...",
  "joyride": {
    "recommendedModelTitle": "Recommended Model",
@@ -249,6 +249,22 @@
    "downloadCancelled": {
      "title": "Unduhan Dibatalkan",
      "description": "Proses unduhan telah dibatalkan"
+    },
+    "downloadFailed": {
+      "title": "Unduhan Gagal",
+      "description": "Unduhan {{item}} gagal"
+    },
+    "modelValidationStarted": {
+      "title": "Memvalidasi Model",
+      "description": "Model \"{{modelId}}\" berhasil diunduh. Memverifikasi integritas..."
+    },
+    "modelValidationFailed": {
+      "title": "Validasi Model Gagal",
+      "description": "Model yang diunduh \"{{modelId}}\" gagal verifikasi integritas dan telah dihapus. File mungkin rusak atau telah dimanipulasi."
+    },
+    "downloadAndVerificationComplete": {
+      "title": "Unduhan Selesai",
+      "description": "Model \"{{item}}\" berhasil diunduh dan diverifikasi"
    }
  }
}
@@ -12,6 +12,7 @@
  "showVariants": "Tampilkan Varian",
  "useModel": "Gunakan model ini",
  "downloadModel": "Unduh model",
+  "tools": "Alat",
  "searchPlaceholder": "Cari model di Hugging Face...",
  "joyride": {
    "recommendedModelTitle": "Model yang Direkomendasikan",
@@ -249,6 +249,22 @@
    "downloadCancelled": {
      "title": "Đã hủy tải xuống",
      "description": "Quá trình tải xuống đã bị hủy"
+    },
+    "downloadFailed": {
+      "title": "Tải xuống thất bại",
+      "description": "Tải xuống {{item}} thất bại"
+    },
+    "modelValidationStarted": {
+      "title": "Đang xác thực mô hình",
+      "description": "Đã tải xuống mô hình \"{{modelId}}\" thành công. Đang xác minh tính toàn vẹn..."
+    },
+    "modelValidationFailed": {
+      "title": "Xác thực mô hình thất bại",
+      "description": "Mô hình đã tải xuống \"{{modelId}}\" không vượt qua kiểm tra tính toàn vẹn và đã bị xóa. Tệp có thể bị hỏng hoặc bị giả mạo."
+    },
+    "downloadAndVerificationComplete": {
+      "title": "Tải xuống hoàn tất",
+      "description": "Mô hình \"{{item}}\" đã được tải xuống và xác minh thành công"
    }
  }
}
@@ -12,6 +12,7 @@
  "showVariants": "Hiển thị biến thể",
  "useModel": "Sử dụng mô hình này",
  "downloadModel": "Tải xuống mô hình",
+  "tools": "Công cụ",
  "searchPlaceholder": "Tìm kiếm các mô hình trên Hugging Face...",
  "joyride": {
    "recommendedModelTitle": "Mô hình được đề xuất",
@@ -249,6 +249,22 @@
     "downloadCancelled": {
       "title": "下载已取消",
       "description": "下载过程已取消"
+    },
+    "downloadFailed": {
+      "title": "下载失败",
+      "description": "{{item}} 下载失败"
+    },
+    "modelValidationStarted": {
+      "title": "正在验证模型",
+      "description": "模型 \"{{modelId}}\" 下载成功。正在验证完整性..."
+    },
+    "modelValidationFailed": {
+      "title": "模型验证失败",
+      "description": "已下载的模型 \"{{modelId}}\" 未通过完整性验证并已被删除。文件可能损坏或被篡改。"
+    },
+    "downloadAndVerificationComplete": {
+      "title": "下载完成",
+      "description": "模型 \"{{item}}\" 下载并验证成功"
+    }
     }
   }
 }

@@ -12,6 +12,7 @@
   "showVariants": "显示变体",
   "useModel": "使用此模型",
   "downloadModel": "下载模型",
+  "tools": "工具",
   "searchPlaceholder": "在 Hugging Face 上搜索模型...",
   "joyride": {
     "recommendedModelTitle": "推荐模型",

@@ -249,6 +249,22 @@
     "downloadCancelled": {
       "title": "下載已取消",
       "description": "下載過程已取消"
+    },
+    "downloadFailed": {
+      "title": "下載失敗",
+      "description": "{{item}} 下載失敗"
+    },
+    "modelValidationStarted": {
+      "title": "正在驗證模型",
+      "description": "模型 \"{{modelId}}\" 下載成功。正在驗證完整性..."
+    },
+    "modelValidationFailed": {
+      "title": "模型驗證失敗",
+      "description": "已下載的模型 \"{{modelId}}\" 未通過完整性驗證並已被刪除。檔案可能損壞或被篡改。"
+    },
+    "downloadAndVerificationComplete": {
+      "title": "下載完成",
+      "description": "模型 \"{{item}}\" 下載並驗證成功"
+    }
     }
   }
 }

@@ -12,6 +12,7 @@
   "showVariants": "顯示變體",
   "useModel": "使用此模型",
   "downloadModel": "下載模型",
+  "tools": "工具",
   "searchPlaceholder": "在 Hugging Face 上搜尋模型...",
   "joyride": {
     "recommendedModelTitle": "推薦模型",

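The five toast entries added to each locale above pair with the validation events this commit introduces on the download bus (`onModelValidationStarted`, `onModelValidationFailed`, `onFileDownloadAndVerificationSuccess`). A minimal sketch of how a notification layer might consume them; the `toast.*` key prefix, the event payload fields, and the use of i18next and sonner are illustrative assumptions, not code from this commit:

import { events, DownloadEvent } from '@janhq/core'
import i18next from 'i18next'
import { toast } from 'sonner'

// Hypothetical wiring: each validation event surfaces the matching locale
// string. The payload shapes ({ modelId } / { item }) are assumed from the
// {{modelId}} and {{item}} placeholders in the new locale entries.
events.on(DownloadEvent.onModelValidationStarted, (e: { modelId: string }) => {
  toast.info(i18next.t('toast.modelValidationStarted.title'), {
    description: i18next.t('toast.modelValidationStarted.description', {
      modelId: e.modelId,
    }),
  })
})

events.on(DownloadEvent.onModelValidationFailed, (e: { modelId: string }) => {
  toast.error(i18next.t('toast.modelValidationFailed.title'), {
    description: i18next.t('toast.modelValidationFailed.description', {
      modelId: e.modelId,
    }),
  })
})

events.on(DownloadEvent.onFileDownloadAndVerificationSuccess, (e: { item: string }) => {
  toast.success(i18next.t('toast.downloadAndVerificationComplete.title'), {
    description: i18next.t('toast.downloadAndVerificationComplete.description', {
      item: e.item,
    }),
  })
})
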
@@ -22,7 +22,7 @@ import {
   CatalogModel,
   convertHfRepoToCatalogModel,
   fetchHuggingFaceRepo,
-  pullModel,
+  pullModelWithMetadata,
 } from '@/services/models'
 import { Progress } from '@/components/ui/progress'
 import { Button } from '@/components/ui/button'
@@ -408,9 +408,11 @@ function HubModelDetail() {
                       addLocalDownloadingModel(
                         variant.model_id
                       )
-                      pullModel(
+                      pullModelWithMetadata(
                         variant.model_id,
-                        variant.path
+                        variant.path,
+                        modelData.mmproj_models?.[0]?.path,
+                        huggingfaceToken
                       )
                     }}
                     className={cn(isDownloading && 'hidden')}

@@ -41,7 +41,7 @@ import {
 } from '@/components/ui/dropdown-menu'
 import {
   CatalogModel,
-  pullModel,
+  pullModelWithMetadata,
   fetchHuggingFaceRepo,
   convertHfRepoToCatalogModel,
 } from '@/services/models'
@@ -313,7 +313,12 @@ function Hub() {
     // Immediately set local downloading state
     addLocalDownloadingModel(modelId)
     const mmprojPath = model.mmproj_models?.[0]?.path
-    pullModel(modelId, modelUrl, mmprojPath)
+    pullModelWithMetadata(
+      modelId,
+      modelUrl,
+      mmprojPath,
+      huggingfaceToken
+    )
   }

   return (
@@ -812,12 +817,13 @@ function Hub() {
                           addLocalDownloadingModel(
                             variant.model_id
                           )
-                          pullModel(
+                          pullModelWithMetadata(
                             variant.model_id,
                             variant.path,
                             filteredModels[
                               virtualItem.index
-                            ].mmproj_models?.[0]?.path
+                            ].mmproj_models?.[0]?.path,
+                            huggingfaceToken
                           )
                         }}
                       >

@@ -325,7 +325,7 @@ describe('models service', () => {

       expect(result).toEqual(mockRepoData)
       expect(fetch).toHaveBeenCalledWith(
-        'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true',
+        'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
         {
           headers: {},
         }
@@ -344,7 +344,7 @@ describe('models service', () => {
         'https://huggingface.co/microsoft/DialoGPT-medium'
       )
       expect(fetch).toHaveBeenCalledWith(
-        'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true',
+        'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
         {
           headers: {},
         }
@@ -353,7 +353,7 @@ describe('models service', () => {
       // Test with domain prefix
       await fetchHuggingFaceRepo('huggingface.co/microsoft/DialoGPT-medium')
       expect(fetch).toHaveBeenCalledWith(
-        'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true',
+        'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
         {
           headers: {},
         }
@@ -362,7 +362,7 @@ describe('models service', () => {
       // Test with trailing slash
       await fetchHuggingFaceRepo('microsoft/DialoGPT-medium/')
       expect(fetch).toHaveBeenCalledWith(
-        'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true',
+        'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true&files_metadata=true',
         {
           headers: {},
         }
@@ -391,7 +391,7 @@ describe('models service', () => {

       expect(result).toBeNull()
       expect(fetch).toHaveBeenCalledWith(
-        'https://huggingface.co/api/models/nonexistent/model?blobs=true',
+        'https://huggingface.co/api/models/nonexistent/model?blobs=true&files_metadata=true',
         {
           headers: {},
         }

@@ -62,6 +62,11 @@ export interface HuggingFaceRepo {
     rfilename: string
     size?: number
     blobId?: string
+    lfs?: {
+      sha256: string
+      size: number
+      pointerSize: number
+    }
   }>
   readme?: string
 }

@@ -126,7 +131,7 @@ export const fetchHuggingFaceRepo = async (
   }

   const response = await fetch(
-    `https://huggingface.co/api/models/${cleanRepoId}?blobs=true`,
+    `https://huggingface.co/api/models/${cleanRepoId}?blobs=true&files_metadata=true`,
     {
       headers: hfToken
         ? {
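
With `files_metadata=true` appended, the Hub API is asked to include per-file metadata, which is where the `lfs` block typed above comes from. A minimal usage sketch of reading a file's expected checksum and size; the repo ID is an illustrative placeholder, and this assumes the repo has an LFS-tracked GGUF file:

import { fetchHuggingFaceRepo } from '@/services/models'

async function logExpectedChecksum(repoId: string) {
  // Ask the Hub for repo info; siblings now carry LFS metadata thanks to
  // the files_metadata=true query parameter added above.
  const repoInfo = await fetchHuggingFaceRepo(repoId)
  const ggufFile = repoInfo?.siblings?.find((s) => s.rfilename.endsWith('.gguf'))

  if (ggufFile?.lfs) {
    // sha256 and size come straight from the Hub's LFS records and can be
    // compared against the downloaded file.
    console.log('expected sha256:', ggufFile.lfs.sha256)
    console.log('expected size (bytes):', ggufFile.lfs.size)
  }
}

// Illustrative repo ID; any repo with LFS-tracked GGUF weights works.
logExpectedChecksum('org/repo')
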
@@ -237,14 +242,103 @@ export const updateModel = async (
 export const pullModel = async (
   id: string,
   modelPath: string,
-  mmprojPath?: string
+  modelSha256?: string,
+  modelSize?: number,
+  mmprojPath?: string,
+  mmprojSha256?: string,
+  mmprojSize?: number
 ) => {
   return getEngine()?.import(id, {
     modelPath,
     mmprojPath,
+    modelSha256,
+    modelSize,
+    mmprojSha256,
+    mmprojSize,
   })
 }
+
+/**
+ * Pull a model with real-time metadata fetching from HuggingFace.
+ * Extracts hash and size information from the model URL for both main model and mmproj files.
+ * @param id The model ID
+ * @param modelPath The model file URL (HuggingFace download URL)
+ * @param mmprojPath Optional mmproj file URL
+ * @param hfToken Optional HuggingFace token for authentication
+ * @returns A promise that resolves when the model download task is created.
+ */
+export const pullModelWithMetadata = async (
+  id: string,
+  modelPath: string,
+  mmprojPath?: string,
+  hfToken?: string
+) => {
+  let modelSha256: string | undefined
+  let modelSize: number | undefined
+  let mmprojSha256: string | undefined
+  let mmprojSize: number | undefined
+
+  // Extract repo ID from model URL
+  // URL format: https://huggingface.co/{repo}/resolve/main/{filename}
+  const modelUrlMatch = modelPath.match(
+    /https:\/\/huggingface\.co\/([^/]+\/[^/]+)\/resolve\/main\/(.+)/
+  )
+
+  if (modelUrlMatch) {
+    const [, repoId, modelFilename] = modelUrlMatch
+
+    try {
+      // Fetch real-time metadata from HuggingFace
+      const repoInfo = await fetchHuggingFaceRepo(repoId, hfToken)
+
+      if (repoInfo?.siblings) {
+        // Find the specific model file
+        const modelFile = repoInfo.siblings.find(
+          (file) => file.rfilename === modelFilename
+        )
+        if (modelFile?.lfs) {
+          modelSha256 = modelFile.lfs.sha256
+          modelSize = modelFile.lfs.size
+        }
+
+        // If mmproj path provided, extract its metadata too
+        if (mmprojPath) {
+          const mmprojUrlMatch = mmprojPath.match(
+            /https:\/\/huggingface\.co\/[^/]+\/[^/]+\/resolve\/main\/(.+)/
+          )
+          if (mmprojUrlMatch) {
+            const [, mmprojFilename] = mmprojUrlMatch
+            const mmprojFile = repoInfo.siblings.find(
+              (file) => file.rfilename === mmprojFilename
+            )
+            if (mmprojFile?.lfs) {
+              mmprojSha256 = mmprojFile.lfs.sha256
+              mmprojSize = mmprojFile.lfs.size
+            }
+          }
+        }
+      }
+    } catch (error) {
+      console.warn(
+        'Failed to fetch HuggingFace metadata, proceeding without hash verification:',
+        error
+      )
+      // Continue with download even if metadata fetch fails
+    }
+  }
+
+  // Call the original pullModel with the fetched metadata
+  return pullModel(
+    id,
+    modelPath,
+    modelSha256,
+    modelSize,
+    mmprojPath,
+    mmprojSha256,
+    mmprojSize
+  )
+}
+
 /**
  * Aborts a model download.
  * @param id
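
Downstream, the Hub screens shown earlier now hand pullModelWithMetadata the raw resolve URLs and an optional token. A short sketch of a direct call under the same `resolve/main` URL convention the regex above parses; the model ID, repo path, and filenames are illustrative placeholders:

import { pullModelWithMetadata } from '@/services/models'

// Illustrative inputs: a GGUF weight file plus its mmproj companion from the
// same (hypothetical) repo, addressed via the resolve/main URL scheme.
await pullModelWithMetadata(
  'my-model:Q4_K_M', // model ID as registered with the engine
  'https://huggingface.co/org/repo/resolve/main/model-Q4_K_M.gguf',
  'https://huggingface.co/org/repo/resolve/main/mmproj-f16.gguf', // optional
  undefined // hfToken: only needed for gated or private repos
)

Note the deliberate fallback in the catch block above: if the metadata lookup fails, the call still reaches pullModel, just without the sha256/size pairs, so downloads keep working when the Hub metadata API is unreachable; they simply skip hash verification.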