feat: update AIEngine load method and backend path handling
- Changed load method to accept modelId instead of loadOptions for better clarity and simplicity
- Renamed engineBasePath parameter to backendPath for consistency with the backend's directory structure
- Added getRandomPort method to ensure unique ports for each session to prevent conflicts
- Refactored configuration and model loading logic to improve maintainability and reduce redundancy
parent 9e24e28341
commit fd9e034461
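The caller-facing change is the load() signature: instead of assembling a loadOptions object with explicit paths and a port, a caller now passes a model ID and the extension resolves the model path, mmproj path, and port on its own. A minimal before/after sketch of a call site (the engine variable and model ID are hypothetical; the loadOptions fields are the ones the old code read):

    // before: the caller supplied paths and a port via loadOptions
    // const session = await engine.load({ modelId: 'my-model', modelPath: '/abs/model.gguf', port: 8080 })

    // after: the extension resolves everything from the model's model.yml
    const session: sessionInfo = await engine.load('my-model')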
@@ -178,7 +178,7 @@ export abstract class AIEngine extends BaseExtension {
   /**
    * Loads a model into memory
    */
-  abstract load(opts: loadOptions): Promise<sessionInfo>
+  abstract load(modelId: string): Promise<sessionInfo>

   /**
    * Unloads a model from memory
@@ -21,34 +21,38 @@ import {
   chatCompletionRequest,
   events,
 } from '@janhq/core'
-import { listSupportedBackends, downloadBackend, isBackendInstalled } from './backend'
+import {
+  listSupportedBackends,
+  downloadBackend,
+  isBackendInstalled,
+} from './backend'
 import { invoke } from '@tauri-apps/api/core'

 type LlamacppConfig = {
-  backend: string;
-  n_gpu_layers: number;
-  ctx_size: number;
-  threads: number;
-  threads_batch: number;
-  n_predict: number;
-  batch_size: number;
-  ubatch_size: number;
-  device: string;
-  split_mode: string;
-  main_gpu: number;
-  flash_attn: boolean;
-  cont_batching: boolean;
-  no_mmap: boolean;
-  mlock: boolean;
-  no_kv_offload: boolean;
-  cache_type_k: string;
-  cache_type_v: string;
-  defrag_thold: number;
-  rope_scaling: string;
-  rope_scale: number;
-  rope_freq_base: number;
-  rope_freq_scale: number;
-  reasoning_budget: number;
+  backend: string
+  n_gpu_layers: number
+  ctx_size: number
+  threads: number
+  threads_batch: number
+  n_predict: number
+  batch_size: number
+  ubatch_size: number
+  device: string
+  split_mode: string
+  main_gpu: number
+  flash_attn: boolean
+  cont_batching: boolean
+  no_mmap: boolean
+  mlock: boolean
+  no_kv_offload: boolean
+  cache_type_k: string
+  cache_type_v: string
+  defrag_thold: number
+  rope_scaling: string
+  rope_scale: number
+  rope_freq_base: number
+  rope_freq_scale: number
+  reasoning_budget: number
 }

 interface DownloadItem {
@@ -64,7 +68,6 @@ interface ModelConfig {
   size_bytes: number
 }

-
 /**
  * A class that implements the InferenceExtension interface from the @janhq/core package.
  * The class provides methods for initializing and stopping a model, and for making inference requests.
@@ -85,11 +88,10 @@ export default class llamacpp_extension extends AIEngine {

   private config: LlamacppConfig
   private downloadManager
   private downloadBackend // for testing
   private activeSessions: Map<string, sessionInfo> = new Map()
   private modelsBasePath!: string
-  private enginesBasePath!: string
-  private apiSecret: string = "Jan"
+  private apiSecret: string = 'Jan'

   override async onLoad(): Promise<void> {
     super.onLoad() // Calls registerEngine() from AIEngine
@@ -117,33 +119,36 @@ export default class llamacpp_extension extends AIEngine {
     let config = {}
     for (const item of SETTINGS) {
       const defaultValue = item.controllerProps.value
-      config[item.key] = await this.getSetting<typeof defaultValue>(item.key, defaultValue)
+      config[item.key] = await this.getSetting<typeof defaultValue>(
+        item.key,
+        defaultValue
+      )
     }
     this.config = config as LlamacppConfig

-    this.downloadManager = window.core.extensionManager.getByName('@janhq/download-extension')
+    this.downloadManager = window.core.extensionManager.getByName(
+      '@janhq/download-extension'
+    )

     // Initialize models base path - assuming this would be retrieved from settings
     this.modelsBasePath = await joinPath([
       await getJanDataFolderPath(),
       'models',
     ])
-
-    this.enginesBasePath = await joinPath([await getJanDataFolderPath(), 'engines'])
   }

   override async onUnload(): Promise<void> {
     // Terminate all active sessions
     for (const [sessionId, _] of this.activeSessions) {
       try {
-        await this.unload(sessionId);
+        await this.unload(sessionId)
       } catch (error) {
-        console.error(`Failed to unload session ${sessionId}:`, error);
+        console.error(`Failed to unload session ${sessionId}:`, error)
       }
     }

     // Clear the sessions map
-    this.activeSessions.clear();
+    this.activeSessions.clear()
   }

   onSettingUpdate<T>(key: string, value: T): void {
@@ -168,7 +173,7 @@ export default class llamacpp_extension extends AIEngine {
   private async generateApiKey(modelId: string): Promise<string> {
     const hash = await invoke<string>('generate_api_key', {
       modelId: modelId,
-      apiSecret: this.apiSecret
+      apiSecret: this.apiSecret,
     })
     return hash
   }
@@ -211,7 +216,12 @@ export default class llamacpp_extension extends AIEngine {

     let modelInfos: modelInfo[] = []
     for (const modelId of modelIds) {
-      const path = await joinPath([this.modelsBasePath, this.provider, modelId, 'model.yml'])
+      const path = await joinPath([
+        this.modelsBasePath,
+        this.provider,
+        modelId,
+        'model.yml',
+      ])
       const modelConfig = await invoke<ModelConfig>('read_yaml', { path })

       const modelInfo = {
@@ -235,14 +245,21 @@ export default class llamacpp_extension extends AIEngine {

       // check for empty parts or path traversal
       const parts = id.split('/')
-      return parts.every(s => s !== '' && s !== '.' && s !== '..')
+      return parts.every((s) => s !== '' && s !== '.' && s !== '..')
     }

     if (!isValidModelId(modelId)) {
-      throw new Error(`Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`)
+      throw new Error(
+        `Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`
+      )
     }

-    let configPath = await joinPath([this.modelsBasePath, this.provider, modelId, 'model.yml'])
+    let configPath = await joinPath([
+      this.modelsBasePath,
+      this.provider,
+      modelId,
+      'model.yml',
+    ])
     if (await fs.existsSync(configPath)) {
       throw new Error(`Model ${modelId} already exists`)
     }
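As a worked example of the check above (condensed into a standalone helper; the full validation in the source also restricts the character set, per the error message):

    // hypothetical extraction of the empty-part / traversal rule
    const hasSafeParts = (id: string) =>
      id.split('/').every((s) => s !== '' && s !== '.' && s !== '..')

    hasSafeParts('qwen3/model-q4')  // true
    hasSafeParts('a//b')            // false - empty segment
    hasSafeParts('../etc/passwd')   // false - path traversal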
@@ -260,8 +277,11 @@ export default class llamacpp_extension extends AIEngine {
     let modelPath = opts.modelPath
     let mmprojPath = opts.mmprojPath

-    const modelItem = { url: opts.modelPath, save_path: `${modelDir}/model.gguf` }
-    if (opts.modelPath.startsWith("https://")) {
+    const modelItem = {
+      url: opts.modelPath,
+      save_path: `${modelDir}/model.gguf`,
+    }
+    if (opts.modelPath.startsWith('https://')) {
       downloadItems.push(modelItem)
       modelPath = modelItem.save_path
     } else {
@@ -272,8 +292,11 @@ export default class llamacpp_extension extends AIEngine {
     }

     if (opts.mmprojPath) {
-      const mmprojItem = { url: opts.mmprojPath, save_path: `${modelDir}/mmproj.gguf` }
-      if (opts.mmprojPath.startsWith("https://")) {
+      const mmprojItem = {
+        url: opts.mmprojPath,
+        save_path: `${modelDir}/mmproj.gguf`,
+      }
+      if (opts.mmprojPath.startsWith('https://')) {
         downloadItems.push(mmprojItem)
         mmprojPath = mmprojItem.save_path
       } else {
@@ -298,7 +321,11 @@ export default class llamacpp_extension extends AIEngine {
         })
         downloadCompleted = transferred === total
       }
-      await this.downloadManager.downloadFiles(downloadItems, taskId, onProgress)
+      await this.downloadManager.downloadFiles(
+        downloadItems,
+        taskId,
+        onProgress
+      )
     } catch (error) {
       console.error('Error downloading model:', modelId, opts, error)
       events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
@@ -307,7 +334,9 @@ export default class llamacpp_extension extends AIEngine {

     // once we reach this point, it either means download finishes or it was cancelled.
     // if there was an error, it would have been caught above
-    const eventName = downloadCompleted ? 'onFileDownloadSuccess' : 'onFileDownloadStopped'
+    const eventName = downloadCompleted
+      ? 'onFileDownloadSuccess'
+      : 'onFileDownloadStopped'
     events.emit(eventName, { modelId, downloadType: 'Model' })
   }

@@ -315,9 +344,13 @@ export default class llamacpp_extension extends AIEngine {
     // NOTE: modelPath and mmprojPath can be either relative to Jan's data folder (if they are downloaded)
     // or absolute paths (if they are provided as local files)
     const janDataFolderPath = await getJanDataFolderPath()
-    let size_bytes = (await fs.fileStat(await joinPath([janDataFolderPath, modelPath]))).size
+    let size_bytes = (
+      await fs.fileStat(await joinPath([janDataFolderPath, modelPath]))
+    ).size
     if (mmprojPath) {
-      size_bytes += (await fs.fileStat(await joinPath([janDataFolderPath, mmprojPath]))).size
+      size_bytes += (
+        await fs.fileStat(await joinPath([janDataFolderPath, mmprojPath]))
+      ).size
     }

     // TODO: add name as import() argument
@@ -328,10 +361,10 @@ export default class llamacpp_extension extends AIEngine {
       name: modelId,
       size_bytes,
     } as ModelConfig
-    await invoke<void>(
-      'write_yaml',
-      { data: modelConfig, savePath: `${modelDir}/model.yml` },
-    )
+    await invoke<void>('write_yaml', {
+      data: modelConfig,
+      savePath: `${modelDir}/model.yml`,
+    })
   }

   override async abortImport(modelId: string): Promise<void> {
@@ -339,24 +372,62 @@ export default class llamacpp_extension extends AIEngine {
     const taskId = this.createDownloadTaskId(modelId)
     await this.downloadManager.cancelDownload(taskId)
   }
+  /**
+   * Function to find a random port
+   */
+  private async getRandomPort(): Promise<number> {
+    let port: number
+    do {
+      port = Math.floor(Math.random() * 1000) + 3000
+    } while (
+      Array.from(this.activeSessions.values()).some(
+        (info) => info.port === port
+      )
+    )
+    return port
+  }

-  override async load(opts: loadOptions): Promise<sessionInfo> {
+  override async load(modelId: string): Promise<sessionInfo> {
     const args: string[] = []
     const cfg = this.config
+    const sysInfo = await window.core.api.getSystemInfo()
+    const [backend, version] = cfg.backend.split('-')
+    const exe_name =
+      sysInfo.os_type === 'windows' ? 'llama-server.exe' : 'llama-server'
+    const backendPath = await joinPath([
+      await getJanDataFolderPath(),
+      'llamacpp',
+      'backends',
+      backend,
+      version,
+      'build',
+      'bin',
+      exe_name,
+    ])
+    const modelPath = await joinPath([
+      this.modelsBasePath,
+      this.provider,
+      modelId,
+    ])
+    const modelConfigPath = await joinPath([modelPath, 'model.yml'])
+    const modelConfig = await invoke<ModelConfig>('read_yaml', {
+      modelConfigPath,
+    })
+    const port = await this.getRandomPort()

     // disable llama-server webui
     args.push('--no-webui')
     // update key for security; TODO: (qnixsynapse) Make it more secure
-    const api_key = this.generateApiKey(opts.modelPath)
+    const api_key = await this.generateApiKey(modelId)
     args.push(`--api-key ${api_key}`)

     // model option is required
     // TODO: llama.cpp extension lookup model path based on modelId
-    args.push('-m', opts.modelPath)
-    args.push('-a', opts.modelId)
-    args.push('--port', String(opts.port || 8080)) // Default port if not specified
-    if (opts.mmprojPath) {
-      args.push('--mmproj', opts.mmprojPath)
+    args.push('-m', modelConfig.model_path)
+    args.push('-a', modelId)
+    args.push('--port', String(port)) // Default port if not specified
+    if (modelConfig.mmproj_path) {
+      args.push('--mmproj', modelConfig.mmproj_path)
     }

     if (cfg.ctx_size !== undefined) {
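The new getRandomPort draws a port from 3000-3999 and retries until the candidate is not held by one of the extension's own sessions; note it does not probe whether an unrelated process already occupies the port. Condensed into a standalone sketch (the sessions parameter stands in for this.activeSessions, assuming sessionInfo carries a numeric port):

    // minimal sketch of the collision check used by getRandomPort
    function pickPort(sessions: Map<string, { port: number }>): number {
      let port: number
      do {
        port = Math.floor(Math.random() * 1000) + 3000 // 3000..3999
      } while (Array.from(sessions.values()).some((s) => s.port === port))
      return port
    }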
@@ -366,14 +437,16 @@ export default class llamacpp_extension extends AIEngine {
     // Add remaining options from the interface
     if (cfg.n_gpu_layers > 0) args.push('-ngl', String(cfg.n_gpu_layers))
     if (cfg.threads > 0) args.push('--threads', String(cfg.threads))
-    if (cfg.threads_batch > 0) args.push('--threads-batch', String(cfg.threads_batch))
+    if (cfg.threads_batch > 0)
+      args.push('--threads-batch', String(cfg.threads_batch))
     if (cfg.ctx_size > 0) args.push('--ctx-size', String(cfg.ctx_size))
     if (cfg.n_predict > 0) args.push('--n-predict', String(cfg.n_predict))
     if (cfg.batch_size > 0) args.push('--batch-size', String(cfg.batch_size))
     if (cfg.ubatch_size > 0) args.push('--ubatch-size', String(cfg.ubatch_size))
     if (cfg.device.length > 0) args.push('--device', cfg.device)
     if (cfg.split_mode.length > 0) args.push('--split-mode', cfg.split_mode)
-    if (cfg.main_gpu !== undefined) args.push('--main-gpu', String(cfg.main_gpu))
+    if (cfg.main_gpu !== undefined)
+      args.push('--main-gpu', String(cfg.main_gpu))

     // Boolean flags
     if (cfg.flash_attn) args.push('--flash-attn')
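As a worked example of this mapping (values hypothetical), a config of { n_gpu_layers: 99, threads: 8, ctx_size: 8192, flash_attn: true } contributes -ngl 99 --threads 8 --ctx-size 8192 --flash-attn to the llama-server command line; zero, empty, or false values are simply skipped.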
@@ -396,7 +469,7 @@ export default class llamacpp_extension extends AIEngine {

     try {
       const sInfo = await invoke<sessionInfo>('load_llama_model', {
-        server_path: this.enginesBasePath,
+        backendPath: backendPath,
         args: args,
       })

@@ -545,7 +618,11 @@ export default class llamacpp_extension extends AIEngine {
   }

   override async delete(modelId: string): Promise<void> {
-    const modelDir = await joinPath([this.modelsBasePath, this.provider, modelId])
+    const modelDir = await joinPath([
+      this.modelsBasePath,
+      this.provider,
+      modelId,
+    ])

     if (!(await fs.existsSync(await joinPath([modelDir, 'model.yml'])))) {
       throw new Error(`Model ${modelId} does not exist`)
@@ -58,7 +58,7 @@ pub struct unloadResult {
 pub async fn load_llama_model(
     _app_handle: AppHandle, // Get the AppHandle
     state: State<'_, AppState>, // Access the shared state
-    engineBasePath: String,
+    backendPath: String,
     args: Vec<String>, // Arguments from the frontend
 ) -> ServerResult<sessionInfo> {
     let mut process_lock = state.llama_server_process.lock().await;
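The Rust parameter had to be renamed in lockstep with the TypeScript call site because Tauri matches invoke payload keys against the command's parameter names. Notably, the pre-commit code passed server_path while the Rust command expected engineBasePath, a mismatch this rename also resolves; the aligned call now reads:

    // payload key matches the Rust parameter name (camelCase on both sides)
    const sInfo = await invoke<sessionInfo>('load_llama_model', {
      backendPath: backendPath,
      args: args,
    })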