From a7a2dcc8d80a82f749a226f583a5294fb15be7ac Mon Sep 17 00:00:00 2001
From: Akarshan Biswas
Date: Tue, 20 May 2025 19:33:26 +0530
Subject: [PATCH] refactor load/unload again; move types to core and refactor
 AIEngine abstract class

---
 .../browser/extensions/engines/AIEngine.ts    | 246 +++++++++++++++++-
 extensions/llamacpp-extension/package.json    |   3 +-
 extensions/llamacpp-extension/src/index.ts    | 154 +++++++++--
 extensions/llamacpp-extension/src/types.ts    | 214 ---------------
 .../inference_llamacpp_extension/server.rs    |   4 +-
 src-tauri/src/lib.rs                          |   4 +-
 web-app/src/lib/completion.ts                 |   2 +-
 7 files changed, 375 insertions(+), 252 deletions(-)
 delete mode 100644 extensions/llamacpp-extension/src/types.ts

diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts
index b5261b4e0..75f74f16c 100644
--- a/core/src/browser/extensions/engines/AIEngine.ts
+++ b/core/src/browser/extensions/engines/AIEngine.ts
@@ -1,15 +1,208 @@
-import { events } from '../../events'
 import { BaseExtension } from '../../extension'
-import { MessageRequest, Model, ModelEvent } from '../../../types'
 import { EngineManager } from './EngineManager'
 
+/* AIEngine class types */
+
+export interface chatCompletionRequestMessage {
+  role: 'system' | 'user' | 'assistant' | 'tool'
+  content: string | null
+  name?: string
+  tool_calls?: any[] // Simplified
+  tool_call_id?: string
+}
+
+export interface chatCompletionRequest {
+  provider: string,
+  model: string // Model ID, though for local it might be implicit via sessionId
+  messages: chatCompletionRequestMessage[]
+  temperature?: number | null
+  top_p?: number | null
+  n?: number | null
+  stream?: boolean | null
+  stop?: string | string[] | null
+  max_tokens?: number
+  presence_penalty?: number | null
+  frequency_penalty?: number | null
+  logit_bias?: { [key: string]: number } | null
+  user?: string
+  // ... TODO: other OpenAI params
+}
+
+export interface chatCompletionChunkChoiceDelta {
+  content?: string | null
+  role?: 'system' | 'user' | 'assistant' | 'tool'
+  tool_calls?: any[] // Simplified
+}
+
+export interface chatCompletionChunkChoice {
+  index: number
+  delta: chatCompletionChunkChoiceDelta
+  finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null
+}
+
+export interface chatCompletionChunk {
+  id: string
+  object: 'chat.completion.chunk'
+  created: number
+  model: string
+  choices: chatCompletionChunkChoice[]
+  system_fingerprint?: string
+}
+
+export interface chatCompletionChoice {
+  index: number
+  message: chatCompletionRequestMessage // Response message
+  finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call'
+  logprobs?: any // Simplified
+}
+
+export interface chatCompletion {
+  id: string
+  object: 'chat.completion'
+  created: number
+  model: string // Model ID used
+  choices: chatCompletionChoice[]
+  usage?: {
+    prompt_tokens: number
+    completion_tokens: number
+    total_tokens: number
+  }
+  system_fingerprint?: string
+}
+// --- End OpenAI types ---
+
+// Shared model metadata
+export interface modelInfo {
+  id: string // e.g. "qwen3-4B" or "org/model/quant"
+  name: string // human‑readable, e.g., "Qwen3 4B Q4_0"
+  quant_type?: string // q4_0 (optional as it might be part of ID or name)
+  providerId: string // e.g. "llama.cpp"
+  port: number
+  sizeBytes: number
+  tags?: string[]
+  path?: string // Absolute path to the model file, if applicable
+  // Additional provider-specific metadata can be added here
+  [key: string]: any
+}
+
+// 1. /list
+export interface listOptions {
+  providerId: string // To specify which provider if a central manager calls this
+}
+export type listResult = modelInfo[]
+
+// 2. /pull
+export interface pullOptions {
+  providerId: string
+  modelId: string // Identifier for the model to pull (e.g., from a known registry)
+  downloadUrl: string // URL to download the model from
+  /** optional callback to receive download progress */
+  onProgress?: (progress: { percent: number; downloadedBytes: number; totalBytes?: number }) => void
+}
+export interface pullResult {
+  success: boolean
+  path?: string // local file path to the pulled model
+  error?: string
+  modelInfo?: modelInfo // Info of the pulled model
+}
+
+// 3. /load
+export interface loadOptions {
+  modelPath: string
+  port?: number
+  n_gpu_layers?: number
+  n_ctx?: number
+  threads?: number
+  threads_batch?: number
+  ctx_size?: number
+  n_predict?: number
+  batch_size?: number
+  ubatch_size?: number
+  device?: string
+  split_mode?: string
+  main_gpu?: number
+  flash_attn?: boolean
+  cont_batching?: boolean
+  no_mmap?: boolean
+  mlock?: boolean
+  no_kv_offload?: boolean
+  cache_type_k?: string
+  cache_type_v?: string
+  defrag_thold?: number
+  rope_scaling?: string
+  rope_scale?: number
+  rope_freq_base?: number
+  rope_freq_scale?: number
+}
+
+export interface sessionInfo {
+  sessionId: string // opaque handle for unload/chat
+  port: number // llama-server output port (corrected from portid)
+  modelName: string, //name of the model
+  modelPath: string // path of the loaded model
+}
+
+// 4. /unload
+export interface unloadOptions {
+  providerId: string
+  sessionId: string
+}
+export interface unloadResult {
+  success: boolean
+  error?: string
+}
+
+// 5. /chat
+export interface chatOptions {
+  providerId: string
+  sessionId: string
+  /** Full OpenAI ChatCompletionRequest payload */
+  payload: chatCompletionRequest
+}
+// Output for /chat will be Promise<chatCompletion> for non-streaming
+// or Promise<AsyncIterable<chatCompletionChunk>> for streaming
+
+// 6. /delete
+export interface deleteOptions {
+  providerId: string
+  modelId: string // The ID of the model to delete (implies finding its path)
+  modelPath?: string // Optionally, direct path can be provided
+}
+export interface deleteResult {
+  success: boolean
+  error?: string
+}
+
+// 7. /import
+export interface importOptions {
+  providerId: string
+  sourcePath: string // Path to the local model file to import
+  desiredModelId?: string // Optional: if user wants to name it specifically
+}
+export interface importResult {
+  success: boolean
+  modelInfo?: modelInfo
+  error?: string
+}
+
+// 8. /abortPull
+export interface abortPullOptions {
+  providerId: string
+  modelId: string // The modelId whose download is to be aborted
+}
+export interface abortPullResult {
+  success: boolean
+  error?: string
+}
+
 /**
  * Base AIEngine
  * Applicable to all AI Engines
  */
+
 export abstract class AIEngine extends BaseExtension {
-  // The inference engine
-  abstract provider: string
+  // The inference engine ID, implementing the readonly providerId from interface
+  abstract readonly provider: string
 
   /**
    * On extension load, subscribe to events.
@@ -24,4 +217,49 @@ export abstract class AIEngine extends BaseExtension {
   registerEngine() {
     EngineManager.instance().register(this)
   }
+
+  /**
+   * Lists available models
+   */
+  abstract list(opts: listOptions): Promise<listResult>
+
+  /**
+   * Pulls/downloads a model
+   */
+  abstract pull(opts: pullOptions): Promise<pullResult>
+
+  /**
+   * Loads a model into memory
+   */
+  abstract load(opts: loadOptions): Promise<sessionInfo>
+
+  /**
+   * Unloads a model from memory
+   */
+  abstract unload(opts: unloadOptions): Promise<unloadResult>
+
+  /**
+   * Sends a chat request to the model
+   */
+  abstract chat(opts: chatCompletionRequest): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>>
+
+  /**
+   * Deletes a model
+   */
+  abstract delete(opts: deleteOptions): Promise<deleteResult>
+
+  /**
+   * Imports a model
+   */
+  abstract import(opts: importOptions): Promise<importResult>
+
+  /**
+   * Aborts an ongoing model pull
+   */
+  abstract abortPull(opts: abortPullOptions): Promise<abortPullResult>
+
+  /**
+   * Optional method to get the underlying chat client
+   */
+  getChatClient?(sessionId: string): any
 }
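
(Aside: a minimal usage sketch of the new abstract surface. Everything below is illustrative, not part of the patch — it assumes EngineManager and these types are re-exported from '@janhq/core', and the model path and ctx_size are placeholder values. It only shows how load/chat/unload compose once an engine implementing AIEngine is registered.)

import { EngineManager } from '@janhq/core'
import type { AIEngine, sessionInfo, chatCompletion } from '@janhq/core'

// Hypothetical caller: look up a registered engine and run one non-streaming completion.
async function runOnce(): Promise<void> {
  const engine = EngineManager.instance().get('llamacpp') as AIEngine | undefined
  if (!engine) throw new Error('llamacpp engine is not registered')

  // Load a local GGUF file; modelPath and ctx_size are assumptions.
  const session: sessionInfo = await engine.load({
    modelPath: '/models/qwen3-4b-q4_0.gguf',
    ctx_size: 4096,
  })

  // Non-streaming request, routed by model name to the active session.
  const result = (await engine.chat({
    provider: 'llamacpp',
    model: session.modelName,
    messages: [{ role: 'user', content: 'Hello!' }],
    stream: false,
  })) as chatCompletion

  console.log(result.choices[0]?.message?.content)

  await engine.unload({ providerId: 'llamacpp', sessionId: session.sessionId })
}
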
diff --git a/extensions/llamacpp-extension/package.json b/extensions/llamacpp-extension/package.json
index 297c4225a..746ad9d06 100644
--- a/extensions/llamacpp-extension/package.json
+++ b/extensions/llamacpp-extension/package.json
@@ -30,8 +30,7 @@
   },
   "files": [
     "dist/*",
-    "package.json",
-    "README.md"
+    "package.json"
   ],
   "bundleDependencies": [
     "fetch-retry"
diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts
index 188945763..3edd1b6e5 100644
--- a/extensions/llamacpp-extension/src/index.ts
+++ b/extensions/llamacpp-extension/src/index.ts
@@ -6,11 +6,11 @@
  * @module llamacpp-extension/src/index
  */
 
-import { AIEngine, getJanDataFolderPath, fs, joinPath } from '@janhq/core'
-
-import { invoke } from '@tauri-apps/api/tauri'
 import {
-  localProvider,
+  AIEngine,
+  getJanDataFolderPath,
+  fs,
+  joinPath,
   modelInfo,
   listOptions,
   listResult,
@@ -30,7 +30,9 @@ import {
   abortPullOptions,
   abortPullResult,
   chatCompletionRequest,
-} from './types'
+} from '@janhq/core'
+
+import { invoke } from '@tauri-apps/api/tauri'
 
 /**
  * Helper to convert GGUF model filename to a more structured ID/name
@@ -56,10 +58,7 @@ function parseGGUFFileName(filename: string): {
  * The class provides methods for initializing and stopping a model, and for making inference requests.
  * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 */
-export default class llamacpp_extension
-  extends AIEngine
-  implements localProvider
-{
+export default class llamacpp_extension extends AIEngine {
   provider: string = 'llamacpp'
   readonly providerId: string = 'llamacpp'
 
@@ -79,17 +78,20 @@ export default class llamacpp_extension
   }
 
   // Implement the required LocalProvider interface methods
-  async list(opts: listOptions): Promise<listResult> {
+  override async list(opts: listOptions): Promise<listResult> {
     throw new Error('method not implemented yet')
   }
 
-  async pull(opts: pullOptions): Promise<pullResult> {
+  override async pull(opts: pullOptions): Promise<pullResult> {
     throw new Error('method not implemented yet')
   }
 
-  async load(opts: loadOptions): Promise<sessionInfo> {
+  override async load(opts: loadOptions): Promise<sessionInfo> {
     const args: string[] = []
 
+    // disable llama-server webui
+    args.push('--no-webui')
+
     // model option is required
     args.push('-m', opts.modelPath)
     args.push('--port', String(opts.port || 8080)) // Default port if not specified
@@ -193,24 +195,24 @@ export default class llamacpp_extension
     console.log('Calling Tauri command load with args:', args)
 
     try {
-      const sessionInfo = await invoke<sessionInfo>('plugin:llamacpp|load', {
+      const sInfo = await invoke<sessionInfo>('load_llama_model', {
        args: args,
       })
 
       // Store the session info for later use
-      this.activeSessions.set(sessionInfo.sessionId, sessionInfo)
+      this.activeSessions.set(sInfo.sessionId, sInfo)
 
-      return sessionInfo
+      return sInfo
     } catch (error) {
       console.error('Error loading llama-server:', error)
       throw new Error(`Failed to load llama-server: ${error}`)
     }
   }
 
-  async unload(opts: unloadOptions): Promise<unloadResult> {
+  override async unload(opts: unloadOptions): Promise<unloadResult> {
     try {
       // Pass the PID as the session_id
-      const result = await invoke<unloadResult>('plugin:llamacpp|unload', {
+      const result = await invoke<unloadResult>('unload_llama_model', {
         session_id: opts.sessionId, // Using PID as session ID
       })
 
@@ -232,27 +234,125 @@ export default class llamacpp_extension
     }
   }
 
-  async chat(
-    opts: chatOptions
+  private async *handleStreamingResponse(
+    url: string,
+    headers: HeadersInit,
+    body: string
+  ): AsyncIterable<chatCompletionChunk> {
+    const response = await fetch(url, {
+      method: 'POST',
+      headers,
+      body,
+    })
+    if (!response.ok) {
+      const errorData = await response.json().catch(() => null)
+      throw new Error(
+        `API request failed with status ${response.status}: ${JSON.stringify(errorData)}`
+      )
+    }
+
+    if (!response.body) {
+      throw new Error('Response body is null')
+    }
+
+    const reader = response.body.getReader()
+    const decoder = new TextDecoder('utf-8')
+    let buffer = ''
+    try {
+      while (true) {
+        const { done, value } = await reader.read()
+
+        if (done) {
+          break
+        }
+
+        buffer += decoder.decode(value, { stream: true })
+
+        // Process complete lines in the buffer
+        const lines = buffer.split('\n')
+        buffer = lines.pop() || '' // Keep the last incomplete line in the buffer
+
+        for (const line of lines) {
+          const trimmedLine = line.trim()
+          if (!trimmedLine || trimmedLine === 'data: [DONE]') {
+            continue
+          }
+
+          if (trimmedLine.startsWith('data: ')) {
+            const jsonStr = trimmedLine.slice(6)
+            try {
+              const chunk = JSON.parse(jsonStr) as chatCompletionChunk
+              yield chunk
+            } catch (e) {
+              console.error('Error parsing JSON from stream:', e)
+            }
+          }
+        }
+      }
+    } finally {
+      reader.releaseLock()
+    }
+  }
+
+  private findSessionByModel(modelName: string): sessionInfo | undefined {
+    for (const [, session] of this.activeSessions) {
+      if (session.modelName === modelName) {
+        return session
+      }
+    }
+    return undefined
+  }
+
+  override async chat(
+    opts: chatCompletionRequest
   ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
-    throw new Error("method not implemented yet")
+    const sessionInfo = this.findSessionByModel(opts.model)
+    if (!sessionInfo) {
+      throw new Error(`No active session found for model: ${opts.model}`)
+    }
+    const baseUrl = `http://localhost:${sessionInfo.port}/v1`
+    const url = `${baseUrl}/chat/completions`
+    const headers = {
+      'Content-Type': 'application/json',
+      'Authorization': `Bearer test-k`,
+    }
+
+    const body = JSON.stringify(opts)
+    if (opts.stream) {
+      return this.handleStreamingResponse(url, headers, body)
+    }
+    // Handle non-streaming response
+    const response = await fetch(url, {
+      method: 'POST',
+      headers,
+      body,
+    })
+
+    if (!response.ok) {
+      const errorData = await response.json().catch(() => null)
+      throw new Error(
+        `API request failed with status ${response.status}: ${JSON.stringify(errorData)}`
+      )
+    }
+
+    return (await response.json()) as chatCompletion
   }
 
-  async delete(opts: deleteOptions): Promise<deleteResult> {
-    throw new Error("method not implemented yet")
+  override async delete(opts: deleteOptions): Promise<deleteResult> {
+    throw new Error('method not implemented yet')
   }
 
-  async import(opts: importOptions): Promise<importResult> {
-    throw new Error("method not implemented yet")
+  override async import(opts: importOptions): Promise<importResult> {
+    throw new Error('method not implemented yet')
   }
 
-  async abortPull(opts: abortPullOptions): Promise<abortPullResult> {
+  override async abortPull(opts: abortPullOptions): Promise<abortPullResult> {
     throw new Error('method not implemented yet')
   }
 
   // Optional method for direct client access
-  getChatClient(sessionId: string): any {
-    throw new Error("method not implemented yet")
+  override getChatClient(sessionId: string): any {
+    throw new Error('method not implemented yet')
   }
 
   onUnload(): void {
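
(Aside: when stream is true, the chat() implementation above returns the AsyncIterable produced by handleStreamingResponse, which parses llama-server's `data: ...` SSE lines into chatCompletionChunk objects. A minimal consumer sketch, assuming the engine instance and model name come from an earlier load() as in the previous aside:)

import type { AIEngine, chatCompletionChunk } from '@janhq/core'

// Hypothetical helper: stream a reply and assemble the text from the deltas.
async function streamReply(engine: AIEngine, modelName: string): Promise<string> {
  const reply = await engine.chat({
    provider: 'llamacpp',
    model: modelName,
    messages: [{ role: 'user', content: 'Write a haiku about GGUF files.' }],
    stream: true,
  })

  let text = ''
  // With stream: true the llamacpp implementation returns the async-iterable branch of the union.
  for await (const chunk of reply as AsyncIterable<chatCompletionChunk>) {
    text += chunk.choices[0]?.delta?.content ?? ''
  }
  return text
}
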
Error("method not implemented yet") + const sessionInfo = this.findSessionByModel(opts.model) + if (!sessionInfo) { + throw new Error(`No active session found for model: ${opts.model}`) + } + const baseUrl = `http://localhost:${sessionInfo.port}/v1` + const url = `${baseUrl}/chat/completions` + const headers = { + 'Content-Type': 'application/json', + 'Authorization': `Bearer test-k`, + } + + const body = JSON.stringify(opts) + if (opts.stream) { + return this.handleStreamingResponse(url, headers, body) + } + // Handle non-streaming response + const response = await fetch(url, { + method: 'POST', + headers, + body, + }) + + if (!response.ok) { + const errorData = await response.json().catch(() => null) + throw new Error( + `API request failed with status ${response.status}: ${JSON.stringify(errorData)}` + ) + } + + return (await response.json()) as chatCompletion } - async delete(opts: deleteOptions): Promise { - throw new Error("method not implemented yet") + override async delete(opts: deleteOptions): Promise { + throw new Error('method not implemented yet') } - async import(opts: importOptions): Promise { - throw new Error("method not implemented yet") + override async import(opts: importOptions): Promise { + throw new Error('method not implemented yet') } - async abortPull(opts: abortPullOptions): Promise { + override async abortPull(opts: abortPullOptions): Promise { throw new Error('method not implemented yet') } // Optional method for direct client access - getChatClient(sessionId: string): any { - throw new Error("method not implemented yet") + override getChatClient(sessionId: string): any { + throw new Error('method not implemented yet') } onUnload(): void { diff --git a/extensions/llamacpp-extension/src/types.ts b/extensions/llamacpp-extension/src/types.ts deleted file mode 100644 index dda61f2fd..000000000 --- a/extensions/llamacpp-extension/src/types.ts +++ /dev/null @@ -1,214 +0,0 @@ -// src/providers/local/types.ts - -// --- Re-using OpenAI types (minimal definitions for this example) --- -// In a real project, you'd import these from 'openai' or a shared types package. -export interface chatCompletionRequestMessage { - role: 'system' | 'user' | 'assistant' | 'tool'; - content: string | null; - name?: string; - tool_calls?: any[]; // Simplified - tool_call_id?: string; -} - -export interface chatCompletionRequest { - model: string; // Model ID, though for local it might be implicit via sessionId - messages: chatCompletionRequestMessage[]; - temperature?: number | null; - top_p?: number | null; - n?: number | null; - stream?: boolean | null; - stop?: string | string[] | null; - max_tokens?: number; - presence_penalty?: number | null; - frequency_penalty?: number | null; - logit_bias?: Record | null; - user?: string; - // ... 
-}
-
-export interface chatCompletionChunkChoiceDelta {
-  content?: string | null;
-  role?: 'system' | 'user' | 'assistant' | 'tool';
-  tool_calls?: any[]; // Simplified
-}
-
-export interface chatCompletionChunkChoice {
-  index: number;
-  delta: chatCompletionChunkChoiceDelta;
-  finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null;
-}
-
-export interface chatCompletionChunk {
-  id: string;
-  object: 'chat.completion.chunk';
-  created: number;
-  model: string;
-  choices: chatCompletionChunkChoice[];
-  system_fingerprint?: string;
-}
-
-
-export interface chatCompletionChoice {
-  index: number;
-  message: chatCompletionRequestMessage; // Response message
-  finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call';
-  logprobs?: any; // Simplified
-}
-
-export interface chatCompletion {
-  id: string;
-  object: 'chat.completion';
-  created: number;
-  model: string; // Model ID used
-  choices: chatCompletionChoice[];
-  usage?: {
-    prompt_tokens: number;
-    completion_tokens: number;
-    total_tokens: number;
-  };
-  system_fingerprint?: string;
-}
-// --- End OpenAI types ---
-
-
-// Shared model metadata
-export interface modelInfo {
-  id: string; // e.g. "qwen3-4B" or "org/model/quant"
-  name: string; // human‑readable, e.g., "Qwen3 4B Q4_0"
-  quant_type?: string; // q4_0 (optional as it might be part of ID or name)
-  providerId: string; // e.g. "llama.cpp"
-  port: number;
-  sizeBytes: number;
-  tags?: string[];
-  path?: string; // Absolute path to the model file, if applicable
-  // Additional provider-specific metadata can be added here
-  [key: string]: any;
-}
-
-// 1. /list
-export interface listOptions {
-  providerId: string; // To specify which provider if a central manager calls this
-}
-export type listResult = modelInfo[];
-
-// 2. /pull
-export interface pullOptions {
-  providerId: string;
-  modelId: string; // Identifier for the model to pull (e.g., from a known registry)
-  downloadUrl: string; // URL to download the model from
-  /** optional callback to receive download progress */
-  onProgress?: (progress: { percent: number; downloadedBytes: number; totalBytes?: number; }) => void;
-}
-export interface pullResult {
-  success: boolean;
-  path?: string; // local file path to the pulled model
-  error?: string;
-  modelInfo?: modelInfo; // Info of the pulled model
-}
-
-// 3. /load
-export interface loadOptions {
-  modelPath: string
-  port?: number
-  n_gpu_layers?: number
-  n_ctx?: number
-  threads?: number
-  threads_batch?: number
-  ctx_size?: number
-  n_predict?: number
-  batch_size?: number
-  ubatch_size?: number
-  device?: string
-  split_mode?: string
-  main_gpu?: number
-  flash_attn?: boolean
-  cont_batching?: boolean
-  no_mmap?: boolean
-  mlock?: boolean
-  no_kv_offload?: boolean
-  cache_type_k?: string
-  cache_type_v?: string
-  defrag_thold?: number
-  rope_scaling?: string
-  rope_scale?: number
-  rope_freq_base?: number
-  rope_freq_scale?: number
-}
-
-export interface sessionInfo {
-  sessionId: string; // opaque handle for unload/chat
-  port: number; // llama-server output port (corrected from portid)
-  modelPath: string; // path of the loaded model
-  settings: Record<string, any>; // The actual settings used to load
-}
-
-// 4. /unload
-export interface unloadOptions {
-  providerId: string;
-  sessionId: string;
-}
-export interface unloadResult {
-  success: boolean;
-  error?: string;
-}
-
-// 5. /chat
-export interface chatOptions {
-  providerId: string;
-  sessionId: string;
-  /** Full OpenAI ChatCompletionRequest payload */
-  payload: chatCompletionRequest;
-}
-// Output for /chat will be Promise<chatCompletion> for non-streaming
-// or Promise<AsyncIterable<chatCompletionChunk>> for streaming
-
-// 6. /delete
-export interface deleteOptions {
-  providerId: string;
-  modelId: string; // The ID of the model to delete (implies finding its path)
-  modelPath?: string; // Optionally, direct path can be provided
-}
-export interface deleteResult {
-  success: boolean;
-  error?: string;
-}
-
-// 7. /import
-export interface importOptions {
-  providerId: string;
-  sourcePath: string; // Path to the local model file to import
-  desiredModelId?: string; // Optional: if user wants to name it specifically
-}
-export interface importResult {
-  success: boolean;
-  modelInfo?: modelInfo;
-  error?: string;
-}
-
-// 8. /abortPull
-export interface abortPullOptions {
-  providerId: string;
-  modelId: string; // The modelId whose download is to be aborted
-}
-export interface abortPullResult {
-  success: boolean;
-  error?: string;
-}
-
-
-// The interface for any local provider
-export interface localProvider {
-  readonly providerId: string;
-
-  list(opts: listOptions): Promise<listResult>;
-  pull(opts: pullOptions): Promise<pullResult>;
-  load(opts: loadOptions): Promise<sessionInfo>;
-  unload(opts: unloadOptions): Promise<unloadResult>;
-  chat(opts: chatOptions): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>>;
-  delete(opts: deleteOptions): Promise<deleteResult>;
-  import(opts: importOptions): Promise<importResult>;
-  abortPull(opts: abortPullOptions): Promise<abortPullResult>;
-
-  // Optional: for direct access to underlying client if needed for specific streaming cases
-  getChatClient?(sessionId: string): any; // e.g., an OpenAI client instance configured for the session
-}
diff --git a/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs b/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
index a990b7780..283d07849 100644
--- a/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
+++ b/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
@@ -79,7 +79,7 @@ pub struct UnloadResult {
 
 // --- Load Command ---
 #[tauri::command]
-pub async fn load(
+pub async fn load_llama_model(
     app_handle: AppHandle, // Get the AppHandle
     state: State<'_, AppState>, // Access the shared state
     args: Vec<String>, // Arguments from the frontend
@@ -143,7 +143,7 @@
 
 // --- Unload Command ---
 #[tauri::command]
-pub async fn unload(session_id: String, state: State<'_, AppState>) -> ServerResult {
+pub async fn unload_llama_model(session_id: String, state: State<'_, AppState>) -> ServerResult {
     let mut process_lock = state.llama_server_process.lock().await;
     // Take the child process out of the Option, leaving None in its place
     if let Some(mut child) = process_lock.take() {
diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs
index 68f497607..23636a883 100644
--- a/src-tauri/src/lib.rs
+++ b/src-tauri/src/lib.rs
@@ -86,8 +86,8 @@ pub fn run() {
             core::hardware::get_system_info,
             core::hardware::get_system_usage,
             // llama-cpp extension
-            core::utils::extensions::inference_llamacpp_extension::server::load,
-            core::utils::extensions::inference_llamacpp_extension::server::unload,
+            core::utils::extensions::inference_llamacpp_extension::server::load_llama_model,
+            core::utils::extensions::inference_llamacpp_extension::server::unload_llama_model,
         ])
         .manage(AppState {
             app_token: Some(generate_app_token()),
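
(Aside: with the rename, the commands are reachable from the webview as load_llama_model and unload_llama_model. A sketch of calling them directly via Tauri invoke, mirroring the extension code above — the argument values are placeholders, and the response shapes are assumed to match the sessionInfo/unloadResult types added in this patch:)

import { invoke } from '@tauri-apps/api/tauri'
import type { sessionInfo, unloadResult } from '@janhq/core'

// Hypothetical direct invocation of the renamed commands.
async function loadThenUnload(modelPath: string): Promise<void> {
  // Matches the args: Vec<String> parameter of load_llama_model.
  const session = await invoke<sessionInfo>('load_llama_model', {
    args: ['-m', modelPath, '--port', '8080', '--no-webui'],
  })

  // Matches the session_id: String parameter of unload_llama_model.
  const result = await invoke<unloadResult>('unload_llama_model', {
    session_id: session.sessionId,
  })
  if (!result.success) {
    console.error('unload failed:', result.error)
  }
}
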
diff --git a/web-app/src/lib/completion.ts b/web-app/src/lib/completion.ts
index 5ffd4fa4b..cbdd3cc77 100644
--- a/web-app/src/lib/completion.ts
+++ b/web-app/src/lib/completion.ts
@@ -211,7 +211,7 @@ export const stopModel = async (
 ): Promise<void> => {
   const providerObj = EngineManager.instance().get(normalizeProvider(provider))
   const modelObj = ModelManager.instance().get(model)
-  if (providerObj && modelObj) return providerObj?.unloadModel(modelObj)
+  if (providerObj && modelObj) return providerObj?.unload(modelObj)
 }
 
 /**
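
(Aside: most of the loadOptions-to-CLI mapping inside load() falls outside the hunk context shown above. The sketch below only illustrates the push-style pattern the visible lines follow; which flags the extension actually forwards is not visible in this diff, so treat the specific mapping as an assumption. Flag names follow llama-server's CLI.)

import type { loadOptions } from '@janhq/core'

// Hypothetical helper showing how loadOptions could be turned into llama-server arguments.
function buildServerArgs(opts: loadOptions): string[] {
  const args: string[] = ['--no-webui', '-m', opts.modelPath]
  args.push('--port', String(opts.port ?? 8080))

  if (opts.n_gpu_layers !== undefined) args.push('-ngl', String(opts.n_gpu_layers))
  if (opts.ctx_size !== undefined) args.push('--ctx-size', String(opts.ctx_size))
  if (opts.threads !== undefined) args.push('--threads', String(opts.threads))
  if (opts.flash_attn) args.push('--flash-attn')
  if (opts.mlock) args.push('--mlock')
  if (opts.no_mmap) args.push('--no-mmap')

  return args
}
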