diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts
index 885199869..c9b9fa361 100644
--- a/core/src/browser/extensions/engines/AIEngine.ts
+++ b/core/src/browser/extensions/engines/AIEngine.ts
@@ -158,7 +158,7 @@ export interface chatOptions {
 // 7. /import
 export interface ImportOptions {
   modelPath: string
-  mmprojPath: string
+  mmprojPath?: string
 }

 export interface importResult {
@@ -193,7 +193,7 @@ export abstract class AIEngine extends BaseExtension {
   /**
    * Lists available models
    */
-  abstract list(): Promise<modelInfo[]>
+  abstract list(): Promise<modelInfo[]>

   /**
    * Loads a model into memory
diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts
index 6157b2640..bbd682145 100644
--- a/extensions/llamacpp-extension/src/index.ts
+++ b/extensions/llamacpp-extension/src/index.ts
@@ -68,6 +68,22 @@ interface ModelConfig {
   size_bytes: number
 }

+interface EmbeddingResponse {
+  model: string
+  object: string
+  usage: {
+    prompt_tokens: number
+    total_tokens: number
+  }
+  data: EmbeddingData[]
+}
+
+interface EmbeddingData {
+  embedding: number[]
+  index: number
+  object: string
+}
+
 /**
  * A class that implements the InferenceExtension interface from the @janhq/core package.
  * The class provides methods for initializing and stopping a model, and for making inference requests.
@@ -370,24 +386,30 @@ export default class llamacpp_extension extends AIEngine {
   }

   private async sleep(ms: number): Promise<void> {
-    return new Promise(resolve => setTimeout(resolve, ms))
+    return new Promise((resolve) => setTimeout(resolve, ms))
   }

-  private async waitForModelLoad(port: number, timeoutMs = 30_000): Promise<void> {
-    const start = Date.now()
-    while (Date.now() - start < timeoutMs) {
-      try {
-        const res = await fetch(`http://localhost:${port}/health`)
-        if(res.ok) {
-          return
-        }
-      } catch (e) {}
-      await this.sleep(500) // 500 sec interval during rechecks
-    }
-    throw new Error(`Timed out loading model after ${timeoutMs}`)
+  private async waitForModelLoad(
+    port: number,
+    timeoutMs = 30_000
+  ): Promise<void> {
+    const start = Date.now()
+    while (Date.now() - start < timeoutMs) {
+      try {
+        const res = await fetch(`http://localhost:${port}/health`)
+        if (res.ok) {
+          return
+        }
+      } catch (e) {}
+      await this.sleep(500) // 500 ms interval between health rechecks
+    }
+    throw new Error(`Timed out loading model after ${timeoutMs} ms`)
   }

-  override async load(modelId: string): Promise<SessionInfo> {
+  override async load(
+    modelId: string,
+    isEmbedding: boolean = false
+  ): Promise<SessionInfo> {
     const sInfo = this.findSessionByModel(modelId)
     if (sInfo) {
       throw new Error('Model already loaded!!')
@@ -444,8 +466,6 @@ export default class llamacpp_extension extends AIEngine {

     if (cfg.threads > 0) args.push('--threads', String(cfg.threads))
     if (cfg.threads_batch > 0) args.push('--threads-batch', String(cfg.threads_batch))
-    if (cfg.ctx_size > 0) args.push('--ctx-size', String(cfg.ctx_size))
-    if (cfg.n_predict > 0) args.push('--n-predict', String(cfg.n_predict))
     if (cfg.batch_size > 0) args.push('--batch-size', String(cfg.batch_size))
     if (cfg.ubatch_size > 0) args.push('--ubatch-size', String(cfg.ubatch_size))
     if (cfg.device.length > 0) args.push('--device', cfg.device)
@@ -459,16 +479,22 @@ export default class llamacpp_extension extends AIEngine {
     if (cfg.no_mmap) args.push('--no-mmap')
     if (cfg.mlock) args.push('--mlock')
     if (cfg.no_kv_offload) args.push('--no-kv-offload')
+    if (isEmbedding) {
+      args.push('--embedding')
+      args.push('--pooling', 'mean')
+    } else {
+      if (cfg.ctx_size > 0) args.push('--ctx-size', String(cfg.ctx_size))
+      if (cfg.n_predict > 0) args.push('--n-predict', String(cfg.n_predict))
+      args.push('--cache-type-k', cfg.cache_type_k)
+      args.push('--cache-type-v', cfg.cache_type_v)
+      args.push('--defrag-thold', String(cfg.defrag_thold))
-    args.push('--cache-type-k', cfg.cache_type_k)
-    args.push('--cache-type-v', cfg.cache_type_v)
-    args.push('--defrag-thold', String(cfg.defrag_thold))
-
-    args.push('--rope-scaling', cfg.rope_scaling)
-    args.push('--rope-scale', String(cfg.rope_scale))
-    args.push('--rope-freq-base', String(cfg.rope_freq_base))
-    args.push('--rope-freq-scale', String(cfg.rope_freq_scale))
-    args.push('--reasoning-budget', String(cfg.reasoning_budget))
+      args.push('--rope-scaling', cfg.rope_scaling)
+      args.push('--rope-scale', String(cfg.rope_scale))
+      args.push('--rope-freq-base', String(cfg.rope_freq_base))
+      args.push('--rope-freq-scale', String(cfg.rope_freq_scale))
+      args.push('--reasoning-budget', String(cfg.reasoning_budget))
+    }

     console.log('Calling Tauri command llama_load with args:', args)
     const backendPath = await getBackendExePath(backend, version)
@@ -479,7 +505,7 @@ export default class llamacpp_extension extends AIEngine {
       const sInfo = await invoke<SessionInfo>('load_llama_model', {
         backendPath,
         libraryPath,
-        args
+        args,
       })

       await this.waitForModelLoad(sInfo.port)
@@ -503,7 +529,7 @@ export default class llamacpp_extension extends AIEngine {
     try {
       // Pass the PID as the session_id
       const result = await invoke<UnloadResult>('unload_llama_model', {
-        pid: pid
+        pid: pid,
       })

       // If successful, remove from active sessions
@@ -648,6 +674,48 @@ export default class llamacpp_extension extends AIEngine {
     return lmodels
   }

+  async embed(text: string[]): Promise<EmbeddingResponse> {
+    let sInfo = this.findSessionByModel('sentence-transformer-mini')
+    if (!sInfo) {
+      const downloadedModelList = await this.list()
+      if (
+        !downloadedModelList.some(
+          (model) => model.id === 'sentence-transformer-mini'
+        )
+      ) {
+        await this.import('sentence-transformer-mini', {
+          modelPath:
+            'https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-ggml-model-f16.gguf?download=true',
+        })
+      }
+      sInfo = await this.load('sentence-transformer-mini', true)
+    }
+    const baseUrl = `http://localhost:${sInfo.port}/v1/embeddings`
+    const headers = {
+      'Content-Type': 'application/json',
+      'Authorization': `Bearer ${sInfo.api_key}`,
+    }
+    const body = JSON.stringify({
+      input: text,
+      model: sInfo.model_id,
+      encoding_format: 'float',
+    })
+    const response = await fetch(baseUrl, {
+      method: 'POST',
+      headers,
+      body,
+    })
+
+    if (!response.ok) {
+      const errorData = await response.json().catch(() => null)
+      throw new Error(
+        `API request failed with status ${response.status}: ${JSON.stringify(errorData)}`
+      )
+    }
+    const responseData = await response.json()
+    return responseData as EmbeddingResponse
+  }
+
   // Optional method for direct client access
   override getChatClient(sessionId: string): any {
     throw new Error('method not implemented yet')
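
A few usage notes follow; none of the snippets below are part of the diff itself, and all of them assume `engine` is a resolved `llamacpp_extension` instance. First, the new `isEmbedding` flag on `load()` switches the spawned llama-server into embedding mode (`--embedding --pooling mean`) and skips the generation-only flags; `load()` still throws if the model is already resident, so each id is loaded once (the model ids here are invented):

```ts
// Sketch: the second argument selects embedding mode and defaults to false.
const chatSession = await engine.load('qwen3-0.6b') // generation flags applied
const embedSession = await engine.load('all-minilm-l6-v2', true) // --embedding --pooling mean
```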
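Here is a rough sketch of how a caller might consume the new `embed()` method; the `rankByQuery` wrapper and `cosineSimilarity` helper are illustrative only:

```ts
// Hypothetical consumer of embed(): rank two candidates against a query.
async function rankByQuery(engine: llamacpp_extension): Promise<void> {
  const res = await engine.embed([
    'how do I load a gguf model?', // query
    'Loading GGUF models into memory', // candidate A
    'Best pizza toppings', // candidate B
  ])
  // Order the vectors by their `index` field rather than assuming array order.
  const [query, candA, candB] = res.data
    .slice()
    .sort((a, b) => a.index - b.index)
    .map((d) => d.embedding)
  console.log('A:', cosineSimilarity(query, candA)) // expected: high
  console.log('B:', cosineSimilarity(query, candB)) // expected: low
}

// Plain cosine similarity over the float vectors in `data[].embedding`.
function cosineSimilarity(a: number[], b: number[]): number {
  let dot = 0
  let na = 0
  let nb = 0
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i]
    na += a[i] * a[i]
    nb += b[i] * b[i]
  }
  return dot / (Math.sqrt(na) * Math.sqrt(nb))
}
```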
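`embed()` returns llama-server's OpenAI-compatible `/v1/embeddings` payload unmodified. An illustrative value typed against the new interfaces (the numbers are invented; all-MiniLM-L6-v2 actually emits 384-dimensional vectors):

```ts
// Invented example payload, truncated for readability; real `embedding`
// arrays have 384 entries and `usage` reflects actual tokenization.
const example: EmbeddingResponse = {
  model: 'sentence-transformer-mini',
  object: 'list',
  usage: { prompt_tokens: 14, total_tokens: 14 },
  data: [
    { embedding: [0.0213, -0.0471, 0.0088], index: 0, object: 'embedding' },
    { embedding: [0.0034, 0.0152, -0.0307], index: 1, object: 'embedding' },
  ],
}
```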
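Finally, on the core change: with `mmprojPath` now optional on `ImportOptions`, `import()` accepts text-only GGUFs such as the embedding model above. A sketch (the first call mirrors what `embed()` does; the vision-model id and local paths are invented):

```ts
// Text-only import: no multimodal projector needed.
await engine.import('sentence-transformer-mini', {
  modelPath:
    'https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-ggml-model-f16.gguf?download=true',
})

// Multimodal import: mmprojPath is still accepted, and functionally
// needed for vision models (id and paths are invented).
await engine.import('some-vision-model', {
  modelPath: '/models/some-vision-model.gguf',
  mmprojPath: '/models/some-vision-model-mmproj.gguf',
})
```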