Update folder structure; small refactoring
This commit is contained in:
parent
3b72d80979
commit
f7bcf43334
@ -7,7 +7,7 @@ import {
|
|||||||
import { invoke } from '@tauri-apps/api/core'
|
import { invoke } from '@tauri-apps/api/core'
|
||||||
|
|
||||||
// folder structure
|
// folder structure
|
||||||
// <Jan's data folder>/llamacpp/backends/<backend_name>/<backend_version>
|
// <Jan's data folder>/llamacpp/backends/<backend_version>/<backend_type>
|
||||||
|
|
||||||
// what should be available to the user for selection?
|
// what should be available to the user for selection?
|
||||||
export async function listSupportedBackends(): Promise<{ version: string, backend: string }[]> {
|
export async function listSupportedBackends(): Promise<{ version: string, backend: string }[]> {
|
||||||
@ -74,26 +74,38 @@ export async function listSupportedBackends(): Promise<{ version: string, backen
|
|||||||
return backendVersions
|
return backendVersions
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function isBackendInstalled(backend: string, version: string): Promise<boolean> {
|
export async function getBackendDir(backend: string, version: string): Promise<string> {
|
||||||
|
const janDataFolderPath = await getJanDataFolderPath()
|
||||||
|
const backendDir = await joinPath([janDataFolderPath, 'llamacpp', 'backends', version, backend])
|
||||||
|
return backendDir
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function getBackendExePath(backend: string, version: string): Promise<string> {
|
||||||
const sysInfo = await window.core.api.getSystemInfo()
|
const sysInfo = await window.core.api.getSystemInfo()
|
||||||
const exe_name = sysInfo.os_type === 'windows' ? 'llama-server.exe' : 'llama-server'
|
const exe_name = sysInfo.os_type === 'windows' ? 'llama-server.exe' : 'llama-server'
|
||||||
|
const backendDir = await getBackendDir(backend, version)
|
||||||
|
const exePath = await joinPath([backendDir, 'build', 'bin', exe_name])
|
||||||
|
return exePath
|
||||||
|
}
|
||||||
|
|
||||||
const janDataFolderPath = await getJanDataFolderPath()
|
export async function isBackendInstalled(backend: string, version: string): Promise<boolean> {
|
||||||
const backendPath = await joinPath([janDataFolderPath, 'llamacpp', 'backends', backend, version, 'build', 'bin', exe_name])
|
const exePath = await getBackendExePath(backend, version)
|
||||||
const result = await fs.existsSync(backendPath)
|
const result = await fs.existsSync(exePath)
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function downloadBackend(backend: string, version: string): Promise<void> {
|
export async function downloadBackend(backend: string, version: string): Promise<void> {
|
||||||
const janDataFolderPath = await getJanDataFolderPath()
|
const janDataFolderPath = await getJanDataFolderPath()
|
||||||
const llamacppPath = await joinPath([janDataFolderPath, 'llamacpp'])
|
const llamacppPath = await joinPath([janDataFolderPath, 'llamacpp'])
|
||||||
|
const backendDir = await getBackendDir(backend, version)
|
||||||
|
const libDir = await joinPath([llamacppPath, 'lib'])
|
||||||
|
|
||||||
const downloadManager = window.core.extensionManager.getByName('@janhq/download-extension')
|
const downloadManager = window.core.extensionManager.getByName('@janhq/download-extension')
|
||||||
|
|
||||||
const downloadItems = [
|
const downloadItems = [
|
||||||
{
|
{
|
||||||
url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/llama-${version}-bin-${backend}.tar.gz`,
|
url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/llama-${version}-bin-${backend}.tar.gz`,
|
||||||
save_path: await joinPath([llamacppPath, 'backends', backend, version, 'backend.tar.gz']),
|
save_path: await joinPath([backendDir, 'backend.tar.gz']),
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -101,12 +113,12 @@ export async function downloadBackend(backend: string, version: string): Promise
|
|||||||
if (backend.includes('cu11.7') && !(await _isCudaInstalled('11.7'))) {
|
if (backend.includes('cu11.7') && !(await _isCudaInstalled('11.7'))) {
|
||||||
downloadItems.push({
|
downloadItems.push({
|
||||||
url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu11.7-x64.tar.gz`,
|
url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu11.7-x64.tar.gz`,
|
||||||
save_path: await joinPath([llamacppPath, 'lib', 'cuda11.tar.gz']),
|
save_path: await joinPath([libDir, 'cuda11.tar.gz']),
|
||||||
})
|
})
|
||||||
} else if (backend.includes('cu12.0') && !(await _isCudaInstalled('12.0'))) {
|
} else if (backend.includes('cu12.0') && !(await _isCudaInstalled('12.0'))) {
|
||||||
downloadItems.push({
|
downloadItems.push({
|
||||||
url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu12.0-x64.tar.gz`,
|
url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu12.0-x64.tar.gz`,
|
||||||
save_path: await joinPath([llamacppPath, 'lib', 'cuda12.tar.gz']),
|
save_path: await joinPath([libDir, 'cuda12.tar.gz']),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -25,6 +25,7 @@ import {
|
|||||||
listSupportedBackends,
|
listSupportedBackends,
|
||||||
downloadBackend,
|
downloadBackend,
|
||||||
isBackendInstalled,
|
isBackendInstalled,
|
||||||
|
getBackendExePath,
|
||||||
} from './backend'
|
} from './backend'
|
||||||
import { invoke } from '@tauri-apps/api/core'
|
import { invoke } from '@tauri-apps/api/core'
|
||||||
|
|
||||||
@ -74,23 +75,27 @@ interface ModelConfig {
|
|||||||
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
|
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Folder structure for downloaded models:
|
// Folder structure for llamacpp extension:
|
||||||
// <Jan's data folder>/models/llamacpp/<modelId>
|
// <Jan's data folder>/llamacpp
|
||||||
// - model.yml (required)
|
// - models/<modelId>/
|
||||||
// - model.gguf (optional, present if downloaded from URL)
|
// - model.yml (required)
|
||||||
// - mmproj.gguf (optional, present if mmproj exists and it was downloaded from URL)
|
// - model.gguf (optional, present if downloaded from URL)
|
||||||
//
|
// - mmproj.gguf (optional, present if mmproj exists and it was downloaded from URL)
|
||||||
// Contents of model.yml can be found in ModelConfig interface
|
// Contents of model.yml can be found in ModelConfig interface
|
||||||
|
//
|
||||||
|
// - backends/<backend_version>/<backend_type>/
|
||||||
|
// - build/bin/llama-server (or llama-server.exe on Windows)
|
||||||
|
//
|
||||||
|
// - lib/
|
||||||
|
// - e.g. libcudart.so.12
|
||||||
|
|
||||||
export default class llamacpp_extension extends AIEngine {
|
export default class llamacpp_extension extends AIEngine {
|
||||||
provider: string = 'llamacpp'
|
provider: string = 'llamacpp'
|
||||||
readonly providerId: string = 'llamacpp'
|
readonly providerId: string = 'llamacpp'
|
||||||
|
|
||||||
private config: LlamacppConfig
|
private config: LlamacppConfig
|
||||||
private downloadManager
|
|
||||||
private downloadBackend // for testing
|
|
||||||
private activeSessions: Map<string, sessionInfo> = new Map()
|
private activeSessions: Map<string, sessionInfo> = new Map()
|
||||||
private modelsBasePath!: string
|
private providerPath!: string
|
||||||
private apiSecret: string = 'Jan'
|
private apiSecret: string = 'Jan'
|
||||||
|
|
||||||
override async onLoad(): Promise<void> {
|
override async onLoad(): Promise<void> {
|
||||||
@ -114,7 +119,6 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
}
|
}
|
||||||
|
|
||||||
this.registerSettings(settings)
|
this.registerSettings(settings)
|
||||||
this.downloadBackend = downloadBackend
|
|
||||||
|
|
||||||
let config = {}
|
let config = {}
|
||||||
for (const item of SETTINGS) {
|
for (const item of SETTINGS) {
|
||||||
@ -126,14 +130,10 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
}
|
}
|
||||||
this.config = config as LlamacppConfig
|
this.config = config as LlamacppConfig
|
||||||
|
|
||||||
this.downloadManager = window.core.extensionManager.getByName(
|
|
||||||
'@janhq/download-extension'
|
|
||||||
)
|
|
||||||
|
|
||||||
// Initialize models base path - assuming this would be retrieved from settings
|
// Initialize models base path - assuming this would be retrieved from settings
|
||||||
this.modelsBasePath = await joinPath([
|
this.providerPath = await joinPath([
|
||||||
await getJanDataFolderPath(),
|
await getJanDataFolderPath(),
|
||||||
'models',
|
this.providerId,
|
||||||
])
|
])
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -178,7 +178,7 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
|
|
||||||
// Implement the required LocalProvider interface methods
|
// Implement the required LocalProvider interface methods
|
||||||
override async list(): Promise<modelInfo[]> {
|
override async list(): Promise<modelInfo[]> {
|
||||||
const modelsDir = await joinPath([this.modelsBasePath, this.provider])
|
const modelsDir = await joinPath([this.providerPath, 'models'])
|
||||||
if (!(await fs.existsSync(modelsDir))) {
|
if (!(await fs.existsSync(modelsDir))) {
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
@ -215,8 +215,7 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
let modelInfos: modelInfo[] = []
|
let modelInfos: modelInfo[] = []
|
||||||
for (const modelId of modelIds) {
|
for (const modelId of modelIds) {
|
||||||
const path = await joinPath([
|
const path = await joinPath([
|
||||||
this.modelsBasePath,
|
modelsDir,
|
||||||
this.provider,
|
|
||||||
modelId,
|
modelId,
|
||||||
'model.yml',
|
'model.yml',
|
||||||
])
|
])
|
||||||
@ -246,64 +245,46 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
return parts.every((s) => s !== '' && s !== '.' && s !== '..')
|
return parts.every((s) => s !== '' && s !== '.' && s !== '..')
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isValidModelId(modelId)) {
|
if (!isValidModelId(modelId))
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`
|
`Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`
|
||||||
)
|
)
|
||||||
}
|
|
||||||
|
|
||||||
let configPath = await joinPath([
|
const configPath = await joinPath([
|
||||||
this.modelsBasePath,
|
this.providerPath,
|
||||||
this.provider,
|
'models',
|
||||||
modelId,
|
modelId,
|
||||||
'model.yml',
|
'model.yml',
|
||||||
])
|
])
|
||||||
if (await fs.existsSync(configPath)) {
|
if (await fs.existsSync(configPath))
|
||||||
throw new Error(`Model ${modelId} already exists`)
|
throw new Error(`Model ${modelId} already exists`)
|
||||||
}
|
|
||||||
|
|
||||||
const taskId = this.createDownloadTaskId(modelId)
|
|
||||||
|
|
||||||
// this is relative to Jan's data folder
|
// this is relative to Jan's data folder
|
||||||
const modelDir = `models/${this.provider}/${modelId}`
|
const modelDir = `${this.providerId}/models/${modelId}`
|
||||||
|
|
||||||
// we only use these from opts
|
// we only use these from opts
|
||||||
// opts.modelPath: URL to the model file
|
// opts.modelPath: URL to the model file
|
||||||
// opts.mmprojPath: URL to the mmproj file
|
// opts.mmprojPath: URL to the mmproj file
|
||||||
|
|
||||||
let downloadItems: DownloadItem[] = []
|
let downloadItems: DownloadItem[] = []
|
||||||
let modelPath = opts.modelPath
|
|
||||||
let mmprojPath = opts.mmprojPath
|
|
||||||
|
|
||||||
const modelItem = {
|
const maybeDownload = async (path: string, saveName: string) => {
|
||||||
url: opts.modelPath,
|
// if URL, add to downloadItems, and return local path
|
||||||
save_path: `${modelDir}/model.gguf`,
|
if (path.startsWith('https://')) {
|
||||||
}
|
const localPath = `${modelDir}/${saveName}`
|
||||||
if (opts.modelPath.startsWith('https://')) {
|
downloadItems.push({ url: path, save_path: localPath })
|
||||||
downloadItems.push(modelItem)
|
return localPath
|
||||||
modelPath = modelItem.save_path
|
|
||||||
} else {
|
|
||||||
// this should be absolute path
|
|
||||||
if (!(await fs.existsSync(modelPath))) {
|
|
||||||
throw new Error(`Model file not found: ${modelPath}`)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if local file (absolute path), check if it exists
|
||||||
|
// and return the path
|
||||||
|
if (!(await fs.existsSync(path)))
|
||||||
|
throw new Error(`File not found: ${path}`)
|
||||||
|
return path
|
||||||
}
|
}
|
||||||
|
|
||||||
if (opts.mmprojPath) {
|
let modelPath = await maybeDownload(opts.modelPath, 'model.gguf')
|
||||||
const mmprojItem = {
|
let mmprojPath = opts.mmprojPath ? await maybeDownload(opts.mmprojPath, 'mmproj.gguf') : undefined
|
||||||
url: opts.mmprojPath,
|
|
||||||
save_path: `${modelDir}/mmproj.gguf`,
|
|
||||||
}
|
|
||||||
if (opts.mmprojPath.startsWith('https://')) {
|
|
||||||
downloadItems.push(mmprojItem)
|
|
||||||
mmprojPath = mmprojItem.save_path
|
|
||||||
} else {
|
|
||||||
// this should be absolute path
|
|
||||||
if (!(await fs.existsSync(mmprojPath))) {
|
|
||||||
throw new Error(`MMProj file not found: ${mmprojPath}`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (downloadItems.length > 0) {
|
if (downloadItems.length > 0) {
|
||||||
let downloadCompleted = false
|
let downloadCompleted = false
|
||||||
@ -319,23 +300,24 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
})
|
})
|
||||||
downloadCompleted = transferred === total
|
downloadCompleted = transferred === total
|
||||||
}
|
}
|
||||||
await this.downloadManager.downloadFiles(
|
const downloadManager = window.core.extensionManager.getByName(
|
||||||
|
'@janhq/download-extension'
|
||||||
|
)
|
||||||
|
await downloadManager.downloadFiles(
|
||||||
downloadItems,
|
downloadItems,
|
||||||
taskId,
|
this.createDownloadTaskId(modelId),
|
||||||
onProgress
|
onProgress
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const eventName = downloadCompleted
|
||||||
|
? 'onFileDownloadSuccess'
|
||||||
|
: 'onFileDownloadStopped'
|
||||||
|
events.emit(eventName, { modelId, downloadType: 'Model' })
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error downloading model:', modelId, opts, error)
|
console.error('Error downloading model:', modelId, opts, error)
|
||||||
events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
|
events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
|
||||||
throw error
|
throw error
|
||||||
}
|
}
|
||||||
|
|
||||||
// once we reach this point, it either means download finishes or it was cancelled.
|
|
||||||
// if there was an error, it would have been caught above
|
|
||||||
const eventName = downloadCompleted
|
|
||||||
? 'onFileDownloadSuccess'
|
|
||||||
: 'onFileDownloadStopped'
|
|
||||||
events.emit(eventName, { modelId, downloadType: 'Model' })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: check if files are valid GGUF files
|
// TODO: check if files are valid GGUF files
|
||||||
@ -362,14 +344,17 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
await fs.mkdir(await joinPath([janDataFolderPath, modelDir]))
|
await fs.mkdir(await joinPath([janDataFolderPath, modelDir]))
|
||||||
await invoke<void>('write_yaml', {
|
await invoke<void>('write_yaml', {
|
||||||
data: modelConfig,
|
data: modelConfig,
|
||||||
savePath: `${modelDir}/model.yml`,
|
savePath: configPath,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
override async abortImport(modelId: string): Promise<void> {
|
override async abortImport(modelId: string): Promise<void> {
|
||||||
// prepend provider name to avoid name collision
|
// prepend provider name to avoid name collision
|
||||||
const taskId = this.createDownloadTaskId(modelId)
|
const taskId = this.createDownloadTaskId(modelId)
|
||||||
await this.downloadManager.cancelDownload(taskId)
|
const downloadManager = window.core.extensionManager.getByName(
|
||||||
|
'@janhq/download-extension'
|
||||||
|
)
|
||||||
|
await downloadManager.cancelDownload(taskId)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -390,31 +375,17 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
override async load(modelId: string): Promise<sessionInfo> {
|
override async load(modelId: string): Promise<sessionInfo> {
|
||||||
const args: string[] = []
|
const args: string[] = []
|
||||||
const cfg = this.config
|
const cfg = this.config
|
||||||
const sysInfo = await window.core.api.getSystemInfo()
|
|
||||||
const [version, backend] = cfg.version_backend.split('/')
|
const [version, backend] = cfg.version_backend.split('/')
|
||||||
if (!version || !backend) {
|
if (!version || !backend) {
|
||||||
// TODO: sometimes version_backend is not set correctly. to investigate
|
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`Invalid version/backend format: ${cfg.version_backend}. Expected format: <version>/<backend>`
|
`Invalid version/backend format: ${cfg.version_backend}. Expected format: <version>/<backend>`
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
const exe_name =
|
|
||||||
sysInfo.os_type === 'windows' ? 'llama-server.exe' : 'llama-server'
|
|
||||||
const janDataFolderPath = await getJanDataFolderPath()
|
const janDataFolderPath = await getJanDataFolderPath()
|
||||||
const backendPath = await joinPath([
|
|
||||||
janDataFolderPath,
|
|
||||||
'llamacpp',
|
|
||||||
'backends',
|
|
||||||
backend,
|
|
||||||
version,
|
|
||||||
'build',
|
|
||||||
'bin',
|
|
||||||
exe_name,
|
|
||||||
])
|
|
||||||
const modelConfigPath = await joinPath([
|
const modelConfigPath = await joinPath([
|
||||||
this.modelsBasePath,
|
this.providerPath,
|
||||||
this.provider,
|
'models',
|
||||||
modelId,
|
modelId,
|
||||||
'model.yml',
|
'model.yml',
|
||||||
])
|
])
|
||||||
@ -483,8 +454,9 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
console.log('Calling Tauri command llama_load with args:', args)
|
console.log('Calling Tauri command llama_load with args:', args)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
// TODO: add LIBRARY_PATH
|
||||||
const sInfo = await invoke<sessionInfo>('load_llama_model', {
|
const sInfo = await invoke<sessionInfo>('load_llama_model', {
|
||||||
backendPath,
|
backendPath: await getBackendExePath(backend, version),
|
||||||
args,
|
args,
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -636,8 +608,8 @@ export default class llamacpp_extension extends AIEngine {
|
|||||||
|
|
||||||
override async delete(modelId: string): Promise<void> {
|
override async delete(modelId: string): Promise<void> {
|
||||||
const modelDir = await joinPath([
|
const modelDir = await joinPath([
|
||||||
this.modelsBasePath,
|
this.providerPath,
|
||||||
this.provider,
|
'models',
|
||||||
modelId,
|
modelId,
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user