diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts index 75f74f16c..f46b00a13 100644 --- a/core/src/browser/extensions/engines/AIEngine.ts +++ b/core/src/browser/extensions/engines/AIEngine.ts @@ -91,21 +91,6 @@ export interface listOptions { } export type listResult = modelInfo[] -// 2. /pull -export interface pullOptions { - providerId: string - modelId: string // Identifier for the model to pull (e.g., from a known registry) - downloadUrl: string // URL to download the model from - /** optional callback to receive download progress */ - onProgress?: (progress: { percent: number; downloadedBytes: number; totalBytes?: number }) => void -} -export interface pullResult { - success: boolean - path?: string // local file path to the pulled model - error?: string - modelInfo?: modelInfo // Info of the pulled model -} - // 3. /load export interface loadOptions { modelPath: string @@ -174,27 +159,16 @@ export interface deleteResult { } // 7. /import -export interface importOptions { - providerId: string - sourcePath: string // Path to the local model file to import - desiredModelId?: string // Optional: if user wants to name it specifically +export interface ImportOptions { + [key: string]: any } + export interface importResult { success: boolean modelInfo?: modelInfo error?: string } -// 8. /abortPull -export interface abortPullOptions { - providerId: string - modelId: string // The modelId whose download is to be aborted -} -export interface abortPullResult { - success: boolean - error?: string -} - /** * Base AIEngine * Applicable to all AI Engines @@ -223,11 +197,6 @@ export abstract class AIEngine extends BaseExtension { */ abstract list(opts: listOptions): Promise - /** - * Pulls/downloads a model - */ - abstract pull(opts: pullOptions): Promise - /** * Loads a model into memory */ @@ -251,12 +220,12 @@ export abstract class AIEngine extends BaseExtension { /** * Imports a model */ - abstract import(opts: importOptions): Promise + abstract import(modelId: string, opts: ImportOptions): Promise /** - * Aborts an ongoing model pull + * Aborts an ongoing model import */ - abstract abortPull(opts: abortPullOptions): Promise + abstract abortImport(modelId: string): Promise /** * Optional method to get the underlying chat client diff --git a/extensions/llamacpp-extension/package.json b/extensions/llamacpp-extension/package.json index 746ad9d06..10be232d9 100644 --- a/extensions/llamacpp-extension/package.json +++ b/extensions/llamacpp-extension/package.json @@ -21,7 +21,7 @@ }, "dependencies": { "@janhq/core": "../../core/package.tgz", - "@tauri-apps/api": "^1.4.0", + "@tauri-apps/api": "^2.5.0", "fetch-retry": "^5.0.6", "ulidx": "^2.3.0" }, diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts index 3edd1b6e5..77257f17e 100644 --- a/extensions/llamacpp-extension/src/index.ts +++ b/extensions/llamacpp-extension/src/index.ts @@ -14,8 +14,6 @@ import { modelInfo, listOptions, listResult, - pullOptions, - pullResult, loadOptions, sessionInfo, unloadOptions, @@ -25,14 +23,17 @@ import { chatCompletionChunk, deleteOptions, deleteResult, - importOptions, - importResult, - abortPullOptions, - abortPullResult, + ImportOptions, chatCompletionRequest, + events, } from '@janhq/core' -import { invoke } from '@tauri-apps/api/tauri' +import { invoke } from '@tauri-apps/api/core' + +interface DownloadItem { + url: string + save_path: string +} /** * Helper to convert GGUF model filename to a more structured ID/name @@ -62,6 +63,7 @@ export default class llamacpp_extension extends AIEngine { provider: string = 'llamacpp' readonly providerId: string = 'llamacpp' + private downloadManager private activeSessions: Map = new Map() private modelsBasePath!: string private activeRequests: Map = new Map() @@ -70,6 +72,8 @@ export default class llamacpp_extension extends AIEngine { super.onLoad() // Calls registerEngine() from AIEngine this.registerSettings(SETTINGS) + this.downloadManager = window.core.extensionManager.getByName('@janhq/download-extension') + // Initialize models base path - assuming this would be retrieved from settings this.modelsBasePath = await joinPath([ await getJanDataFolderPath(), @@ -82,8 +86,91 @@ export default class llamacpp_extension extends AIEngine { throw new Error('method not implemented yet') } - override async pull(opts: pullOptions): Promise { - throw new Error('method not implemented yet') + override async import(modelId: string, opts: ImportOptions): Promise { + // TODO: sanitize modelId + // TODO: check if modelId already exists + const taskId = this.createDownloadTaskId(modelId) + + // this is relative to Jan's data folder + const modelDir = `models/${this.provider}/${modelId}` + + // we only use these from opts + // opts.modelPath: URL to the model file + // opts.mmprojPath: URL to the mmproj file + + let downloadItems: DownloadItem[] = [] + let modelPath = opts.modelPath + let mmprojPath = opts.mmprojPath + + const modelItem = { url: opts.modelPath, save_path: `${modelDir}/model.gguf` } + if (opts.modelPath.startsWith("https://")) { + downloadItems.push(modelItem) + modelPath = modelItem.save_path + } else { + // this should be absolute path + if (!(await fs.existsSync(modelPath))) { + throw new Error(`Model file not found: ${modelPath}`) + } + } + + if (opts.mmprojPath) { + const mmprojItem = { url: opts.mmprojPath, save_path: `${modelDir}/mmproj.gguf` } + if (opts.mmprojPath.startsWith("https://")) { + downloadItems.push(mmprojItem) + mmprojPath = mmprojItem.save_path + } else { + // this should be absolute path + if (!(await fs.existsSync(mmprojPath))) { + throw new Error(`MMProj file not found: ${mmprojPath}`) + } + } + } + + if (downloadItems.length > 0) { + let downloadCompleted = false + + try { + // emit download update event on progress + const onProgress = (transferred: number, total: number) => { + events.emit('onFileDownloadUpdate', { + modelId, + percent: transferred / total, + size: { transferred, total }, + downloadType: 'Model', + }) + downloadCompleted = transferred === total + } + await this.downloadManager.downloadFiles(downloadItems, taskId, onProgress) + } catch (error) { + console.error('Error downloading model:', modelId, opts, error) + events.emit('onFileDownloadError', { modelId, downloadType: 'Model' }) + throw error + } + + // once we reach this point, it either means download finishes or it was cancelled. + // if there was an error, it would have been caught above + const eventName = downloadCompleted ? 'onFileDownloadSuccess' : 'onFileDownloadStopped' + events.emit(eventName, { modelId, downloadType: 'Model' }) + } + + // TODO: check if files are valid GGUF files + + await invoke( + 'write_yaml', + { + data: { + model_path: modelPath, + mmproj_path: mmprojPath, + }, + savePath: `${modelDir}/model.yml`, + }, + ) + } + + override async abortImport(modelId: string): Promise { + // prepand provider name to avoid name collision + const taskId = this.createDownloadTaskId(modelId) + await this.downloadManager.cancelDownload(taskId) } override async load(opts: loadOptions): Promise { @@ -234,6 +321,11 @@ export default class llamacpp_extension extends AIEngine { } } + private createDownloadTaskId(modelId: string) { + // prepend provider to make taksId unique across providers + return `${this.provider}/${modelId}` + } + private async *handleStreamingResponse( url: string, headers: HeadersInit, @@ -342,14 +434,6 @@ export default class llamacpp_extension extends AIEngine { throw new Error('method not implemented yet') } - override async import(opts: importOptions): Promise { - throw new Error('method not implemented yet') - } - - override async abortPull(opts: abortPullOptions): Promise { - throw new Error('method not implemented yet') - } - // Optional method for direct client access override getChatClient(sessionId: string): any { throw new Error('method not implemented yet') diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 04843c1a9..df939aa55 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -52,6 +52,7 @@ ash = "0.38.0" nvml-wrapper = "0.10.0" tauri-plugin-deep-link = "2" fix-path-env = { git = "https://github.com/tauri-apps/fix-path-env-rs" } +serde_yaml = "0.9.34" [target.'cfg(windows)'.dependencies] libloading = "0.8.7" diff --git a/src-tauri/src/core/cmd.rs b/src-tauri/src/core/cmd.rs index 4b4463d12..8027eb5d5 100644 --- a/src-tauri/src/core/cmd.rs +++ b/src-tauri/src/core/cmd.rs @@ -283,14 +283,6 @@ fn copy_dir_recursive(src: &PathBuf, dst: &PathBuf) -> Result<(), io::Error> { Ok(()) } -#[tauri::command] -pub async fn reset_cortex_restart_count(state: State<'_, AppState>) -> Result<(), String> { - let mut count = state.cortex_restart_count.lock().await; - *count = 0; - log::info!("Cortex server restart count reset to 0."); - Ok(()) -} - #[tauri::command] pub fn change_app_data_folder( app_handle: tauri::AppHandle, diff --git a/src-tauri/src/core/utils/mod.rs b/src-tauri/src/core/utils/mod.rs index faf8aeda2..2880b0e1d 100644 --- a/src-tauri/src/core/utils/mod.rs +++ b/src-tauri/src/core/utils/mod.rs @@ -1,4 +1,5 @@ pub mod download; +pub mod extensions; use std::fs; use std::path::{Component, Path, PathBuf}; @@ -76,4 +77,24 @@ pub fn normalize_path(path: &Path) -> PathBuf { } ret } -pub mod extensions; + +#[tauri::command] +pub fn write_yaml( + app: tauri::AppHandle, + data: serde_json::Value, + save_path: &str, +) -> Result<(), String> { + // TODO: have an internal function to check scope + let jan_data_folder = get_jan_data_folder_path(app.clone()); + let save_path = normalize_path(&jan_data_folder.join(save_path)); + if !save_path.starts_with(&jan_data_folder) { + return Err(format!( + "Error: save path {} is not under jan_data_folder {}", + save_path.to_string_lossy(), + jan_data_folder.to_string_lossy(), + )); + } + let mut file = fs::File::create(&save_path).map_err(|e| e.to_string())?; + serde_yaml::to_writer(&mut file, &data).map_err(|e| e.to_string())?; + Ok(()) +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 23636a883..a3791552b 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -56,7 +56,7 @@ pub fn run() { core::cmd::get_server_status, core::cmd::read_logs, core::cmd::change_app_data_folder, - core::cmd::reset_cortex_restart_count, + core::migration::get_legacy_browser_data, // MCP commands core::mcp::get_tools, core::mcp::call_tool, @@ -79,6 +79,8 @@ pub fn run() { core::threads::get_thread_assistant, core::threads::create_thread_assistant, core::threads::modify_thread_assistant, + // generic utils + core::utils::write_yaml, // Download core::utils::download::download_files, core::utils::download::cancel_download_task,