feat: add embedding support to llamacpp extension
This commit introduces embedding functionality to the llamacpp extension, allowing users to generate embeddings for text inputs using the 'sentence-transformer-mini' model. The changes include:

- Adding a new `embed` method to the `llamacpp_extension` class.
- Implementing model loading and API interaction for embeddings.
- Handling potential errors during API requests.
- Adding the types needed for embedding responses and data.
- Extending the `load` method with a boolean parameter that indicates whether an embedding model is being loaded.
commit 48d1164858 (parent 2eeabf8ae6)
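For orientation before the diff, a minimal, hypothetical caller-side sketch of the new API (construction and registration of the extension are elided; `ext` stands in for an initialized `llamacpp_extension` instance):

// Hypothetical usage sketch; `ext` is an already-initialized
// llamacpp_extension instance (construction elided).
const result = await ext.embed(['hello world', 'how are you'])
// result.data[i].embedding is the number[] vector for text[i];
// all-MiniLM-L6-v2 produces 384-dimensional embeddings.
console.log(result.data.length, result.data[0].embedding.length)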
@@ -158,7 +158,7 @@ export interface chatOptions {
 // 7. /import
 export interface ImportOptions {
   modelPath: string
-  mmprojPath: string
+  mmprojPath?: string
 }

 export interface importResult {
@@ -193,7 +193,7 @@ export abstract class AIEngine extends BaseExtension {
   /**
    * Lists available models
    */
-  abstract list(): Promise<listResult>
+  abstract list(): Promise<modelInfo[]>

   /**
    * Loads a model into memory
@@ -68,6 +68,22 @@ interface ModelConfig {
   size_bytes: number
 }

+interface EmbeddingResponse {
+  model: string
+  object: string
+  usage: {
+    prompt_tokens: number
+    total_tokens: number
+  }
+  data: EmbeddingData[]
+}
+
+interface EmbeddingData {
+  embedding: number[]
+  index: number
+  object: string
+}
+
 /**
  * A class that implements the InferenceExtension interface from the @janhq/core package.
  * The class provides methods for initializing and stopping a model, and for making inference requests.
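These types follow the OpenAI-compatible shape that the llama.cpp server returns from /v1/embeddings. A representative value (numbers illustrative, not from a real run) would be:

// Illustrative payload satisfying the new interfaces (values made up):
const example: EmbeddingResponse = {
  model: 'sentence-transformer-mini',
  object: 'list',
  usage: { prompt_tokens: 8, total_tokens: 8 },
  data: [
    { embedding: [0.0123, -0.0456, 0.0789], index: 0, object: 'embedding' },
  ],
}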
@@ -370,10 +386,13 @@ export default class llamacpp_extension extends AIEngine {
   }

   private async sleep(ms: number): Promise<void> {
-    return new Promise(resolve => setTimeout(resolve, ms))
+    return new Promise((resolve) => setTimeout(resolve, ms))
   }

-  private async waitForModelLoad(port: number, timeoutMs = 30_000): Promise<void> {
+  private async waitForModelLoad(
+    port: number,
+    timeoutMs = 30_000
+  ): Promise<void> {
     const start = Date.now()
     while (Date.now() - start < timeoutMs) {
       try {
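The body of the polling loop falls outside this hunk. A plausible shape for it, assuming the llama.cpp server's GET /health endpoint and the sleep() helper above, is:

// Hypothetical reconstruction of the elided loop body (not shown in the diff):
while (Date.now() - start < timeoutMs) {
  try {
    const res = await fetch(`http://localhost:${port}/health`)
    if (res.ok) return // server is up and the model is ready
  } catch {
    // connection refused: server process is still starting
  }
  await this.sleep(500) // back off briefly before the next probe
}
throw new Error(`Timed out loading model after ${timeoutMs} ms`)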
@@ -387,7 +406,10 @@ export default class llamacpp_extension extends AIEngine {
     throw new Error(`Timed out loading model after ${timeoutMs}`)
   }

-  override async load(modelId: string): Promise<SessionInfo> {
+  override async load(
+    modelId: string,
+    isEmbedding: boolean = false
+  ): Promise<SessionInfo> {
     const sInfo = this.findSessionByModel(modelId)
     if (sInfo) {
       throw new Error('Model already loaded!!')
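Because isEmbedding defaults to false, existing call sites keep compiling unchanged. Hypothetical calls (the chat model ID is a placeholder):

const chatSession = await this.load('some-chat-model') // default: chat mode
const embedSession = await this.load('sentence-transformer-mini', true) // embedding mode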
@@ -444,8 +466,6 @@ export default class llamacpp_extension extends AIEngine {
     if (cfg.threads > 0) args.push('--threads', String(cfg.threads))
     if (cfg.threads_batch > 0)
       args.push('--threads-batch', String(cfg.threads_batch))
-    if (cfg.ctx_size > 0) args.push('--ctx-size', String(cfg.ctx_size))
-    if (cfg.n_predict > 0) args.push('--n-predict', String(cfg.n_predict))
     if (cfg.batch_size > 0) args.push('--batch-size', String(cfg.batch_size))
     if (cfg.ubatch_size > 0) args.push('--ubatch-size', String(cfg.ubatch_size))
     if (cfg.device.length > 0) args.push('--device', cfg.device)
@@ -459,7 +479,12 @@ export default class llamacpp_extension extends AIEngine {
     if (cfg.no_mmap) args.push('--no-mmap')
     if (cfg.mlock) args.push('--mlock')
     if (cfg.no_kv_offload) args.push('--no-kv-offload')
+    if (isEmbedding) {
+      args.push('--embedding')
+      args.push('--pooling mean')
+    } else {
+      if (cfg.ctx_size > 0) args.push('--ctx-size', String(cfg.ctx_size))
+      if (cfg.n_predict > 0) args.push('--n-predict', String(cfg.n_predict))
     args.push('--cache-type-k', cfg.cache_type_k)
     args.push('--cache-type-v', cfg.cache_type_v)
     args.push('--defrag-thold', String(cfg.defrag_thold))
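One caveat with this branch: if the args array is forwarded to the spawned llama-server process as separate argv entries, the single string '--pooling mean' would likely not parse, since the server expects the flag and its value as two arguments. A sketch of the split form:

// Sketch of the embedding branch with flag and value pushed separately
// (assumes args is the same string[] built above):
if (isEmbedding) {
  args.push('--embedding')
  args.push('--pooling', 'mean') // flag and value as separate argv entries
}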
@@ -469,6 +494,7 @@ export default class llamacpp_extension extends AIEngine {
     args.push('--rope-freq-base', String(cfg.rope_freq_base))
     args.push('--rope-freq-scale', String(cfg.rope_freq_scale))
     args.push('--reasoning-budget', String(cfg.reasoning_budget))
+    }

     console.log('Calling Tauri command llama_load with args:', args)
     const backendPath = await getBackendExePath(backend, version)
@@ -479,7 +505,7 @@ export default class llamacpp_extension extends AIEngine {
     const sInfo = await invoke<SessionInfo>('load_llama_model', {
       backendPath,
       libraryPath,
-      args
+      args,
     })

     await this.waitForModelLoad(sInfo.port)
@@ -503,7 +529,7 @@ export default class llamacpp_extension extends AIEngine {
     try {
       // Pass the PID as the session_id
       const result = await invoke<UnloadResult>('unload_llama_model', {
-        pid: pid
+        pid: pid,
       })

       // If successful, remove from active sessions
@@ -648,6 +674,48 @@ export default class llamacpp_extension extends AIEngine {
     return lmodels
   }

+  async embed(text: string[]): Promise<EmbeddingResponse> {
+    let sInfo = this.findSessionByModel('sentence-transformer-mini')
+    if (!sInfo) {
+      const downloadedModelList = await this.list()
+      if (
+        !downloadedModelList.some(
+          (model) => model.id === 'sentence-transformer-mini'
+        )
+      ) {
+        await this.import('sentence-transformer-mini', {
+          modelPath:
+            'https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-ggml-model-f16.gguf?download=true',
+        })
+      }
+      sInfo = await this.load('sentence-transformer-mini')
+    }
+    const baseUrl = `http://localhost:${sInfo.port}/v1/embeddings`
+    const headers = {
+      'Content-Type': 'application/json',
+      'Authorization': `Bearer ${sInfo.api_key}`,
+    }
+    const body = JSON.stringify({
+      input: text,
+      model: sInfo.model_id,
+      encoding_format: 'float',
+    })
+    const response = await fetch(baseUrl, {
+      method: 'POST',
+      headers,
+      body,
+    })
+
+    if (!response.ok) {
+      const errorData = await response.json().catch(() => null)
+      throw new Error(
+        `API request failed with status ${response.status}: ${JSON.stringify(errorData)}`
+      )
+    }
+    const responseData = await response.json()
+    return responseData as EmbeddingResponse
+  }
+
   // Optional method for direct client access
   override getChatClient(sessionId: string): any {
     throw new Error('method not implemented yet')
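As a follow-up to the embed method, a hypothetical downstream use: comparing two inputs by cosine similarity over the returned vectors (`ext` is an initialized extension instance, as in the sketch at the top):

// Cosine similarity over two embedding vectors of equal length.
function cosine(a: number[], b: number[]): number {
  let dot = 0
  let na = 0
  let nb = 0
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i]
    na += a[i] * a[i]
    nb += b[i] * b[i]
  }
  return dot / (Math.sqrt(na) * Math.sqrt(nb))
}

const res = await ext.embed(['first sentence', 'a different sentence'])
const similarity = cosine(res.data[0].embedding, res.data[1].embedding)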