feat: update AIEngine load method and backend path handling
- Changed load method to accept modelId instead of loadOptions for better clarity and simplicity
- Renamed engineBasePath parameter to backendPath for consistency with the backend's directory structure
- Added getRandomPort method to ensure unique ports for each session to prevent conflicts
- Refactored configuration and model loading logic to improve maintainability and reduce redundancy
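For orientation, a minimal caller-side sketch of the new surface in TypeScript. The startModel helper, the engine parameter shape, the example model id, and the sessionInfo import location are assumptions for illustration; the load(modelId) signature, the port field on sessionInfo, and the 3000-3999 port range come from the diff below.

// Hypothetical caller. Before this commit, load() took a loadOptions object
// (modelPath, modelId, port, mmprojPath, ...); after it, the caller passes only
// the model id and the extension resolves model.yml, the backend binary path,
// and a free port on its own.
import type { sessionInfo } from '@janhq/core' // assumed export location

async function startModel(
  engine: { load(modelId: string): Promise<sessionInfo> }, // shape per the new signature
  modelId: string
): Promise<sessionInfo> {
  // Internally, load() reads <modelsBasePath>/<provider>/<modelId>/model.yml,
  // picks an unused port in the 3000-3999 range via getRandomPort(), and
  // starts llama-server from the resolved backendPath.
  const session = await engine.load(modelId)
  console.log(`model ${modelId} is listening on port ${session.port}`)
  return session
}

// Illustrative usage (extension instance and model id are hypothetical):
// const session = await startModel(llamacppExtension, 'Qwen/Qwen3-4B-GGUF')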
parent 9e24e28341
commit fd9e034461
@@ -178,7 +178,7 @@ export abstract class AIEngine extends BaseExtension {
   /**
    * Loads a model into memory
    */
-  abstract load(opts: loadOptions): Promise<sessionInfo>
+  abstract load(modelId: string): Promise<sessionInfo>
 
   /**
    * Unloads a model from memory
@@ -21,34 +21,38 @@ import {
   chatCompletionRequest,
   events,
 } from '@janhq/core'
-import { listSupportedBackends, downloadBackend, isBackendInstalled } from './backend'
+import {
+  listSupportedBackends,
+  downloadBackend,
+  isBackendInstalled,
+} from './backend'
 import { invoke } from '@tauri-apps/api/core'
 
 type LlamacppConfig = {
-  backend: string;
-  n_gpu_layers: number;
-  ctx_size: number;
-  threads: number;
-  threads_batch: number;
-  n_predict: number;
-  batch_size: number;
-  ubatch_size: number;
-  device: string;
-  split_mode: string;
-  main_gpu: number;
-  flash_attn: boolean;
-  cont_batching: boolean;
-  no_mmap: boolean;
-  mlock: boolean;
-  no_kv_offload: boolean;
-  cache_type_k: string;
-  cache_type_v: string;
-  defrag_thold: number;
-  rope_scaling: string;
-  rope_scale: number;
-  rope_freq_base: number;
-  rope_freq_scale: number;
-  reasoning_budget: number;
+  backend: string
+  n_gpu_layers: number
+  ctx_size: number
+  threads: number
+  threads_batch: number
+  n_predict: number
+  batch_size: number
+  ubatch_size: number
+  device: string
+  split_mode: string
+  main_gpu: number
+  flash_attn: boolean
+  cont_batching: boolean
+  no_mmap: boolean
+  mlock: boolean
+  no_kv_offload: boolean
+  cache_type_k: string
+  cache_type_v: string
+  defrag_thold: number
+  rope_scaling: string
+  rope_scale: number
+  rope_freq_base: number
+  rope_freq_scale: number
+  reasoning_budget: number
 }
 
 interface DownloadItem {
@@ -64,7 +68,6 @@ interface ModelConfig {
   size_bytes: number
 }
 
-
 /**
  * A class that implements the InferenceExtension interface from the @janhq/core package.
  * The class provides methods for initializing and stopping a model, and for making inference requests.
@@ -85,11 +88,10 @@ export default class llamacpp_extension extends AIEngine {
 
   private config: LlamacppConfig
   private downloadManager
-  private downloadBackend // for testing
+  private downloadBackend // for testing
   private activeSessions: Map<string, sessionInfo> = new Map()
   private modelsBasePath!: string
   private enginesBasePath!: string
-  private apiSecret: string = "Jan"
+  private apiSecret: string = 'Jan'
 
   override async onLoad(): Promise<void> {
     super.onLoad() // Calls registerEngine() from AIEngine
@@ -117,33 +119,36 @@ export default class llamacpp_extension extends AIEngine {
     let config = {}
     for (const item of SETTINGS) {
       const defaultValue = item.controllerProps.value
-      config[item.key] = await this.getSetting<typeof defaultValue>(item.key, defaultValue)
+      config[item.key] = await this.getSetting<typeof defaultValue>(
+        item.key,
+        defaultValue
+      )
     }
     this.config = config as LlamacppConfig
 
-    this.downloadManager = window.core.extensionManager.getByName('@janhq/download-extension')
+    this.downloadManager = window.core.extensionManager.getByName(
+      '@janhq/download-extension'
+    )
 
     // Initialize models base path - assuming this would be retrieved from settings
     this.modelsBasePath = await joinPath([
       await getJanDataFolderPath(),
       'models',
     ])
 
     this.enginesBasePath = await joinPath([await getJanDataFolderPath(), 'engines'])
   }
 
   override async onUnload(): Promise<void> {
     // Terminate all active sessions
     for (const [sessionId, _] of this.activeSessions) {
       try {
-        await this.unload(sessionId);
+        await this.unload(sessionId)
       } catch (error) {
-        console.error(`Failed to unload session ${sessionId}:`, error);
+        console.error(`Failed to unload session ${sessionId}:`, error)
       }
     }
 
     // Clear the sessions map
-    this.activeSessions.clear();
+    this.activeSessions.clear()
   }
 
   onSettingUpdate<T>(key: string, value: T): void {
@@ -168,7 +173,7 @@ export default class llamacpp_extension extends AIEngine {
   private async generateApiKey(modelId: string): Promise<string> {
     const hash = await invoke<string>('generate_api_key', {
       modelId: modelId,
-      apiSecret: this.apiSecret
+      apiSecret: this.apiSecret,
     })
     return hash
   }
@@ -211,7 +216,12 @@ export default class llamacpp_extension extends AIEngine {
 
     let modelInfos: modelInfo[] = []
     for (const modelId of modelIds) {
-      const path = await joinPath([this.modelsBasePath, this.provider, modelId, 'model.yml'])
+      const path = await joinPath([
+        this.modelsBasePath,
+        this.provider,
+        modelId,
+        'model.yml',
+      ])
       const modelConfig = await invoke<ModelConfig>('read_yaml', { path })
 
       const modelInfo = {
@@ -235,14 +245,21 @@ export default class llamacpp_extension extends AIEngine {
 
       // check for empty parts or path traversal
       const parts = id.split('/')
-      return parts.every(s => s !== '' && s !== '.' && s !== '..')
+      return parts.every((s) => s !== '' && s !== '.' && s !== '..')
     }
 
     if (!isValidModelId(modelId)) {
-      throw new Error(`Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`)
+      throw new Error(
+        `Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`
+      )
     }
 
-    let configPath = await joinPath([this.modelsBasePath, this.provider, modelId, 'model.yml'])
+    let configPath = await joinPath([
+      this.modelsBasePath,
+      this.provider,
+      modelId,
+      'model.yml',
+    ])
     if (await fs.existsSync(configPath)) {
       throw new Error(`Model ${modelId} already exists`)
     }
@@ -260,8 +277,11 @@ export default class llamacpp_extension extends AIEngine {
     let modelPath = opts.modelPath
     let mmprojPath = opts.mmprojPath
 
-    const modelItem = { url: opts.modelPath, save_path: `${modelDir}/model.gguf` }
-    if (opts.modelPath.startsWith("https://")) {
+    const modelItem = {
+      url: opts.modelPath,
+      save_path: `${modelDir}/model.gguf`,
+    }
+    if (opts.modelPath.startsWith('https://')) {
       downloadItems.push(modelItem)
       modelPath = modelItem.save_path
     } else {
@@ -272,8 +292,11 @@ export default class llamacpp_extension extends AIEngine {
     }
 
     if (opts.mmprojPath) {
-      const mmprojItem = { url: opts.mmprojPath, save_path: `${modelDir}/mmproj.gguf` }
-      if (opts.mmprojPath.startsWith("https://")) {
+      const mmprojItem = {
+        url: opts.mmprojPath,
+        save_path: `${modelDir}/mmproj.gguf`,
+      }
+      if (opts.mmprojPath.startsWith('https://')) {
        downloadItems.push(mmprojItem)
        mmprojPath = mmprojItem.save_path
      } else {
@@ -298,7 +321,11 @@ export default class llamacpp_extension extends AIEngine {
         })
         downloadCompleted = transferred === total
       }
-      await this.downloadManager.downloadFiles(downloadItems, taskId, onProgress)
+      await this.downloadManager.downloadFiles(
+        downloadItems,
+        taskId,
+        onProgress
+      )
     } catch (error) {
       console.error('Error downloading model:', modelId, opts, error)
       events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
@@ -307,7 +334,9 @@ export default class llamacpp_extension extends AIEngine {
 
     // once we reach this point, it either means download finishes or it was cancelled.
     // if there was an error, it would have been caught above
-    const eventName = downloadCompleted ? 'onFileDownloadSuccess' : 'onFileDownloadStopped'
+    const eventName = downloadCompleted
+      ? 'onFileDownloadSuccess'
+      : 'onFileDownloadStopped'
     events.emit(eventName, { modelId, downloadType: 'Model' })
   }
 
@@ -315,9 +344,13 @@ export default class llamacpp_extension extends AIEngine {
     // NOTE: modelPath and mmprojPath can be either relative to Jan's data folder (if they are downloaded)
     // or absolute paths (if they are provided as local files)
     const janDataFolderPath = await getJanDataFolderPath()
-    let size_bytes = (await fs.fileStat(await joinPath([janDataFolderPath, modelPath]))).size
+    let size_bytes = (
+      await fs.fileStat(await joinPath([janDataFolderPath, modelPath]))
+    ).size
     if (mmprojPath) {
-      size_bytes += (await fs.fileStat(await joinPath([janDataFolderPath, mmprojPath]))).size
+      size_bytes += (
+        await fs.fileStat(await joinPath([janDataFolderPath, mmprojPath]))
+      ).size
     }
 
     // TODO: add name as import() argument
@@ -328,10 +361,10 @@ export default class llamacpp_extension extends AIEngine {
       name: modelId,
       size_bytes,
     } as ModelConfig
-    await invoke<void>(
-      'write_yaml',
-      { data: modelConfig, savePath: `${modelDir}/model.yml` },
-    )
+    await invoke<void>('write_yaml', {
+      data: modelConfig,
+      savePath: `${modelDir}/model.yml`,
+    })
   }
 
   override async abortImport(modelId: string): Promise<void> {
@@ -339,24 +372,62 @@ export default class llamacpp_extension extends AIEngine {
     const taskId = this.createDownloadTaskId(modelId)
     await this.downloadManager.cancelDownload(taskId)
   }
+  /**
+   * Function to find a random port
+   */
+  private async getRandomPort(): Promise<number> {
+    let port: number
+    do {
+      port = Math.floor(Math.random() * 1000) + 3000
+    } while (
+      Array.from(this.activeSessions.values()).some(
+        (info) => info.port === port
+      )
+    )
+    return port
+  }
 
-  override async load(opts: loadOptions): Promise<sessionInfo> {
+  override async load(modelId: string): Promise<sessionInfo> {
     const args: string[] = []
     const cfg = this.config
+    const sysInfo = await window.core.api.getSystemInfo()
+    const [backend, version] = cfg.backend.split('-')
+    const exe_name =
+      sysInfo.os_type === 'windows' ? 'llama-server.exe' : 'llama-server'
+    const backendPath = await joinPath([
+      await getJanDataFolderPath(),
+      'llamacpp',
+      'backends',
+      backend,
+      version,
+      'build',
+      'bin',
+      exe_name,
+    ])
+    const modelPath = await joinPath([
+      this.modelsBasePath,
+      this.provider,
+      modelId,
+    ])
+    const modelConfigPath = await joinPath([modelPath, 'model.yml'])
+    const modelConfig = await invoke<ModelConfig>('read_yaml', {
+      modelConfigPath,
+    })
+    const port = await this.getRandomPort()
 
     // disable llama-server webui
     args.push('--no-webui')
     // update key for security; TODO: (qnixsynapse) Make it more secure
-    const api_key = this.generateApiKey(opts.modelPath)
+    const api_key = await this.generateApiKey(modelId)
     args.push(`--api-key ${api_key}`)
 
     // model option is required
-    // TODO: llama.cpp extension lookup model path based on modelId
-    args.push('-m', opts.modelPath)
-    args.push('-a', opts.modelId)
-    args.push('--port', String(opts.port || 8080)) // Default port if not specified
-    if (opts.mmprojPath) {
-      args.push('--mmproj', opts.mmprojPath)
+    args.push('-m', modelConfig.model_path)
+    args.push('-a', modelId)
+    args.push('--port', String(port)) // Default port if not specified
+    if (modelConfig.mmproj_path) {
+      args.push('--mmproj', modelConfig.mmproj_path)
     }
 
     if (cfg.ctx_size !== undefined) {
@@ -366,14 +437,16 @@ export default class llamacpp_extension extends AIEngine {
     // Add remaining options from the interface
     if (cfg.n_gpu_layers > 0) args.push('-ngl', String(cfg.n_gpu_layers))
     if (cfg.threads > 0) args.push('--threads', String(cfg.threads))
-    if (cfg.threads_batch > 0) args.push('--threads-batch', String(cfg.threads_batch))
+    if (cfg.threads_batch > 0)
+      args.push('--threads-batch', String(cfg.threads_batch))
     if (cfg.ctx_size > 0) args.push('--ctx-size', String(cfg.ctx_size))
     if (cfg.n_predict > 0) args.push('--n-predict', String(cfg.n_predict))
     if (cfg.batch_size > 0) args.push('--batch-size', String(cfg.batch_size))
     if (cfg.ubatch_size > 0) args.push('--ubatch-size', String(cfg.ubatch_size))
     if (cfg.device.length > 0) args.push('--device', cfg.device)
     if (cfg.split_mode.length > 0) args.push('--split-mode', cfg.split_mode)
-    if (cfg.main_gpu !== undefined) args.push('--main-gpu', String(cfg.main_gpu))
+    if (cfg.main_gpu !== undefined)
+      args.push('--main-gpu', String(cfg.main_gpu))
 
     // Boolean flags
     if (cfg.flash_attn) args.push('--flash-attn')
@@ -396,7 +469,7 @@ export default class llamacpp_extension extends AIEngine {
 
     try {
       const sInfo = await invoke<sessionInfo>('load_llama_model', {
-        server_path: this.enginesBasePath,
+        backendPath: backendPath,
         args: args,
       })
 
@@ -545,7 +618,11 @@ export default class llamacpp_extension extends AIEngine {
   }
 
   override async delete(modelId: string): Promise<void> {
-    const modelDir = await joinPath([this.modelsBasePath, this.provider, modelId])
+    const modelDir = await joinPath([
+      this.modelsBasePath,
+      this.provider,
+      modelId,
+    ])
 
     if (!(await fs.existsSync(await joinPath([modelDir, 'model.yml'])))) {
       throw new Error(`Model ${modelId} does not exist`)
@@ -58,7 +58,7 @@ pub struct unloadResult {
 pub async fn load_llama_model(
     _app_handle: AppHandle, // Get the AppHandle
     state: State<'_, AppState>, // Access the shared state
-    engineBasePath: String,
+    backendPath: String,
     args: Vec<String>, // Arguments from the frontend
 ) -> ServerResult<sessionInfo> {
     let mut process_lock = state.llama_server_process.lock().await;