feat: Model import (download + local import) for llama.cpp extension (#5087)

* add pull and abortPull

* add model import (download only)

* write model.yaml. support local model import

* remove cortex-related command

* add TODO

* remove cortex-related command
This commit is contained in:
Thien Tran 2025-05-23 22:39:23 +08:00 committed by Louis
parent a7a2dcc8d8
commit ded9ae733a
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
7 changed files with 134 additions and 65 deletions

View File

@ -91,21 +91,6 @@ export interface listOptions {
}
export type listResult = modelInfo[]
// 2. /pull
/** Options for pulling/downloading a model from a remote source. */
export interface pullOptions {
providerId: string
modelId: string // Identifier for the model to pull (e.g., from a known registry)
downloadUrl: string // URL to download the model from
/** optional callback to receive download progress */
onProgress?: (progress: { percent: number; downloadedBytes: number; totalBytes?: number }) => void
}
/** Result of a pull operation. */
export interface pullResult {
success: boolean // true when the model was pulled successfully
path?: string // local file path to the pulled model
error?: string // error description when success is false
modelInfo?: modelInfo // Info of the pulled model
}
// 3. /load
export interface loadOptions {
modelPath: string
@ -174,27 +159,16 @@ export interface deleteResult {
}
// 7. /import
export interface importOptions {
providerId: string
sourcePath: string // Path to the local model file to import
desiredModelId?: string // Optional: if user wants to name it specifically
/**
 * Open-ended options bag for model import.
 * Intentionally untyped: each provider reads its own keys
 * (e.g. the llama.cpp provider reads `modelPath` / `mmprojPath`).
 */
export interface ImportOptions {
[key: string]: any
}
/** Result of an import operation. */
// NOTE(review): the new AIEngine.import() returns Promise<void>; this result
// type looks superseded — confirm remaining usages before relying on it.
export interface importResult {
success: boolean // true when the import completed
modelInfo?: modelInfo // info of the imported model, when available
error?: string // error description when success is false
}
// 8. /abortPull
/** Options identifying which in-flight model pull to abort. */
export interface abortPullOptions {
providerId: string
modelId: string // The modelId whose download is to be aborted
}
/** Result of an abort-pull request. */
export interface abortPullResult {
success: boolean // true when the abort request was accepted
error?: string // error description when success is false
}
/**
* Base AIEngine
* Applicable to all AI Engines
@ -223,11 +197,6 @@ export abstract class AIEngine extends BaseExtension {
*/
abstract list(opts: listOptions): Promise<listResult>
/**
* Pulls/downloads a model
*/
abstract pull(opts: pullOptions): Promise<pullResult>
/**
* Loads a model into memory
*/
@ -251,12 +220,12 @@ export abstract class AIEngine extends BaseExtension {
/**
* Imports a model
*/
abstract import(opts: importOptions): Promise<importResult>
abstract import(modelId: string, opts: ImportOptions): Promise<void>
/**
* Aborts an ongoing model pull
* Aborts an ongoing model import
*/
abstract abortPull(opts: abortPullOptions): Promise<abortPullResult>
abstract abortImport(modelId: string): Promise<void>
/**
* Optional method to get the underlying chat client

View File

@ -21,7 +21,7 @@
},
"dependencies": {
"@janhq/core": "../../core/package.tgz",
"@tauri-apps/api": "^1.4.0",
"@tauri-apps/api": "^2.5.0",
"fetch-retry": "^5.0.6",
"ulidx": "^2.3.0"
},

View File

@ -14,8 +14,6 @@ import {
modelInfo,
listOptions,
listResult,
pullOptions,
pullResult,
loadOptions,
sessionInfo,
unloadOptions,
@ -25,14 +23,17 @@ import {
chatCompletionChunk,
deleteOptions,
deleteResult,
importOptions,
importResult,
abortPullOptions,
abortPullResult,
ImportOptions,
chatCompletionRequest,
events,
} from '@janhq/core'
import { invoke } from '@tauri-apps/api/tauri'
import { invoke } from '@tauri-apps/api/core'
interface DownloadItem {
url: string
save_path: string
}
/**
* Helper to convert GGUF model filename to a more structured ID/name
@ -62,6 +63,7 @@ export default class llamacpp_extension extends AIEngine {
provider: string = 'llamacpp'
readonly providerId: string = 'llamacpp'
private downloadManager
private activeSessions: Map<string, sessionInfo> = new Map()
private modelsBasePath!: string
private activeRequests: Map<string, AbortController> = new Map()
@ -70,6 +72,8 @@ export default class llamacpp_extension extends AIEngine {
super.onLoad() // Calls registerEngine() from AIEngine
this.registerSettings(SETTINGS)
this.downloadManager = window.core.extensionManager.getByName('@janhq/download-extension')
// Initialize models base path - assuming this would be retrieved from settings
this.modelsBasePath = await joinPath([
await getJanDataFolderPath(),
@ -82,8 +86,91 @@ export default class llamacpp_extension extends AIEngine {
throw new Error('method not implemented yet')
}
override async pull(opts: pullOptions): Promise<pullResult> {
throw new Error('method not implemented yet')
/**
 * Imports a model either by downloading it (https:// URLs) or by referencing
 * an existing local GGUF file (absolute paths). Finishes by writing a
 * `model.yml` manifest for the model via the Rust `write_yaml` command.
 *
 * Emits `onFileDownloadUpdate` during download, `onFileDownloadError` on
 * failure, and `onFileDownloadSuccess` / `onFileDownloadStopped` when the
 * download finishes or is cancelled.
 *
 * @param modelId - identifier used for the on-disk folder and the download task id
 * @param opts - reads `opts.modelPath` (required) and `opts.mmprojPath`
 *   (optional); each may be an https URL or an absolute local path
 */
override async import(modelId: string, opts: ImportOptions): Promise<void> {
// TODO: sanitize modelId
// TODO: check if modelId already exists
const taskId = this.createDownloadTaskId(modelId)
// this is relative to Jan's data folder
const modelDir = `models/${this.provider}/${modelId}`
// we only use these from opts
// opts.modelPath: URL to the model file
// opts.mmprojPath: URL to the mmproj file
let downloadItems: DownloadItem[] = []
// NOTE(review): downloaded files end up as paths relative to the Jan data
// folder while local imports keep their absolute paths — consumers of
// model.yml must handle both forms. Confirm this is intended.
let modelPath = opts.modelPath
let mmprojPath = opts.mmprojPath
const modelItem = { url: opts.modelPath, save_path: `${modelDir}/model.gguf` }
if (opts.modelPath.startsWith("https://")) {
downloadItems.push(modelItem)
modelPath = modelItem.save_path
} else {
// this should be absolute path
// (fs.existsSync is awaited — presumably the core fs wrapper is async; verify)
if (!(await fs.existsSync(modelPath))) {
throw new Error(`Model file not found: ${modelPath}`)
}
}
// The mmproj (multimodal projector) file is optional and handled the same
// way as the main model file: remote URLs are queued, local paths validated.
if (opts.mmprojPath) {
const mmprojItem = { url: opts.mmprojPath, save_path: `${modelDir}/mmproj.gguf` }
if (opts.mmprojPath.startsWith("https://")) {
downloadItems.push(mmprojItem)
mmprojPath = mmprojItem.save_path
} else {
// this should be absolute path
if (!(await fs.existsSync(mmprojPath))) {
throw new Error(`MMProj file not found: ${mmprojPath}`)
}
}
}
// Only run the download machinery when at least one remote file was queued;
// purely local imports skip straight to writing the manifest.
if (downloadItems.length > 0) {
let downloadCompleted = false
try {
// emit download update event on progress
const onProgress = (transferred: number, total: number) => {
events.emit('onFileDownloadUpdate', {
modelId,
percent: transferred / total,
size: { transferred, total },
downloadType: 'Model',
})
// NOTE(review): completion is inferred from the last progress callback;
// if downloadFiles never reports transferred === total, a finished
// download would be announced as 'stopped' — confirm the download
// extension always emits a final full-progress event.
downloadCompleted = transferred === total
}
await this.downloadManager.downloadFiles(downloadItems, taskId, onProgress)
} catch (error) {
console.error('Error downloading model:', modelId, opts, error)
events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
throw error
}
// once we reach this point, it either means download finishes or it was cancelled.
// if there was an error, it would have been caught above
const eventName = downloadCompleted ? 'onFileDownloadSuccess' : 'onFileDownloadStopped'
events.emit(eventName, { modelId, downloadType: 'Model' })
}
// TODO: check if files are valid GGUF files
// Persist the manifest via the Rust side so path-scope checks apply.
await invoke<void>(
'write_yaml',
{
data: {
model_path: modelPath,
mmproj_path: mmprojPath,
},
savePath: `${modelDir}/model.yml`,
},
)
}
/**
 * Aborts an in-flight model download started by import().
 * No-op semantics for already-finished tasks depend on the download
 * extension's cancelDownload — TODO confirm.
 *
 * @param modelId - id of the model whose download task should be cancelled
 */
override async abortImport(modelId: string): Promise<void> {
// prepend provider name to avoid name collision
const taskId = this.createDownloadTaskId(modelId)
await this.downloadManager.cancelDownload(taskId)
}
override async load(opts: loadOptions): Promise<sessionInfo> {
@ -234,6 +321,11 @@ export default class llamacpp_extension extends AIEngine {
}
}
/** Builds the download-task id used with the download extension for a model. */
private createDownloadTaskId(modelId: string) {
// prepend provider to make taskId unique across providers
return `${this.provider}/${modelId}`
}
private async *handleStreamingResponse(
url: string,
headers: HeadersInit,
@ -342,14 +434,6 @@ export default class llamacpp_extension extends AIEngine {
throw new Error('method not implemented yet')
}
override async import(opts: importOptions): Promise<importResult> {
throw new Error('method not implemented yet')
}
override async abortPull(opts: abortPullOptions): Promise<abortPullResult> {
throw new Error('method not implemented yet')
}
// Optional method for direct client access
override getChatClient(sessionId: string): any {
throw new Error('method not implemented yet')

View File

@ -52,6 +52,7 @@ ash = "0.38.0"
nvml-wrapper = "0.10.0"
tauri-plugin-deep-link = "2"
fix-path-env = { git = "https://github.com/tauri-apps/fix-path-env-rs" }
serde_yaml = "0.9.34"
[target.'cfg(windows)'.dependencies]
libloading = "0.8.7"

View File

@ -283,14 +283,6 @@ fn copy_dir_recursive(src: &PathBuf, dst: &PathBuf) -> Result<(), io::Error> {
Ok(())
}
/// Resets the Cortex server restart counter in shared app state back to zero.
#[tauri::command]
pub async fn reset_cortex_restart_count(state: State<'_, AppState>) -> Result<(), String> {
// Async mutex guard is held only for the duration of the reset.
let mut count = state.cortex_restart_count.lock().await;
*count = 0;
log::info!("Cortex server restart count reset to 0.");
Ok(())
}
#[tauri::command]
pub fn change_app_data_folder(
app_handle: tauri::AppHandle,

View File

@ -1,4 +1,5 @@
pub mod download;
pub mod extensions;
use std::fs;
use std::path::{Component, Path, PathBuf};
@ -76,4 +77,24 @@ pub fn normalize_path(path: &Path) -> PathBuf {
}
ret
}
pub mod extensions;
/// Serializes `data` as YAML to `save_path`, interpreted relative to the Jan
/// data folder.
///
/// # Errors
/// Returns an error string when the resolved path escapes the Jan data folder
/// (path-traversal guard), or when directory creation, file creation, or YAML
/// serialization fails.
#[tauri::command]
pub fn write_yaml(
    app: tauri::AppHandle,
    data: serde_json::Value,
    save_path: &str,
) -> Result<(), String> {
    // TODO: have an internal function to check scope
    // `app` is consumed here; no clone needed since it is not used afterwards.
    let jan_data_folder = get_jan_data_folder_path(app);
    let save_path = normalize_path(&jan_data_folder.join(save_path));
    // Path::starts_with compares whole path components, so this also rejects
    // sibling-prefix escapes like `<data_folder>_evil/...`.
    if !save_path.starts_with(&jan_data_folder) {
        return Err(format!(
            "Error: save path {} is not under jan_data_folder {}",
            save_path.to_string_lossy(),
            jan_data_folder.to_string_lossy(),
        ));
    }
    // Fix: ensure the target directory exists. Local model imports skip the
    // downloader (which would have created `models/<provider>/<id>`), and
    // File::create does not create intermediate directories.
    if let Some(parent) = save_path.parent() {
        fs::create_dir_all(parent).map_err(|e| e.to_string())?;
    }
    let mut file = fs::File::create(&save_path).map_err(|e| e.to_string())?;
    serde_yaml::to_writer(&mut file, &data).map_err(|e| e.to_string())?;
    Ok(())
}

View File

@ -56,7 +56,7 @@ pub fn run() {
core::cmd::get_server_status,
core::cmd::read_logs,
core::cmd::change_app_data_folder,
core::cmd::reset_cortex_restart_count,
core::migration::get_legacy_browser_data,
// MCP commands
core::mcp::get_tools,
core::mcp::call_tool,
@ -79,6 +79,8 @@ pub fn run() {
core::threads::get_thread_assistant,
core::threads::create_thread_assistant,
core::threads::modify_thread_assistant,
// generic utils
core::utils::write_yaml,
// Download
core::utils::download::download_files,
core::utils::download::cancel_download_task,