update folder structure. small refactoring
parent 3b72d80979
commit f7bcf43334
@@ -7,7 +7,7 @@ import {
 import { invoke } from '@tauri-apps/api/core'

 // folder structure
-// <Jan's data folder>/llamacpp/backends/<backend_name>/<backend_version>
+// <Jan's data folder>/llamacpp/backends/<backend_version>/<backend_type>

 // what should be available to the user for selection?
 export async function listSupportedBackends(): Promise<{ version: string, backend: string }[]> {
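For reference, the (version, backend) pairs returned by listSupportedBackends map onto the `<version>/<backend>` strings used by the version_backend setting further down in this diff. A minimal caller sketch, assuming only the signature shown above (buildBackendOptions is illustrative, not part of this commit):

import { listSupportedBackends } from './backend'

// Flatten the supported (version, backend) pairs into "<version>/<backend>"
// option strings, e.g. for a settings dropdown.
async function buildBackendOptions(): Promise<string[]> {
  const backends = await listSupportedBackends()
  return backends.map(({ version, backend }) => `${version}/${backend}`)
}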
@@ -74,26 +74,38 @@ export async function listSupportedBackends(): Promise<{ version: string, backen
   return backendVersions
 }

-export async function isBackendInstalled(backend: string, version: string): Promise<boolean> {
+export async function getBackendDir(backend: string, version: string): Promise<string> {
+  const janDataFolderPath = await getJanDataFolderPath()
+  const backendDir = await joinPath([janDataFolderPath, 'llamacpp', 'backends', version, backend])
+  return backendDir
+}
+
+export async function getBackendExePath(backend: string, version: string): Promise<string> {
   const sysInfo = await window.core.api.getSystemInfo()
   const exe_name = sysInfo.os_type === 'windows' ? 'llama-server.exe' : 'llama-server'
+  const backendDir = await getBackendDir(backend, version)
+  const exePath = await joinPath([backendDir, 'build', 'bin', exe_name])
+  return exePath
+}
+
-  const janDataFolderPath = await getJanDataFolderPath()
-  const backendPath = await joinPath([janDataFolderPath, 'llamacpp', 'backends', backend, version, 'build', 'bin', exe_name])
-  const result = await fs.existsSync(backendPath)
+export async function isBackendInstalled(backend: string, version: string): Promise<boolean> {
+  const exePath = await getBackendExePath(backend, version)
+  const result = await fs.existsSync(exePath)
   return result
 }

 export async function downloadBackend(backend: string, version: string): Promise<void> {
   const janDataFolderPath = await getJanDataFolderPath()
   const llamacppPath = await joinPath([janDataFolderPath, 'llamacpp'])
+  const backendDir = await getBackendDir(backend, version)
+  const libDir = await joinPath([llamacppPath, 'lib'])

   const downloadManager = window.core.extensionManager.getByName('@janhq/download-extension')

   const downloadItems = [
     {
       url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/llama-${version}-bin-${backend}.tar.gz`,
-      save_path: await joinPath([llamacppPath, 'backends', backend, version, 'backend.tar.gz']),
+      save_path: await joinPath([backendDir, 'backend.tar.gz']),
     }
   ]

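getBackendDir and getBackendExePath now centralize the path layout, so isBackendInstalled reduces to an existence check on the executable. A hedged usage sketch, with made-up version/backend values and an illustrative ensureBackend helper that is not part of this commit:

import {
  getBackendExePath,
  isBackendInstalled,
  downloadBackend,
} from './backend'

// Make sure a backend is on disk before launching llama-server.
// 'b5192' / 'win-avx2-x64' below are invented example values.
async function ensureBackend(backend: string, version: string): Promise<string> {
  if (!(await isBackendInstalled(backend, version))) {
    await downloadBackend(backend, version)
  }
  return await getBackendExePath(backend, version)
}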
@@ -101,12 +113,12 @@ export async function downloadBackend(backend: string, version: string): Promise
   if (backend.includes('cu11.7') && !(await _isCudaInstalled('11.7'))) {
     downloadItems.push({
       url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu11.7-x64.tar.gz`,
-      save_path: await joinPath([llamacppPath, 'lib', 'cuda11.tar.gz']),
+      save_path: await joinPath([libDir, 'cuda11.tar.gz']),
     })
   } else if (backend.includes('cu12.0') && !(await _isCudaInstalled('12.0'))) {
     downloadItems.push({
       url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu12.0-x64.tar.gz`,
-      save_path: await joinPath([llamacppPath, 'lib', 'cuda12.tar.gz']),
+      save_path: await joinPath([libDir, 'cuda12.tar.gz']),
     })
   }

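The cudart archives land in the shared lib directory only when the selected backend needs a CUDA runtime the system lacks. _isCudaInstalled is private to the extension and not shown here; the standalone helper below merely illustrates the backend-name-to-archive mapping used above:

// Map a backend name to the cudart archive it may require.
// Returns undefined for non-CUDA backends. Illustrative only.
function cudartArchiveFor(backend: string): string | undefined {
  if (backend.includes('cu11.7')) return 'cudart-llama-bin-linux-cu11.7-x64.tar.gz'
  if (backend.includes('cu12.0')) return 'cudart-llama-bin-linux-cu12.0-x64.tar.gz'
  return undefined
}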
@@ -25,6 +25,7 @@ import {
   listSupportedBackends,
   downloadBackend,
   isBackendInstalled,
+  getBackendExePath,
 } from './backend'
 import { invoke } from '@tauri-apps/api/core'

@@ -74,23 +75,27 @@ interface ModelConfig {
  * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
  */

-// Folder structure for downloaded models:
-// <Jan's data folder>/models/llamacpp/<modelId>
+// Folder structure for llamacpp extension:
+// <Jan's data folder>/llamacpp
+// - models/<modelId>/
+//   - model.yml (required)
+//   - model.gguf (optional, present if downloaded from URL)
+//   - mmproj.gguf (optional, present if mmproj exists and it was downloaded from URL)
+//
+// Contents of model.yml can be found in ModelConfig interface
+//
+// - backends/<backend_version>/<backend_type>/
+//   - build/bin/llama-server (or llama-server.exe on Windows)
+//
+// - lib/
+//   - e.g. libcudart.so.12

 export default class llamacpp_extension extends AIEngine {
-  provider: string = 'llamacpp'
+  readonly providerId: string = 'llamacpp'

   private config: LlamacppConfig
-  private downloadManager
+  private downloadBackend // for testing
   private activeSessions: Map<string, sessionInfo> = new Map()
-  private modelsBasePath!: string
+  private providerPath!: string
   private apiSecret: string = 'Jan'

   override async onLoad(): Promise<void> {
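The comment block above references the ModelConfig interface, which this diff does not show. As a purely hypothetical illustration of what a parsed model.yml might carry (field names are invented for the example):

// Hypothetical shape only; the authoritative fields live in ModelConfig.
const exampleModelYml = {
  model_path: 'llamacpp/models/my-model/model.gguf', // relative to Jan's data folder
  mmproj_path: 'llamacpp/models/my-model/mmproj.gguf', // only for multimodal models
  name: 'My Model', // invented display-name field
}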
@@ -114,7 +119,6 @@ export default class llamacpp_extension extends AIEngine {
     }

     this.registerSettings(settings)
+    this.downloadBackend = downloadBackend

     let config = {}
     for (const item of SETTINGS) {
@@ -126,14 +130,10 @@ export default class llamacpp_extension extends AIEngine {
     }
     this.config = config as LlamacppConfig

-    this.downloadManager = window.core.extensionManager.getByName(
-      '@janhq/download-extension'
-    )
-
-    // Initialize models base path - assuming this would be retrieved from settings
-    this.modelsBasePath = await joinPath([
+    this.providerPath = await joinPath([
       await getJanDataFolderPath(),
-      'models',
       this.providerId,
     ])
   }

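With providerId fixed to 'llamacpp', providerPath resolves to <Jan's data folder>/llamacpp, which moves models from <data folder>/models/llamacpp/<modelId> to <data folder>/llamacpp/models/<modelId>. A tiny sketch of the resulting paths, assuming joinPath performs plain segment joining (values are illustrative):

// Illustrative only: how the new providerPath is composed.
const dataFolder = '/home/user/jan' // example value
const providerPath = `${dataFolder}/llamacpp`
const modelsDir = `${providerPath}/models` // what list() scans below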
@@ -178,7 +178,7 @@ export default class llamacpp_extension extends AIEngine {

   // Implement the required LocalProvider interface methods
   override async list(): Promise<modelInfo[]> {
-    const modelsDir = await joinPath([this.modelsBasePath, this.provider])
+    const modelsDir = await joinPath([this.providerPath, 'models'])
     if (!(await fs.existsSync(modelsDir))) {
       return []
     }
@@ -215,8 +215,7 @@ export default class llamacpp_extension extends AIEngine {
     let modelInfos: modelInfo[] = []
     for (const modelId of modelIds) {
       const path = await joinPath([
-        this.modelsBasePath,
-        this.provider,
+        modelsDir,
         modelId,
         'model.yml',
       ])
@@ -246,65 +245,47 @@ export default class llamacpp_extension extends AIEngine {
       return parts.every((s) => s !== '' && s !== '.' && s !== '..')
     }

-    if (!isValidModelId(modelId)) {
+    if (!isValidModelId(modelId))
       throw new Error(
         `Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`
       )
-    }

-    let configPath = await joinPath([
-      this.modelsBasePath,
-      this.provider,
+    const configPath = await joinPath([
+      this.providerPath,
+      'models',
       modelId,
       'model.yml',
     ])
-    if (await fs.existsSync(configPath)) {
+    if (await fs.existsSync(configPath))
       throw new Error(`Model ${modelId} already exists`)
-    }

-    const taskId = this.createDownloadTaskId(modelId)
-
     // this is relative to Jan's data folder
-    const modelDir = `models/${this.provider}/${modelId}`
+    const modelDir = `${this.providerId}/models/${modelId}`

     // we only use these from opts
     // opts.modelPath: URL to the model file
     // opts.mmprojPath: URL to the mmproj file

     let downloadItems: DownloadItem[] = []
-    let modelPath = opts.modelPath
-    let mmprojPath = opts.mmprojPath
-
-    const modelItem = {
-      url: opts.modelPath,
-      save_path: `${modelDir}/model.gguf`,
-    }
-    if (opts.modelPath.startsWith('https://')) {
-      downloadItems.push(modelItem)
-      modelPath = modelItem.save_path
-    } else {
-      // this should be absolute path
-      if (!(await fs.existsSync(modelPath))) {
-        throw new Error(`Model file not found: ${modelPath}`)
-      }
-    }
+    const maybeDownload = async (path: string, saveName: string) => {
+      // if URL, add to downloadItems, and return local path
+      if (path.startsWith('https://')) {
+        const localPath = `${modelDir}/${saveName}`
+        downloadItems.push({ url: path, save_path: localPath })
+        return localPath
+      }

-    if (opts.mmprojPath) {
-      const mmprojItem = {
-        url: opts.mmprojPath,
-        save_path: `${modelDir}/mmproj.gguf`,
-      }
-      if (opts.mmprojPath.startsWith('https://')) {
-        downloadItems.push(mmprojItem)
-        mmprojPath = mmprojItem.save_path
-      } else {
-        // this should be absolute path
-        if (!(await fs.existsSync(mmprojPath))) {
-          throw new Error(`MMProj file not found: ${mmprojPath}`)
-        }
-      }
-    }
+      // if local file (absolute path), check if it exists
+      // and return the path
+      if (!(await fs.existsSync(path)))
+        throw new Error(`File not found: ${path}`)
+      return path
+    }
+
+    let modelPath = await maybeDownload(opts.modelPath, 'model.gguf')
+    let mmprojPath = opts.mmprojPath ? await maybeDownload(opts.mmprojPath, 'mmproj.gguf') : undefined

     if (downloadItems.length > 0) {
       let downloadCompleted = false

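maybeDownload folds the two copy-pasted URL/local branches above into one helper. Its contract, sketched as comments (paths are illustrative):

// Hedged illustration of maybeDownload's behaviour (names from the diff above):
// 1) remote input is queued and remapped under modelDir:
//      await maybeDownload('https://example.com/m.gguf', 'model.gguf')
//      // -> pushes { url, save_path: `${modelDir}/model.gguf` } and returns that save_path
// 2) local input must already exist on disk and is returned unchanged:
//      await maybeDownload('/abs/path/m.gguf', 'model.gguf')
//      // -> returns '/abs/path/m.gguf', or throws `File not found: ...`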
@@ -319,23 +300,24 @@ export default class llamacpp_extension extends AIEngine {
         })
         downloadCompleted = transferred === total
       }
-      await this.downloadManager.downloadFiles(
+      const downloadManager = window.core.extensionManager.getByName(
+        '@janhq/download-extension'
+      )
+      await downloadManager.downloadFiles(
         downloadItems,
-        taskId,
+        this.createDownloadTaskId(modelId),
         onProgress
       )

-      const eventName = downloadCompleted
-        ? 'onFileDownloadSuccess'
-        : 'onFileDownloadStopped'
-      events.emit(eventName, { modelId, downloadType: 'Model' })
     } catch (error) {
       console.error('Error downloading model:', modelId, opts, error)
       events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
+      throw error
     }

+      // once we reach this point, it either means download finishes or it was cancelled.
+      // if there was an error, it would have been caught above
+      const eventName = downloadCompleted
+        ? 'onFileDownloadSuccess'
+        : 'onFileDownloadStopped'
+      events.emit(eventName, { modelId, downloadType: 'Model' })
     }

     // TODO: check if files are valid GGUF files
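On the GGUF TODO above: a cheap validity pre-check only needs the 4-byte magic 'GGUF' at offset 0 of the file. A sketch using Node's fs promises API, which would need adapting to the extension's fs wrapper:

import { open } from 'node:fs/promises'

// Returns true if the file starts with the GGUF magic bytes.
async function looksLikeGguf(path: string): Promise<boolean> {
  const file = await open(path, 'r')
  try {
    const buf = Buffer.alloc(4)
    await file.read(buf, 0, 4, 0) // read the first 4 bytes
    return buf.toString('ascii') === 'GGUF'
  } finally {
    await file.close()
  }
}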
@@ -362,14 +344,17 @@ export default class llamacpp_extension extends AIEngine {
     await fs.mkdir(await joinPath([janDataFolderPath, modelDir]))
     await invoke<void>('write_yaml', {
       data: modelConfig,
-      savePath: `${modelDir}/model.yml`,
+      savePath: configPath,
     })
   }

   override async abortImport(modelId: string): Promise<void> {
     // prepend provider name to avoid name collision
     const taskId = this.createDownloadTaskId(modelId)
-    await this.downloadManager.cancelDownload(taskId)
+    const downloadManager = window.core.extensionManager.getByName(
+      '@janhq/download-extension'
+    )
+    await downloadManager.cancelDownload(taskId)
   }

   /**
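createDownloadTaskId itself is not shown in this diff; per the comment it namespaces the model id with the provider. A hypothetical reconstruction (the exact format is an assumption):

// Hypothetical reconstruction; the real method lives on llamacpp_extension.
// Prefixing with the provider keeps llamacpp download tasks from colliding
// with other providers' tasks.
function createDownloadTaskId(providerId: string, modelId: string): string {
  return `${providerId}/${modelId}` // e.g. 'llamacpp/my-model' (format assumed)
}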
@@ -390,31 +375,17 @@ export default class llamacpp_extension extends AIEngine {
   override async load(modelId: string): Promise<sessionInfo> {
     const args: string[] = []
     const cfg = this.config
-    const sysInfo = await window.core.api.getSystemInfo()
     const [version, backend] = cfg.version_backend.split('/')
     if (!version || !backend) {
       // TODO: sometimes version_backend is not set correctly. to investigate
       throw new Error(
         `Invalid version/backend format: ${cfg.version_backend}. Expected format: <version>/<backend>`
       )
     }

-    const exe_name =
-      sysInfo.os_type === 'windows' ? 'llama-server.exe' : 'llama-server'
-    const janDataFolderPath = await getJanDataFolderPath()
-    const backendPath = await joinPath([
-      janDataFolderPath,
-      'llamacpp',
-      'backends',
-      backend,
-      version,
-      'build',
-      'bin',
-      exe_name,
-    ])
     const modelConfigPath = await joinPath([
-      this.modelsBasePath,
-      this.provider,
+      this.providerPath,
+      'models',
       modelId,
       'model.yml',
     ])
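version_backend stores the dropdown selection as one '<version>/<backend>' string, which load() splits back apart above. A hedged example with invented values:

// Hypothetical setting value; real values come from listSupportedBackends().
const versionBackend = 'b5192/linux-avx2-cuda-cu12.0-x64'
const [version, backend] = versionBackend.split('/')
// version === 'b5192', backend === 'linux-avx2-cuda-cu12.0-x64'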
@@ -483,8 +454,9 @@ export default class llamacpp_extension extends AIEngine {
     console.log('Calling Tauri command llama_load with args:', args)

     try {
+      // TODO: add LIBRARY_PATH
       const sInfo = await invoke<sessionInfo>('load_llama_model', {
-        backendPath,
+        backendPath: await getBackendExePath(backend, version),
         args,
       })

@@ -636,8 +608,8 @@ export default class llamacpp_extension extends AIEngine {

   override async delete(modelId: string): Promise<void> {
     const modelDir = await joinPath([
-      this.modelsBasePath,
-      this.provider,
+      this.providerPath,
+      'models',
       modelId,
     ])
