From f7bcf43334686819d6d57161bb0d30f3b85d0fa9 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 2 Jun 2025 11:58:29 +0800 Subject: [PATCH] update folde structure. small refactoring --- extensions/llamacpp-extension/src/backend.ts | 28 +++- extensions/llamacpp-extension/src/index.ts | 146 ++++++++----------- 2 files changed, 79 insertions(+), 95 deletions(-) diff --git a/extensions/llamacpp-extension/src/backend.ts b/extensions/llamacpp-extension/src/backend.ts index ab9f93a03..9cc69432b 100644 --- a/extensions/llamacpp-extension/src/backend.ts +++ b/extensions/llamacpp-extension/src/backend.ts @@ -7,7 +7,7 @@ import { import { invoke } from '@tauri-apps/api/core' // folder structure -// /llamacpp/backends// +// /llamacpp/backends// // what should be available to the user for selection? export async function listSupportedBackends(): Promise<{ version: string, backend: string }[]> { @@ -74,26 +74,38 @@ export async function listSupportedBackends(): Promise<{ version: string, backen return backendVersions } -export async function isBackendInstalled(backend: string, version: string): Promise { +export async function getBackendDir(backend: string, version: string): Promise { + const janDataFolderPath = await getJanDataFolderPath() + const backendDir = await joinPath([janDataFolderPath, 'llamacpp', 'backends', version, backend]) + return backendDir +} + +export async function getBackendExePath(backend: string, version: string): Promise { const sysInfo = await window.core.api.getSystemInfo() const exe_name = sysInfo.os_type === 'windows' ? 'llama-server.exe' : 'llama-server' + const backendDir = await getBackendDir(backend, version) + const exePath = await joinPath([backendDir, 'build', 'bin', exe_name]) + return exePath +} - const janDataFolderPath = await getJanDataFolderPath() - const backendPath = await joinPath([janDataFolderPath, 'llamacpp', 'backends', backend, version, 'build', 'bin', exe_name]) - const result = await fs.existsSync(backendPath) +export async function isBackendInstalled(backend: string, version: string): Promise { + const exePath = await getBackendExePath(backend, version) + const result = await fs.existsSync(exePath) return result } export async function downloadBackend(backend: string, version: string): Promise { const janDataFolderPath = await getJanDataFolderPath() const llamacppPath = await joinPath([janDataFolderPath, 'llamacpp']) + const backendDir = await getBackendDir(backend, version) + const libDir = await joinPath([llamacppPath, 'lib']) const downloadManager = window.core.extensionManager.getByName('@janhq/download-extension') const downloadItems = [ { url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/llama-${version}-bin-${backend}.tar.gz`, - save_path: await joinPath([llamacppPath, 'backends', backend, version, 'backend.tar.gz']), + save_path: await joinPath([backendDir, 'backend.tar.gz']), } ] @@ -101,12 +113,12 @@ export async function downloadBackend(backend: string, version: string): Promise if (backend.includes('cu11.7') && !(await _isCudaInstalled('11.7'))) { downloadItems.push({ url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu11.7-x64.tar.gz`, - save_path: await joinPath([llamacppPath, 'lib', 'cuda11.tar.gz']), + save_path: await joinPath([libDir, 'cuda11.tar.gz']), }) } else if (backend.includes('cu12.0') && !(await _isCudaInstalled('12.0'))) { downloadItems.push({ url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu12.0-x64.tar.gz`, - save_path: await joinPath([llamacppPath, 'lib', 'cuda12.tar.gz']), + save_path: await joinPath([libDir, 'cuda12.tar.gz']), }) } diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts index 424f8bcb3..7165ff7b8 100644 --- a/extensions/llamacpp-extension/src/index.ts +++ b/extensions/llamacpp-extension/src/index.ts @@ -25,6 +25,7 @@ import { listSupportedBackends, downloadBackend, isBackendInstalled, + getBackendExePath, } from './backend' import { invoke } from '@tauri-apps/api/core' @@ -74,23 +75,27 @@ interface ModelConfig { * It also subscribes to events emitted by the @janhq/core package and handles new message requests. */ -// Folder structure for downloaded models: -// /models/llamacpp/ -// - model.yml (required) -// - model.gguf (optional, present if downloaded from URL) -// - mmproj.gguf (optional, present if mmproj exists and it was downloaded from URL) -// +// Folder structure for llamacpp extension: +// /llamacpp +// - models// +// - model.yml (required) +// - model.gguf (optional, present if downloaded from URL) +// - mmproj.gguf (optional, present if mmproj exists and it was downloaded from URL) // Contents of model.yml can be found in ModelConfig interface +// +// - backends/// +// - build/bin/llama-server (or llama-server.exe on Windows) +// +// - lib/ +// - e.g. libcudart.so.12 export default class llamacpp_extension extends AIEngine { provider: string = 'llamacpp' readonly providerId: string = 'llamacpp' private config: LlamacppConfig - private downloadManager - private downloadBackend // for testing private activeSessions: Map = new Map() - private modelsBasePath!: string + private providerPath!: string private apiSecret: string = 'Jan' override async onLoad(): Promise { @@ -114,7 +119,6 @@ export default class llamacpp_extension extends AIEngine { } this.registerSettings(settings) - this.downloadBackend = downloadBackend let config = {} for (const item of SETTINGS) { @@ -126,14 +130,10 @@ export default class llamacpp_extension extends AIEngine { } this.config = config as LlamacppConfig - this.downloadManager = window.core.extensionManager.getByName( - '@janhq/download-extension' - ) - // Initialize models base path - assuming this would be retrieved from settings - this.modelsBasePath = await joinPath([ + this.providerPath = await joinPath([ await getJanDataFolderPath(), - 'models', + this.providerId, ]) } @@ -178,7 +178,7 @@ export default class llamacpp_extension extends AIEngine { // Implement the required LocalProvider interface methods override async list(): Promise { - const modelsDir = await joinPath([this.modelsBasePath, this.provider]) + const modelsDir = await joinPath([this.providerPath, 'models']) if (!(await fs.existsSync(modelsDir))) { return [] } @@ -215,8 +215,7 @@ export default class llamacpp_extension extends AIEngine { let modelInfos: modelInfo[] = [] for (const modelId of modelIds) { const path = await joinPath([ - this.modelsBasePath, - this.provider, + modelsDir, modelId, 'model.yml', ]) @@ -246,64 +245,46 @@ export default class llamacpp_extension extends AIEngine { return parts.every((s) => s !== '' && s !== '.' && s !== '..') } - if (!isValidModelId(modelId)) { + if (!isValidModelId(modelId)) throw new Error( `Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.` ) - } - let configPath = await joinPath([ - this.modelsBasePath, - this.provider, + const configPath = await joinPath([ + this.providerPath, + 'models', modelId, 'model.yml', ]) - if (await fs.existsSync(configPath)) { + if (await fs.existsSync(configPath)) throw new Error(`Model ${modelId} already exists`) - } - - const taskId = this.createDownloadTaskId(modelId) // this is relative to Jan's data folder - const modelDir = `models/${this.provider}/${modelId}` + const modelDir = `${this.providerId}/models/${modelId}` // we only use these from opts // opts.modelPath: URL to the model file // opts.mmprojPath: URL to the mmproj file let downloadItems: DownloadItem[] = [] - let modelPath = opts.modelPath - let mmprojPath = opts.mmprojPath - const modelItem = { - url: opts.modelPath, - save_path: `${modelDir}/model.gguf`, - } - if (opts.modelPath.startsWith('https://')) { - downloadItems.push(modelItem) - modelPath = modelItem.save_path - } else { - // this should be absolute path - if (!(await fs.existsSync(modelPath))) { - throw new Error(`Model file not found: ${modelPath}`) + const maybeDownload = async (path: string, saveName: string) => { + // if URL, add to downloadItems, and return local path + if (path.startsWith('https://')) { + const localPath = `${modelDir}/${saveName}` + downloadItems.push({ url: path, save_path: localPath }) + return localPath } + + // if local file (absolute path), check if it exists + // and return the path + if (!(await fs.existsSync(path))) + throw new Error(`File not found: ${path}`) + return path } - if (opts.mmprojPath) { - const mmprojItem = { - url: opts.mmprojPath, - save_path: `${modelDir}/mmproj.gguf`, - } - if (opts.mmprojPath.startsWith('https://')) { - downloadItems.push(mmprojItem) - mmprojPath = mmprojItem.save_path - } else { - // this should be absolute path - if (!(await fs.existsSync(mmprojPath))) { - throw new Error(`MMProj file not found: ${mmprojPath}`) - } - } - } + let modelPath = await maybeDownload(opts.modelPath, 'model.gguf') + let mmprojPath = opts.mmprojPath ? await maybeDownload(opts.mmprojPath, 'mmproj.gguf') : undefined if (downloadItems.length > 0) { let downloadCompleted = false @@ -319,23 +300,24 @@ export default class llamacpp_extension extends AIEngine { }) downloadCompleted = transferred === total } - await this.downloadManager.downloadFiles( + const downloadManager = window.core.extensionManager.getByName( + '@janhq/download-extension' + ) + await downloadManager.downloadFiles( downloadItems, - taskId, + this.createDownloadTaskId(modelId), onProgress ) + + const eventName = downloadCompleted + ? 'onFileDownloadSuccess' + : 'onFileDownloadStopped' + events.emit(eventName, { modelId, downloadType: 'Model' }) } catch (error) { console.error('Error downloading model:', modelId, opts, error) events.emit('onFileDownloadError', { modelId, downloadType: 'Model' }) throw error } - - // once we reach this point, it either means download finishes or it was cancelled. - // if there was an error, it would have been caught above - const eventName = downloadCompleted - ? 'onFileDownloadSuccess' - : 'onFileDownloadStopped' - events.emit(eventName, { modelId, downloadType: 'Model' }) } // TODO: check if files are valid GGUF files @@ -362,14 +344,17 @@ export default class llamacpp_extension extends AIEngine { await fs.mkdir(await joinPath([janDataFolderPath, modelDir])) await invoke('write_yaml', { data: modelConfig, - savePath: `${modelDir}/model.yml`, + savePath: configPath, }) } override async abortImport(modelId: string): Promise { // prepand provider name to avoid name collision const taskId = this.createDownloadTaskId(modelId) - await this.downloadManager.cancelDownload(taskId) + const downloadManager = window.core.extensionManager.getByName( + '@janhq/download-extension' + ) + await downloadManager.cancelDownload(taskId) } /** @@ -390,31 +375,17 @@ export default class llamacpp_extension extends AIEngine { override async load(modelId: string): Promise { const args: string[] = [] const cfg = this.config - const sysInfo = await window.core.api.getSystemInfo() const [version, backend] = cfg.version_backend.split('/') if (!version || !backend) { - // TODO: sometimes version_backend is not set correctly. to investigate throw new Error( `Invalid version/backend format: ${cfg.version_backend}. Expected format: /` ) } - const exe_name = - sysInfo.os_type === 'windows' ? 'llama-server.exe' : 'llama-server' const janDataFolderPath = await getJanDataFolderPath() - const backendPath = await joinPath([ - janDataFolderPath, - 'llamacpp', - 'backends', - backend, - version, - 'build', - 'bin', - exe_name, - ]) const modelConfigPath = await joinPath([ - this.modelsBasePath, - this.provider, + this.providerPath, + 'models', modelId, 'model.yml', ]) @@ -483,8 +454,9 @@ export default class llamacpp_extension extends AIEngine { console.log('Calling Tauri command llama_load with args:', args) try { + // TODO: add LIBRARY_PATH const sInfo = await invoke('load_llama_model', { - backendPath, + backendPath: await getBackendExePath(backend, version), args, }) @@ -636,8 +608,8 @@ export default class llamacpp_extension extends AIEngine { override async delete(modelId: string): Promise { const modelDir = await joinPath([ - this.modelsBasePath, - this.provider, + this.providerPath, + 'models', modelId, ])