refactor load/unload

parent b4670b5526
commit bbbf4779df
@@ -16,9 +16,6 @@ export abstract class AIEngine extends BaseExtension {
   */
  override onLoad() {
    this.registerEngine()
-
-    events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
-    events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
  }

  /**
@@ -27,31 +24,4 @@ export abstract class AIEngine extends BaseExtension {
  registerEngine() {
    EngineManager.instance().register(this)
  }
-
-  /**
-   * Loads the model.
-   */
-  async loadModel(model: Partial<Model>, abortController?: AbortController): Promise<any> {
-    if (model?.engine?.toString() !== this.provider) return Promise.resolve()
-    events.emit(ModelEvent.OnModelReady, model)
-    return Promise.resolve()
-  }
-  /**
-   * Stops the model.
-   */
-  async unloadModel(model?: Partial<Model>): Promise<any> {
-    if (model?.engine && model.engine.toString() !== this.provider) return Promise.resolve()
-    events.emit(ModelEvent.OnModelStopped, model ?? {})
-    return Promise.resolve()
-  }
-
-  /**
-   * Inference request
-   */
-  inference(data: MessageRequest) {}
-
-  /**
-   * Stop inference
-   */
-  stopInference() {}
 }
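Note on the hunks above: after this change, `AIEngine.onLoad()` only registers the engine; it no longer wires the `ModelEvent` handlers, and `loadModel`/`unloadModel`/`inference`/`stopInference` are removed from the base class. A minimal sketch of what a provider that still wants the old event-driven lifecycle would now do itself (the `MyEngine` class and its method bodies are illustrative assumptions, not part of this commit; any other abstract members of `AIEngine` are elided):

```ts
import { AIEngine, events, Model, ModelEvent } from '@janhq/core'

// Hypothetical subclass re-creating the wiring this commit removes.
class MyEngine extends AIEngine {
  provider = 'my-engine'

  override onLoad() {
    super.onLoad() // still calls registerEngine()
    // Subscriptions formerly made by AIEngine.onLoad:
    events.on(ModelEvent.OnModelInit, (m: Model) => this.loadModel(m))
    events.on(ModelEvent.OnModelStop, (m: Model) => this.unloadModel(m))
  }

  async loadModel(model: Model): Promise<void> {
    if (model.engine?.toString() !== this.provider) return
    events.emit(ModelEvent.OnModelReady, model) // as the old base class did
  }

  async unloadModel(model?: Model): Promise<void> {
    if (model?.engine && model.engine.toString() !== this.provider) return
    events.emit(ModelEvent.OnModelStopped, model ?? {})
  }
}
```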
@@ -6,35 +6,30 @@
  * @module llamacpp-extension/src/index
  */
 
-import {
-  AIEngine,
-  getJanDataFolderPath,
-  fs,
-  Model,
-} from '@janhq/core'
+import { AIEngine, getJanDataFolderPath, fs, joinPath } from '@janhq/core'
 
 import { invoke } from '@tauri-apps/api/tauri'
 import {
-  LocalProvider,
-  ModelInfo,
-  ListOptions,
-  ListResult,
-  PullOptions,
-  PullResult,
-  LoadOptions,
-  SessionInfo,
-  UnloadOptions,
-  UnloadResult,
-  ChatOptions,
-  ChatCompletion,
-  ChatCompletionChunk,
-  DeleteOptions,
-  DeleteResult,
-  ImportOptions,
-  ImportResult,
-  AbortPullOptions,
-  AbortPullResult,
-  ChatCompletionRequest,
+  localProvider,
+  modelInfo,
+  listOptions,
+  listResult,
+  pullOptions,
+  pullResult,
+  loadOptions,
+  sessionInfo,
+  unloadOptions,
+  unloadResult,
+  chatOptions,
+  chatCompletion,
+  chatCompletionChunk,
+  deleteOptions,
+  deleteResult,
+  importOptions,
+  importResult,
+  abortPullOptions,
+  abortPullResult,
+  chatCompletionRequest,
 } from './types'
 
 /**
@@ -61,246 +56,224 @@ function parseGGUFFileName(filename: string): {
  * The class provides methods for initializing and stopping a model, and for making inference requests.
  * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
  */
-export default class inference_llamacpp_extension
+export default class llamacpp_extension
   extends AIEngine
-  implements LocalProvider
+  implements localProvider
 {
   provider: string = 'llamacpp'
-  readonly providerId: string = 'llamcpp'
+  readonly providerId: string = 'llamacpp'
 
-  private activeSessions: Map<string, SessionInfo> = new Map()
 
+  private activeSessions: Map<string, sessionInfo> = new Map()
   private modelsBasePath!: string
+  private activeRequests: Map<string, AbortController> = new Map()
 
   override async onLoad(): Promise<void> {
     super.onLoad() // Calls registerEngine() from AIEngine
-    this.registerSettings(SETTINGS_DEFINITIONS)
+    this.registerSettings(SETTINGS)
 
-    const customPath = await this.getSetting<string>(
-      LlamaCppSettings.ModelsPath,
-      ''
-    )
-    if (customPath && (await fs.exists(customPath))) {
-      this.modelsBasePath = customPath
-    } else {
-      this.modelsBasePath = await path.join(
-        await getJanDataFolderPath(),
-        'models',
-        ENGINE_ID
-      )
-    }
-    await fs.createDirAll(this.modelsBasePath)
-
-    console.log(
-      `${this.providerId} provider loaded. Models path: ${this.modelsBasePath}`
-    )
-
-    // Optionally, list and register models with the core system if AIEngine expects it
-    // const models = await this.listModels({ providerId: this.providerId });
-    // this.registerModels(this.mapModelInfoToCoreModel(models)); // mapModelInfoToCoreModel would be a helper
+    // Initialize models base path - assuming this would be retrieved from settings
+    this.modelsBasePath = await joinPath([
+      await getJanDataFolderPath(),
+      'models',
+    ])
   }
 
-  async getModelsPath(): Promise<string> {
-    // Ensure modelsBasePath is initialized
-    if (!this.modelsBasePath) {
-      const customPath = await this.getSetting<string>(
-        LlamaCppSettings.ModelsPath,
-        ''
-      )
-      if (customPath && (await fs.exists(customPath))) {
-        this.modelsBasePath = customPath
-      } else {
-        this.modelsBasePath = await path.join(
-          await getJanDataFolderPath(),
-          'models',
-          ENGINE_ID
-        )
-      }
-      await fs.createDirAll(this.modelsBasePath)
-    }
-    return this.modelsBasePath
-  }
-
-  async listModels(_opts: ListOptions): Promise<ListResult> {
-    const modelsDir = await this.getModelsPath()
-    const result: ModelInfo[] = []
+  // Implement the required LocalProvider interface methods
+  async list(opts: listOptions): Promise<listResult> {
+    throw new Error('method not implemented yet')
+  }
+
+  async pull(opts: pullOptions): Promise<pullResult> {
+    throw new Error('method not implemented yet')
+  }
+
+  async load(opts: loadOptions): Promise<sessionInfo> {
+    const args: string[] = []
+
+    // model option is required
+    args.push('-m', opts.modelPath)
+    args.push('--port', String(opts.port || 8080)) // Default port if not specified
+
+    if (opts.n_gpu_layers === undefined) {
+      // in case of CPU only build, this option will be ignored
+      args.push('-ngl', '99')
+    } else {
+      args.push('-ngl', String(opts.n_gpu_layers))
+    }
+
+    if (opts.n_ctx !== undefined) {
+      args.push('-c', String(opts.n_ctx))
+    }
+
+    // Add remaining options from the interface
+    if (opts.threads !== undefined) {
+      args.push('--threads', String(opts.threads))
+    }
+
+    if (opts.threads_batch !== undefined) {
+      args.push('--threads-batch', String(opts.threads_batch))
+    }
+
+    if (opts.ctx_size !== undefined) {
+      args.push('--ctx-size', String(opts.ctx_size))
+    }
+
+    if (opts.n_predict !== undefined) {
+      args.push('--n-predict', String(opts.n_predict))
+    }
+
+    if (opts.batch_size !== undefined) {
+      args.push('--batch-size', String(opts.batch_size))
+    }
+
+    if (opts.ubatch_size !== undefined) {
+      args.push('--ubatch-size', String(opts.ubatch_size))
+    }
+
+    if (opts.device !== undefined) {
+      args.push('--device', opts.device)
+    }
+
+    if (opts.split_mode !== undefined) {
+      args.push('--split-mode', opts.split_mode)
+    }
+
+    if (opts.main_gpu !== undefined) {
+      args.push('--main-gpu', String(opts.main_gpu))
+    }
+
+    // Boolean flags
+    if (opts.flash_attn === true) {
+      args.push('--flash-attn')
+    }
+
+    if (opts.cont_batching === true) {
+      args.push('--cont-batching')
+    }
+
+    if (opts.no_mmap === true) {
+      args.push('--no-mmap')
+    }
+
+    if (opts.mlock === true) {
+      args.push('--mlock')
+    }
+
+    if (opts.no_kv_offload === true) {
+      args.push('--no-kv-offload')
+    }
+
+    if (opts.cache_type_k !== undefined) {
+      args.push('--cache-type-k', opts.cache_type_k)
+    }
+
+    if (opts.cache_type_v !== undefined) {
+      args.push('--cache-type-v', opts.cache_type_v)
+    }
+
+    if (opts.defrag_thold !== undefined) {
+      args.push('--defrag-thold', String(opts.defrag_thold))
+    }
+
+    if (opts.rope_scaling !== undefined) {
+      args.push('--rope-scaling', opts.rope_scaling)
+    }
+
+    if (opts.rope_scale !== undefined) {
+      args.push('--rope-scale', String(opts.rope_scale))
+    }
+
+    if (opts.rope_freq_base !== undefined) {
+      args.push('--rope-freq-base', String(opts.rope_freq_base))
+    }
+
+    if (opts.rope_freq_scale !== undefined) {
+      args.push('--rope-freq-scale', String(opts.rope_freq_scale))
+    }
+    console.log('Calling Tauri command load with args:', args)
 
     try {
-      if (!(await fs.exists(modelsDir))) {
-        await fs.createDirAll(modelsDir)
-        return []
-      }
-
-      const entries = await fs.readDir(modelsDir)
-      for (const entry of entries) {
-        if (entry.name?.endsWith('.gguf') && entry.isFile) {
-          const modelPath = await path.join(modelsDir, entry.name)
-          const stats = await fs.stat(modelPath)
-          const parsedName = parseGGUFFileName(entry.name)
-
-          result.push({
-            id: `${parsedName.baseModelId}${parsedName.quant ? `/${parsedName.quant}` : ''}`, // e.g., "mistral-7b/Q4_0"
-            name: entry.name.replace('.gguf', ''), // Or a more human-friendly name
-            quant_type: parsedName.quant,
-            providerId: this.providerId,
-            sizeBytes: stats.size,
-            path: modelPath,
-            tags: [this.providerId, parsedName.quant || 'unknown_quant'].filter(
-              Boolean
-            ) as string[],
-          })
-        }
-      }
-    } catch (error) {
-      console.error(`[${this.providerId}] Error listing models:`, error)
-      // Depending on desired behavior, either throw or return empty/partial list
-    }
-    return result
-  }
-
-  // pullModel
-  async pullModel(opts: PullOptions): Promise<PullResult> {
-    // TODO: Implement pullModel
-    return 0;
-  }
-
-  // abortPull
-  async abortPull(opts: AbortPullOptions): Promise<AbortPullResult> {
-    // TODO: implement abortPull
-  }
-
-  async load(opts: LoadOptions): Promise<SessionInfo> {
-    if (opts.providerId !== this.providerId) {
-      throw new Error('Invalid providerId for LlamaCppProvider.loadModel')
-    }
-
-    const sessionId = uuidv4()
-    const loadParams = {
-      model_path: opts.modelPath,
-      session_id: sessionId, // Pass sessionId to Rust for tracking
-      // Default llama.cpp server options, can be overridden by opts.options
-      port: opts.options?.port ?? 0, // 0 for dynamic port assignment by OS
-      n_gpu_layers:
-        opts.options?.n_gpu_layers ??
-        (await this.getSetting(LlamaCppSettings.DefaultNGpuLayers, -1)),
-      n_ctx:
-        opts.options?.n_ctx ??
-        (await this.getSetting(LlamaCppSettings.DefaultNContext, 2048)),
-      // Spread any other options from opts.options
-      ...(opts.options || {}),
-    }
-
-    try {
-      console.log(
-        `[${this.providerId}] Requesting to load model: ${opts.modelPath} with options:`,
-        loadParams
-      )
-      // This matches the Rust handler: core::utils::extensions::inference_llamacpp_extension::server::load
-      const rustResponse: {
-        session_id: string
-        port: number
-        model_path: string
-        settings: Record<string, unknown>
-      } = await invoke('plugin:llamacpp|load', { params: loadParams }) // Adjust namespace if needed
-
-      if (!rustResponse || !rustResponse.port) {
-        throw new Error(
-          'Rust load function did not return expected port or session info.'
-        )
-      }
-
-      const sessionInfo: SessionInfo = {
-        sessionId: rustResponse.session_id, // Use sessionId from Rust if it regenerates/confirms it
-        port: rustResponse.port,
-        modelPath: rustResponse.model_path,
-        providerId: this.providerId,
-        settings: rustResponse.settings, // Settings actually used by the server
-      }
-
+      const sessionInfo = await invoke<sessionInfo>('plugin:llamacpp|load', {
+        args: args,
+      })
+
+      // Store the session info for later use
       this.activeSessions.set(sessionInfo.sessionId, sessionInfo)
-      console.log(
-        `[${this.providerId}] Model loaded: ${sessionInfo.modelPath} on port ${sessionInfo.port}, session: ${sessionInfo.sessionId}`
-      )
       return sessionInfo
     } catch (error) {
-      console.error(
-        `[${this.providerId}] Error loading model ${opts.modelPath}:`,
-        error
-      )
-      throw error // Re-throw to be handled by the caller
+      console.error('Error loading llama-server:', error)
+      throw new Error(`Failed to load llama-server: ${error}`)
     }
   }
 
-  async unload(opts: UnloadOptions): Promise<UnloadResult> {
-    if (opts.providerId !== this.providerId) {
-      return { success: false, error: 'Invalid providerId' }
-    }
-    const session = this.activeSessions.get(opts.sessionId)
-    if (!session) {
-      return {
-        success: false,
-        error: `No active session found for id: ${opts.sessionId}`,
-      }
-    }
-
+  async unload(opts: unloadOptions): Promise<unloadResult> {
     try {
-      console.log(
-        `[${this.providerId}] Requesting to unload model for session: ${opts.sessionId}`
-      )
-      // Matches: core::utils::extensions::inference_llamacpp_extension::server::unload
-      const rustResponse: { success: boolean; error?: string } = await invoke(
-        'plugin:llamacpp|unload',
-        { sessionId: opts.sessionId }
-      )
+      // Pass the PID as the session_id
      const result = await invoke<unloadResult>('plugin:llamacpp|unload', {
        session_id: opts.sessionId, // Using PID as session ID
      })
 
-      if (rustResponse.success) {
+      // If successful, remove from active sessions
+      if (result.success) {
         this.activeSessions.delete(opts.sessionId)
-        console.log(
-          `[${this.providerId}] Session ${opts.sessionId} unloaded successfully.`
-        )
-        return { success: true }
+        console.log(`Successfully unloaded model with PID ${opts.sessionId}`)
       } else {
-        console.error(
-          `[${this.providerId}] Failed to unload session ${opts.sessionId}: ${rustResponse.error}`
-        )
-        return {
-          success: false,
-          error: rustResponse.error || 'Unknown error during unload',
-        }
+        console.warn(`Failed to unload model: ${result.error}`)
       }
-    } catch (error: any) {
-      console.error(
-        `[${this.providerId}] Error invoking unload for session ${opts.sessionId}:`,
-        error
-      )
-      return { success: false, error: error.message || String(error) }
+
+      return result
+    } catch (error) {
+      console.error('Error in unload command:', error)
+      return {
+        success: false,
+        error: `Failed to unload model: ${error}`,
+      }
     }
   }
 
   async chat(
-    opts: ChatOptions
-  ): Promise<ChatCompletion | AsyncIterable<ChatCompletionChunk>> {}
-
-  async deleteModel(opts: DeleteOptions): Promise<DeleteResult> {}
-
-  async importModel(opts: ImportOptions): Promise<ImportResult> {}
-
-  override async loadModel(model: Model): Promise<any> {
-    if (model.engine?.toString() !== this.provider) return Promise.resolve()
-    console.log(
-      `[${this.providerId} AIEngine] Received OnModelInit for:`,
-      model.id
-    )
-    return super.load(model)
-  }
-
-  override async unloadModel(model?: Model): Promise<any> {
-    if (model?.engine && model.engine.toString() !== this.provider)
-      return Promise.resolve()
-    console.log(
-      `[${this.providerId} AIEngine] Received OnModelStop for:`,
-      model?.id || 'all models'
-    )
-    return super.unload(model)
-  }
+    opts: chatOptions
+  ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
+    const sessionInfo = this.activeSessions.get(opts.sessionId)
+    if (!sessionInfo) {
+      throw new Error(
+        `No active session found for sessionId: ${opts.sessionId}`
+      )
+    }
+
+    // For streaming responses
+    if (opts.stream) {
+      return this.streamChat(opts)
+    }
+
+    // For non-streaming responses
+    try {
+      return await invoke<chatCompletion>('plugin:llamacpp|chat', { opts })
+    } catch (error) {
+      console.error('Error during chat completion:', error)
+      throw new Error(`Chat completion failed: ${error}`)
+    }
+  }
+
+  async delete(opts: deleteOptions): Promise<deleteResult> {
+    throw new Error("method not implemented yet")
+  }
+
+  async import(opts: importOptions): Promise<importResult> {
+    throw new Error("method not implemented yet")
+  }
+
+  async abortPull(opts: abortPullOptions): Promise<abortPullResult> {
+    throw new Error('method not implemented yet')
+  }
+
+  // Optional method for direct client access
+  getChatClient(sessionId: string): any {
+    throw new Error("method not implemented yet")
+  }
+
+  onUnload(): void {
+    throw new Error('Method not implemented.')
+  }
 }
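Caller-side sketch of the refactored flow above: `load()` now turns `loadOptions` into llama-server CLI flags and hands them to the `plugin:llamacpp|load` Tauri command, and the returned `sessionId` is the spawned server's PID, which `unload()` passes straight back. Only the method and type names come from the diff; the option values and the `providerId` field are illustrative assumptions:

```ts
import type { loadOptions, unloadOptions } from './types'
import llamacpp_extension from './index'

async function cycle(ext: llamacpp_extension) {
  // Assumed option values; field names follow loadOptions in types.ts
  const session = await ext.load({
    modelPath: '/path/to/model.gguf',
    port: 8080,
    ctx_size: 4096,
    flash_attn: true,
  } as loadOptions)

  console.log(`llama-server on port ${session.port}, PID ${session.sessionId}`)

  const result = await ext.unload({
    providerId: 'llamacpp',
    sessionId: session.sessionId, // the PID, per this commit
  } as unloadOptions)
  if (!result.success) console.warn(result.error)
}
```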
@@ -2,7 +2,7 @@
 
 // --- Re-using OpenAI types (minimal definitions for this example) ---
 // In a real project, you'd import these from 'openai' or a shared types package.
-export interface ChatCompletionRequestMessage {
+export interface chatCompletionRequestMessage {
   role: 'system' | 'user' | 'assistant' | 'tool';
   content: string | null;
   name?: string;
@@ -10,9 +10,9 @@ export interface ChatCompletionRequestMessage {
   tool_call_id?: string;
 }
 
-export interface ChatCompletionRequest {
+export interface chatCompletionRequest {
   model: string; // Model ID, though for local it might be implicit via sessionId
-  messages: ChatCompletionRequestMessage[];
+  messages: chatCompletionRequestMessage[];
   temperature?: number | null;
   top_p?: number | null;
   n?: number | null;
@@ -26,41 +26,41 @@ export interface ChatCompletionRequest {
   // ... TODO: other OpenAI params
 }
 
-export interface ChatCompletionChunkChoiceDelta {
+export interface chatCompletionChunkChoiceDelta {
   content?: string | null;
   role?: 'system' | 'user' | 'assistant' | 'tool';
   tool_calls?: any[]; // Simplified
 }
 
-export interface ChatCompletionChunkChoice {
+export interface chatCompletionChunkChoice {
   index: number;
-  delta: ChatCompletionChunkChoiceDelta;
+  delta: chatCompletionChunkChoiceDelta;
   finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null;
 }
 
-export interface ChatCompletionChunk {
+export interface chatCompletionChunk {
   id: string;
   object: 'chat.completion.chunk';
   created: number;
   model: string;
-  choices: ChatCompletionChunkChoice[];
+  choices: chatCompletionChunkChoice[];
   system_fingerprint?: string;
 }
 
 
-export interface ChatCompletionChoice {
+export interface chatCompletionChoice {
   index: number;
-  message: ChatCompletionRequestMessage; // Response message
+  message: chatCompletionRequestMessage; // Response message
   finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call';
   logprobs?: any; // Simplified
 }
 
-export interface ChatCompletion {
+export interface chatCompletion {
   id: string;
   object: 'chat.completion';
   created: number;
   model: string; // Model ID used
-  choices: ChatCompletionChoice[];
+  choices: chatCompletionChoice[];
   usage?: {
     prompt_tokens: number;
     completion_tokens: number;
@@ -72,7 +72,7 @@ export interface ChatCompletion {
 
 
 // Shared model metadata
-export interface ModelInfo {
+export interface modelInfo {
   id: string; // e.g. "qwen3-4B" or "org/model/quant"
   name: string; // human‑readable, e.g., "Qwen3 4B Q4_0"
   quant_type?: string; // q4_0 (optional as it might be part of ID or name)
@@ -86,24 +86,24 @@ export interface ModelInfo {
 }
 
 // 1. /list
-export interface ListOptions {
+export interface listOptions {
   providerId: string; // To specify which provider if a central manager calls this
 }
-export type ListResult = ModelInfo[];
+export type listResult = ModelInfo[];
 
 // 2. /pull
-export interface PullOptions {
+export interface pullOptions {
   providerId: string;
   modelId: string; // Identifier for the model to pull (e.g., from a known registry)
   downloadUrl: string; // URL to download the model from
   /** optional callback to receive download progress */
   onProgress?: (progress: { percent: number; downloadedBytes: number; totalBytes?: number; }) => void;
 }
-export interface PullResult {
+export interface pullResult {
   success: boolean;
   path?: string; // local file path to the pulled model
   error?: string;
-  modelInfo?: ModelInfo; // Info of the pulled model
+  modelInfo?: modelInfo; // Info of the pulled model
 }
 
 // 3. /load
@@ -135,7 +135,7 @@ export interface loadOptions {
   rope_freq_scale?: number
 }
 
-export interface SessionInfo {
+export interface sessionInfo {
   sessionId: string; // opaque handle for unload/chat
   port: number; // llama-server output port (corrected from portid)
   modelPath: string; // path of the loaded model
@@ -143,71 +143,71 @@ export interface SessionInfo {
 }
 
 // 4. /unload
-export interface UnloadOptions {
+export interface unloadOptions {
   providerId: string;
   sessionId: string;
 }
-export interface UnloadResult {
+export interface unloadResult {
   success: boolean;
   error?: string;
 }
 
 // 5. /chat
-export interface ChatOptions {
+export interface chatOptions {
   providerId: string;
   sessionId: string;
   /** Full OpenAI ChatCompletionRequest payload */
-  payload: ChatCompletionRequest;
+  payload: chatCompletionRequest;
 }
 // Output for /chat will be Promise<ChatCompletion> for non-streaming
 // or Promise<AsyncIterable<ChatCompletionChunk>> for streaming
 
 // 6. /delete
-export interface DeleteOptions {
+export interface deleteOptions {
   providerId: string;
   modelId: string; // The ID of the model to delete (implies finding its path)
   modelPath?: string; // Optionally, direct path can be provided
 }
-export interface DeleteResult {
+export interface deleteResult {
   success: boolean;
   error?: string;
 }
 
 // 7. /import
-export interface ImportOptions {
+export interface importOptions {
   providerId: string;
   sourcePath: string; // Path to the local model file to import
   desiredModelId?: string; // Optional: if user wants to name it specifically
 }
-export interface ImportResult {
+export interface importResult {
   success: boolean;
-  modelInfo?: ModelInfo;
+  modelInfo?: modelInfo;
   error?: string;
 }
 
 // 8. /abortPull
-export interface AbortPullOptions {
+export interface abortPullOptions {
   providerId: string;
   modelId: string; // The modelId whose download is to be aborted
 }
-export interface AbortPullResult {
+export interface abortPullResult {
   success: boolean;
   error?: string;
 }
 
 
 // The interface for any local provider
-export interface LocalProvider {
+export interface localProvider {
   readonly providerId: string;
 
-  listModels(opts: ListOptions): Promise<ListResult>;
-  pullModel(opts: PullOptions): Promise<PullResult>;
-  loadModel(opts: LoadOptions): Promise<SessionInfo>;
-  unloadModel(opts: UnloadOptions): Promise<UnloadResult>;
-  chat(opts: ChatOptions): Promise<ChatCompletion | AsyncIterable<ChatCompletionChunk>>;
-  deleteModel(opts: DeleteOptions): Promise<DeleteResult>;
-  importModel(opts: ImportOptions): Promise<ImportResult>;
-  abortPull(opts: AbortPullOptions): Promise<AbortPullResult>;
+  list(opts: listOptions): Promise<listResult>;
+  pull(opts: pullOptions): Promise<pullResult>;
+  load(opts: loadOptions): Promise<sessionInfo>;
+  unload(opts: unloadOptions): Promise<unloadResult>;
+  chat(opts: chatOptions): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>>;
+  delete(opts: deleteOptions): Promise<deleteResult>;
+  import(opts: importOptions): Promise<importResult>;
+  abortPull(opts: abortPullOptions): Promise<abortPullResult>;
 
   // Optional: for direct access to underlying client if needed for specific streaming cases
   getChatClient?(sessionId: string): any; // e.g., an OpenAI client instance configured for the session
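For reference, a stub that type-checks against the renamed `localProvider` contract, showing the new method names (`list`/`pull`/`load`/`unload`/`delete`/`import`) and the camelCase type renames in one place. Purely illustrative; not part of the commit:

```ts
import type {
  localProvider,
  listOptions, listResult,
  pullOptions, pullResult,
  loadOptions, sessionInfo,
  unloadOptions, unloadResult,
  chatOptions, chatCompletion, chatCompletionChunk,
  deleteOptions, deleteResult,
  importOptions, importResult,
  abortPullOptions, abortPullResult,
} from './types'

// Hypothetical no-op provider conforming to the renamed interface.
class StubProvider implements localProvider {
  readonly providerId = 'stub'

  async list(_opts: listOptions): Promise<listResult> { return [] }
  async pull(_opts: pullOptions): Promise<pullResult> { return { success: false, error: 'stub' } }
  async load(_opts: loadOptions): Promise<sessionInfo> { throw new Error('stub') }
  async unload(_opts: unloadOptions): Promise<unloadResult> { return { success: true } }
  async chat(_opts: chatOptions): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
    throw new Error('stub')
  }
  async delete(_opts: deleteOptions): Promise<deleteResult> { return { success: true } }
  async import(_opts: importOptions): Promise<importResult> { return { success: false, error: 'stub' } }
  async abortPull(_opts: abortPullOptions): Promise<abortPullResult> { return { success: true } }
}
```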
@@ -3,7 +3,6 @@ use serde::{Serialize, Deserialize};
 use tauri::path::BaseDirectory;
 use tauri::{AppHandle, Manager, State}; // Import Manager trait
 use tokio::process::Command;
-use std::collections::HashMap;
 use uuid::Uuid;
 use thiserror;
 
@@ -70,7 +69,12 @@ pub struct SessionInfo {
     pub session_id: String, // opaque handle for unload/chat
     pub port: u16, // llama-server output port
     pub model_path: String, // path of the loaded model
-    pub settings: HashMap<String, serde_json::Value>, // The actual settings used to load
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct UnloadResult {
+    success: bool,
+    error: Option<String>,
 }
 
 // --- Load Command ---
@@ -102,40 +106,12 @@ pub async fn load(
         )));
     }
 
-    let mut port = 8080; // Default port
-    let mut model_path = String::new();
-    let mut settings: HashMap<String, serde_json::Value> = HashMap::new();
-
-    // Extract arguments into settings map and specific fields
-    let mut i = 0;
-    while i < args.len() {
-        if args[i] == "--port" && i + 1 < args.len() {
-            if let Ok(p) = args[i + 1].parse::<u16>() {
-                port = p;
-            }
-            settings.insert("port".to_string(), serde_json::Value::String(args[i + 1].clone()));
-            i += 2;
-        } else if args[i] == "-m" && i + 1 < args.len() {
-            model_path = args[i + 1].clone();
-            settings.insert("modelPath".to_string(), serde_json::Value::String(model_path.clone()));
-            i += 2;
-        } else if i + 1 < args.len() && args[i].starts_with("-") {
-            // Store other arguments as settings
-            let key = args[i].trim_start_matches("-").trim_start_matches("-");
-            settings.insert(key.to_string(), serde_json::Value::String(args[i + 1].clone()));
-            i += 2;
-        } else {
-            // Handle boolean flags
-            if args[i].starts_with("-") {
-                let key = args[i].trim_start_matches("-").trim_start_matches("-");
-                settings.insert(key.to_string(), serde_json::Value::Bool(true));
-            }
-            i += 1;
-        }
-    }
+    let port = 8080; // Default port
 
     // Configure the command to run the server
     let mut command = Command::new(server_path);
 
+    let model_path = args[0].replace("-m", "");
     command.args(args);
 
     // Optional: Redirect stdio if needed (e.g., for logging within Jan)
@@ -145,17 +121,21 @@ pub async fn load(
     // Spawn the child process
     let child = command.spawn().map_err(ServerError::Io)?;
 
-    log::info!("Server process started with PID: {:?}", child.id());
+    // Get the PID to use as session ID
+    let pid = child.id().map(|id| id.to_string()).unwrap_or_else(|| {
+        // Fallback in case we can't get the PID for some reason
+        format!("unknown_pid_{}", Uuid::new_v4())
+    });
+
+    log::info!("Server process started with PID: {}", pid);
 
     // Store the child process handle in the state
     *process_lock = Some(child);
 
-    let session_id = format!("session_{}", Uuid::new_v4());
     let session_info = SessionInfo {
-        session_id,
+        session_id: pid, // Use PID as session ID
         port,
         model_path,
-        settings,
     };
 
     Ok(session_info)
@@ -163,32 +143,65 @@ pub async fn load(
 
 // --- Unload Command ---
 #[tauri::command]
-pub async fn unload(state: State<'_, AppState>) -> ServerResult<()> {
+pub async fn unload(session_id: String, state: State<'_, AppState>) -> ServerResult<UnloadResult> {
     let mut process_lock = state.llama_server_process.lock().await;
 
     // Take the child process out of the Option, leaving None in its place
     if let Some(mut child) = process_lock.take() {
+        // Convert the PID to a string to compare with the session_id
+        let process_pid = child.id().map(|pid| pid.to_string()).unwrap_or_default();
+
+        // Check if the session_id matches the PID
+        if session_id != process_pid && !session_id.is_empty() && !process_pid.is_empty() {
+            // Put the process back in the lock since we're not killing it
+            *process_lock = Some(child);
+
+            log::warn!(
+                "Session ID mismatch: provided {} vs process {}",
+                session_id,
+                process_pid
+            );
+
+            return Ok(UnloadResult {
+                success: false,
+                error: Some(format!("Session ID mismatch: provided {} doesn't match process {}",
+                    session_id, process_pid)),
+            });
+        }
+
         log::info!(
             "Attempting to terminate server process with PID: {:?}",
             child.id()
         );
 
         // Kill the process
-        // `start_kill` is preferred in async contexts
         match child.start_kill() {
             Ok(_) => {
-                log::info!("Server process termination signal sent.");
-                Ok(())
+                log::info!("Server process termination signal sent successfully");
+
+                Ok(UnloadResult {
+                    success: true,
+                    error: None,
+                })
             }
             Err(e) => {
-                // For simplicity, we log and return error.
                 log::error!("Failed to kill server process: {}", e);
-                // Put it back? Maybe not useful if kill failed.
-                // *process_lock = Some(child);
-                Err(ServerError::Io(e))
+                // Return formatted error
+                Ok(UnloadResult {
+                    success: false,
+                    error: Some(format!("Failed to kill server process: {}", e)),
+                })
             }
         }
     } else {
-        log::warn!("Attempted to unload server, but it was not running.");
-        Ok(())
+        log::warn!("Attempted to unload server, but no process was running");
+
+        // If no process is running but client thinks there is,
+        // still report success since the end state is what they wanted
+        Ok(UnloadResult {
+            success: true,
+            error: None,
+        })
     }
 }
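The Rust side now keys unload on the spawned process's PID rather than a generated `session_{uuid}` handle, and returns a structured `UnloadResult` instead of a bare `Ok(())`, so a stale or mismatched session id comes back as data rather than an error. From TypeScript the call reduces to something like the sketch below (the camelCase `sessionId` argument key relies on Tauri's default camelCase-to-snake_case argument mapping, which is an assumption here):

```ts
import { invoke } from '@tauri-apps/api/tauri'

type UnloadResult = { success: boolean; error?: string }

// Mirrors `pub async fn unload(session_id: String, ...) -> ServerResult<UnloadResult>`
async function unloadByPid(pid: string): Promise<UnloadResult> {
  return invoke<UnloadResult>('plugin:llamacpp|unload', { sessionId: pid })
}
```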