feat: add embedding support to llamacpp extension

This commit introduces embedding functionality to the llamacpp extension, allowing users to generate embeddings for text inputs with the 'sentence-transformer-mini' model. The changes include:

- Adding a new `embed` method to the `llamacpp_extension` class.
- Implementing model loading and API interaction for embeddings.
- Handling potential errors during API requests.
- Adding necessary types for embedding responses and data.
- Extending the `load` method with a boolean parameter that indicates whether an embedding model is being loaded (see the usage sketch below).
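
For orientation, a rough sketch of how the new surface might be called. This is illustrative only and not part of the diff; `getLlamacppExtension()` is a hypothetical accessor for an initialized `llamacpp_extension` instance.

// Sketch: calling the new embedding API (run inside an async context).
const engine: llamacpp_extension = getLlamacppExtension() // hypothetical accessor, not in this commit

// Optional: load the embedding model explicitly; `true` starts the server in embedding mode.
await engine.load('sentence-transformer-mini', true)

// embed() also imports and loads the model on demand if it is not present yet.
const res = await engine.embed(['hello world', 'good morning'])
console.log(res.data.length)              // one entry per input
console.log(res.data[0].embedding.length) // embedding dimension (384 for all-MiniLM-L6-v2)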
Author: Akarshan, 2025-06-16 12:26:28 +05:30 (committed by Louis)
Commit: 48d1164858
Parent: 2eeabf8ae6
2 changed files with 97 additions and 29 deletions


@@ -158,7 +158,7 @@ export interface chatOptions {
 // 7. /import
 export interface ImportOptions {
   modelPath: string
-  mmprojPath: string
+  mmprojPath?: string
 }

 export interface importResult {
@@ -193,7 +193,7 @@ export abstract class AIEngine extends BaseExtension {
   /**
    * Lists available models
    */
-  abstract list(): Promise<listResult>
+  abstract list(): Promise<modelInfo[]>
   /**
    * Loads a model into memory


@@ -68,6 +68,22 @@ interface ModelConfig {
   size_bytes: number
 }

+interface EmbeddingResponse {
+  model: string
+  object: string
+  usage: {
+    prompt_tokens: number
+    total_tokens: number
+  }
+  data: EmbeddingData[]
+}
+
+interface EmbeddingData {
+  embedding: number[]
+  index: number
+  object: string
+}
+
 /**
  * A class that implements the InferenceExtension interface from the @janhq/core package.
  * The class provides methods for initializing and stopping a model, and for making inference requests.
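
These interfaces follow the OpenAI-compatible response shape served by llama.cpp's /v1/embeddings endpoint. As a rough illustration (field values are made up), a response deserializes to something like:

const example: EmbeddingResponse = {
  model: 'sentence-transformer-mini',
  object: 'list',
  usage: { prompt_tokens: 5, total_tokens: 5 },
  data: [
    {
      embedding: [0.013, -0.202, 0.118], // truncated; the real vector has the model's full dimension
      index: 0,
      object: 'embedding',
    },
  ],
}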
@@ -370,24 +386,30 @@ export default class llamacpp_extension extends AIEngine {
   }

   private async sleep(ms: number): Promise<void> {
-    return new Promise(resolve => setTimeout(resolve, ms))
+    return new Promise((resolve) => setTimeout(resolve, ms))
   }

-  private async waitForModelLoad(port: number, timeoutMs = 30_000): Promise<void> {
+  private async waitForModelLoad(
+    port: number,
+    timeoutMs = 30_000
+  ): Promise<void> {
     const start = Date.now()
     while (Date.now() - start < timeoutMs) {
       try {
         const res = await fetch(`http://localhost:${port}/health`)
-        if(res.ok) {
+        if (res.ok) {
           return
         }
       } catch (e) {}
       await this.sleep(500) // 500 sec interval during rechecks
     }
     throw new Error(`Timed out loading model after ${timeoutMs}`)
   }

-  override async load(modelId: string): Promise<SessionInfo> {
+  override async load(
+    modelId: string,
+    isEmbedding: boolean = false
+  ): Promise<SessionInfo> {
     const sInfo = this.findSessionByModel(modelId)
     if (sInfo) {
       throw new Error('Model already loaded!!')
@@ -444,8 +466,6 @@ export default class llamacpp_extension extends AIEngine {
     if (cfg.threads > 0) args.push('--threads', String(cfg.threads))
     if (cfg.threads_batch > 0)
      args.push('--threads-batch', String(cfg.threads_batch))
-    if (cfg.ctx_size > 0) args.push('--ctx-size', String(cfg.ctx_size))
-    if (cfg.n_predict > 0) args.push('--n-predict', String(cfg.n_predict))
     if (cfg.batch_size > 0) args.push('--batch-size', String(cfg.batch_size))
     if (cfg.ubatch_size > 0) args.push('--ubatch-size', String(cfg.ubatch_size))
     if (cfg.device.length > 0) args.push('--device', cfg.device)
@@ -459,16 +479,22 @@ export default class llamacpp_extension extends AIEngine {
     if (cfg.no_mmap) args.push('--no-mmap')
     if (cfg.mlock) args.push('--mlock')
     if (cfg.no_kv_offload) args.push('--no-kv-offload')
-
-    args.push('--cache-type-k', cfg.cache_type_k)
-    args.push('--cache-type-v', cfg.cache_type_v)
-    args.push('--defrag-thold', String(cfg.defrag_thold))
-
-    args.push('--rope-scaling', cfg.rope_scaling)
-    args.push('--rope-scale', String(cfg.rope_scale))
-    args.push('--rope-freq-base', String(cfg.rope_freq_base))
-    args.push('--rope-freq-scale', String(cfg.rope_freq_scale))
-    args.push('--reasoning-budget', String(cfg.reasoning_budget))
+    if (isEmbedding) {
+      args.push('--embedding')
+      args.push('--pooling mean')
+    } else {
+      if (cfg.ctx_size > 0) args.push('--ctx-size', String(cfg.ctx_size))
+      if (cfg.n_predict > 0) args.push('--n-predict', String(cfg.n_predict))
+      args.push('--cache-type-k', cfg.cache_type_k)
+      args.push('--cache-type-v', cfg.cache_type_v)
+      args.push('--defrag-thold', String(cfg.defrag_thold))
+      args.push('--rope-scaling', cfg.rope_scaling)
+      args.push('--rope-scale', String(cfg.rope_scale))
+      args.push('--rope-freq-base', String(cfg.rope_freq_base))
+      args.push('--rope-freq-scale', String(cfg.rope_freq_scale))
+      args.push('--reasoning-budget', String(cfg.reasoning_budget))
+    }

     console.log('Calling Tauri command llama_load with args:', args)

     const backendPath = await getBackendExePath(backend, version)
@@ -479,7 +505,7 @@ export default class llamacpp_extension extends AIEngine {
     const sInfo = await invoke<SessionInfo>('load_llama_model', {
       backendPath,
       libraryPath,
-      args
+      args,
     })

     await this.waitForModelLoad(sInfo.port)
@@ -503,7 +529,7 @@ export default class llamacpp_extension extends AIEngine {
     try {
       // Pass the PID as the session_id
       const result = await invoke<UnloadResult>('unload_llama_model', {
-        pid: pid
+        pid: pid,
       })

       // If successful, remove from active sessions
@@ -648,6 +674,48 @@ export default class llamacpp_extension extends AIEngine {
     return lmodels
   }

+  async embed(text: string[]): Promise<EmbeddingResponse> {
+    let sInfo = this.findSessionByModel('sentence-transformer-mini')
+    if (!sInfo) {
+      const downloadedModelList = await this.list()
+      if (
+        !downloadedModelList.some(
+          (model) => model.id === 'sentence-transformer-mini'
+        )
+      ) {
+        await this.import('sentence-transformer-mini', {
+          modelPath:
+            'https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-ggml-model-f16.gguf?download=true',
+        })
+      }
+      sInfo = await this.load('sentence-transformer-mini')
+    }
+
+    const baseUrl = `http://localhost:${sInfo.port}/v1/embeddings`
+    const headers = {
+      'Content-Type': 'application/json',
+      'Authorization': `Bearer ${sInfo.api_key}`,
+    }
+    const body = JSON.stringify({
+      input: text,
+      model: sInfo.model_id,
+      encoding_format: 'float',
+    })
+
+    const response = await fetch(baseUrl, {
+      method: 'POST',
+      headers,
+      body,
+    })
+
+    if (!response.ok) {
+      const errorData = await response.json().catch(() => null)
+      throw new Error(
+        `API request failed with status ${response.status}: ${JSON.stringify(errorData)}`
+      )
+    }
+    const responseData = await response.json()
+    return responseData as EmbeddingResponse
+  }
+
   // Optional method for direct client access
   override getChatClient(sessionId: string): any {
     throw new Error('method not implemented yet')
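
As a closing illustration of consuming the new embed() result, a minimal sketch (not part of this commit; `engine` is an initialized extension instance as in the sketch above, and cosineSimilarity is a hypothetical helper):

// Rank two documents against a query by cosine similarity over the returned vectors.
function cosineSimilarity(a: number[], b: number[]): number {
  let dot = 0
  let normA = 0
  let normB = 0
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i]
    normA += a[i] * a[i]
    normB += b[i] * b[i]
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB))
}

const { data } = await engine.embed([
  'which model should I use for search?',
  'all-MiniLM-L6-v2 is a small, fast embedding model',
  'llama-server also exposes a chat completions endpoint',
])
const [query, ...docs] = data.map((d) => d.embedding)
const scores = docs.map((doc) => cosineSimilarity(query, doc)) // higher score = more similar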