refactor: rename interfaces and add getLoadedModels
The changes include:

- Renaming interfaces (sessionInfo -> SessionInfo, unloadResult -> UnloadResult) for naming consistency
- Adding a getLoadedModels() method to retrieve the IDs of active models
- Renaming SessionInfo fields from modelId to model_id (likewise modelPath and apiKey) to align with the backend's field names
- Updating the Makefile cleanup paths to XDG-standard locations
- Improving type consistency across the extension implementation
parent 4ffc504150
commit dbcce86bb8

Makefile (4 changed lines)
@@ -79,8 +79,8 @@ else ifeq ($(shell uname -s),Linux)
 	rm -rfv ./electron/pre-install/*.tgz
 	rm -rfv ./src-tauri/resources
 	rm -rfv ./src-tauri/target
-	rm -rfv "~/jan/extensions"
-	rm -rfv "~/.cache/jan*"
+	rm -rfv ~/.local/share/Jan/data/extensions
+	rm -rfv ~/.cache/jan*
 else
 	find . -name "node_modules" -type d -prune -exec rm -rfv '{}' +
 	find . -name ".next" -type d -exec rm -rfv '{}' +
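Two things change in this Linux cleanup branch: the paths move to the XDG-standard locations (`~/.local/share` for data, `~/.cache` for caches), and the double quotes are dropped. The quotes were a bug in their own right, since a quoted `"~/..."` suppresses both tilde expansion and globbing, so the old `rm -rfv "~/.cache/jan*"` matched nothing.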
@@ -132,28 +132,15 @@ export interface modelInfo {
 // 1. /list
 export type listResult = modelInfo[]

 // 3. /load
 export interface loadOptions {
   modelId: string
   modelPath: string
   mmprojPath?: string
   port?: number
 }

-export interface sessionInfo {
+export interface SessionInfo {
   pid: string // opaque handle for unload/chat
   port: number // llama-server output port (corrected from portid)
-  modelId: string, //name of the model
-  modelPath: string // path of the loaded model
-  apiKey: string
+  model_id: string, //name of the model
+  model_path: string // path of the loaded model
+  api_key: string
 }

 // 4. /unload
 export interface unloadOptions {
   providerId: string
   sessionId: string
 }
-export interface unloadResult {
+export interface UnloadResult {
   success: boolean
   error?: string
 }
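For illustration, here is a minimal sketch (not part of the commit) of how a caller consumes the renamed types; the engine instance and model id are placeholders, and the types are the SDK exports shown above:

```ts
// Hypothetical usage sketch; AIEngine, SessionInfo, and UnloadResult are
// the SDK types from this diff, and the model id is made up.
async function loadThenUnload(engine: AIEngine): Promise<void> {
  const session: SessionInfo = await engine.load('some-model')
  console.log(`model ${session.model_id} serving on port ${session.port}`)

  // The llamacpp implementation below resolves the session by model id.
  const result: UnloadResult = await engine.unload(session.model_id)
  if (!result.success) {
    console.error('unload failed:', result.error)
  }
}
```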
@@ -211,12 +198,12 @@ export abstract class AIEngine extends BaseExtension {
   /**
    * Loads a model into memory
    */
-  abstract load(modelId: string): Promise<sessionInfo>
+  abstract load(modelId: string): Promise<SessionInfo>

   /**
    * Unloads a model from memory
    */
-  abstract unload(sessionId: string): Promise<unloadResult>
+  abstract unload(sessionId: string): Promise<UnloadResult>

   /**
    * Sends a chat request to the model
@@ -238,6 +225,11 @@ export abstract class AIEngine extends BaseExtension {
    */
   abstract abortImport(modelId: string): Promise<void>

+  /**
+   * Get currently loaded models
+   */
+  abstract getLoadedModels(): Promise<string[]>
+
   /**
    * Optional method to get the underlying chat client
    */
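The new abstract method gives callers a provider-agnostic way to ask which models are resident. A hypothetical helper built on it (names assumed, not part of the commit):

```ts
// Hypothetical helper: list the active model ids across several engines
// using the new abstract getLoadedModels().
async function logLoadedModels(engines: AIEngine[]): Promise<void> {
  for (const engine of engines) {
    const models = await engine.getLoadedModels()
    console.log(`loaded models: ${models.length ? models.join(', ') : '(none)'}`)
  }
}
```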
@@ -12,9 +12,8 @@ import {
   fs,
   joinPath,
   modelInfo,
   loadOptions,
-  sessionInfo,
-  unloadResult,
+  SessionInfo,
+  UnloadResult,
   chatCompletion,
   chatCompletionChunk,
   ImportOptions,
@@ -94,7 +93,7 @@ export default class llamacpp_extension extends AIEngine {
   readonly providerId: string = 'llamacpp'

   private config: LlamacppConfig
-  private activeSessions: Map<string, sessionInfo> = new Map()
+  private activeSessions: Map<string, SessionInfo> = new Map()
   private providerPath!: string
   private apiSecret: string = 'Jan'
@@ -141,7 +140,7 @@ export default class llamacpp_extension extends AIEngine {
     // Terminate all active sessions
     for (const [_, sInfo] of this.activeSessions) {
       try {
-        await this.unload(sInfo.modelId)
+        await this.unload(sInfo.model_id)
       } catch (error) {
         console.error(`Failed to unload model ${sInfo.modelId}:`, error)
       }
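Note that the `console.error` line in this hunk is unchanged context and still reads `sInfo.modelId`; after the field rename that property no longer exists on SessionInfo, so the message would log `undefined` as the model name.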
@@ -214,11 +213,7 @@ export default class llamacpp_extension extends AIEngine {

     let modelInfos: modelInfo[] = []
     for (const modelId of modelIds) {
-      const path = await joinPath([
-        modelsDir,
-        modelId,
-        'model.yml',
-      ])
+      const path = await joinPath([modelsDir, modelId, 'model.yml'])
       const modelConfig = await invoke<ModelConfig>('read_yaml', { path })

       const modelInfo = {
@@ -284,7 +279,9 @@ export default class llamacpp_extension extends AIEngine {
     }

     let modelPath = await maybeDownload(opts.modelPath, 'model.gguf')
-    let mmprojPath = opts.mmprojPath ? await maybeDownload(opts.mmprojPath, 'mmproj.gguf') : undefined
+    let mmprojPath = opts.mmprojPath
+      ? await maybeDownload(opts.mmprojPath, 'mmproj.gguf')
+      : undefined

     if (downloadItems.length > 0) {
       let downloadCompleted = false
@@ -372,10 +369,10 @@ export default class llamacpp_extension extends AIEngine {
     return port
   }

-  override async load(modelId: string): Promise<sessionInfo> {
+  override async load(modelId: string): Promise<SessionInfo> {
     const sInfo = this.findSessionByModel(modelId)
     if (sInfo) {
-      throw new Error("Model already loaded!!")
+      throw new Error('Model already loaded!!')
     }
     const args: string[] = []
     const cfg = this.config
@@ -456,13 +453,15 @@ export default class llamacpp_extension extends AIEngine {
       args.push('--reasoning-budget', String(cfg.reasoning_budget))

     console.log('Calling Tauri command llama_load with args:', args)
+    const backendPath = await getBackendExePath(backend, version)
+    const libraryPath = await joinPath([this.providerPath, 'lib'])

     try {
       // TODO: add LIBRARY_PATH
-      const sInfo = await invoke<sessionInfo>('load_llama_model', {
-        backend_path: await getBackendExePath(backend, version),
-        library_path: await joinPath([this.providerPath, 'lib']),
-        args: args
+      const sInfo = await invoke<SessionInfo>('load_llama_model', {
+        backendPath,
+        libraryPath,
+        args
       })

       // Store the session info for later use
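Hoisting `backendPath` and `libraryPath` also switches the `invoke` payload keys from snake_case to camelCase, which matches Tauri's default convention: command arguments declared as snake_case parameters on the Rust side are expected as camelCase keys when invoked from JavaScript.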
@@ -475,15 +474,15 @@ export default class llamacpp_extension extends AIEngine {
     }
   }

-  override async unload(modelId: string): Promise<unloadResult> {
-    const sInfo: sessionInfo = this.findSessionByModel(modelId)
+  override async unload(modelId: string): Promise<UnloadResult> {
+    const sInfo: SessionInfo = this.findSessionByModel(modelId)
     if (!sInfo) {
       throw new Error(`No active session found for model: ${modelId}`)
     }
     const pid = sInfo.pid
     try {
       // Pass the PID as the session_id
-      const result = await invoke<unloadResult>('unload_llama_model', {
+      const result = await invoke<UnloadResult>('unload_llama_model', {
         pid: pid
       })
@@ -570,9 +569,9 @@ export default class llamacpp_extension extends AIEngine {
     }
   }

-  private findSessionByModel(modelId: string): sessionInfo | undefined {
+  private findSessionByModel(modelId: string): SessionInfo | undefined {
     return Array.from(this.activeSessions.values()).find(
-      (session) => session.modelId === modelId
+      (session) => session.model_id === modelId
     )
   }
@@ -612,11 +611,7 @@ export default class llamacpp_extension extends AIEngine {
   }

   override async delete(modelId: string): Promise<void> {
-    const modelDir = await joinPath([
-      this.providerPath,
-      'models',
-      modelId,
-    ])
+    const modelDir = await joinPath([this.providerPath, 'models', modelId])

     if (!(await fs.existsSync(await joinPath([modelDir, 'model.yml'])))) {
       throw new Error(`Model ${modelId} does not exist`)
@@ -625,6 +620,14 @@ export default class llamacpp_extension extends AIEngine {
     await fs.rm(modelDir)
   }

+  override async getLoadedModels(): Promise<string[]> {
+    let lmodels: string[] = []
+    for (const [_, sInfo] of this.activeSessions) {
+      lmodels.push(sInfo.model_id)
+    }
+    return lmodels
+  }
+
   // Optional method for direct client access
   override getChatClient(sessionId: string): any {
     throw new Error('method not implemented yet')
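The new override simply projects `model_id` out of the active-session map. An equivalent, more compact formulation (a stylistic alternative, not what the commit uses) would be:

```ts
// Alternative sketch: map the session values straight to their model ids.
override async getLoadedModels(): Promise<string[]> {
  return Array.from(this.activeSessions.values()).map((s) => s.model_id)
}
```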