feat: Model import (download + local import) for llama.cpp extension (#5087)
* add pull and abortPull * add model import (download only) * write model.yaml. support local model import * remove cortex-related command * add TODO * remove cortex-related command
This commit is contained in:
parent
a7a2dcc8d8
commit
ded9ae733a
@ -91,21 +91,6 @@ export interface listOptions {
|
|||||||
}
|
}
|
||||||
export type listResult = modelInfo[]
|
export type listResult = modelInfo[]
|
||||||
|
|
||||||
// 2. /pull
|
|
||||||
export interface pullOptions {
|
|
||||||
providerId: string
|
|
||||||
modelId: string // Identifier for the model to pull (e.g., from a known registry)
|
|
||||||
downloadUrl: string // URL to download the model from
|
|
||||||
/** optional callback to receive download progress */
|
|
||||||
onProgress?: (progress: { percent: number; downloadedBytes: number; totalBytes?: number }) => void
|
|
||||||
}
|
|
||||||
export interface pullResult {
|
|
||||||
success: boolean
|
|
||||||
path?: string // local file path to the pulled model
|
|
||||||
error?: string
|
|
||||||
modelInfo?: modelInfo // Info of the pulled model
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. /load
|
// 3. /load
|
||||||
export interface loadOptions {
|
export interface loadOptions {
|
||||||
modelPath: string
|
modelPath: string
|
||||||
@ -174,27 +159,16 @@ export interface deleteResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 7. /import
|
// 7. /import
|
||||||
export interface importOptions {
|
export interface ImportOptions {
|
||||||
providerId: string
|
[key: string]: any
|
||||||
sourcePath: string // Path to the local model file to import
|
|
||||||
desiredModelId?: string // Optional: if user wants to name it specifically
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface importResult {
|
export interface importResult {
|
||||||
success: boolean
|
success: boolean
|
||||||
modelInfo?: modelInfo
|
modelInfo?: modelInfo
|
||||||
error?: string
|
error?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
// 8. /abortPull
|
|
||||||
export interface abortPullOptions {
|
|
||||||
providerId: string
|
|
||||||
modelId: string // The modelId whose download is to be aborted
|
|
||||||
}
|
|
||||||
export interface abortPullResult {
|
|
||||||
success: boolean
|
|
||||||
error?: string
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base AIEngine
|
* Base AIEngine
|
||||||
* Applicable to all AI Engines
|
* Applicable to all AI Engines
|
||||||
@ -223,11 +197,6 @@ export abstract class AIEngine extends BaseExtension {
|
|||||||
*/
|
*/
|
||||||
abstract list(opts: listOptions): Promise<listResult>
|
abstract list(opts: listOptions): Promise<listResult>
|
||||||
|
|
||||||
/**
|
|
||||||
* Pulls/downloads a model
|
|
||||||
*/
|
|
||||||
abstract pull(opts: pullOptions): Promise<pullResult>
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads a model into memory
|
* Loads a model into memory
|
||||||
*/
|
*/
|
||||||
@ -251,12 +220,12 @@ export abstract class AIEngine extends BaseExtension {
|
|||||||
/**
|
/**
|
||||||
* Imports a model
|
* Imports a model
|
||||||
*/
|
*/
|
||||||
abstract import(opts: importOptions): Promise<importResult>
|
abstract import(modelId: string, opts: ImportOptions): Promise<void>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Aborts an ongoing model pull
|
* Aborts an ongoing model import
|
||||||
*/
|
*/
|
||||||
abstract abortPull(opts: abortPullOptions): Promise<abortPullResult>
|
abstract abortImport(modelId: string): Promise<void>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Optional method to get the underlying chat client
|
* Optional method to get the underlying chat client
|
||||||
|
|||||||
@ -21,7 +21,7 @@
|
|||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@janhq/core": "../../core/package.tgz",
|
"@janhq/core": "../../core/package.tgz",
|
||||||
"@tauri-apps/api": "^1.4.0",
|
"@tauri-apps/api": "^2.5.0",
|
||||||
"fetch-retry": "^5.0.6",
|
"fetch-retry": "^5.0.6",
|
||||||
"ulidx": "^2.3.0"
|
"ulidx": "^2.3.0"
|
||||||
},
|
},
|
||||||
|
|||||||
@ -14,8 +14,6 @@ import {
|
|||||||
modelInfo,
|
modelInfo,
|
||||||
listOptions,
|
listOptions,
|
||||||
listResult,
|
listResult,
|
||||||
pullOptions,
|
|
||||||
pullResult,
|
|
||||||
loadOptions,
|
loadOptions,
|
||||||
sessionInfo,
|
sessionInfo,
|
||||||
unloadOptions,
|
unloadOptions,
|
||||||
@ -25,14 +23,17 @@ import {
|
|||||||
chatCompletionChunk,
|
chatCompletionChunk,
|
||||||
deleteOptions,
|
deleteOptions,
|
||||||
deleteResult,
|
deleteResult,
|
||||||
importOptions,
|
ImportOptions,
|
||||||
importResult,
|
|
||||||
abortPullOptions,
|
|
||||||
abortPullResult,
|
|
||||||
chatCompletionRequest,
|
chatCompletionRequest,
|
||||||
|
events,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
|
|
||||||
import { invoke } from '@tauri-apps/api/tauri'
|
import { invoke } from '@tauri-apps/api/core'
|
||||||
|
|
||||||
|
interface DownloadItem {
|
||||||
|
url: string
|
||||||
|
save_path: string
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Helper to convert GGUF model filename to a more structured ID/name
|
* Helper to convert GGUF model filename to a more structured ID/name
|
||||||
@ -62,6 +63,7 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
provider: string = 'llamacpp'
|
provider: string = 'llamacpp'
|
||||||
readonly providerId: string = 'llamacpp'
|
readonly providerId: string = 'llamacpp'
|
||||||
|
|
||||||
|
private downloadManager
|
||||||
private activeSessions: Map<string, sessionInfo> = new Map()
|
private activeSessions: Map<string, sessionInfo> = new Map()
|
||||||
private modelsBasePath!: string
|
private modelsBasePath!: string
|
||||||
private activeRequests: Map<string, AbortController> = new Map()
|
private activeRequests: Map<string, AbortController> = new Map()
|
||||||
@ -70,6 +72,8 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
super.onLoad() // Calls registerEngine() from AIEngine
|
super.onLoad() // Calls registerEngine() from AIEngine
|
||||||
this.registerSettings(SETTINGS)
|
this.registerSettings(SETTINGS)
|
||||||
|
|
||||||
|
this.downloadManager = window.core.extensionManager.getByName('@janhq/download-extension')
|
||||||
|
|
||||||
// Initialize models base path - assuming this would be retrieved from settings
|
// Initialize models base path - assuming this would be retrieved from settings
|
||||||
this.modelsBasePath = await joinPath([
|
this.modelsBasePath = await joinPath([
|
||||||
await getJanDataFolderPath(),
|
await getJanDataFolderPath(),
|
||||||
@ -82,8 +86,91 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
throw new Error('method not implemented yet')
|
throw new Error('method not implemented yet')
|
||||||
}
|
}
|
||||||
|
|
||||||
override async pull(opts: pullOptions): Promise<pullResult> {
|
override async import(modelId: string, opts: ImportOptions): Promise<void> {
|
||||||
throw new Error('method not implemented yet')
|
// TODO: sanitize modelId
|
||||||
|
// TODO: check if modelId already exists
|
||||||
|
const taskId = this.createDownloadTaskId(modelId)
|
||||||
|
|
||||||
|
// this is relative to Jan's data folder
|
||||||
|
const modelDir = `models/${this.provider}/${modelId}`
|
||||||
|
|
||||||
|
// we only use these from opts
|
||||||
|
// opts.modelPath: URL to the model file
|
||||||
|
// opts.mmprojPath: URL to the mmproj file
|
||||||
|
|
||||||
|
let downloadItems: DownloadItem[] = []
|
||||||
|
let modelPath = opts.modelPath
|
||||||
|
let mmprojPath = opts.mmprojPath
|
||||||
|
|
||||||
|
const modelItem = { url: opts.modelPath, save_path: `${modelDir}/model.gguf` }
|
||||||
|
if (opts.modelPath.startsWith("https://")) {
|
||||||
|
downloadItems.push(modelItem)
|
||||||
|
modelPath = modelItem.save_path
|
||||||
|
} else {
|
||||||
|
// this should be absolute path
|
||||||
|
if (!(await fs.existsSync(modelPath))) {
|
||||||
|
throw new Error(`Model file not found: ${modelPath}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (opts.mmprojPath) {
|
||||||
|
const mmprojItem = { url: opts.mmprojPath, save_path: `${modelDir}/mmproj.gguf` }
|
||||||
|
if (opts.mmprojPath.startsWith("https://")) {
|
||||||
|
downloadItems.push(mmprojItem)
|
||||||
|
mmprojPath = mmprojItem.save_path
|
||||||
|
} else {
|
||||||
|
// this should be absolute path
|
||||||
|
if (!(await fs.existsSync(mmprojPath))) {
|
||||||
|
throw new Error(`MMProj file not found: ${mmprojPath}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (downloadItems.length > 0) {
|
||||||
|
let downloadCompleted = false
|
||||||
|
|
||||||
|
try {
|
||||||
|
// emit download update event on progress
|
||||||
|
const onProgress = (transferred: number, total: number) => {
|
||||||
|
events.emit('onFileDownloadUpdate', {
|
||||||
|
modelId,
|
||||||
|
percent: transferred / total,
|
||||||
|
size: { transferred, total },
|
||||||
|
downloadType: 'Model',
|
||||||
|
})
|
||||||
|
downloadCompleted = transferred === total
|
||||||
|
}
|
||||||
|
await this.downloadManager.downloadFiles(downloadItems, taskId, onProgress)
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error downloading model:', modelId, opts, error)
|
||||||
|
events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
|
||||||
|
// once we reach this point, it either means download finishes or it was cancelled.
|
||||||
|
// if there was an error, it would have been caught above
|
||||||
|
const eventName = downloadCompleted ? 'onFileDownloadSuccess' : 'onFileDownloadStopped'
|
||||||
|
events.emit(eventName, { modelId, downloadType: 'Model' })
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: check if files are valid GGUF files
|
||||||
|
|
||||||
|
await invoke<void>(
|
||||||
|
'write_yaml',
|
||||||
|
{
|
||||||
|
data: {
|
||||||
|
model_path: modelPath,
|
||||||
|
mmproj_path: mmprojPath,
|
||||||
|
},
|
||||||
|
savePath: `${modelDir}/model.yml`,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
override async abortImport(modelId: string): Promise<void> {
|
||||||
|
// prepand provider name to avoid name collision
|
||||||
|
const taskId = this.createDownloadTaskId(modelId)
|
||||||
|
await this.downloadManager.cancelDownload(taskId)
|
||||||
}
|
}
|
||||||
|
|
||||||
override async load(opts: loadOptions): Promise<sessionInfo> {
|
override async load(opts: loadOptions): Promise<sessionInfo> {
|
||||||
@ -234,6 +321,11 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private createDownloadTaskId(modelId: string) {
|
||||||
|
// prepend provider to make taksId unique across providers
|
||||||
|
return `${this.provider}/${modelId}`
|
||||||
|
}
|
||||||
|
|
||||||
private async *handleStreamingResponse(
|
private async *handleStreamingResponse(
|
||||||
url: string,
|
url: string,
|
||||||
headers: HeadersInit,
|
headers: HeadersInit,
|
||||||
@ -342,14 +434,6 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
throw new Error('method not implemented yet')
|
throw new Error('method not implemented yet')
|
||||||
}
|
}
|
||||||
|
|
||||||
override async import(opts: importOptions): Promise<importResult> {
|
|
||||||
throw new Error('method not implemented yet')
|
|
||||||
}
|
|
||||||
|
|
||||||
override async abortPull(opts: abortPullOptions): Promise<abortPullResult> {
|
|
||||||
throw new Error('method not implemented yet')
|
|
||||||
}
|
|
||||||
|
|
||||||
// Optional method for direct client access
|
// Optional method for direct client access
|
||||||
override getChatClient(sessionId: string): any {
|
override getChatClient(sessionId: string): any {
|
||||||
throw new Error('method not implemented yet')
|
throw new Error('method not implemented yet')
|
||||||
|
|||||||
@ -52,6 +52,7 @@ ash = "0.38.0"
|
|||||||
nvml-wrapper = "0.10.0"
|
nvml-wrapper = "0.10.0"
|
||||||
tauri-plugin-deep-link = "2"
|
tauri-plugin-deep-link = "2"
|
||||||
fix-path-env = { git = "https://github.com/tauri-apps/fix-path-env-rs" }
|
fix-path-env = { git = "https://github.com/tauri-apps/fix-path-env-rs" }
|
||||||
|
serde_yaml = "0.9.34"
|
||||||
|
|
||||||
[target.'cfg(windows)'.dependencies]
|
[target.'cfg(windows)'.dependencies]
|
||||||
libloading = "0.8.7"
|
libloading = "0.8.7"
|
||||||
|
|||||||
@ -283,14 +283,6 @@ fn copy_dir_recursive(src: &PathBuf, dst: &PathBuf) -> Result<(), io::Error> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tauri::command]
|
|
||||||
pub async fn reset_cortex_restart_count(state: State<'_, AppState>) -> Result<(), String> {
|
|
||||||
let mut count = state.cortex_restart_count.lock().await;
|
|
||||||
*count = 0;
|
|
||||||
log::info!("Cortex server restart count reset to 0.");
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
pub fn change_app_data_folder(
|
pub fn change_app_data_folder(
|
||||||
app_handle: tauri::AppHandle,
|
app_handle: tauri::AppHandle,
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
pub mod download;
|
pub mod download;
|
||||||
|
pub mod extensions;
|
||||||
|
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::path::{Component, Path, PathBuf};
|
use std::path::{Component, Path, PathBuf};
|
||||||
@ -76,4 +77,24 @@ pub fn normalize_path(path: &Path) -> PathBuf {
|
|||||||
}
|
}
|
||||||
ret
|
ret
|
||||||
}
|
}
|
||||||
pub mod extensions;
|
|
||||||
|
#[tauri::command]
|
||||||
|
pub fn write_yaml(
|
||||||
|
app: tauri::AppHandle,
|
||||||
|
data: serde_json::Value,
|
||||||
|
save_path: &str,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
// TODO: have an internal function to check scope
|
||||||
|
let jan_data_folder = get_jan_data_folder_path(app.clone());
|
||||||
|
let save_path = normalize_path(&jan_data_folder.join(save_path));
|
||||||
|
if !save_path.starts_with(&jan_data_folder) {
|
||||||
|
return Err(format!(
|
||||||
|
"Error: save path {} is not under jan_data_folder {}",
|
||||||
|
save_path.to_string_lossy(),
|
||||||
|
jan_data_folder.to_string_lossy(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let mut file = fs::File::create(&save_path).map_err(|e| e.to_string())?;
|
||||||
|
serde_yaml::to_writer(&mut file, &data).map_err(|e| e.to_string())?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|||||||
@ -56,7 +56,7 @@ pub fn run() {
|
|||||||
core::cmd::get_server_status,
|
core::cmd::get_server_status,
|
||||||
core::cmd::read_logs,
|
core::cmd::read_logs,
|
||||||
core::cmd::change_app_data_folder,
|
core::cmd::change_app_data_folder,
|
||||||
core::cmd::reset_cortex_restart_count,
|
core::migration::get_legacy_browser_data,
|
||||||
// MCP commands
|
// MCP commands
|
||||||
core::mcp::get_tools,
|
core::mcp::get_tools,
|
||||||
core::mcp::call_tool,
|
core::mcp::call_tool,
|
||||||
@ -79,6 +79,8 @@ pub fn run() {
|
|||||||
core::threads::get_thread_assistant,
|
core::threads::get_thread_assistant,
|
||||||
core::threads::create_thread_assistant,
|
core::threads::create_thread_assistant,
|
||||||
core::threads::modify_thread_assistant,
|
core::threads::modify_thread_assistant,
|
||||||
|
// generic utils
|
||||||
|
core::utils::write_yaml,
|
||||||
// Download
|
// Download
|
||||||
core::utils::download::download_files,
|
core::utils::download::download_files,
|
||||||
core::utils::download::cancel_download_task,
|
core::utils::download::cancel_download_task,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user