feat: Model import (download + local import) for llama.cpp extension (#5087)
* add pull and abortPull * add model import (download only) * write model.yaml. support local model import * remove cortex-related command * add TODO * remove cortex-related command
This commit is contained in:
parent
a7a2dcc8d8
commit
ded9ae733a
@ -91,21 +91,6 @@ export interface listOptions {
|
||||
}
|
||||
export type listResult = modelInfo[]
|
||||
|
||||
// 2. /pull
|
||||
export interface pullOptions {
|
||||
providerId: string
|
||||
modelId: string // Identifier for the model to pull (e.g., from a known registry)
|
||||
downloadUrl: string // URL to download the model from
|
||||
/** optional callback to receive download progress */
|
||||
onProgress?: (progress: { percent: number; downloadedBytes: number; totalBytes?: number }) => void
|
||||
}
|
||||
export interface pullResult {
|
||||
success: boolean
|
||||
path?: string // local file path to the pulled model
|
||||
error?: string
|
||||
modelInfo?: modelInfo // Info of the pulled model
|
||||
}
|
||||
|
||||
// 3. /load
|
||||
export interface loadOptions {
|
||||
modelPath: string
|
||||
@ -174,27 +159,16 @@ export interface deleteResult {
|
||||
}
|
||||
|
||||
// 7. /import
|
||||
export interface importOptions {
|
||||
providerId: string
|
||||
sourcePath: string // Path to the local model file to import
|
||||
desiredModelId?: string // Optional: if user wants to name it specifically
|
||||
export interface ImportOptions {
|
||||
[key: string]: any
|
||||
}
|
||||
|
||||
export interface importResult {
|
||||
success: boolean
|
||||
modelInfo?: modelInfo
|
||||
error?: string
|
||||
}
|
||||
|
||||
// 8. /abortPull
|
||||
export interface abortPullOptions {
|
||||
providerId: string
|
||||
modelId: string // The modelId whose download is to be aborted
|
||||
}
|
||||
export interface abortPullResult {
|
||||
success: boolean
|
||||
error?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Base AIEngine
|
||||
* Applicable to all AI Engines
|
||||
@ -223,11 +197,6 @@ export abstract class AIEngine extends BaseExtension {
|
||||
*/
|
||||
abstract list(opts: listOptions): Promise<listResult>
|
||||
|
||||
/**
|
||||
* Pulls/downloads a model
|
||||
*/
|
||||
abstract pull(opts: pullOptions): Promise<pullResult>
|
||||
|
||||
/**
|
||||
* Loads a model into memory
|
||||
*/
|
||||
@ -251,12 +220,12 @@ export abstract class AIEngine extends BaseExtension {
|
||||
/**
|
||||
* Imports a model
|
||||
*/
|
||||
abstract import(opts: importOptions): Promise<importResult>
|
||||
abstract import(modelId: string, opts: ImportOptions): Promise<void>
|
||||
|
||||
/**
|
||||
* Aborts an ongoing model pull
|
||||
* Aborts an ongoing model import
|
||||
*/
|
||||
abstract abortPull(opts: abortPullOptions): Promise<abortPullResult>
|
||||
abstract abortImport(modelId: string): Promise<void>
|
||||
|
||||
/**
|
||||
* Optional method to get the underlying chat client
|
||||
|
||||
@ -21,7 +21,7 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@janhq/core": "../../core/package.tgz",
|
||||
"@tauri-apps/api": "^1.4.0",
|
||||
"@tauri-apps/api": "^2.5.0",
|
||||
"fetch-retry": "^5.0.6",
|
||||
"ulidx": "^2.3.0"
|
||||
},
|
||||
|
||||
@ -14,8 +14,6 @@ import {
|
||||
modelInfo,
|
||||
listOptions,
|
||||
listResult,
|
||||
pullOptions,
|
||||
pullResult,
|
||||
loadOptions,
|
||||
sessionInfo,
|
||||
unloadOptions,
|
||||
@ -25,14 +23,17 @@ import {
|
||||
chatCompletionChunk,
|
||||
deleteOptions,
|
||||
deleteResult,
|
||||
importOptions,
|
||||
importResult,
|
||||
abortPullOptions,
|
||||
abortPullResult,
|
||||
ImportOptions,
|
||||
chatCompletionRequest,
|
||||
events,
|
||||
} from '@janhq/core'
|
||||
|
||||
import { invoke } from '@tauri-apps/api/tauri'
|
||||
import { invoke } from '@tauri-apps/api/core'
|
||||
|
||||
/**
 * One file handed to the download manager.
 */
interface DownloadItem {
  /** Remote location the file is fetched from. */
  url: string
  /** Destination of the downloaded file, relative to Jan's data folder. */
  save_path: string
}
|
||||
|
||||
/**
|
||||
* Helper to convert GGUF model filename to a more structured ID/name
|
||||
@ -62,6 +63,7 @@ export default class llamacpp_extension extends AIEngine {
|
||||
provider: string = 'llamacpp'
|
||||
readonly providerId: string = 'llamacpp'
|
||||
|
||||
private downloadManager
|
||||
private activeSessions: Map<string, sessionInfo> = new Map()
|
||||
private modelsBasePath!: string
|
||||
private activeRequests: Map<string, AbortController> = new Map()
|
||||
@ -70,6 +72,8 @@ export default class llamacpp_extension extends AIEngine {
|
||||
super.onLoad() // Calls registerEngine() from AIEngine
|
||||
this.registerSettings(SETTINGS)
|
||||
|
||||
this.downloadManager = window.core.extensionManager.getByName('@janhq/download-extension')
|
||||
|
||||
// Initialize models base path - assuming this would be retrieved from settings
|
||||
this.modelsBasePath = await joinPath([
|
||||
await getJanDataFolderPath(),
|
||||
@ -82,8 +86,91 @@ export default class llamacpp_extension extends AIEngine {
|
||||
throw new Error('method not implemented yet')
|
||||
}
|
||||
|
||||
override async pull(opts: pullOptions): Promise<pullResult> {
|
||||
throw new Error('method not implemented yet')
|
||||
/**
 * Imports a model for this provider, downloading remote files first when needed.
 *
 * Only two fields of `opts` are used:
 *  - `opts.modelPath`  (required): `https://` URL or local path of the model file
 *  - `opts.mmprojPath` (optional): `https://` URL or local path of the mmproj file
 *
 * A source starting with `https://` is queued for download into
 * `models/<provider>/<modelId>/` (relative to Jan's data folder). Any other value
 * is treated as a local path, which must already exist and is referenced in place.
 * Once all files are available, `model.yml` is written into the model directory
 * recording `model_path` and `mmproj_path`.
 *
 * Emits `onFileDownloadUpdate` while downloading, then exactly one of
 * `onFileDownloadSuccess`, `onFileDownloadStopped` or `onFileDownloadError`.
 * When the download is stopped (cancelled) no `model.yml` is written.
 *
 * @throws Error when `opts.modelPath` is missing, a local file does not exist,
 *         the download fails, or `model.yml` cannot be written.
 */
override async import(modelId: string, opts: ImportOptions): Promise<void> {
  // TODO: sanitize modelId
  // TODO: check if modelId already exists
  if (typeof opts?.modelPath !== 'string' || opts.modelPath.length === 0) {
    // fail with a clear message instead of a TypeError on `.startsWith`
    throw new Error('opts.modelPath is required to import a model')
  }

  const taskId = this.createDownloadTaskId(modelId)

  // this is relative to Jan's data folder
  const modelDir = `models/${this.provider}/${modelId}`

  const downloadItems: DownloadItem[] = []

  // Resolves one source to the path recorded in model.yml: remote sources are
  // queued for download and mapped to their save path, local ones are checked
  // for existence and used as-is.
  const resolveSource = async (
    source: string,
    fileName: string,
    label: string
  ): Promise<string> => {
    if (source.startsWith('https://')) {
      const save_path = `${modelDir}/${fileName}`
      downloadItems.push({ url: source, save_path })
      return save_path
    }
    // this should be absolute path
    if (!(await fs.existsSync(source))) {
      throw new Error(`${label} file not found: ${source}`)
    }
    return source
  }

  const modelPath = await resolveSource(opts.modelPath, 'model.gguf', 'Model')
  const mmprojPath = opts.mmprojPath
    ? await resolveSource(opts.mmprojPath, 'mmproj.gguf', 'MMProj')
    : opts.mmprojPath

  if (downloadItems.length > 0) {
    let downloadCompleted = false

    try {
      // emit download update event on progress
      const onProgress = (transferred: number, total: number) => {
        events.emit('onFileDownloadUpdate', {
          modelId,
          // guard against a zero total, which would yield NaN/Infinity
          percent: total > 0 ? transferred / total : 0,
          size: { transferred, total },
          downloadType: 'Model',
        })
        downloadCompleted = transferred === total
      }
      await this.downloadManager.downloadFiles(downloadItems, taskId, onProgress)
    } catch (error) {
      console.error('Error downloading model:', modelId, opts, error)
      events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
      throw error
    }

    // once we reach this point, it either means download finishes or it was cancelled.
    // if there was an error, it would have been caught above
    if (!downloadCompleted) {
      events.emit('onFileDownloadStopped', { modelId, downloadType: 'Model' })
      // a cancelled download leaves no usable files, so do not register the model
      return
    }
    events.emit('onFileDownloadSuccess', { modelId, downloadType: 'Model' })
  }

  // TODO: check if files are valid GGUF files

  await invoke<void>('write_yaml', {
    data: {
      model_path: modelPath,
      mmproj_path: mmprojPath,
    },
    savePath: `${modelDir}/model.yml`,
  })
}
|
||||
|
||||
/**
 * Aborts an ongoing import of `modelId` by cancelling its download task.
 */
override async abortImport(modelId: string): Promise<void> {
  // the task id has the provider name prepended to avoid name collisions
  await this.downloadManager.cancelDownload(this.createDownloadTaskId(modelId))
}
|
||||
|
||||
override async load(opts: loadOptions): Promise<sessionInfo> {
|
||||
@ -234,6 +321,11 @@ export default class llamacpp_extension extends AIEngine {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Builds the download task id for a model.
 * The provider name is prepended so the task id is unique across providers.
 */
private createDownloadTaskId(modelId: string) {
  const taskId = `${this.provider}/${modelId}`
  return taskId
}
|
||||
|
||||
private async *handleStreamingResponse(
|
||||
url: string,
|
||||
headers: HeadersInit,
|
||||
@ -342,14 +434,6 @@ export default class llamacpp_extension extends AIEngine {
|
||||
throw new Error('method not implemented yet')
|
||||
}
|
||||
|
||||
override async import(opts: importOptions): Promise<importResult> {
|
||||
throw new Error('method not implemented yet')
|
||||
}
|
||||
|
||||
override async abortPull(opts: abortPullOptions): Promise<abortPullResult> {
|
||||
throw new Error('method not implemented yet')
|
||||
}
|
||||
|
||||
// Optional method for direct client access
|
||||
override getChatClient(sessionId: string): any {
|
||||
throw new Error('method not implemented yet')
|
||||
|
||||
@ -52,6 +52,7 @@ ash = "0.38.0"
|
||||
nvml-wrapper = "0.10.0"
|
||||
tauri-plugin-deep-link = "2"
|
||||
fix-path-env = { git = "https://github.com/tauri-apps/fix-path-env-rs" }
|
||||
serde_yaml = "0.9.34"
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
libloading = "0.8.7"
|
||||
|
||||
@ -283,14 +283,6 @@ fn copy_dir_recursive(src: &PathBuf, dst: &PathBuf) -> Result<(), io::Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn reset_cortex_restart_count(state: State<'_, AppState>) -> Result<(), String> {
|
||||
let mut count = state.cortex_restart_count.lock().await;
|
||||
*count = 0;
|
||||
log::info!("Cortex server restart count reset to 0.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub fn change_app_data_folder(
|
||||
app_handle: tauri::AppHandle,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
pub mod download;
|
||||
pub mod extensions;
|
||||
|
||||
use std::fs;
|
||||
use std::path::{Component, Path, PathBuf};
|
||||
@ -76,4 +77,24 @@ pub fn normalize_path(path: &Path) -> PathBuf {
|
||||
}
|
||||
ret
|
||||
}
|
||||
pub mod extensions;
|
||||
|
||||
#[tauri::command]
|
||||
pub fn write_yaml(
|
||||
app: tauri::AppHandle,
|
||||
data: serde_json::Value,
|
||||
save_path: &str,
|
||||
) -> Result<(), String> {
|
||||
// TODO: have an internal function to check scope
|
||||
let jan_data_folder = get_jan_data_folder_path(app.clone());
|
||||
let save_path = normalize_path(&jan_data_folder.join(save_path));
|
||||
if !save_path.starts_with(&jan_data_folder) {
|
||||
return Err(format!(
|
||||
"Error: save path {} is not under jan_data_folder {}",
|
||||
save_path.to_string_lossy(),
|
||||
jan_data_folder.to_string_lossy(),
|
||||
));
|
||||
}
|
||||
let mut file = fs::File::create(&save_path).map_err(|e| e.to_string())?;
|
||||
serde_yaml::to_writer(&mut file, &data).map_err(|e| e.to_string())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -56,7 +56,7 @@ pub fn run() {
|
||||
core::cmd::get_server_status,
|
||||
core::cmd::read_logs,
|
||||
core::cmd::change_app_data_folder,
|
||||
core::cmd::reset_cortex_restart_count,
|
||||
core::migration::get_legacy_browser_data,
|
||||
// MCP commands
|
||||
core::mcp::get_tools,
|
||||
core::mcp::call_tool,
|
||||
@ -79,6 +79,8 @@ pub fn run() {
|
||||
core::threads::get_thread_assistant,
|
||||
core::threads::create_thread_assistant,
|
||||
core::threads::modify_thread_assistant,
|
||||
// generic utils
|
||||
core::utils::write_yaml,
|
||||
// Download
|
||||
core::utils::download::download_files,
|
||||
core::utils::download::cancel_download_task,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user