refactor load/unload

Authored by Akarshan Biswas on 2025-05-20 12:39:18 +05:30, committed by Louis
parent b4670b5526
commit bbbf4779df
4 changed files with 306 additions and 350 deletions

View File

@@ -16,9 +16,6 @@ export abstract class AIEngine extends BaseExtension {
   */
  override onLoad() {
    this.registerEngine()
-
-    events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
-    events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
  }

  /**
@@ -27,31 +24,4 @@ export abstract class AIEngine extends BaseExtension {
  registerEngine() {
    EngineManager.instance().register(this)
  }
-
-  /**
-   * Loads the model.
-   */
-  async loadModel(model: Partial<Model>, abortController?: AbortController): Promise<any> {
-    if (model?.engine?.toString() !== this.provider) return Promise.resolve()
-    events.emit(ModelEvent.OnModelReady, model)
-    return Promise.resolve()
-  }
-
-  /**
-   * Stops the model.
-   */
-  async unloadModel(model?: Partial<Model>): Promise<any> {
-    if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve()
-    events.emit(ModelEvent.OnModelStopped, model ?? {})
-    return Promise.resolve()
-  }
-
-  /**
-   * Inference request
-   */
-  inference(data: MessageRequest) {}
-
-  /**
-   * Stop inference
-   */
-  stopInference() {}
 }
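
Note: with the base class no longer subscribing to ModelEvent.OnModelInit / ModelEvent.OnModelStop and no longer shipping default loadModel/unloadModel/inference stubs, each concrete engine now owns that wiring. A minimal sketch of what a subclass might do if it still wants event-driven load/unload; the ExampleEngine name and the handleLoad/handleUnload helpers are hypothetical, not part of this commit:

import { AIEngine, events, Model, ModelEvent } from '@janhq/core'

// Hypothetical engine showing the lifecycle work that no longer lives in the base class.
export class ExampleEngine extends AIEngine {
  provider = 'example'

  override onLoad() {
    super.onLoad() // still registers the engine with EngineManager
    // The base class no longer subscribes to these events, so do it here if needed.
    events.on(ModelEvent.OnModelInit, (model: Model) => this.handleLoad(model))
    events.on(ModelEvent.OnModelStop, (model: Model) => this.handleUnload(model))
  }

  private async handleLoad(model: Model) {
    if (model.engine?.toString() !== this.provider) return
    // ...start the model here, then announce readiness
    events.emit(ModelEvent.OnModelReady, model)
  }

  private async handleUnload(model?: Model) {
    if (model?.engine && model.engine.toString() !== this.provider) return
    // ...stop the model here, then announce it stopped
    events.emit(ModelEvent.OnModelStopped, model ?? {})
  }
}

The llamacpp extension below takes the other route: it drops the event hooks entirely and exposes explicit load/unload/chat calls instead.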

View File

@@ -6,35 +6,30 @@
  * @module llamacpp-extension/src/index
  */
-import {
-  AIEngine,
-  getJanDataFolderPath,
-  fs,
-  Model,
-} from '@janhq/core'
+import { AIEngine, getJanDataFolderPath, fs, joinPath } from '@janhq/core'
 import { invoke } from '@tauri-apps/api/tauri'
 import {
-  LocalProvider,
-  ModelInfo,
-  ListOptions,
-  ListResult,
-  PullOptions,
-  PullResult,
-  LoadOptions,
-  SessionInfo,
-  UnloadOptions,
-  UnloadResult,
-  ChatOptions,
-  ChatCompletion,
-  ChatCompletionChunk,
-  DeleteOptions,
-  DeleteResult,
-  ImportOptions,
-  ImportResult,
-  AbortPullOptions,
-  AbortPullResult,
-  ChatCompletionRequest,
+  localProvider,
+  modelInfo,
+  listOptions,
+  listResult,
+  pullOptions,
+  pullResult,
+  loadOptions,
+  sessionInfo,
+  unloadOptions,
+  unloadResult,
+  chatOptions,
+  chatCompletion,
+  chatCompletionChunk,
+  deleteOptions,
+  deleteResult,
+  importOptions,
+  importResult,
+  abortPullOptions,
+  abortPullResult,
+  chatCompletionRequest,
 } from './types'

 /**
@@ -61,246 +56,224 @@ function parseGGUFFileName(filename: string): {
  * The class provides methods for initializing and stopping a model, and for making inference requests.
  * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
  */
-export default class inference_llamacpp_extension
+export default class llamacpp_extension
   extends AIEngine
-  implements LocalProvider
+  implements localProvider
 {
   provider: string = 'llamacpp'
-  readonly providerId: string = 'llamcpp'
-  private activeSessions: Map<string, SessionInfo> = new Map()
+  readonly providerId: string = 'llamacpp'
+  private activeSessions: Map<string, sessionInfo> = new Map()
   private modelsBasePath!: string
+  private activeRequests: Map<string, AbortController> = new Map()

   override async onLoad(): Promise<void> {
     super.onLoad() // Calls registerEngine() from AIEngine
-    this.registerSettings(SETTINGS_DEFINITIONS)
-    const customPath = await this.getSetting<string>(
-      LlamaCppSettings.ModelsPath,
-      ''
-    )
-    if (customPath && (await fs.exists(customPath))) {
-      this.modelsBasePath = customPath
-    } else {
-      this.modelsBasePath = await path.join(
-        await getJanDataFolderPath(),
-        'models',
-        ENGINE_ID
-      )
-    }
-    await fs.createDirAll(this.modelsBasePath)
-    console.log(
-      `${this.providerId} provider loaded. Models path: ${this.modelsBasePath}`
-    )
-    // Optionally, list and register models with the core system if AIEngine expects it
-    // const models = await this.listModels({ providerId: this.providerId });
-    // this.registerModels(this.mapModelInfoToCoreModel(models)); // mapModelInfoToCoreModel would be a helper
+    this.registerSettings(SETTINGS)
+    // Initialize models base path - assuming this would be retrieved from settings
+    this.modelsBasePath = await joinPath([
+      await getJanDataFolderPath(),
+      'models',
+    ])
   }

-  async getModelsPath(): Promise<string> {
-    // Ensure modelsBasePath is initialized
-    if (!this.modelsBasePath) {
-      const customPath = await this.getSetting<string>(
-        LlamaCppSettings.ModelsPath,
-        ''
-      )
-      if (customPath && (await fs.exists(customPath))) {
-        this.modelsBasePath = customPath
-      } else {
-        this.modelsBasePath = await path.join(
-          await getJanDataFolderPath(),
-          'models',
-          ENGINE_ID
-        )
-      }
-      await fs.createDirAll(this.modelsBasePath)
-    }
-    return this.modelsBasePath
-  }
-
-  async listModels(_opts: ListOptions): Promise<ListResult> {
-    const modelsDir = await this.getModelsPath()
-    const result: ModelInfo[] = []
-    try {
-      if (!(await fs.exists(modelsDir))) {
-        await fs.createDirAll(modelsDir)
-        return []
-      }
-      const entries = await fs.readDir(modelsDir)
-      for (const entry of entries) {
-        if (entry.name?.endsWith('.gguf') && entry.isFile) {
-          const modelPath = await path.join(modelsDir, entry.name)
-          const stats = await fs.stat(modelPath)
-          const parsedName = parseGGUFFileName(entry.name)
-          result.push({
-            id: `${parsedName.baseModelId}${parsedName.quant ? `/${parsedName.quant}` : ''}`, // e.g., "mistral-7b/Q4_0"
-            name: entry.name.replace('.gguf', ''), // Or a more human-friendly name
-            quant_type: parsedName.quant,
-            providerId: this.providerId,
-            sizeBytes: stats.size,
-            path: modelPath,
-            tags: [this.providerId, parsedName.quant || 'unknown_quant'].filter(
-              Boolean
-            ) as string[],
-          })
-        }
-      }
-    } catch (error) {
-      console.error(`[${this.providerId}] Error listing models:`, error)
-      // Depending on desired behavior, either throw or return empty/partial list
-    }
-    return result
-  }
-
-  // pullModel
-  async pullModel(opts: PullOptions): Promise<PullResult> {
-    // TODO: Implement pullModel
-    return 0;
-  }
-
-  // abortPull
-  async abortPull(opts: AbortPullOptions): Promise<AbortPullResult> {
-    // TODO: implement abortPull
-  }
-
-  async load(opts: LoadOptions): Promise<SessionInfo> {
-    if (opts.providerId !== this.providerId) {
-      throw new Error('Invalid providerId for LlamaCppProvider.loadModel')
-    }
-    const sessionId = uuidv4()
-    const loadParams = {
-      model_path: opts.modelPath,
-      session_id: sessionId, // Pass sessionId to Rust for tracking
-      // Default llama.cpp server options, can be overridden by opts.options
-      port: opts.options?.port ?? 0, // 0 for dynamic port assignment by OS
-      n_gpu_layers:
-        opts.options?.n_gpu_layers ??
-        (await this.getSetting(LlamaCppSettings.DefaultNGpuLayers, -1)),
-      n_ctx:
-        opts.options?.n_ctx ??
-        (await this.getSetting(LlamaCppSettings.DefaultNContext, 2048)),
-      // Spread any other options from opts.options
-      ...(opts.options || {}),
-    }
-    try {
-      console.log(
-        `[${this.providerId}] Requesting to load model: ${opts.modelPath} with options:`,
-        loadParams
-      )
-      // This matches the Rust handler: core::utils::extensions::inference_llamacpp_extension::server::load
-      const rustResponse: {
-        session_id: string
-        port: number
-        model_path: string
-        settings: Record<string, unknown>
-      } = await invoke('plugin:llamacpp|load', { params: loadParams }) // Adjust namespace if needed
-      if (!rustResponse || !rustResponse.port) {
-        throw new Error(
-          'Rust load function did not return expected port or session info.'
-        )
-      }
-      const sessionInfo: SessionInfo = {
-        sessionId: rustResponse.session_id, // Use sessionId from Rust if it regenerates/confirms it
-        port: rustResponse.port,
-        modelPath: rustResponse.model_path,
-        providerId: this.providerId,
-        settings: rustResponse.settings, // Settings actually used by the server
-      }
-      // Store the session info for later use
-      this.activeSessions.set(sessionInfo.sessionId, sessionInfo)
-      console.log(
-        `[${this.providerId}] Model loaded: ${sessionInfo.modelPath} on port ${sessionInfo.port}, session: ${sessionInfo.sessionId}`
-      )
-      return sessionInfo
-    } catch (error) {
-      console.error(
-        `[${this.providerId}] Error loading model ${opts.modelPath}:`,
-        error
-      )
-      throw error // Re-throw to be handled by the caller
-    }
-  }
-
-  async unload(opts: UnloadOptions): Promise<UnloadResult> {
-    if (opts.providerId !== this.providerId) {
-      return { success: false, error: 'Invalid providerId' }
-    }
-    const session = this.activeSessions.get(opts.sessionId)
-    if (!session) {
-      return {
-        success: false,
-        error: `No active session found for id: ${opts.sessionId}`,
-      }
-    }
-    try {
-      console.log(
-        `[${this.providerId}] Requesting to unload model for session: ${opts.sessionId}`
-      )
-      // Matches: core::utils::extensions::inference_llamacpp_extension::server::unload
-      const rustResponse: { success: boolean; error?: string } = await invoke(
-        'plugin:llamacpp|unload',
-        { sessionId: opts.sessionId }
-      )
-      if (rustResponse.success) {
-        this.activeSessions.delete(opts.sessionId)
-        console.log(
-          `[${this.providerId}] Session ${opts.sessionId} unloaded successfully.`
-        )
-        return { success: true }
-      } else {
-        console.error(
-          `[${this.providerId}] Failed to unload session ${opts.sessionId}: ${rustResponse.error}`
-        )
-        return {
-          success: false,
-          error: rustResponse.error || 'Unknown error during unload',
-        }
-      }
-    } catch (error: any) {
-      console.error(
-        `[${this.providerId}] Error invoking unload for session ${opts.sessionId}:`,
-        error
-      )
-      return { success: false, error: error.message || String(error) }
-    }
-  }
-
-  async chat(
-    opts: ChatOptions
-  ): Promise<ChatCompletion | AsyncIterable<ChatCompletionChunk>> {}
-
-  async deleteModel(opts: DeleteOptions): Promise<DeleteResult> {}
-
-  async importModel(opts: ImportOptions): Promise<ImportResult> {}
-
-  override async loadModel(model: Model): Promise<any> {
-    if (model.engine?.toString() !== this.provider) return Promise.resolve()
-    console.log(
-      `[${this.providerId} AIEngine] Received OnModelInit for:`,
-      model.id
-    )
-    return super.load(model)
-  }
-
-  override async unloadModel(model?: Model): Promise<any> {
-    if (model?.engine && model.engine.toString() !== this.provider)
-      return Promise.resolve()
-    console.log(
-      `[${this.providerId} AIEngine] Received OnModelStop for:`,
-      model?.id || 'all models'
-    )
-    return super.unload(model)
-  }
+  // Implement the required LocalProvider interface methods
+  async list(opts: listOptions): Promise<listResult> {
+    throw new Error('method not implemented yet')
+  }
+
+  async pull(opts: pullOptions): Promise<pullResult> {
+    throw new Error('method not implemented yet')
+  }
+
+  async load(opts: loadOptions): Promise<sessionInfo> {
+    const args: string[] = []
+
+    // model option is required
+    args.push('-m', opts.modelPath)
+    args.push('--port', String(opts.port || 8080)) // Default port if not specified
+
+    if (opts.n_gpu_layers === undefined) {
+      // in case of CPU only build, this option will be ignored
+      args.push('-ngl', '99')
+    } else {
+      args.push('-ngl', String(opts.n_gpu_layers))
+    }
+    if (opts.n_ctx !== undefined) {
+      args.push('-c', String(opts.n_ctx))
+    }
+    // Add remaining options from the interface
+    if (opts.threads !== undefined) {
+      args.push('--threads', String(opts.threads))
+    }
+    if (opts.threads_batch !== undefined) {
+      args.push('--threads-batch', String(opts.threads_batch))
+    }
+    if (opts.ctx_size !== undefined) {
+      args.push('--ctx-size', String(opts.ctx_size))
+    }
+    if (opts.n_predict !== undefined) {
+      args.push('--n-predict', String(opts.n_predict))
+    }
+    if (opts.batch_size !== undefined) {
+      args.push('--batch-size', String(opts.batch_size))
+    }
+    if (opts.ubatch_size !== undefined) {
+      args.push('--ubatch-size', String(opts.ubatch_size))
+    }
+    if (opts.device !== undefined) {
+      args.push('--device', opts.device)
+    }
+    if (opts.split_mode !== undefined) {
+      args.push('--split-mode', opts.split_mode)
+    }
+    if (opts.main_gpu !== undefined) {
+      args.push('--main-gpu', String(opts.main_gpu))
+    }
+    // Boolean flags
+    if (opts.flash_attn === true) {
+      args.push('--flash-attn')
+    }
+    if (opts.cont_batching === true) {
+      args.push('--cont-batching')
+    }
+    if (opts.no_mmap === true) {
+      args.push('--no-mmap')
+    }
+    if (opts.mlock === true) {
+      args.push('--mlock')
+    }
+    if (opts.no_kv_offload === true) {
+      args.push('--no-kv-offload')
+    }
+    if (opts.cache_type_k !== undefined) {
+      args.push('--cache-type-k', opts.cache_type_k)
+    }
+    if (opts.cache_type_v !== undefined) {
+      args.push('--cache-type-v', opts.cache_type_v)
+    }
+    if (opts.defrag_thold !== undefined) {
+      args.push('--defrag-thold', String(opts.defrag_thold))
+    }
+    if (opts.rope_scaling !== undefined) {
+      args.push('--rope-scaling', opts.rope_scaling)
+    }
+    if (opts.rope_scale !== undefined) {
+      args.push('--rope-scale', String(opts.rope_scale))
+    }
+    if (opts.rope_freq_base !== undefined) {
+      args.push('--rope-freq-base', String(opts.rope_freq_base))
+    }
+    if (opts.rope_freq_scale !== undefined) {
+      args.push('--rope-freq-scale', String(opts.rope_freq_scale))
+    }
+
+    console.log('Calling Tauri command load with args:', args)
+    try {
+      const sessionInfo = await invoke<sessionInfo>('plugin:llamacpp|load', {
+        args: args,
+      })
+      this.activeSessions.set(sessionInfo.sessionId, sessionInfo)
+      return sessionInfo
+    } catch (error) {
+      console.error('Error loading llama-server:', error)
+      throw new Error(`Failed to load llama-server: ${error}`)
+    }
+  }
+
+  async unload(opts: unloadOptions): Promise<unloadResult> {
+    try {
+      // Pass the PID as the session_id
+      const result = await invoke<unloadResult>('plugin:llamacpp|unload', {
+        session_id: opts.sessionId, // Using PID as session ID
+      })
+
+      // If successful, remove from active sessions
+      if (result.success) {
+        this.activeSessions.delete(opts.sessionId)
+        console.log(`Successfully unloaded model with PID ${opts.sessionId}`)
+      } else {
+        console.warn(`Failed to unload model: ${result.error}`)
+      }
+
+      return result
+    } catch (error) {
+      console.error('Error in unload command:', error)
+      return {
+        success: false,
+        error: `Failed to unload model: ${error}`,
+      }
+    }
+  }
+
+  async chat(
+    opts: chatOptions
+  ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
+    const sessionInfo = this.activeSessions.get(opts.sessionId)
+    if (!sessionInfo) {
+      throw new Error(
+        `No active session found for sessionId: ${opts.sessionId}`
+      )
+    }
+
+    // For streaming responses
+    if (opts.stream) {
+      return this.streamChat(opts)
+    }
+
+    // For non-streaming responses
+    try {
+      return await invoke<chatCompletion>('plugin:llamacpp|chat', { opts })
+    } catch (error) {
+      console.error('Error during chat completion:', error)
+      throw new Error(`Chat completion failed: ${error}`)
+    }
+  }
+
+  async delete(opts: deleteOptions): Promise<deleteResult> {
+    throw new Error("method not implemented yet")
+  }
+
+  async import(opts: importOptions): Promise<importResult> {
+    throw new Error("method not implemented yet")
+  }
+
+  async abortPull(opts: abortPullOptions): Promise<abortPullResult> {
+    throw new Error('method not implemented yet')
+  }
+
+  // Optional method for direct client access
+  getChatClient(sessionId: string): any {
+    throw new Error("method not implemented yet")
+  }
+
+  onUnload(): void {
+    throw new Error('Method not implemented.')
+  }
 }
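
Reviewer note: the refactored surface is intended to be driven roughly as below. This is a sketch, not code from the commit; the model path and option values are placeholders, and the option fields mirror what load(), chat() and unload() read above (modelPath, port, n_gpu_layers, n_ctx, sessionId, payload) rather than a published schema.

// Hypothetical driver for the refactored provider (not part of this commit).
async function runOnce(provider: llamacpp_extension) {
  // load() turns these options into llama-server CLI args; the returned
  // sessionInfo.sessionId is the spawned server's PID.
  const session = await provider.load({
    modelPath: '/path/to/model.gguf', // placeholder
    port: 8080,
    n_gpu_layers: 99,
    n_ctx: 4096,
  })

  // Leaving `stream` unset takes the non-streaming invoke path in chat().
  const completion = await provider.chat({
    providerId: 'llamacpp',
    sessionId: session.sessionId,
    payload: {
      model: 'llamacpp', // implicit for a local session
      messages: [{ role: 'user', content: 'Hello!' }],
    },
  })

  // unload() passes the PID-based sessionId straight through to the Rust side.
  await provider.unload({ providerId: 'llamacpp', sessionId: session.sessionId })
  return completion
}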

View File

@@ -2,7 +2,7 @@
 // --- Re-using OpenAI types (minimal definitions for this example) ---
 // In a real project, you'd import these from 'openai' or a shared types package.
-export interface ChatCompletionRequestMessage {
+export interface chatCompletionRequestMessage {
   role: 'system' | 'user' | 'assistant' | 'tool';
   content: string | null;
   name?: string;
@@ -10,9 +10,9 @@ export interface ChatCompletionRequestMessage {
   tool_call_id?: string;
 }

-export interface ChatCompletionRequest {
+export interface chatCompletionRequest {
   model: string; // Model ID, though for local it might be implicit via sessionId
-  messages: ChatCompletionRequestMessage[];
+  messages: chatCompletionRequestMessage[];
   temperature?: number | null;
   top_p?: number | null;
   n?: number | null;
@@ -26,41 +26,41 @@ export interface ChatCompletionRequest {
   // ... TODO: other OpenAI params
 }

-export interface ChatCompletionChunkChoiceDelta {
+export interface chatCompletionChunkChoiceDelta {
   content?: string | null;
   role?: 'system' | 'user' | 'assistant' | 'tool';
   tool_calls?: any[]; // Simplified
 }

-export interface ChatCompletionChunkChoice {
+export interface chatCompletionChunkChoice {
   index: number;
-  delta: ChatCompletionChunkChoiceDelta;
+  delta: chatCompletionChunkChoiceDelta;
   finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null;
 }

-export interface ChatCompletionChunk {
+export interface chatCompletionChunk {
   id: string;
   object: 'chat.completion.chunk';
   created: number;
   model: string;
-  choices: ChatCompletionChunkChoice[];
+  choices: chatCompletionChunkChoice[];
   system_fingerprint?: string;
 }

-export interface ChatCompletionChoice {
+export interface chatCompletionChoice {
   index: number;
-  message: ChatCompletionRequestMessage; // Response message
+  message: chatCompletionRequestMessage; // Response message
   finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call';
   logprobs?: any; // Simplified
 }

-export interface ChatCompletion {
+export interface chatCompletion {
   id: string;
   object: 'chat.completion';
   created: number;
   model: string; // Model ID used
-  choices: ChatCompletionChoice[];
+  choices: chatCompletionChoice[];
   usage?: {
     prompt_tokens: number;
     completion_tokens: number;
@@ -72,7 +72,7 @@ export interface ChatCompletion {

 // Shared model metadata
-export interface ModelInfo {
+export interface modelInfo {
   id: string; // e.g. "qwen3-4B" or "org/model/quant"
   name: string; // humanreadable, e.g., "Qwen3 4B Q4_0"
   quant_type?: string; // q4_0 (optional as it might be part of ID or name)
@@ -86,24 +86,24 @@ export interface ModelInfo {
 }

 // 1. /list
-export interface ListOptions {
+export interface listOptions {
   providerId: string; // To specify which provider if a central manager calls this
 }
-export type ListResult = ModelInfo[];
+export type listResult = ModelInfo[];

 // 2. /pull
-export interface PullOptions {
+export interface pullOptions {
   providerId: string;
   modelId: string; // Identifier for the model to pull (e.g., from a known registry)
   downloadUrl: string; // URL to download the model from
   /** optional callback to receive download progress */
   onProgress?: (progress: { percent: number; downloadedBytes: number; totalBytes?: number; }) => void;
 }
-export interface PullResult {
+export interface pullResult {
   success: boolean;
   path?: string; // local file path to the pulled model
   error?: string;
-  modelInfo?: ModelInfo; // Info of the pulled model
+  modelInfo?: modelInfo; // Info of the pulled model
 }

 // 3. /load
@@ -135,7 +135,7 @@ export interface loadOptions {
   rope_freq_scale?: number
 }

-export interface SessionInfo {
+export interface sessionInfo {
   sessionId: string; // opaque handle for unload/chat
   port: number; // llama-server output port (corrected from portid)
   modelPath: string; // path of the loaded model
@@ -143,71 +143,71 @@ export interface SessionInfo {
 }

 // 4. /unload
-export interface UnloadOptions {
+export interface unloadOptions {
   providerId: string;
   sessionId: string;
 }
-export interface UnloadResult {
+export interface unloadResult {
   success: boolean;
   error?: string;
 }

 // 5. /chat
-export interface ChatOptions {
+export interface chatOptions {
   providerId: string;
   sessionId: string;
   /** Full OpenAI ChatCompletionRequest payload */
-  payload: ChatCompletionRequest;
+  payload: chatCompletionRequest;
 }
 // Output for /chat will be Promise<ChatCompletion> for non-streaming
 // or Promise<AsyncIterable<ChatCompletionChunk>> for streaming

 // 6. /delete
-export interface DeleteOptions {
+export interface deleteOptions {
   providerId: string;
   modelId: string; // The ID of the model to delete (implies finding its path)
   modelPath?: string; // Optionally, direct path can be provided
 }
-export interface DeleteResult {
+export interface deleteResult {
   success: boolean;
   error?: string;
 }

 // 7. /import
-export interface ImportOptions {
+export interface importOptions {
   providerId: string;
   sourcePath: string; // Path to the local model file to import
   desiredModelId?: string; // Optional: if user wants to name it specifically
 }
-export interface ImportResult {
+export interface importResult {
   success: boolean;
-  modelInfo?: ModelInfo;
+  modelInfo?: modelInfo;
   error?: string;
 }

 // 8. /abortPull
-export interface AbortPullOptions {
+export interface abortPullOptions {
   providerId: string;
   modelId: string; // The modelId whose download is to be aborted
 }
-export interface AbortPullResult {
+export interface abortPullResult {
   success: boolean;
   error?: string;
 }

 // The interface for any local provider
-export interface LocalProvider {
+export interface localProvider {
   readonly providerId: string;
-  listModels(opts: ListOptions): Promise<ListResult>;
-  pullModel(opts: PullOptions): Promise<PullResult>;
-  loadModel(opts: LoadOptions): Promise<SessionInfo>;
-  unloadModel(opts: UnloadOptions): Promise<UnloadResult>;
-  chat(opts: ChatOptions): Promise<ChatCompletion | AsyncIterable<ChatCompletionChunk>>;
-  deleteModel(opts: DeleteOptions): Promise<DeleteResult>;
-  importModel(opts: ImportOptions): Promise<ImportResult>;
-  abortPull(opts: AbortPullOptions): Promise<AbortPullResult>;
+  list(opts: listOptions): Promise<listResult>;
+  pull(opts: pullOptions): Promise<pullResult>;
+  load(opts: loadOptions): Promise<sessionInfo>;
+  unload(opts: unloadOptions): Promise<unloadResult>;
+  chat(opts: chatOptions): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>>;
+  delete(opts: deleteOptions): Promise<deleteResult>;
+  import(opts: importOptions): Promise<importResult>;
+  abortPull(opts: abortPullOptions): Promise<abortPullResult>;
   // Optional: for direct access to underlying client if needed for specific streaming cases
   getChatClient?(sessionId: string): any; // e.g., an OpenAI client instance configured for the session
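
For orientation, the renamed camelCase types compose roughly like this; a sketch with made-up values (the message content, temperature and session id are illustrative only):

import { chatOptions, chatCompletionRequest, localProvider } from './types'

// Build an OpenAI-style request using the renamed types.
const payload: chatCompletionRequest = {
  model: 'llamacpp', // for a local session the model is implied by the sessionId
  messages: [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Summarize this commit.' },
  ],
  temperature: 0.7,
}

// chatOptions ties the payload to a provider and a running session.
const opts: chatOptions = {
  providerId: 'llamacpp',
  sessionId: '12345', // with this refactor, the llama.cpp session id is the server PID
  payload,
}

// Any backend that implements localProvider can accept it.
async function send(p: localProvider) {
  return p.chat(opts)
}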

View File

@@ -3,7 +3,6 @@ use serde::{Serialize, Deserialize};
 use tauri::path::BaseDirectory;
 use tauri::{AppHandle, Manager, State}; // Import Manager trait
 use tokio::process::Command;
-use std::collections::HashMap;
 use uuid::Uuid;
 use thiserror;
@@ -70,7 +69,12 @@ pub struct SessionInfo {
     pub session_id: String, // opaque handle for unload/chat
     pub port: u16, // llama-server output port
     pub model_path: String, // path of the loaded model
-    pub settings: HashMap<String, serde_json::Value>, // The actual settings used to load
 }
+
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct UnloadResult {
+    success: bool,
+    error: Option<String>,
+}

 // --- Load Command ---
@@ -102,40 +106,12 @@ pub async fn load(
         )));
     }

-    let mut port = 8080; // Default port
-    let mut model_path = String::new();
-    let mut settings: HashMap<String, serde_json::Value> = HashMap::new();
-
-    // Extract arguments into settings map and specific fields
-    let mut i = 0;
-    while i < args.len() {
-        if args[i] == "--port" && i + 1 < args.len() {
-            if let Ok(p) = args[i + 1].parse::<u16>() {
-                port = p;
-            }
-            settings.insert("port".to_string(), serde_json::Value::String(args[i + 1].clone()));
-            i += 2;
-        } else if args[i] == "-m" && i + 1 < args.len() {
-            model_path = args[i + 1].clone();
-            settings.insert("modelPath".to_string(), serde_json::Value::String(model_path.clone()));
-            i += 2;
-        } else if i + 1 < args.len() && args[i].starts_with("-") {
-            // Store other arguments as settings
-            let key = args[i].trim_start_matches("-").trim_start_matches("-");
-            settings.insert(key.to_string(), serde_json::Value::String(args[i + 1].clone()));
-            i += 2;
-        } else {
-            // Handle boolean flags
-            if args[i].starts_with("-") {
-                let key = args[i].trim_start_matches("-").trim_start_matches("-");
-                settings.insert(key.to_string(), serde_json::Value::Bool(true));
-            }
-            i += 1;
-        }
-    }
+    let port = 8080; // Default port

     // Configure the command to run the server
     let mut command = Command::new(server_path);
+    let model_path = args[0].replace("-m", "");
     command.args(args);

     // Optional: Redirect stdio if needed (e.g., for logging within Jan)
@@ -145,17 +121,21 @@ pub async fn load(
     // Spawn the child process
     let child = command.spawn().map_err(ServerError::Io)?;

-    log::info!("Server process started with PID: {:?}", child.id());
+    // Get the PID to use as session ID
+    let pid = child.id().map(|id| id.to_string()).unwrap_or_else(|| {
+        // Fallback in case we can't get the PID for some reason
+        format!("unknown_pid_{}", Uuid::new_v4())
+    });
+    log::info!("Server process started with PID: {}", pid);

     // Store the child process handle in the state
     *process_lock = Some(child);

-    let session_id = format!("session_{}", Uuid::new_v4());
     let session_info = SessionInfo {
-        session_id,
+        session_id: pid, // Use PID as session ID
         port,
         model_path,
-        settings,
     };

     Ok(session_info)
@@ -163,32 +143,65 @@ pub async fn load(

 // --- Unload Command ---
 #[tauri::command]
-pub async fn unload(state: State<'_, AppState>) -> ServerResult<()> {
+pub async fn unload(session_id: String, state: State<'_, AppState>) -> ServerResult<UnloadResult> {
     let mut process_lock = state.llama_server_process.lock().await;

     // Take the child process out of the Option, leaving None in its place
     if let Some(mut child) = process_lock.take() {
+        // Convert the PID to a string to compare with the session_id
+        let process_pid = child.id().map(|pid| pid.to_string()).unwrap_or_default();
+
+        // Check if the session_id matches the PID
+        if session_id != process_pid && !session_id.is_empty() && !process_pid.is_empty() {
+            // Put the process back in the lock since we're not killing it
+            *process_lock = Some(child);
+            log::warn!(
+                "Session ID mismatch: provided {} vs process {}",
+                session_id,
+                process_pid
+            );
+            return Ok(UnloadResult {
+                success: false,
+                error: Some(format!("Session ID mismatch: provided {} doesn't match process {}",
+                    session_id, process_pid)),
+            });
+        }
+
         log::info!(
             "Attempting to terminate server process with PID: {:?}",
             child.id()
         );

         // Kill the process
-        // `start_kill` is preferred in async contexts
         match child.start_kill() {
             Ok(_) => {
-                log::info!("Server process termination signal sent.");
-                Ok(())
+                log::info!("Server process termination signal sent successfully");
+                Ok(UnloadResult {
+                    success: true,
+                    error: None,
+                })
             }
             Err(e) => {
-                // For simplicity, we log and return error.
                 log::error!("Failed to kill server process: {}", e);
-                // Put it back? Maybe not useful if kill failed.
-                // *process_lock = Some(child);
-                Err(ServerError::Io(e))
+                // Return formatted error
+                Ok(UnloadResult {
+                    success: false,
+                    error: Some(format!("Failed to kill server process: {}", e)),
+                })
             }
         }
     } else {
-        log::warn!("Attempted to unload server, but it was not running.");
-        Ok(())
+        log::warn!("Attempted to unload server, but no process was running");
+        // If no process is running but client thinks there is,
+        // still report success since the end state is what they wanted
+        Ok(UnloadResult {
+            success: true,
+            error: None,
+        })
     }
 }
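
Seen from the TypeScript side, the reshaped Tauri commands are exercised roughly as follows. This is a sketch that mirrors the calls made in index.ts above; the model path and flag values are placeholders, and the session_id argument key follows the extension code (the actual key casing depends on how the Tauri command maps arguments):

import { invoke } from '@tauri-apps/api/tauri'
import { sessionInfo, unloadResult } from './types'

// Sketch of the invoke round trip that index.ts performs, shown in isolation.
async function cycle(modelPath: string): Promise<void> {
  // load: the TS side now sends raw llama-server CLI args; the Rust handler
  // spawns the process and uses the child PID as the session id.
  const session = await invoke<sessionInfo>('plugin:llamacpp|load', {
    args: ['-m', modelPath, '--port', '8080', '-ngl', '99'],
  })

  // unload: the PID-based session id is echoed back; the handler compares it
  // with the tracked child's PID and reports { success, error } instead of Err.
  const result = await invoke<unloadResult>('plugin:llamacpp|unload', {
    session_id: session.sessionId,
  })
  if (!result.success) {
    console.warn(result.error)
  }
}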