feat: update AIEngine load method and backend path handling

- Changed the load method to accept a modelId instead of loadOptions, for clarity and simplicity (caller-side sketch below)
- Renamed the engineBasePath parameter to backendPath for consistency with the backend's directory structure (see the sketch after the Rust diff)
- Added a getRandomPort method that assigns each session a unique port to prevent conflicts
- Refactored the configuration and model-loading logic to improve maintainability and reduce redundancy (condensed flow sketch after the extension diff)
Akarshan Biswas 2025-05-30 10:25:58 +05:30 committed by Louis
parent 9e24e28341
commit fd9e034461
3 changed files with 143 additions and 66 deletions
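
Before the diffs, a caller-side sketch of the contract change. The engine value, model id, and port field are illustrative; the only thing taken from this commit is that load now receives a plain model id and returns a sessionInfo.

// Hypothetical caller; the structural type stands in for any AIEngine implementation.
async function startModel(engine: {
  load(modelId: string): Promise<{ port: number }>
}): Promise<void> {
  // Before this commit the caller assembled a loadOptions object itself, e.g.
  //   engine.load({ modelId, modelPath, mmprojPath, port })
  // (fields reconstructed from the call sites removed in this diff).
  // Now it passes only the id; the engine resolves paths from model.yml,
  // picks a free port, and generates an API key internally.
  const session = await engine.load('qwen3-4b')
  console.log(`llama-server listening on port ${session.port}`)
}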

View File

@@ -178,7 +178,7 @@ export abstract class AIEngine extends BaseExtension {
/**
* Loads a model into memory
*/
abstract load(opts: loadOptions): Promise<sessionInfo>
abstract load(modelId: string): Promise<sessionInfo>
/**
* Unloads a model from memory

View File

@@ -21,34 +21,38 @@ import {
chatCompletionRequest,
events,
} from '@janhq/core'
import { listSupportedBackends, downloadBackend, isBackendInstalled } from './backend'
import {
listSupportedBackends,
downloadBackend,
isBackendInstalled,
} from './backend'
import { invoke } from '@tauri-apps/api/core'
type LlamacppConfig = {
backend: string;
n_gpu_layers: number;
ctx_size: number;
threads: number;
threads_batch: number;
n_predict: number;
batch_size: number;
ubatch_size: number;
device: string;
split_mode: string;
main_gpu: number;
flash_attn: boolean;
cont_batching: boolean;
no_mmap: boolean;
mlock: boolean;
no_kv_offload: boolean;
cache_type_k: string;
cache_type_v: string;
defrag_thold: number;
rope_scaling: string;
rope_scale: number;
rope_freq_base: number;
rope_freq_scale: number;
reasoning_budget: number;
backend: string
n_gpu_layers: number
ctx_size: number
threads: number
threads_batch: number
n_predict: number
batch_size: number
ubatch_size: number
device: string
split_mode: string
main_gpu: number
flash_attn: boolean
cont_batching: boolean
no_mmap: boolean
mlock: boolean
no_kv_offload: boolean
cache_type_k: string
cache_type_v: string
defrag_thold: number
rope_scaling: string
rope_scale: number
rope_freq_base: number
rope_freq_scale: number
reasoning_budget: number
}
interface DownloadItem {
@@ -64,7 +68,6 @@ interface ModelConfig {
size_bytes: number
}
/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
* The class provides methods for initializing and stopping a model, and for making inference requests.
@@ -85,11 +88,10 @@ export default class llamacpp_extension extends AIEngine {
private config: LlamacppConfig
private downloadManager
private downloadBackend // for testing
private downloadBackend // for testing
private activeSessions: Map<string, sessionInfo> = new Map()
private modelsBasePath!: string
private enginesBasePath!: string
private apiSecret: string = "Jan"
private apiSecret: string = 'Jan'
override async onLoad(): Promise<void> {
super.onLoad() // Calls registerEngine() from AIEngine
@@ -117,33 +119,36 @@ export default class llamacpp_extension extends AIEngine {
let config = {}
for (const item of SETTINGS) {
const defaultValue = item.controllerProps.value
config[item.key] = await this.getSetting<typeof defaultValue>(item.key, defaultValue)
config[item.key] = await this.getSetting<typeof defaultValue>(
item.key,
defaultValue
)
}
this.config = config as LlamacppConfig
this.downloadManager = window.core.extensionManager.getByName('@janhq/download-extension')
this.downloadManager = window.core.extensionManager.getByName(
'@janhq/download-extension'
)
// Initialize models base path - assuming this would be retrieved from settings
this.modelsBasePath = await joinPath([
await getJanDataFolderPath(),
'models',
])
this.enginesBasePath = await joinPath([await getJanDataFolderPath(), 'engines'])
}
override async onUnload(): Promise<void> {
// Terminate all active sessions
for (const [sessionId, _] of this.activeSessions) {
try {
await this.unload(sessionId);
await this.unload(sessionId)
} catch (error) {
console.error(`Failed to unload session ${sessionId}:`, error);
console.error(`Failed to unload session ${sessionId}:`, error)
}
}
// Clear the sessions map
this.activeSessions.clear();
this.activeSessions.clear()
}
onSettingUpdate<T>(key: string, value: T): void {
@@ -168,7 +173,7 @@ export default class llamacpp_extension extends AIEngine {
private async generateApiKey(modelId: string): Promise<string> {
const hash = await invoke<string>('generate_api_key', {
modelId: modelId,
apiSecret: this.apiSecret
apiSecret: this.apiSecret,
})
return hash
}
@@ -211,7 +216,12 @@ export default class llamacpp_extension extends AIEngine {
let modelInfos: modelInfo[] = []
for (const modelId of modelIds) {
const path = await joinPath([this.modelsBasePath, this.provider, modelId, 'model.yml'])
const path = await joinPath([
this.modelsBasePath,
this.provider,
modelId,
'model.yml',
])
const modelConfig = await invoke<ModelConfig>('read_yaml', { path })
const modelInfo = {
@@ -235,14 +245,21 @@ export default class llamacpp_extension extends AIEngine {
// check for empty parts or path traversal
const parts = id.split('/')
return parts.every(s => s !== '' && s !== '.' && s !== '..')
return parts.every((s) => s !== '' && s !== '.' && s !== '..')
}
if (!isValidModelId(modelId)) {
throw new Error(`Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`)
throw new Error(
`Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`
)
}
let configPath = await joinPath([this.modelsBasePath, this.provider, modelId, 'model.yml'])
let configPath = await joinPath([
this.modelsBasePath,
this.provider,
modelId,
'model.yml',
])
if (await fs.existsSync(configPath)) {
throw new Error(`Model ${modelId} already exists`)
}
@@ -260,8 +277,11 @@ export default class llamacpp_extension extends AIEngine {
let modelPath = opts.modelPath
let mmprojPath = opts.mmprojPath
const modelItem = { url: opts.modelPath, save_path: `${modelDir}/model.gguf` }
if (opts.modelPath.startsWith("https://")) {
const modelItem = {
url: opts.modelPath,
save_path: `${modelDir}/model.gguf`,
}
if (opts.modelPath.startsWith('https://')) {
downloadItems.push(modelItem)
modelPath = modelItem.save_path
} else {
@@ -272,8 +292,11 @@ export default class llamacpp_extension extends AIEngine {
}
if (opts.mmprojPath) {
const mmprojItem = { url: opts.mmprojPath, save_path: `${modelDir}/mmproj.gguf` }
if (opts.mmprojPath.startsWith("https://")) {
const mmprojItem = {
url: opts.mmprojPath,
save_path: `${modelDir}/mmproj.gguf`,
}
if (opts.mmprojPath.startsWith('https://')) {
downloadItems.push(mmprojItem)
mmprojPath = mmprojItem.save_path
} else {
@@ -298,7 +321,11 @@ export default class llamacpp_extension extends AIEngine {
})
downloadCompleted = transferred === total
}
await this.downloadManager.downloadFiles(downloadItems, taskId, onProgress)
await this.downloadManager.downloadFiles(
downloadItems,
taskId,
onProgress
)
} catch (error) {
console.error('Error downloading model:', modelId, opts, error)
events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
@@ -307,7 +334,9 @@ export default class llamacpp_extension extends AIEngine {
// once we reach this point, it either means download finishes or it was cancelled.
// if there was an error, it would have been caught above
const eventName = downloadCompleted ? 'onFileDownloadSuccess' : 'onFileDownloadStopped'
const eventName = downloadCompleted
? 'onFileDownloadSuccess'
: 'onFileDownloadStopped'
events.emit(eventName, { modelId, downloadType: 'Model' })
}
@@ -315,9 +344,13 @@ export default class llamacpp_extension extends AIEngine {
// NOTE: modelPath and mmprojPath can be either relative to Jan's data folder (if they are downloaded)
// or absolute paths (if they are provided as local files)
const janDataFolderPath = await getJanDataFolderPath()
let size_bytes = (await fs.fileStat(await joinPath([janDataFolderPath, modelPath]))).size
let size_bytes = (
await fs.fileStat(await joinPath([janDataFolderPath, modelPath]))
).size
if (mmprojPath) {
size_bytes += (await fs.fileStat(await joinPath([janDataFolderPath, mmprojPath]))).size
size_bytes += (
await fs.fileStat(await joinPath([janDataFolderPath, mmprojPath]))
).size
}
// TODO: add name as import() argument
@@ -328,10 +361,10 @@ export default class llamacpp_extension extends AIEngine {
name: modelId,
size_bytes,
} as ModelConfig
await invoke<void>(
'write_yaml',
{ data: modelConfig, savePath: `${modelDir}/model.yml` },
)
await invoke<void>('write_yaml', {
data: modelConfig,
savePath: `${modelDir}/model.yml`,
})
}
override async abortImport(modelId: string): Promise<void> {
@@ -339,24 +372,62 @@ export default class llamacpp_extension extends AIEngine {
const taskId = this.createDownloadTaskId(modelId)
await this.downloadManager.cancelDownload(taskId)
}
/**
* Function to find a random port
*/
private async getRandomPort(): Promise<number> {
let port: number
do {
port = Math.floor(Math.random() * 1000) + 3000
} while (
Array.from(this.activeSessions.values()).some(
(info) => info.port === port
)
)
return port
}
override async load(opts: loadOptions): Promise<sessionInfo> {
override async load(modelId: string): Promise<sessionInfo> {
const args: string[] = []
const cfg = this.config
const sysInfo = await window.core.api.getSystemInfo()
const [backend, version] = cfg.backend.split('-')
const exe_name =
sysInfo.os_type === 'windows' ? 'llama-server.exe' : 'llama-server'
const backendPath = await joinPath([
await getJanDataFolderPath(),
'llamacpp',
'backends',
backend,
version,
'build',
'bin',
exe_name,
])
const modelPath = await joinPath([
this.modelsBasePath,
this.provider,
modelId,
])
const modelConfigPath = await joinPath([modelPath, 'model.yml'])
const modelConfig = await invoke<ModelConfig>('read_yaml', {
modelConfigPath,
})
const port = await this.getRandomPort()
// disable llama-server webui
args.push('--no-webui')
// update key for security; TODO: (qnixsynapse) Make it more secure
const api_key = this.generateApiKey(opts.modelPath)
const api_key = await this.generateApiKey(modelId)
args.push(`--api-key ${api_key}`)
// model option is required
// TODO: llama.cpp extension lookup model path based on modelId
args.push('-m', opts.modelPath)
args.push('-a', opts.modelId)
args.push('--port', String(opts.port || 8080)) // Default port if not specified
if (opts.mmprojPath) {
args.push('--mmproj', opts.mmprojPath)
args.push('-m', modelConfig.model_path)
args.push('-a', modelId)
args.push('--port', String(port)) // Default port if not specified
if (modelConfig.mmproj_path) {
args.push('--mmproj', modelConfig.mmproj_path)
}
if (cfg.ctx_size !== undefined) {
@@ -366,14 +437,16 @@ export default class llamacpp_extension extends AIEngine {
// Add remaining options from the interface
if (cfg.n_gpu_layers > 0) args.push('-ngl', String(cfg.n_gpu_layers))
if (cfg.threads > 0) args.push('--threads', String(cfg.threads))
if (cfg.threads_batch > 0) args.push('--threads-batch', String(cfg.threads_batch))
if (cfg.threads_batch > 0)
args.push('--threads-batch', String(cfg.threads_batch))
if (cfg.ctx_size > 0) args.push('--ctx-size', String(cfg.ctx_size))
if (cfg.n_predict > 0) args.push('--n-predict', String(cfg.n_predict))
if (cfg.batch_size > 0) args.push('--batch-size', String(cfg.batch_size))
if (cfg.ubatch_size > 0) args.push('--ubatch-size', String(cfg.ubatch_size))
if (cfg.device.length > 0) args.push('--device', cfg.device)
if (cfg.split_mode.length > 0) args.push('--split-mode', cfg.split_mode)
if (cfg.main_gpu !== undefined) args.push('--main-gpu', String(cfg.main_gpu))
if (cfg.main_gpu !== undefined)
args.push('--main-gpu', String(cfg.main_gpu))
// Boolean flags
if (cfg.flash_attn) args.push('--flash-attn')
@@ -396,7 +469,7 @@ export default class llamacpp_extension extends AIEngine {
try {
const sInfo = await invoke<sessionInfo>('load_llama_model', {
server_path: this.enginesBasePath,
backendPath: backendPath,
args: args,
})
@@ -545,7 +618,11 @@ export default class llamacpp_extension extends AIEngine {
}
override async delete(modelId: string): Promise<void> {
const modelDir = await joinPath([this.modelsBasePath, this.provider, modelId])
const modelDir = await joinPath([
this.modelsBasePath,
this.provider,
modelId,
])
if (!(await fs.existsSync(await joinPath([modelDir, 'model.yml'])))) {
throw new Error(`Model ${modelId} does not exist`)
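
To make the new flow easier to follow than the hunk-by-hunk view above, here is a condensed, standalone sketch of what load now does. The helper names (pickFreePort, buildServerArgs), the example paths, and the API key value are illustrative; the behaviour itself (port in the 3000-3999 range checked against active sessions, per-model API key, --no-webui, -m/-a/--port/--mmproj arguments) comes from this diff.

// Condensed sketch of the new load flow; plain values stand in for the
// joinPath/getJanDataFolderPath/invoke helpers used in the real extension.
interface SketchModelConfig {
  model_path: string
  mmproj_path?: string
}

// Pick a port in 3000-3999 that no tracked session is using (same idea as getRandomPort).
function pickFreePort(usedPorts: Set<number>): number {
  let port: number
  do {
    port = Math.floor(Math.random() * 1000) + 3000
  } while (usedPorts.has(port))
  return port
}

// Build the llama-server argument list the way the new load method does.
function buildServerArgs(
  modelId: string,
  cfg: SketchModelConfig,
  port: number,
  apiKey: string
): string[] {
  const args = ['--no-webui', `--api-key ${apiKey}`]
  args.push('-m', cfg.model_path)
  args.push('-a', modelId)
  args.push('--port', String(port))
  if (cfg.mmproj_path) args.push('--mmproj', cfg.mmproj_path)
  return args
}

// Example wiring with illustrative values; the real code then spawns the
// backend binary via invoke('load_llama_model', { backendPath, args }).
const usedPorts = new Set<number>([3100])
const exampleArgs = buildServerArgs(
  'qwen3-4b',
  { model_path: '/data/models/llamacpp/qwen3-4b/model.gguf' },
  pickFreePort(usedPorts),
  'per-model-api-key'
)
console.log(exampleArgs.join(' '))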

View File

@@ -58,7 +58,7 @@ pub struct unloadResult {
pub async fn load_llama_model(
_app_handle: AppHandle, // Get the AppHandle
state: State<'_, AppState>, // Access the shared state
engineBasePath: String,
backendPath: String,
args: Vec<String>, // Arguments from the frontend
) -> ServerResult<sessionInfo> {
let mut process_lock = state.llama_server_process.lock().await;
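
On the TypeScript side the renamed parameter has to be mirrored in the invoke payload, which the extension diff above already does. A minimal sketch of that pairing; the backend path and args below are illustrative, not real defaults.

import { invoke } from '@tauri-apps/api/core'

// The payload key must line up with the Rust command's backendPath parameter
// (the previous code passed server_path / engineBasePath instead).
async function startServer(): Promise<unknown> {
  return invoke('load_llama_model', {
    backendPath: '/data/llamacpp/backends/cpu/b5000/build/bin/llama-server', // illustrative
    args: ['--no-webui', '--port', '3123'],
  })
}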