update folder structure. small refactoring

This commit is contained in:
Thien Tran 2025-06-02 11:58:29 +08:00 committed by Louis
parent 3b72d80979
commit f7bcf43334
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
2 changed files with 79 additions and 95 deletions

View File

@ -7,7 +7,7 @@ import {
import { invoke } from '@tauri-apps/api/core' import { invoke } from '@tauri-apps/api/core'
// folder structure // folder structure
// <Jan's data folder>/llamacpp/backends/<backend_name>/<backend_version> // <Jan's data folder>/llamacpp/backends/<backend_version>/<backend_type>
// what should be available to the user for selection? // what should be available to the user for selection?
export async function listSupportedBackends(): Promise<{ version: string, backend: string }[]> { export async function listSupportedBackends(): Promise<{ version: string, backend: string }[]> {
@ -74,26 +74,38 @@ export async function listSupportedBackends(): Promise<{ version: string, backen
return backendVersions return backendVersions
} }
export async function isBackendInstalled(backend: string, version: string): Promise<boolean> { export async function getBackendDir(backend: string, version: string): Promise<string> {
const janDataFolderPath = await getJanDataFolderPath()
const backendDir = await joinPath([janDataFolderPath, 'llamacpp', 'backends', version, backend])
return backendDir
}
export async function getBackendExePath(backend: string, version: string): Promise<string> {
const sysInfo = await window.core.api.getSystemInfo() const sysInfo = await window.core.api.getSystemInfo()
const exe_name = sysInfo.os_type === 'windows' ? 'llama-server.exe' : 'llama-server' const exe_name = sysInfo.os_type === 'windows' ? 'llama-server.exe' : 'llama-server'
const backendDir = await getBackendDir(backend, version)
const exePath = await joinPath([backendDir, 'build', 'bin', exe_name])
return exePath
}
const janDataFolderPath = await getJanDataFolderPath() export async function isBackendInstalled(backend: string, version: string): Promise<boolean> {
const backendPath = await joinPath([janDataFolderPath, 'llamacpp', 'backends', backend, version, 'build', 'bin', exe_name]) const exePath = await getBackendExePath(backend, version)
const result = await fs.existsSync(backendPath) const result = await fs.existsSync(exePath)
return result return result
} }
export async function downloadBackend(backend: string, version: string): Promise<void> { export async function downloadBackend(backend: string, version: string): Promise<void> {
const janDataFolderPath = await getJanDataFolderPath() const janDataFolderPath = await getJanDataFolderPath()
const llamacppPath = await joinPath([janDataFolderPath, 'llamacpp']) const llamacppPath = await joinPath([janDataFolderPath, 'llamacpp'])
const backendDir = await getBackendDir(backend, version)
const libDir = await joinPath([llamacppPath, 'lib'])
const downloadManager = window.core.extensionManager.getByName('@janhq/download-extension') const downloadManager = window.core.extensionManager.getByName('@janhq/download-extension')
const downloadItems = [ const downloadItems = [
{ {
url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/llama-${version}-bin-${backend}.tar.gz`, url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/llama-${version}-bin-${backend}.tar.gz`,
save_path: await joinPath([llamacppPath, 'backends', backend, version, 'backend.tar.gz']), save_path: await joinPath([backendDir, 'backend.tar.gz']),
} }
] ]
@ -101,12 +113,12 @@ export async function downloadBackend(backend: string, version: string): Promise
if (backend.includes('cu11.7') && !(await _isCudaInstalled('11.7'))) { if (backend.includes('cu11.7') && !(await _isCudaInstalled('11.7'))) {
downloadItems.push({ downloadItems.push({
url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu11.7-x64.tar.gz`, url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu11.7-x64.tar.gz`,
save_path: await joinPath([llamacppPath, 'lib', 'cuda11.tar.gz']), save_path: await joinPath([libDir, 'cuda11.tar.gz']),
}) })
} else if (backend.includes('cu12.0') && !(await _isCudaInstalled('12.0'))) { } else if (backend.includes('cu12.0') && !(await _isCudaInstalled('12.0'))) {
downloadItems.push({ downloadItems.push({
url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu12.0-x64.tar.gz`, url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu12.0-x64.tar.gz`,
save_path: await joinPath([llamacppPath, 'lib', 'cuda12.tar.gz']), save_path: await joinPath([libDir, 'cuda12.tar.gz']),
}) })
} }

View File

@ -25,6 +25,7 @@ import {
listSupportedBackends, listSupportedBackends,
downloadBackend, downloadBackend,
isBackendInstalled, isBackendInstalled,
getBackendExePath,
} from './backend' } from './backend'
import { invoke } from '@tauri-apps/api/core' import { invoke } from '@tauri-apps/api/core'
@ -74,23 +75,27 @@ interface ModelConfig {
* It also subscribes to events emitted by the @janhq/core package and handles new message requests. * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/ */
// Folder structure for downloaded models: // Folder structure for llamacpp extension:
// <Jan's data folder>/models/llamacpp/<modelId> // <Jan's data folder>/llamacpp
// - models/<modelId>/
// - model.yml (required) // - model.yml (required)
// - model.gguf (optional, present if downloaded from URL) // - model.gguf (optional, present if downloaded from URL)
// - mmproj.gguf (optional, present if mmproj exists and it was downloaded from URL) // - mmproj.gguf (optional, present if mmproj exists and it was downloaded from URL)
//
// Contents of model.yml can be found in ModelConfig interface // Contents of model.yml can be found in ModelConfig interface
//
// - backends/<backend_version>/<backend_type>/
// - build/bin/llama-server (or llama-server.exe on Windows)
//
// - lib/
// - e.g. libcudart.so.12
export default class llamacpp_extension extends AIEngine { export default class llamacpp_extension extends AIEngine {
provider: string = 'llamacpp' provider: string = 'llamacpp'
readonly providerId: string = 'llamacpp' readonly providerId: string = 'llamacpp'
private config: LlamacppConfig private config: LlamacppConfig
private downloadManager
private downloadBackend // for testing
private activeSessions: Map<string, sessionInfo> = new Map() private activeSessions: Map<string, sessionInfo> = new Map()
private modelsBasePath!: string private providerPath!: string
private apiSecret: string = 'Jan' private apiSecret: string = 'Jan'
override async onLoad(): Promise<void> { override async onLoad(): Promise<void> {
@ -114,7 +119,6 @@ export default class llamacpp_extension extends AIEngine {
} }
this.registerSettings(settings) this.registerSettings(settings)
this.downloadBackend = downloadBackend
let config = {} let config = {}
for (const item of SETTINGS) { for (const item of SETTINGS) {
@ -126,14 +130,10 @@ export default class llamacpp_extension extends AIEngine {
} }
this.config = config as LlamacppConfig this.config = config as LlamacppConfig
this.downloadManager = window.core.extensionManager.getByName(
'@janhq/download-extension'
)
// Initialize models base path - assuming this would be retrieved from settings // Initialize models base path - assuming this would be retrieved from settings
this.modelsBasePath = await joinPath([ this.providerPath = await joinPath([
await getJanDataFolderPath(), await getJanDataFolderPath(),
'models', this.providerId,
]) ])
} }
@ -178,7 +178,7 @@ export default class llamacpp_extension extends AIEngine {
// Implement the required LocalProvider interface methods // Implement the required LocalProvider interface methods
override async list(): Promise<modelInfo[]> { override async list(): Promise<modelInfo[]> {
const modelsDir = await joinPath([this.modelsBasePath, this.provider]) const modelsDir = await joinPath([this.providerPath, 'models'])
if (!(await fs.existsSync(modelsDir))) { if (!(await fs.existsSync(modelsDir))) {
return [] return []
} }
@ -215,8 +215,7 @@ export default class llamacpp_extension extends AIEngine {
let modelInfos: modelInfo[] = [] let modelInfos: modelInfo[] = []
for (const modelId of modelIds) { for (const modelId of modelIds) {
const path = await joinPath([ const path = await joinPath([
this.modelsBasePath, modelsDir,
this.provider,
modelId, modelId,
'model.yml', 'model.yml',
]) ])
@ -246,65 +245,47 @@ export default class llamacpp_extension extends AIEngine {
return parts.every((s) => s !== '' && s !== '.' && s !== '..') return parts.every((s) => s !== '' && s !== '.' && s !== '..')
} }
if (!isValidModelId(modelId)) { if (!isValidModelId(modelId))
throw new Error( throw new Error(
`Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.` `Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`
) )
}
let configPath = await joinPath([ const configPath = await joinPath([
this.modelsBasePath, this.providerPath,
this.provider, 'models',
modelId, modelId,
'model.yml', 'model.yml',
]) ])
if (await fs.existsSync(configPath)) { if (await fs.existsSync(configPath))
throw new Error(`Model ${modelId} already exists`) throw new Error(`Model ${modelId} already exists`)
}
const taskId = this.createDownloadTaskId(modelId)
// this is relative to Jan's data folder // this is relative to Jan's data folder
const modelDir = `models/${this.provider}/${modelId}` const modelDir = `${this.providerId}/models/${modelId}`
// we only use these from opts // we only use these from opts
// opts.modelPath: URL to the model file // opts.modelPath: URL to the model file
// opts.mmprojPath: URL to the mmproj file // opts.mmprojPath: URL to the mmproj file
let downloadItems: DownloadItem[] = [] let downloadItems: DownloadItem[] = []
let modelPath = opts.modelPath
let mmprojPath = opts.mmprojPath
const modelItem = { const maybeDownload = async (path: string, saveName: string) => {
url: opts.modelPath, // if URL, add to downloadItems, and return local path
save_path: `${modelDir}/model.gguf`, if (path.startsWith('https://')) {
} const localPath = `${modelDir}/${saveName}`
if (opts.modelPath.startsWith('https://')) { downloadItems.push({ url: path, save_path: localPath })
downloadItems.push(modelItem) return localPath
modelPath = modelItem.save_path
} else {
// this should be absolute path
if (!(await fs.existsSync(modelPath))) {
throw new Error(`Model file not found: ${modelPath}`)
}
} }
if (opts.mmprojPath) { // if local file (absolute path), check if it exists
const mmprojItem = { // and return the path
url: opts.mmprojPath, if (!(await fs.existsSync(path)))
save_path: `${modelDir}/mmproj.gguf`, throw new Error(`File not found: ${path}`)
} return path
if (opts.mmprojPath.startsWith('https://')) {
downloadItems.push(mmprojItem)
mmprojPath = mmprojItem.save_path
} else {
// this should be absolute path
if (!(await fs.existsSync(mmprojPath))) {
throw new Error(`MMProj file not found: ${mmprojPath}`)
}
}
} }
let modelPath = await maybeDownload(opts.modelPath, 'model.gguf')
let mmprojPath = opts.mmprojPath ? await maybeDownload(opts.mmprojPath, 'mmproj.gguf') : undefined
if (downloadItems.length > 0) { if (downloadItems.length > 0) {
let downloadCompleted = false let downloadCompleted = false
@ -319,23 +300,24 @@ export default class llamacpp_extension extends AIEngine {
}) })
downloadCompleted = transferred === total downloadCompleted = transferred === total
} }
await this.downloadManager.downloadFiles( const downloadManager = window.core.extensionManager.getByName(
'@janhq/download-extension'
)
await downloadManager.downloadFiles(
downloadItems, downloadItems,
taskId, this.createDownloadTaskId(modelId),
onProgress onProgress
) )
const eventName = downloadCompleted
? 'onFileDownloadSuccess'
: 'onFileDownloadStopped'
events.emit(eventName, { modelId, downloadType: 'Model' })
} catch (error) { } catch (error) {
console.error('Error downloading model:', modelId, opts, error) console.error('Error downloading model:', modelId, opts, error)
events.emit('onFileDownloadError', { modelId, downloadType: 'Model' }) events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
throw error throw error
} }
// once we reach this point, it either means download finishes or it was cancelled.
// if there was an error, it would have been caught above
const eventName = downloadCompleted
? 'onFileDownloadSuccess'
: 'onFileDownloadStopped'
events.emit(eventName, { modelId, downloadType: 'Model' })
} }
// TODO: check if files are valid GGUF files // TODO: check if files are valid GGUF files
@ -362,14 +344,17 @@ export default class llamacpp_extension extends AIEngine {
await fs.mkdir(await joinPath([janDataFolderPath, modelDir])) await fs.mkdir(await joinPath([janDataFolderPath, modelDir]))
await invoke<void>('write_yaml', { await invoke<void>('write_yaml', {
data: modelConfig, data: modelConfig,
savePath: `${modelDir}/model.yml`, savePath: configPath,
}) })
} }
override async abortImport(modelId: string): Promise<void> { override async abortImport(modelId: string): Promise<void> {
// prepend provider name to avoid name collision // prepend provider name to avoid name collision
const taskId = this.createDownloadTaskId(modelId) const taskId = this.createDownloadTaskId(modelId)
await this.downloadManager.cancelDownload(taskId) const downloadManager = window.core.extensionManager.getByName(
'@janhq/download-extension'
)
await downloadManager.cancelDownload(taskId)
} }
/** /**
@ -390,31 +375,17 @@ export default class llamacpp_extension extends AIEngine {
override async load(modelId: string): Promise<sessionInfo> { override async load(modelId: string): Promise<sessionInfo> {
const args: string[] = [] const args: string[] = []
const cfg = this.config const cfg = this.config
const sysInfo = await window.core.api.getSystemInfo()
const [version, backend] = cfg.version_backend.split('/') const [version, backend] = cfg.version_backend.split('/')
if (!version || !backend) { if (!version || !backend) {
// TODO: sometimes version_backend is not set correctly. to investigate
throw new Error( throw new Error(
`Invalid version/backend format: ${cfg.version_backend}. Expected format: <version>/<backend>` `Invalid version/backend format: ${cfg.version_backend}. Expected format: <version>/<backend>`
) )
} }
const exe_name =
sysInfo.os_type === 'windows' ? 'llama-server.exe' : 'llama-server'
const janDataFolderPath = await getJanDataFolderPath() const janDataFolderPath = await getJanDataFolderPath()
const backendPath = await joinPath([
janDataFolderPath,
'llamacpp',
'backends',
backend,
version,
'build',
'bin',
exe_name,
])
const modelConfigPath = await joinPath([ const modelConfigPath = await joinPath([
this.modelsBasePath, this.providerPath,
this.provider, 'models',
modelId, modelId,
'model.yml', 'model.yml',
]) ])
@ -483,8 +454,9 @@ export default class llamacpp_extension extends AIEngine {
console.log('Calling Tauri command llama_load with args:', args) console.log('Calling Tauri command llama_load with args:', args)
try { try {
// TODO: add LIBRARY_PATH
const sInfo = await invoke<sessionInfo>('load_llama_model', { const sInfo = await invoke<sessionInfo>('load_llama_model', {
backendPath, backendPath: await getBackendExePath(backend, version),
args, args,
}) })
@ -636,8 +608,8 @@ export default class llamacpp_extension extends AIEngine {
override async delete(modelId: string): Promise<void> { override async delete(modelId: string): Promise<void> {
const modelDir = await joinPath([ const modelDir = await joinPath([
this.modelsBasePath, this.providerPath,
this.provider, 'models',
modelId, modelId,
]) ])