refactor load/unload
parent b4670b5526
commit bbbf4779df
@@ -16,9 +16,6 @@ export abstract class AIEngine extends BaseExtension {
   */
  override onLoad() {
    this.registerEngine()

    events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
    events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
  }

  /**
@@ -27,31 +24,4 @@ export abstract class AIEngine extends BaseExtension {
  registerEngine() {
    EngineManager.instance().register(this)
  }

  /**
   * Loads the model.
   */
  async loadModel(model: Partial<Model>, abortController?: AbortController): Promise<any> {
    if (model?.engine?.toString() !== this.provider) return Promise.resolve()
    events.emit(ModelEvent.OnModelReady, model)
    return Promise.resolve()
  }
  /**
   * Stops the model.
   */
  async unloadModel(model?: Partial<Model>): Promise<any> {
    if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve()
    events.emit(ModelEvent.OnModelStopped, model ?? {})
    return Promise.resolve()
  }

  /**
   * Inference request
   */
  inference(data: MessageRequest) {}

  /**
   * Stop inference
   */
  stopInference() {}
}

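With the event wiring and the default load/unload handlers moved out of the abstract class, a concrete engine that still wants that behaviour now has to set it up itself. A minimal sketch, assuming events, ModelEvent and Model are exported from @janhq/core as used above; the class name and the start/stop helpers are illustrative, not part of this commit:

import { AIEngine, events, Model, ModelEvent } from '@janhq/core'

// Hypothetical engine, shown only to illustrate the new responsibility split.
export class ExampleEngine extends AIEngine {
  provider = 'example'

  override onLoad() {
    super.onLoad() // still registers this engine with EngineManager
    // The subclass now owns the event subscriptions the base class used to create.
    events.on(ModelEvent.OnModelInit, (model: Model) => this.start(model))
    events.on(ModelEvent.OnModelStop, (model: Model) => this.stop(model))
  }

  private async start(model: Model) {
    if (model.engine?.toString() !== this.provider) return
    events.emit(ModelEvent.OnModelReady, model)
  }

  private async stop(model?: Model) {
    if (model?.engine && model.engine.toString() !== this.provider) return
    events.emit(ModelEvent.OnModelStopped, model ?? {})
  }
}
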
@@ -6,35 +6,30 @@
 * @module llamacpp-extension/src/index
 */

import {
  AIEngine,
  getJanDataFolderPath,
  fs,
  Model,
} from '@janhq/core'
import { AIEngine, getJanDataFolderPath, fs, joinPath } from '@janhq/core'

import { invoke } from '@tauri-apps/api/tauri'
import {
  LocalProvider,
  ModelInfo,
  ListOptions,
  ListResult,
  PullOptions,
  PullResult,
  LoadOptions,
  SessionInfo,
  UnloadOptions,
  UnloadResult,
  ChatOptions,
  ChatCompletion,
  ChatCompletionChunk,
  DeleteOptions,
  DeleteResult,
  ImportOptions,
  ImportResult,
  AbortPullOptions,
  AbortPullResult,
  ChatCompletionRequest,
  localProvider,
  modelInfo,
  listOptions,
  listResult,
  pullOptions,
  pullResult,
  loadOptions,
  sessionInfo,
  unloadOptions,
  unloadResult,
  chatOptions,
  chatCompletion,
  chatCompletionChunk,
  deleteOptions,
  deleteResult,
  importOptions,
  importResult,
  abortPullOptions,
  abortPullResult,
  chatCompletionRequest,
} from './types'

/**
@@ -61,246 +56,224 @@ function parseGGUFFileName(filename: string): {
 * The class provides methods for initializing and stopping a model, and for making inference requests.
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 */
export default class inference_llamacpp_extension
export default class llamacpp_extension
  extends AIEngine
  implements LocalProvider
  implements localProvider
{
  provider: string = 'llamacpp'
  readonly providerId: string = 'llamcpp'

  private activeSessions: Map<string, SessionInfo> = new Map()
  readonly providerId: string = 'llamacpp'

  private activeSessions: Map<string, sessionInfo> = new Map()
  private modelsBasePath!: string
  private activeRequests: Map<string, AbortController> = new Map()

  override async onLoad(): Promise<void> {
    super.onLoad() // Calls registerEngine() from AIEngine
    this.registerSettings(SETTINGS_DEFINITIONS)
    this.registerSettings(SETTINGS)

    const customPath = await this.getSetting<string>(
      LlamaCppSettings.ModelsPath,
      ''
    )
    if (customPath && (await fs.exists(customPath))) {
      this.modelsBasePath = customPath
    } else {
      this.modelsBasePath = await path.join(
    // Initialize models base path - assuming this would be retrieved from settings
    this.modelsBasePath = await joinPath([
      await getJanDataFolderPath(),
      'models',
      ENGINE_ID
      )
    }
    await fs.createDirAll(this.modelsBasePath)

    console.log(
      `${this.providerId} provider loaded. Models path: ${this.modelsBasePath}`
    )

    // Optionally, list and register models with the core system if AIEngine expects it
    // const models = await this.listModels({ providerId: this.providerId });
    // this.registerModels(this.mapModelInfoToCoreModel(models)); // mapModelInfoToCoreModel would be a helper
    ])
  }

  async getModelsPath(): Promise<string> {
    // Ensure modelsBasePath is initialized
    if (!this.modelsBasePath) {
      const customPath = await this.getSetting<string>(
        LlamaCppSettings.ModelsPath,
        ''
      )
      if (customPath && (await fs.exists(customPath))) {
        this.modelsBasePath = customPath
  // Implement the required LocalProvider interface methods
  async list(opts: listOptions): Promise<listResult> {
    throw new Error('method not implemented yet')
  }

  async pull(opts: pullOptions): Promise<pullResult> {
    throw new Error('method not implemented yet')
  }

  async load(opts: loadOptions): Promise<sessionInfo> {
    const args: string[] = []

    // model option is required
    args.push('-m', opts.modelPath)
    args.push('--port', String(opts.port || 8080)) // Default port if not specified

    if (opts.n_gpu_layers === undefined) {
      // in case of CPU only build, this option will be ignored
      args.push('-ngl', '99')
    } else {
      this.modelsBasePath = await path.join(
        await getJanDataFolderPath(),
        'models',
        ENGINE_ID
      )
    }
    await fs.createDirAll(this.modelsBasePath)
    }
    return this.modelsBasePath
      args.push('-ngl', String(opts.n_gpu_layers))
    }

  async listModels(_opts: ListOptions): Promise<ListResult> {
    const modelsDir = await this.getModelsPath()
    const result: ModelInfo[] = []
    if (opts.n_ctx !== undefined) {
      args.push('-c', String(opts.n_ctx))
    }

    // Add remaining options from the interface
    if (opts.threads !== undefined) {
      args.push('--threads', String(opts.threads))
    }

    if (opts.threads_batch !== undefined) {
      args.push('--threads-batch', String(opts.threads_batch))
    }

    if (opts.ctx_size !== undefined) {
      args.push('--ctx-size', String(opts.ctx_size))
    }

    if (opts.n_predict !== undefined) {
      args.push('--n-predict', String(opts.n_predict))
    }

    if (opts.batch_size !== undefined) {
      args.push('--batch-size', String(opts.batch_size))
    }

    if (opts.ubatch_size !== undefined) {
      args.push('--ubatch-size', String(opts.ubatch_size))
    }

    if (opts.device !== undefined) {
      args.push('--device', opts.device)
    }

    if (opts.split_mode !== undefined) {
      args.push('--split-mode', opts.split_mode)
    }

    if (opts.main_gpu !== undefined) {
      args.push('--main-gpu', String(opts.main_gpu))
    }

    // Boolean flags
    if (opts.flash_attn === true) {
      args.push('--flash-attn')
    }

    if (opts.cont_batching === true) {
      args.push('--cont-batching')
    }

    if (opts.no_mmap === true) {
      args.push('--no-mmap')
    }

    if (opts.mlock === true) {
      args.push('--mlock')
    }

    if (opts.no_kv_offload === true) {
      args.push('--no-kv-offload')
    }

    if (opts.cache_type_k !== undefined) {
      args.push('--cache-type-k', opts.cache_type_k)
    }

    if (opts.cache_type_v !== undefined) {
      args.push('--cache-type-v', opts.cache_type_v)
    }

    if (opts.defrag_thold !== undefined) {
      args.push('--defrag-thold', String(opts.defrag_thold))
    }

    if (opts.rope_scaling !== undefined) {
      args.push('--rope-scaling', opts.rope_scaling)
    }

    if (opts.rope_scale !== undefined) {
      args.push('--rope-scale', String(opts.rope_scale))
    }

    if (opts.rope_freq_base !== undefined) {
      args.push('--rope-freq-base', String(opts.rope_freq_base))
    }

    if (opts.rope_freq_scale !== undefined) {
      args.push('--rope-freq-scale', String(opts.rope_freq_scale))
    }
    console.log('Calling Tauri command load with args:', args)

    try {
      if (!(await fs.exists(modelsDir))) {
        await fs.createDirAll(modelsDir)
        return []
      }

      const entries = await fs.readDir(modelsDir)
      for (const entry of entries) {
        if (entry.name?.endsWith('.gguf') && entry.isFile) {
          const modelPath = await path.join(modelsDir, entry.name)
          const stats = await fs.stat(modelPath)
          const parsedName = parseGGUFFileName(entry.name)

          result.push({
            id: `${parsedName.baseModelId}${parsedName.quant ? `/${parsedName.quant}` : ''}`, // e.g., "mistral-7b/Q4_0"
            name: entry.name.replace('.gguf', ''), // Or a more human-friendly name
            quant_type: parsedName.quant,
            providerId: this.providerId,
            sizeBytes: stats.size,
            path: modelPath,
            tags: [this.providerId, parsedName.quant || 'unknown_quant'].filter(
              Boolean
            ) as string[],
      const sessionInfo = await invoke<sessionInfo>('plugin:llamacpp|load', {
        args: args,
      })
        }
      }
    } catch (error) {
      console.error(`[${this.providerId}] Error listing models:`, error)
      // Depending on desired behavior, either throw or return empty/partial list
    }
    return result
  }

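For orientation, a sketch of how a caller might drive the new load() path. The field names mirror the opts reads above (modelPath, port, n_gpu_layers, ctx_size); the engine variable, model path and numeric values are placeholders, not part of this commit:

// Hypothetical usage; assumes an instance of the extension is available as `engine`.
const session = await engine.load({
  modelPath: '/path/to/model.gguf', // becomes `-m /path/to/model.gguf`
  port: 8080,                       // becomes `--port 8080`
  n_gpu_layers: 99,                 // becomes `-ngl 99`
  ctx_size: 4096,                   // becomes `--ctx-size 4096`
})
console.log(`llama-server on port ${session.port}, session ${session.sessionId}`)
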
  // pullModel
  async pullModel(opts: PullOptions): Promise<PullResult> {
    // TODO: Implement pullModel
    return 0;
  }

  // abortPull
  async abortPull(opts: AbortPullOptions): Promise<AbortPullResult> {
    // TODO: implement abortPull
  }

  async load(opts: LoadOptions): Promise<SessionInfo> {
    if (opts.providerId !== this.providerId) {
      throw new Error('Invalid providerId for LlamaCppProvider.loadModel')
    }

    const sessionId = uuidv4()
    const loadParams = {
      model_path: opts.modelPath,
      session_id: sessionId, // Pass sessionId to Rust for tracking
      // Default llama.cpp server options, can be overridden by opts.options
      port: opts.options?.port ?? 0, // 0 for dynamic port assignment by OS
      n_gpu_layers:
        opts.options?.n_gpu_layers ??
        (await this.getSetting(LlamaCppSettings.DefaultNGpuLayers, -1)),
      n_ctx:
        opts.options?.n_ctx ??
        (await this.getSetting(LlamaCppSettings.DefaultNContext, 2048)),
      // Spread any other options from opts.options
      ...(opts.options || {}),
    }

    try {
      console.log(
        `[${this.providerId}] Requesting to load model: ${opts.modelPath} with options:`,
        loadParams
      )
      // This matches the Rust handler: core::utils::extensions::inference_llamacpp_extension::server::load
      const rustResponse: {
        session_id: string
        port: number
        model_path: string
        settings: Record<string, unknown>
      } = await invoke('plugin:llamacpp|load', { params: loadParams }) // Adjust namespace if needed

      if (!rustResponse || !rustResponse.port) {
        throw new Error(
          'Rust load function did not return expected port or session info.'
        )
      }

      const sessionInfo: SessionInfo = {
        sessionId: rustResponse.session_id, // Use sessionId from Rust if it regenerates/confirms it
        port: rustResponse.port,
        modelPath: rustResponse.model_path,
        providerId: this.providerId,
        settings: rustResponse.settings, // Settings actually used by the server
      }

      // Store the session info for later use
      this.activeSessions.set(sessionInfo.sessionId, sessionInfo)
      console.log(
        `[${this.providerId}] Model loaded: ${sessionInfo.modelPath} on port ${sessionInfo.port}, session: ${sessionInfo.sessionId}`
      )

      return sessionInfo
    } catch (error) {
      console.error(
        `[${this.providerId}] Error loading model ${opts.modelPath}:`,
        error
      )
      throw error // Re-throw to be handled by the caller
    }
  }

  async unload(opts: UnloadOptions): Promise<UnloadResult> {
    if (opts.providerId !== this.providerId) {
      return { success: false, error: 'Invalid providerId' }
    }
    const session = this.activeSessions.get(opts.sessionId)
    if (!session) {
      return {
        success: false,
        error: `No active session found for id: ${opts.sessionId}`,
      console.error('Error loading llama-server:', error)
      throw new Error(`Failed to load llama-server: ${error}`)
    }
  }

  async unload(opts: unloadOptions): Promise<unloadResult> {
    try {
      console.log(
        `[${this.providerId}] Requesting to unload model for session: ${opts.sessionId}`
      )
      // Matches: core::utils::extensions::inference_llamacpp_extension::server::unload
      const rustResponse: { success: boolean; error?: string } = await invoke(
        'plugin:llamacpp|unload',
        { sessionId: opts.sessionId }
      )
      // Pass the PID as the session_id
      const result = await invoke<unloadResult>('plugin:llamacpp|unload', {
        session_id: opts.sessionId, // Using PID as session ID
      })

      if (rustResponse.success) {
        // If successful, remove from active sessions
        if (result.success) {
        this.activeSessions.delete(opts.sessionId)
        console.log(
          `[${this.providerId}] Session ${opts.sessionId} unloaded successfully.`
        )
        return { success: true }
        console.log(`Successfully unloaded model with PID ${opts.sessionId}`)
      } else {
        console.error(
          `[${this.providerId}] Failed to unload session ${opts.sessionId}: ${rustResponse.error}`
        )
        console.warn(`Failed to unload model: ${result.error}`)
      }

      return result
    } catch (error) {
      console.error('Error in unload command:', error)
      return {
        success: false,
        error: rustResponse.error || 'Unknown error during unload',
        error: `Failed to unload model: ${error}`,
      }
    }
    } catch (error: any) {
      console.error(
        `[${this.providerId}] Error invoking unload for session ${opts.sessionId}:`,
        error
      )
      return { success: false, error: error.message || String(error) }
    }
  }

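The matching unload call. As the comments above note, the session id is the llama-server PID handed back by load(); the variable names are carried over from the previous sketch:

// Hypothetical usage: tear down the session created by the load() sketch above.
const result = await engine.unload({
  providerId: 'llamacpp',
  sessionId: session.sessionId, // the llama-server PID returned by load()
})
if (!result.success) {
  console.warn('Failed to unload model:', result.error)
}
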
  async chat(
    opts: ChatOptions
  ): Promise<ChatCompletion | AsyncIterable<ChatCompletionChunk>> {}

  async deleteModel(opts: DeleteOptions): Promise<DeleteResult> {}

  async importModel(opts: ImportOptions): Promise<ImportResult> {}

  override async loadModel(model: Model): Promise<any> {
    if (model.engine?.toString() !== this.provider) return Promise.resolve()
    console.log(
      `[${this.providerId} AIEngine] Received OnModelInit for:`,
      model.id
    opts: chatOptions
  ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
    const sessionInfo = this.activeSessions.get(opts.sessionId)
    if (!sessionInfo) {
      throw new Error(
        `No active session found for sessionId: ${opts.sessionId}`
      )
    return super.load(model)
    }

  override async unloadModel(model?: Model): Promise<any> {
    if (model?.engine && model.engine.toString() !== this.provider)
      return Promise.resolve()
    console.log(
      `[${this.providerId} AIEngine] Received OnModelStop for:`,
      model?.id || 'all models'
    )
    return super.unload(model)
    // For streaming responses
    if (opts.stream) {
      return this.streamChat(opts)
    }

    // For non-streaming responses
    try {
      return await invoke<chatCompletion>('plugin:llamacpp|chat', { opts })
    } catch (error) {
      console.error('Error during chat completion:', error)
      throw new Error(`Chat completion failed: ${error}`)
    }
  }

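A chat call against that session, for completeness. The payload shape follows chatCompletionRequest in types.ts; the message content is made up, and whether streaming is toggled via opts.stream (as the code above checks) or inside the payload is left as in the diff:

// Hypothetical usage: non-streaming completion against the loaded session.
const completion = await engine.chat({
  providerId: 'llamacpp',
  sessionId: session.sessionId,
  payload: {
    model: 'local', // for a local session the model is implied by the sessionId
    messages: [{ role: 'user', content: 'Hello!' }],
  },
})
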
  async delete(opts: deleteOptions): Promise<deleteResult> {
    throw new Error("method not implemented yet")
  }

  async import(opts: importOptions): Promise<importResult> {
    throw new Error("method not implemented yet")
  }

  async abortPull(opts: abortPullOptions): Promise<abortPullResult> {
    throw new Error('method not implemented yet')
  }

  // Optional method for direct client access
  getChatClient(sessionId: string): any {
    throw new Error("method not implemented yet")
  }

  onUnload(): void {
    throw new Error('Method not implemented.')
  }
}

@@ -2,7 +2,7 @@

// --- Re-using OpenAI types (minimal definitions for this example) ---
// In a real project, you'd import these from 'openai' or a shared types package.
export interface ChatCompletionRequestMessage {
export interface chatCompletionRequestMessage {
  role: 'system' | 'user' | 'assistant' | 'tool';
  content: string | null;
  name?: string;
@@ -10,9 +10,9 @@ export interface ChatCompletionRequestMessage {
  tool_call_id?: string;
}

export interface ChatCompletionRequest {
export interface chatCompletionRequest {
  model: string; // Model ID, though for local it might be implicit via sessionId
  messages: ChatCompletionRequestMessage[];
  messages: chatCompletionRequestMessage[];
  temperature?: number | null;
  top_p?: number | null;
  n?: number | null;
@@ -26,41 +26,41 @@ export interface ChatCompletionRequest {
  // ... TODO: other OpenAI params
}

export interface ChatCompletionChunkChoiceDelta {
export interface chatCompletionChunkChoiceDelta {
  content?: string | null;
  role?: 'system' | 'user' | 'assistant' | 'tool';
  tool_calls?: any[]; // Simplified
}

export interface ChatCompletionChunkChoice {
export interface chatCompletionChunkChoice {
  index: number;
  delta: ChatCompletionChunkChoiceDelta;
  delta: chatCompletionChunkChoiceDelta;
  finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null;
}

export interface ChatCompletionChunk {
export interface chatCompletionChunk {
  id: string;
  object: 'chat.completion.chunk';
  created: number;
  model: string;
  choices: ChatCompletionChunkChoice[];
  choices: chatCompletionChunkChoice[];
  system_fingerprint?: string;
}


export interface ChatCompletionChoice {
export interface chatCompletionChoice {
  index: number;
  message: ChatCompletionRequestMessage; // Response message
  message: chatCompletionRequestMessage; // Response message
  finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call';
  logprobs?: any; // Simplified
}

export interface ChatCompletion {
export interface chatCompletion {
  id: string;
  object: 'chat.completion';
  created: number;
  model: string; // Model ID used
  choices: ChatCompletionChoice[];
  choices: chatCompletionChoice[];
  usage?: {
    prompt_tokens: number;
    completion_tokens: number;
@@ -72,7 +72,7 @@ export interface ChatCompletion {


// Shared model metadata
export interface ModelInfo {
export interface modelInfo {
  id: string; // e.g. "qwen3-4B" or "org/model/quant"
  name: string; // human‑readable, e.g., "Qwen3 4B Q4_0"
  quant_type?: string; // q4_0 (optional as it might be part of ID or name)
@@ -86,24 +86,24 @@ export interface ModelInfo {
}

// 1. /list
export interface ListOptions {
export interface listOptions {
  providerId: string; // To specify which provider if a central manager calls this
}
export type ListResult = ModelInfo[];
export type listResult = ModelInfo[];

// 2. /pull
export interface PullOptions {
export interface pullOptions {
  providerId: string;
  modelId: string; // Identifier for the model to pull (e.g., from a known registry)
  downloadUrl: string; // URL to download the model from
  /** optional callback to receive download progress */
  onProgress?: (progress: { percent: number; downloadedBytes: number; totalBytes?: number; }) => void;
}
export interface PullResult {
export interface pullResult {
  success: boolean;
  path?: string; // local file path to the pulled model
  error?: string;
  modelInfo?: ModelInfo; // Info of the pulled model
  modelInfo?: modelInfo; // Info of the pulled model
}

// 3. /load
@@ -135,7 +135,7 @@ export interface loadOptions {
  rope_freq_scale?: number
}

export interface SessionInfo {
export interface sessionInfo {
  sessionId: string; // opaque handle for unload/chat
  port: number; // llama-server output port (corrected from portid)
  modelPath: string; // path of the loaded model
@@ -143,71 +143,71 @@ export interface SessionInfo {
}

// 4. /unload
export interface UnloadOptions {
export interface unloadOptions {
  providerId: string;
  sessionId: string;
}
export interface UnloadResult {
export interface unloadResult {
  success: boolean;
  error?: string;
}

// 5. /chat
export interface ChatOptions {
export interface chatOptions {
  providerId: string;
  sessionId: string;
  /** Full OpenAI ChatCompletionRequest payload */
  payload: ChatCompletionRequest;
  payload: chatCompletionRequest;
}
// Output for /chat will be Promise<ChatCompletion> for non-streaming
// or Promise<AsyncIterable<ChatCompletionChunk>> for streaming

// 6. /delete
export interface DeleteOptions {
export interface deleteOptions {
  providerId: string;
  modelId: string; // The ID of the model to delete (implies finding its path)
  modelPath?: string; // Optionally, direct path can be provided
}
export interface DeleteResult {
export interface deleteResult {
  success: boolean;
  error?: string;
}

// 7. /import
export interface ImportOptions {
export interface importOptions {
  providerId: string;
  sourcePath: string; // Path to the local model file to import
  desiredModelId?: string; // Optional: if user wants to name it specifically
}
export interface ImportResult {
export interface importResult {
  success: boolean;
  modelInfo?: ModelInfo;
  modelInfo?: modelInfo;
  error?: string;
}

// 8. /abortPull
export interface AbortPullOptions {
export interface abortPullOptions {
  providerId: string;
  modelId: string; // The modelId whose download is to be aborted
}
export interface AbortPullResult {
export interface abortPullResult {
  success: boolean;
  error?: string;
}


// The interface for any local provider
export interface LocalProvider {
export interface localProvider {
  readonly providerId: string;

  listModels(opts: ListOptions): Promise<ListResult>;
  pullModel(opts: PullOptions): Promise<PullResult>;
  loadModel(opts: LoadOptions): Promise<SessionInfo>;
  unloadModel(opts: UnloadOptions): Promise<UnloadResult>;
  chat(opts: ChatOptions): Promise<ChatCompletion | AsyncIterable<ChatCompletionChunk>>;
  deleteModel(opts: DeleteOptions): Promise<DeleteResult>;
  importModel(opts: ImportOptions): Promise<ImportResult>;
  abortPull(opts: AbortPullOptions): Promise<AbortPullResult>;
  list(opts: listOptions): Promise<listResult>;
  pull(opts: pullOptions): Promise<pullResult>;
  load(opts: loadOptions): Promise<sessionInfo>;
  unload(opts: unloadOptions): Promise<unloadResult>;
  chat(opts: chatOptions): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>>;
  delete(opts: deleteOptions): Promise<deleteResult>;
  import(opts: importOptions): Promise<importResult>;
  abortPull(opts: abortPullOptions): Promise<abortPullResult>;

  // Optional: for direct access to underlying client if needed for specific streaming cases
  getChatClient?(sessionId: string): any; // e.g., an OpenAI client instance configured for the session

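To make the renamed localProvider contract concrete, a minimal skeleton that type-checks against the interface above; the class name is hypothetical and every method is a stub:

import {
  localProvider,
  listOptions, listResult,
  pullOptions, pullResult,
  loadOptions, sessionInfo,
  unloadOptions, unloadResult,
  chatOptions, chatCompletion, chatCompletionChunk,
  deleteOptions, deleteResult,
  importOptions, importResult,
  abortPullOptions, abortPullResult,
} from './types'

// Hypothetical provider, shown only to illustrate the interface shape.
export class ExampleProvider implements localProvider {
  readonly providerId = 'example'

  async list(_opts: listOptions): Promise<listResult> { return [] }
  async pull(_opts: pullOptions): Promise<pullResult> { return { success: false, error: 'not implemented' } }
  async load(_opts: loadOptions): Promise<sessionInfo> { throw new Error('not implemented') }
  async unload(_opts: unloadOptions): Promise<unloadResult> { return { success: true } }
  async chat(_opts: chatOptions): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> { throw new Error('not implemented') }
  async delete(_opts: deleteOptions): Promise<deleteResult> { return { success: true } }
  async import(_opts: importOptions): Promise<importResult> { return { success: false, error: 'not implemented' } }
  async abortPull(_opts: abortPullOptions): Promise<abortPullResult> { return { success: true } }
}
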
@@ -3,7 +3,6 @@ use serde::{Serialize, Deserialize};
use tauri::path::BaseDirectory;
use tauri::{AppHandle, Manager, State}; // Import Manager trait
use tokio::process::Command;
use std::collections::HashMap;
use uuid::Uuid;
use thiserror;

@@ -70,7 +69,12 @@ pub struct SessionInfo {
    pub session_id: String, // opaque handle for unload/chat
    pub port: u16, // llama-server output port
    pub model_path: String, // path of the loaded model
    pub settings: HashMap<String, serde_json::Value>, // The actual settings used to load
}

#[derive(serde::Serialize, serde::Deserialize)]
pub struct UnloadResult {
    success: bool,
    error: Option<String>,
}

// --- Load Command ---
@@ -102,40 +106,12 @@ pub async fn load(
        )));
    }

    let mut port = 8080; // Default port
    let mut model_path = String::new();
    let mut settings: HashMap<String, serde_json::Value> = HashMap::new();

    // Extract arguments into settings map and specific fields
    let mut i = 0;
    while i < args.len() {
        if args[i] == "--port" && i + 1 < args.len() {
            if let Ok(p) = args[i + 1].parse::<u16>() {
                port = p;
            }
            settings.insert("port".to_string(), serde_json::Value::String(args[i + 1].clone()));
            i += 2;
        } else if args[i] == "-m" && i + 1 < args.len() {
            model_path = args[i + 1].clone();
            settings.insert("modelPath".to_string(), serde_json::Value::String(model_path.clone()));
            i += 2;
        } else if i + 1 < args.len() && args[i].starts_with("-") {
            // Store other arguments as settings
            let key = args[i].trim_start_matches("-").trim_start_matches("-");
            settings.insert(key.to_string(), serde_json::Value::String(args[i + 1].clone()));
            i += 2;
        } else {
            // Handle boolean flags
            if args[i].starts_with("-") {
                let key = args[i].trim_start_matches("-").trim_start_matches("-");
                settings.insert(key.to_string(), serde_json::Value::Bool(true));
            }
            i += 1;
        }
    }
    let port = 8080; // Default port

    // Configure the command to run the server
    let mut command = Command::new(server_path);

    let model_path = args[0].replace("-m", "");
    command.args(args);

    // Optional: Redirect stdio if needed (e.g., for logging within Jan)
@@ -145,17 +121,21 @@ pub async fn load(
    // Spawn the child process
    let child = command.spawn().map_err(ServerError::Io)?;

    log::info!("Server process started with PID: {:?}", child.id());
    // Get the PID to use as session ID
    let pid = child.id().map(|id| id.to_string()).unwrap_or_else(|| {
        // Fallback in case we can't get the PID for some reason
        format!("unknown_pid_{}", Uuid::new_v4())
    });

    log::info!("Server process started with PID: {}", pid);

    // Store the child process handle in the state
    *process_lock = Some(child);

    let session_id = format!("session_{}", Uuid::new_v4());
    let session_info = SessionInfo {
        session_id,
        session_id: pid, // Use PID as session ID
        port,
        model_path,
        settings,
    };

    Ok(session_info)
@@ -163,32 +143,65 @@ pub async fn load(

// --- Unload Command ---
#[tauri::command]
pub async fn unload(state: State<'_, AppState>) -> ServerResult<()> {
pub async fn unload(session_id: String, state: State<'_, AppState>) -> ServerResult<UnloadResult> {
    let mut process_lock = state.llama_server_process.lock().await;

    // Take the child process out of the Option, leaving None in its place
    if let Some(mut child) = process_lock.take() {
        // Convert the PID to a string to compare with the session_id
        let process_pid = child.id().map(|pid| pid.to_string()).unwrap_or_default();

        // Check if the session_id matches the PID
        if session_id != process_pid && !session_id.is_empty() && !process_pid.is_empty() {
            // Put the process back in the lock since we're not killing it
            *process_lock = Some(child);

            log::warn!(
                "Session ID mismatch: provided {} vs process {}",
                session_id,
                process_pid
            );

            return Ok(UnloadResult {
                success: false,
                error: Some(format!("Session ID mismatch: provided {} doesn't match process {}",
                    session_id, process_pid)),
            });
        }

        log::info!(
            "Attempting to terminate server process with PID: {:?}",
            child.id()
        );

        // Kill the process
        // `start_kill` is preferred in async contexts
        match child.start_kill() {
            Ok(_) => {
                log::info!("Server process termination signal sent.");
                Ok(())
                log::info!("Server process termination signal sent successfully");

                Ok(UnloadResult {
                    success: true,
                    error: None,
                })
            }
            Err(e) => {
                // For simplicity, we log and return error.
                log::error!("Failed to kill server process: {}", e);
                // Put it back? Maybe not useful if kill failed.
                // *process_lock = Some(child);
                Err(ServerError::Io(e))

                // Return formatted error
                Ok(UnloadResult {
                    success: false,
                    error: Some(format!("Failed to kill server process: {}", e)),
                })
            }
        }
    } else {
        log::warn!("Attempted to unload server, but it was not running.");
        Ok(())
        log::warn!("Attempted to unload server, but no process was running");

        // If no process is running but client thinks there is,
        // still report success since the end state is what they wanted
        Ok(UnloadResult {
            success: true,
            error: None,
        })
    }
}

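Putting the two sides together, a sketch of the round trip the TypeScript extension makes against these Tauri commands. Only the command names, the args array and the SessionInfo/UnloadResult field names come from this diff; the model path, the inline result types and the exact argument key for unload (sessionId vs session_id, both of which appear above) are illustrative:

import { invoke } from '@tauri-apps/api/tauri'

// Hypothetical round trip: spawn llama-server, then terminate it by PID.
async function roundTrip() {
  // load spawns llama-server; the Rust side uses the child PID as session_id.
  const session = await invoke<{ session_id: string; port: number; model_path: string }>(
    'plugin:llamacpp|load',
    { args: ['-m', '/path/to/model.gguf', '--port', '8080'] } // placeholder path
  )
  console.log(`llama-server on port ${session.port}, PID ${session.session_id}`)

  // unload kills the process whose PID matches the supplied session id.
  const result = await invoke<{ success: boolean; error?: string }>(
    'plugin:llamacpp|unload',
    { sessionId: session.session_id }
  )
  if (!result.success) console.warn('Failed to unload:', result.error)
}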