refactor load/unload again; move types to core and refactor AIEngine abstract class

This commit is contained in:
Akarshan Biswas 2025-05-20 19:33:26 +05:30 committed by Louis
parent ee2cb9e625
commit a7a2dcc8d8
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
7 changed files with 375 additions and 252 deletions

View File

@@ -1,15 +1,208 @@
import { events } from '../../events'
import { BaseExtension } from '../../extension'
import { MessageRequest, Model, ModelEvent } from '../../../types'
import { EngineManager } from './EngineManager'
/* AIEngine class types */
export interface chatCompletionRequestMessage {
role: 'system' | 'user' | 'assistant' | 'tool'
content: string | null
name?: string
tool_calls?: any[] // Simplified
tool_call_id?: string
}
export interface chatCompletionRequest {
provider: string
model: string // Model ID, though for local it might be implicit via sessionId
messages: chatCompletionRequestMessage[]
temperature?: number | null
top_p?: number | null
n?: number | null
stream?: boolean | null
stop?: string | string[] | null
max_tokens?: number
presence_penalty?: number | null
frequency_penalty?: number | null
logit_bias?: { [key: string]: number } | null
user?: string
// ... TODO: other OpenAI params
}
export interface chatCompletionChunkChoiceDelta {
content?: string | null
role?: 'system' | 'user' | 'assistant' | 'tool'
tool_calls?: any[] // Simplified
}
export interface chatCompletionChunkChoice {
index: number
delta: chatCompletionChunkChoiceDelta
finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null
}
export interface chatCompletionChunk {
id: string
object: 'chat.completion.chunk'
created: number
model: string
choices: chatCompletionChunkChoice[]
system_fingerprint?: string
}
export interface chatCompletionChoice {
index: number
message: chatCompletionRequestMessage // Response message
finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call'
logprobs?: any // Simplified
}
export interface chatCompletion {
id: string
object: 'chat.completion'
created: number
model: string // Model ID used
choices: chatCompletionChoice[]
usage?: {
prompt_tokens: number
completion_tokens: number
total_tokens: number
}
system_fingerprint?: string
}
// --- End OpenAI types ---
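// Illustrative sketch (not part of this commit): a minimal chatCompletionRequest
// built against the types above. The provider and model ids are placeholders.
const exampleRequest: chatCompletionRequest = {
  provider: 'llamacpp',
  model: 'qwen3-4B',
  messages: [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: 'Hello!' },
  ],
  stream: true,
  temperature: 0.7,
}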
// Shared model metadata
export interface modelInfo {
id: string // e.g. "qwen3-4B" or "org/model/quant"
name: string // human-readable, e.g., "Qwen3 4B Q4_0"
quant_type?: string // q4_0 (optional as it might be part of ID or name)
providerId: string // e.g. "llama.cpp"
port: number
sizeBytes: number
tags?: string[]
path?: string // Absolute path to the model file, if applicable
// Additional provider-specific metadata can be added here
[key: string]: any
}
// 1. /list
export interface listOptions {
providerId: string // To specify which provider if a central manager calls this
}
export type listResult = modelInfo[]
// 2. /pull
export interface pullOptions {
providerId: string
modelId: string // Identifier for the model to pull (e.g., from a known registry)
downloadUrl: string // URL to download the model from
/** optional callback to receive download progress */
onProgress?: (progress: { percent: number; downloadedBytes: number; totalBytes?: number }) => void
}
export interface pullResult {
success: boolean
path?: string // local file path to the pulled model
error?: string
modelInfo?: modelInfo // Info of the pulled model
}
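// Illustrative sketch (not part of this commit): pulling a model with a progress
// callback. The engine instance, model id, and download URL are placeholders.
async function pullExample(engine: AIEngine): Promise<pullResult> {
  return engine.pull({
    providerId: 'llamacpp',
    modelId: 'qwen3-4B',
    downloadUrl: 'https://example.com/qwen3-4b-q4_0.gguf', // placeholder URL
    onProgress: ({ percent, downloadedBytes, totalBytes }) => {
      console.log(`downloaded ${downloadedBytes}/${totalBytes ?? '?'} bytes (${percent}%)`)
    },
  })
}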
// 3. /load
export interface loadOptions {
modelPath: string
port?: number
n_gpu_layers?: number
n_ctx?: number
threads?: number
threads_batch?: number
ctx_size?: number
n_predict?: number
batch_size?: number
ubatch_size?: number
device?: string
split_mode?: string
main_gpu?: number
flash_attn?: boolean
cont_batching?: boolean
no_mmap?: boolean
mlock?: boolean
no_kv_offload?: boolean
cache_type_k?: string
cache_type_v?: string
defrag_thold?: number
rope_scaling?: string
rope_scale?: number
rope_freq_base?: number
rope_freq_scale?: number
}
export interface sessionInfo {
sessionId: string // opaque handle for unload/chat
port: number // llama-server output port (corrected from portid)
modelName: string // name of the model
modelPath: string // path of the loaded model
}
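// Illustrative sketch (not part of this commit): loading a model through a concrete
// AIEngine and keeping the returned sessionInfo handle. The path and settings are
// placeholders; sessionId is what later gets passed to unload()/chat().
async function loadExample(engine: AIEngine): Promise<sessionInfo> {
  const session = await engine.load({
    modelPath: '/models/qwen3-4b-q4_0.gguf', // placeholder path
    port: 8080,
    n_gpu_layers: 32,
    ctx_size: 4096,
  })
  console.log(`llama-server for ${session.modelName} listening on port ${session.port}`)
  return session
}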
// 4. /unload
export interface unloadOptions {
providerId: string
sessionId: string
}
export interface unloadResult {
success: boolean
error?: string
}
// 5. /chat
export interface chatOptions {
providerId: string
sessionId: string
/** Full OpenAI ChatCompletionRequest payload */
payload: chatCompletionRequest
}
// Output for /chat will be Promise<ChatCompletion> for non-streaming
// or Promise<AsyncIterable<ChatCompletionChunk>> for streaming
// 6. /delete
export interface deleteOptions {
providerId: string
modelId: string // The ID of the model to delete (implies finding its path)
modelPath?: string // Optionally, direct path can be provided
}
export interface deleteResult {
success: boolean
error?: string
}
// 7. /import
export interface importOptions {
providerId: string
sourcePath: string // Path to the local model file to import
desiredModelId?: string // Optional: if user wants to name it specifically
}
export interface importResult {
success: boolean
modelInfo?: modelInfo
error?: string
}
// 8. /abortPull
export interface abortPullOptions {
providerId: string
modelId: string // The modelId whose download is to be aborted
}
export interface abortPullResult {
success: boolean
error?: string
}
/**
* Base AIEngine
* Applicable to all AI Engines
*/
export abstract class AIEngine extends BaseExtension {
// The inference engine
abstract provider: string
// The inference engine ID, implementing the readonly providerId from interface
abstract readonly provider: string
/**
* On extension load, subscribe to events.
@@ -24,4 +217,49 @@ export abstract class AIEngine extends BaseExtension {
registerEngine() {
EngineManager.instance().register(this)
}
/**
* Lists available models
*/
abstract list(opts: listOptions): Promise<listResult>
/**
* Pulls/downloads a model
*/
abstract pull(opts: pullOptions): Promise<pullResult>
/**
* Loads a model into memory
*/
abstract load(opts: loadOptions): Promise<sessionInfo>
/**
* Unloads a model from memory
*/
abstract unload(opts: unloadOptions): Promise<unloadResult>
/**
* Sends a chat request to the model
*/
abstract chat(opts: chatCompletionRequest): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>>
/**
* Deletes a model
*/
abstract delete(opts: deleteOptions): Promise<deleteResult>
/**
* Imports a model
*/
abstract import(opts: importOptions): Promise<importResult>
/**
* Aborts an ongoing model pull
*/
abstract abortPull(opts: abortPullOptions): Promise<abortPullResult>
/**
* Optional method to get the underlying chat client
*/
getChatClient?(sessionId: string): any
}
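// Illustrative sketch (not part of this commit): one way a caller can branch on the
// chat() union return type. The engine is assumed to be a registered concrete AIEngine.
async function chatExample(
  engine: AIEngine,
  req: chatCompletionRequest
): Promise<string> {
  const result = await engine.chat(req)
  // Streaming responses come back as an AsyncIterable of chunks
  if (typeof (result as any)[Symbol.asyncIterator] === 'function') {
    let text = ''
    for await (const chunk of result as AsyncIterable<chatCompletionChunk>) {
      text += chunk.choices[0]?.delta?.content ?? ''
    }
    return text
  }
  // Non-streaming responses are a single chatCompletion
  return (result as chatCompletion).choices[0]?.message?.content ?? ''
}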

View File

@@ -30,8 +30,7 @@
},
"files": [
"dist/*",
"package.json",
"README.md"
"package.json"
],
"bundleDependencies": [
"fetch-retry"

View File

@@ -6,11 +6,11 @@
* @module llamacpp-extension/src/index
*/
import { AIEngine, getJanDataFolderPath, fs, joinPath } from '@janhq/core'
import { invoke } from '@tauri-apps/api/tauri'
import {
localProvider,
AIEngine,
getJanDataFolderPath,
fs,
joinPath,
modelInfo,
listOptions,
listResult,
@@ -30,7 +30,9 @@ import {
abortPullOptions,
abortPullResult,
chatCompletionRequest,
} from './types'
} from '@janhq/core'
import { invoke } from '@tauri-apps/api/tauri'
/**
* Helper to convert GGUF model filename to a more structured ID/name
@@ -56,10 +58,7 @@ function parseGGUFFileName(filename: string): {
* The class provides methods for initializing and stopping a model, and for making inference requests.
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
*/
export default class llamacpp_extension
extends AIEngine
implements localProvider
{
export default class llamacpp_extension extends AIEngine {
provider: string = 'llamacpp'
readonly providerId: string = 'llamacpp'
@@ -79,17 +78,20 @@ export default class llamacpp_extension
}
// Implement the required LocalProvider interface methods
async list(opts: listOptions): Promise<listResult> {
override async list(opts: listOptions): Promise<listResult> {
throw new Error('method not implemented yet')
}
async pull(opts: pullOptions): Promise<pullResult> {
override async pull(opts: pullOptions): Promise<pullResult> {
throw new Error('method not implemented yet')
}
async load(opts: loadOptions): Promise<sessionInfo> {
override async load(opts: loadOptions): Promise<sessionInfo> {
const args: string[] = []
// disable llama-server webui
args.push('--no-webui')
// model option is required
args.push('-m', opts.modelPath)
args.push('--port', String(opts.port || 8080)) // Default port if not specified
@@ -193,24 +195,24 @@ export default class llamacpp_extension
console.log('Calling Tauri command load with args:', args)
try {
const sessionInfo = await invoke<sessionInfo>('plugin:llamacpp|load', {
const sInfo = await invoke<sessionInfo>('load_llama_model', {
args: args,
})
// Store the session info for later use
this.activeSessions.set(sessionInfo.sessionId, sessionInfo)
this.activeSessions.set(sInfo.sessionId, sInfo)
return sessionInfo
return sInfo
} catch (error) {
console.error('Error loading llama-server:', error)
throw new Error(`Failed to load llama-server: ${error}`)
}
}
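// Illustrative note (not part of this commit): for a call like
//   load({ modelPath: '/models/qwen3-4b-q4_0.gguf', port: 8080 })
// the argument vector handed to the 'load_llama_model' command starts with
//   ['--no-webui', '-m', '/models/qwen3-4b-q4_0.gguf', '--port', '8080', ...]
// and the remaining llama-server flags are appended from the optional loadOptions fields.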
async unload(opts: unloadOptions): Promise<unloadResult> {
override async unload(opts: unloadOptions): Promise<unloadResult> {
try {
// Pass the PID as the session_id
const result = await invoke<unloadResult>('plugin:llamacpp|unload', {
const result = await invoke<unloadResult>('unload_llama_model', {
session_id: opts.sessionId, // Using PID as session ID
})
@@ -232,27 +234,125 @@ export default class llamacpp_extension
}
}
async chat(
opts: chatOptions
private async *handleStreamingResponse(
url: string,
headers: HeadersInit,
body: string
): AsyncIterable<chatCompletionChunk> {
const response = await fetch(url, {
method: 'POST',
headers,
body,
})
if (!response.ok) {
const errorData = await response.json().catch(() => null)
throw new Error(
`API request failed with status ${response.status}: ${JSON.stringify(errorData)}`
)
}
if (!response.body) {
throw new Error('Response body is null')
}
const reader = response.body.getReader()
const decoder = new TextDecoder('utf-8')
let buffer = ''
try {
while (true) {
const { done, value } = await reader.read()
if (done) {
break
}
buffer += decoder.decode(value, { stream: true })
// Process complete lines in the buffer
const lines = buffer.split('\n')
buffer = lines.pop() || '' // Keep the last incomplete line in the buffer
for (const line of lines) {
const trimmedLine = line.trim()
if (!trimmedLine || trimmedLine === 'data: [DONE]') {
continue
}
if (trimmedLine.startsWith('data: ')) {
const jsonStr = trimmedLine.slice(6)
try {
const chunk = JSON.parse(jsonStr) as chatCompletionChunk
yield chunk
} catch (e) {
console.error('Error parsing JSON from stream:', e)
}
}
}
}
} finally {
reader.releaseLock()
}
}
private findSessionByModel(modelName: string): sessionInfo | undefined {
for (const [, session] of this.activeSessions) {
if (session.modelName === modelName) {
return session
}
}
return undefined
}
override async chat(
opts: chatCompletionRequest
): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
throw new Error("method not implemented yet")
const sessionInfo = this.findSessionByModel(opts.model)
if (!sessionInfo) {
throw new Error(`No active session found for model: ${opts.model}`)
}
const baseUrl = `http://localhost:${sessionInfo.port}/v1`
const url = `${baseUrl}/chat/completions`
const headers = {
'Content-Type': 'application/json',
'Authorization': `Bearer test-k`,
}
const body = JSON.stringify(opts)
if (opts.stream) {
return this.handleStreamingResponse(url, headers, body)
}
// Handle non-streaming response
const response = await fetch(url, {
method: 'POST',
headers,
body,
})
if (!response.ok) {
const errorData = await response.json().catch(() => null)
throw new Error(
`API request failed with status ${response.status}: ${JSON.stringify(errorData)}`
)
}
return (await response.json()) as chatCompletion
}
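// Illustrative sketch (not part of this commit): the load -> chat flow this class
// expects. The model path is a placeholder and error handling is omitted; chat()
// resolves the running llama-server session from the loaded session's model name.
async function exampleFlow(ext: llamacpp_extension): Promise<void> {
  const session = await ext.load({ modelPath: '/models/qwen3-4b-q4_0.gguf' })
  const reply = await ext.chat({
    provider: 'llamacpp',
    model: session.modelName,
    messages: [{ role: 'user', content: 'Hello!' }],
    stream: false,
  })
  console.log((reply as chatCompletion).choices[0]?.message?.content)
}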
async delete(opts: deleteOptions): Promise<deleteResult> {
throw new Error("method not implemented yet")
override async delete(opts: deleteOptions): Promise<deleteResult> {
throw new Error('method not implemented yet')
}
async import(opts: importOptions): Promise<importResult> {
throw new Error("method not implemented yet")
override async import(opts: importOptions): Promise<importResult> {
throw new Error('method not implemented yet')
}
async abortPull(opts: abortPullOptions): Promise<abortPullResult> {
override async abortPull(opts: abortPullOptions): Promise<abortPullResult> {
throw new Error('method not implemented yet')
}
// Optional method for direct client access
getChatClient(sessionId: string): any {
throw new Error("method not implemented yet")
override getChatClient(sessionId: string): any {
throw new Error('method not implemented yet')
}
onUnload(): void {

View File

@@ -1,214 +0,0 @@
// src/providers/local/types.ts
// --- Re-using OpenAI types (minimal definitions for this example) ---
// In a real project, you'd import these from 'openai' or a shared types package.
export interface chatCompletionRequestMessage {
role: 'system' | 'user' | 'assistant' | 'tool';
content: string | null;
name?: string;
tool_calls?: any[]; // Simplified
tool_call_id?: string;
}
export interface chatCompletionRequest {
model: string; // Model ID, though for local it might be implicit via sessionId
messages: chatCompletionRequestMessage[];
temperature?: number | null;
top_p?: number | null;
n?: number | null;
stream?: boolean | null;
stop?: string | string[] | null;
max_tokens?: number;
presence_penalty?: number | null;
frequency_penalty?: number | null;
logit_bias?: Record<string, number> | null;
user?: string;
// ... TODO: other OpenAI params
}
export interface chatCompletionChunkChoiceDelta {
content?: string | null;
role?: 'system' | 'user' | 'assistant' | 'tool';
tool_calls?: any[]; // Simplified
}
export interface chatCompletionChunkChoice {
index: number;
delta: chatCompletionChunkChoiceDelta;
finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null;
}
export interface chatCompletionChunk {
id: string;
object: 'chat.completion.chunk';
created: number;
model: string;
choices: chatCompletionChunkChoice[];
system_fingerprint?: string;
}
export interface chatCompletionChoice {
index: number;
message: chatCompletionRequestMessage; // Response message
finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call';
logprobs?: any; // Simplified
}
export interface chatCompletion {
id: string;
object: 'chat.completion';
created: number;
model: string; // Model ID used
choices: chatCompletionChoice[];
usage?: {
prompt_tokens: number;
completion_tokens: number;
total_tokens: number;
};
system_fingerprint?: string;
}
// --- End OpenAI types ---
// Shared model metadata
export interface modelInfo {
id: string; // e.g. "qwen3-4B" or "org/model/quant"
name: string; // humanreadable, e.g., "Qwen3 4B Q4_0"
quant_type?: string; // q4_0 (optional as it might be part of ID or name)
providerId: string; // e.g. "llama.cpp"
port: number;
sizeBytes: number;
tags?: string[];
path?: string; // Absolute path to the model file, if applicable
// Additional provider-specific metadata can be added here
[key: string]: any;
}
// 1. /list
export interface listOptions {
providerId: string; // To specify which provider if a central manager calls this
}
export type listResult = modelInfo[];
// 2. /pull
export interface pullOptions {
providerId: string;
modelId: string; // Identifier for the model to pull (e.g., from a known registry)
downloadUrl: string; // URL to download the model from
/** optional callback to receive download progress */
onProgress?: (progress: { percent: number; downloadedBytes: number; totalBytes?: number; }) => void;
}
export interface pullResult {
success: boolean;
path?: string; // local file path to the pulled model
error?: string;
modelInfo?: modelInfo; // Info of the pulled model
}
// 3. /load
export interface loadOptions {
modelPath: string
port?: number
n_gpu_layers?: number
n_ctx?: number
threads?: number
threads_batch?: number
ctx_size?: number
n_predict?: number
batch_size?: number
ubatch_size?: number
device?: string
split_mode?: string
main_gpu?: number
flash_attn?: boolean
cont_batching?: boolean
no_mmap?: boolean
mlock?: boolean
no_kv_offload?: boolean
cache_type_k?: string
cache_type_v?: string
defrag_thold?: number
rope_scaling?: string
rope_scale?: number
rope_freq_base?: number
rope_freq_scale?: number
}
export interface sessionInfo {
sessionId: string; // opaque handle for unload/chat
port: number; // llama-server output port (corrected from portid)
modelPath: string; // path of the loaded model
settings: Record<string, unknown>; // The actual settings used to load
}
// 4. /unload
export interface unloadOptions {
providerId: string;
sessionId: string;
}
export interface unloadResult {
success: boolean;
error?: string;
}
// 5. /chat
export interface chatOptions {
providerId: string;
sessionId: string;
/** Full OpenAI ChatCompletionRequest payload */
payload: chatCompletionRequest;
}
// Output for /chat will be Promise<ChatCompletion> for non-streaming
// or Promise<AsyncIterable<ChatCompletionChunk>> for streaming
// 6. /delete
export interface deleteOptions {
providerId: string;
modelId: string; // The ID of the model to delete (implies finding its path)
modelPath?: string; // Optionally, direct path can be provided
}
export interface deleteResult {
success: boolean;
error?: string;
}
// 7. /import
export interface importOptions {
providerId: string;
sourcePath: string; // Path to the local model file to import
desiredModelId?: string; // Optional: if user wants to name it specifically
}
export interface importResult {
success: boolean;
modelInfo?: modelInfo;
error?: string;
}
// 8. /abortPull
export interface abortPullOptions {
providerId: string;
modelId: string; // The modelId whose download is to be aborted
}
export interface abortPullResult {
success: boolean;
error?: string;
}
// The interface for any local provider
export interface localProvider {
readonly providerId: string;
list(opts: listOptions): Promise<listResult>;
pull(opts: pullOptions): Promise<pullResult>;
load(opts: loadOptions): Promise<sessionInfo>;
unload(opts: unloadOptions): Promise<unloadResult>;
chat(opts: chatOptions): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>>;
delete(opts: deleteOptions): Promise<deleteResult>;
import(opts: importOptions): Promise<importResult>;
abortPull(opts: abortPullOptions): Promise<abortPullResult>;
// Optional: for direct access to underlying client if needed for specific streaming cases
getChatClient?(sessionId: string): any; // e.g., an OpenAI client instance configured for the session
}

View File

@@ -79,7 +79,7 @@ pub struct UnloadResult {
// --- Load Command ---
#[tauri::command]
pub async fn load(
pub async fn load_llama_model(
app_handle: AppHandle, // Get the AppHandle
state: State<'_, AppState>, // Access the shared state
args: Vec<String>, // Arguments from the frontend
@@ -143,7 +143,7 @@ pub async fn load(
// --- Unload Command ---
#[tauri::command]
pub async fn unload(session_id: String, state: State<'_, AppState>) -> ServerResult<UnloadResult> {
pub async fn unload_llama_model(session_id: String, state: State<'_, AppState>) -> ServerResult<UnloadResult> {
let mut process_lock = state.llama_server_process.lock().await;
// Take the child process out of the Option, leaving None in its place
if let Some(mut child) = process_lock.take() {

View File

@@ -86,8 +86,8 @@ pub fn run() {
core::hardware::get_system_info,
core::hardware::get_system_usage,
// llama-cpp extension
core::utils::extensions::inference_llamacpp_extension::server::load,
core::utils::extensions::inference_llamacpp_extension::server::unload,
core::utils::extensions::inference_llamacpp_extension::server::load_llama_model,
core::utils::extensions::inference_llamacpp_extension::server::unload_llama_model,
])
.manage(AppState {
app_token: Some(generate_app_token()),

View File

@@ -211,7 +211,7 @@ export const stopModel = async (
): Promise<void> => {
const providerObj = EngineManager.instance().get(normalizeProvider(provider))
const modelObj = ModelManager.instance().get(model)
if (providerObj && modelObj) return providerObj?.unloadModel(modelObj)
if (providerObj && modelObj) return providerObj?.unload(modelObj)
}
/**