* feat: Improve llama.cpp argument handling and add device parsing tests
This commit refactors how arguments are passed to llama.cpp: an argument is now added only when its value differs from llama.cpp's default. This shortens the command line and avoids potential conflicts or errors in cases where llama.cpp's default behavior already matches the desired setting.
Additionally, new tests have been added for parsing device output from
llama.cpp, ensuring the accurate extraction of GPU information (ID,
name, total memory, and free memory). This improves the robustness of
device detection.
The following changes were made:
* **Remove redundant `--ctx-size` argument:** `--ctx-size` is now added only if `cfg.ctx_size` is greater than 0.
* **Conditional argument adding for default values** (see the sketch at the end of this message):
* `--split-mode` is only added if `cfg.split_mode` is not empty
and not 'layer'.
* `--main-gpu` is only added if `cfg.main_gpu` is not undefined
and not 0.
* `--cache-type-k` is only added if `cfg.cache_type_k` is not 'f16'.
* `--cache-type-v` is only added when `flash_attn` is enabled and `cfg.cache_type_v` is neither 'f16' nor 'f32'. This also corrects the `flash_attn` condition.
* `--defrag-thold` is only added if `cfg.defrag_thold` is not 0.1.
* `--rope-scaling` is only added if `cfg.rope_scaling` is not
'none'.
* `--rope-scale` is only added if `cfg.rope_scale` is not 1.
* `--rope-freq-base` is only added if `cfg.rope_freq_base` is not 0.
* `--rope-freq-scale` is only added if `cfg.rope_freq_scale` is
not 1.
* **Add `parse_device_output` tests:** Comprehensive unit tests (see the parsing sketch below) were
added to `src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs`
to validate the parsing of llama.cpp device output under various
scenarios, including multiple devices, single devices, different
backends (CUDA, Vulkan, SYCL), complex GPU names, and error
conditions.
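For context, the information those tests validate is a device ID, a name, total memory, and free memory per GPU. The sketch below is a hypothetical TypeScript rendering of that parsing; the real parser and its unit tests are in the Rust file named above, and the exact `--list-devices` line format used here is an assumption for illustration.

```typescript
// Hypothetical illustration only — the real implementation and tests live in
// src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs.
// Assumes device lines shaped roughly like:
//   "  Vulkan0: NVIDIA GeForce RTX 3090 (24576 MiB, 23000 MiB free)"
interface ParsedDevice {
  id: string // e.g. "CUDA0", "Vulkan0", "SYCL0"
  name: string // GPU name, which may contain spaces and punctuation
  mem: number // total memory in MiB
  free: number // free memory in MiB
}

function parseDeviceLine(line: string): ParsedDevice | null {
  const m = line.match(/^\s*(\S+):\s+(.+?)\s+\((\d+)\s+MiB,\s+(\d+)\s+MiB free\)\s*$/)
  if (!m) return null // malformed line: one of the error conditions the tests cover
  return { id: m[1], name: m[2], mem: Number(m[3]), free: Number(m[4]) }
}
```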
* fixup cache_type_v comparison
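To make the conditional-argument pattern concrete, here is a minimal sketch (not the production code) of how flags are pushed only when they differ from the assumed llama.cpp defaults, including the corrected `cache_type_v` condition. The `cfg` shape mirrors a subset of the `LlamacppConfig` type in the extension source below.

```typescript
// Sketch of the "push only when non-default" pattern. Assumed defaults:
// split_mode 'layer', main_gpu 0, cache types 'f16', defrag_thold 0.1.
function buildConditionalArgs(cfg: {
  split_mode: string
  main_gpu: number
  flash_attn: boolean
  cache_type_k: string
  cache_type_v: string
  defrag_thold: number
}): string[] {
  const args: string[] = []
  if (cfg.split_mode.length > 0 && cfg.split_mode !== 'layer')
    args.push('--split-mode', cfg.split_mode)
  if (cfg.main_gpu !== undefined && cfg.main_gpu !== 0)
    args.push('--main-gpu', String(cfg.main_gpu))
  if (cfg.cache_type_k && cfg.cache_type_k !== 'f16')
    args.push('--cache-type-k', cfg.cache_type_k)
  // A quantized V cache only takes effect with flash attention enabled,
  // so the flag is passed only in that case and only for non-default types.
  if (cfg.flash_attn && cfg.cache_type_v !== 'f16' && cfg.cache_type_v !== 'f32')
    args.push('--cache-type-v', cfg.cache_type_v)
  if (cfg.defrag_thold && cfg.defrag_thold !== 0.1)
    args.push('--defrag-thold', String(cfg.defrag_thold))
  return args
}
```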
/**
|
||
* @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package.
|
||
* The class provides methods for initializing and stopping a model, and for making inference requests.
|
||
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
|
||
* @version 1.0.0
|
||
* @module llamacpp-extension/src/index
|
||
*/
|
||
|
||
import {
|
||
AIEngine,
|
||
getJanDataFolderPath,
|
||
fs,
|
||
joinPath,
|
||
modelInfo,
|
||
SessionInfo,
|
||
UnloadResult,
|
||
chatCompletion,
|
||
chatCompletionChunk,
|
||
ImportOptions,
|
||
chatCompletionRequest,
|
||
events,
|
||
} from '@janhq/core'
|
||
|
||
import { error, info, warn } from '@tauri-apps/plugin-log'
|
||
|
||
import {
|
||
listSupportedBackends,
|
||
downloadBackend,
|
||
isBackendInstalled,
|
||
getBackendExePath,
|
||
} from './backend'
|
||
import { invoke } from '@tauri-apps/api/core'
|
||
import { getProxyConfig } from './util'
|
||
import { basename } from '@tauri-apps/api/path'
|
||
|
||
type LlamacppConfig = {
|
||
version_backend: string
|
||
auto_update_engine: boolean
|
||
auto_unload: boolean
|
||
chat_template: string
|
||
n_gpu_layers: number
|
||
ctx_size: number
|
||
threads: number
|
||
threads_batch: number
|
||
n_predict: number
|
||
batch_size: number
|
||
ubatch_size: number
|
||
device: string
|
||
split_mode: string
|
||
main_gpu: number
|
||
flash_attn: boolean
|
||
cont_batching: boolean
|
||
no_mmap: boolean
|
||
mlock: boolean
|
||
no_kv_offload: boolean
|
||
cache_type_k: string
|
||
cache_type_v: string
|
||
defrag_thold: number
|
||
rope_scaling: string
|
||
rope_scale: number
|
||
rope_freq_base: number
|
||
rope_freq_scale: number
|
||
ctx_shift: boolean
|
||
}
|
||
|
||
interface DownloadItem {
|
||
url: string
|
||
save_path: string
|
||
proxy?: Record<string, string | string[] | boolean>
|
||
}
|
||
|
||
interface ModelConfig {
|
||
model_path: string
|
||
mmproj_path?: string
|
||
name: string // user-friendly
|
||
// some model info that we cache upon import
|
||
size_bytes: number
|
||
}
|
||
|
||
interface EmbeddingResponse {
|
||
model: string
|
||
object: string
|
||
usage: {
|
||
prompt_tokens: number
|
||
total_tokens: number
|
||
}
|
||
data: EmbeddingData[]
|
||
}
|
||
|
||
interface EmbeddingData {
|
||
embedding: number[]
|
||
index: number
|
||
object: string
|
||
}
|
||
|
||
interface DeviceList {
|
||
id: string
|
||
name: string
|
||
mem: number
|
||
free: number
|
||
}
|
||
/**
|
||
* Override the default app.log function to use Jan's logging system.
|
||
* @param args
|
||
*/
|
||
const logger = {
|
||
info: function (...args: any[]) {
|
||
console.log(...args)
|
||
info(args.map((arg) => ` ${arg}`).join(` `))
|
||
},
|
||
warn: function (...args: any[]) {
|
||
console.warn(...args)
|
||
warn(args.map((arg) => ` ${arg}`).join(` `))
|
||
},
|
||
error: function (...args: any[]) {
|
||
console.error(...args)
|
||
error(args.map((arg) => ` ${arg}`).join(` `))
|
||
},
|
||
}
|
||
|
||
/**
|
||
* A class that implements the InferenceExtension interface from the @janhq/core package.
|
||
* The class provides methods for initializing and stopping a model, and for making inference requests.
|
||
* It also subscribes to events emitted by the @janhq/core package and handles new message requests.
|
||
*/
|
||
|
||
// Folder structure for llamacpp extension:
|
||
// <Jan's data folder>/llamacpp
|
||
// - models/<modelId>/
|
||
// - model.yml (required)
|
||
// - model.gguf (optional, present if downloaded from URL)
|
||
// - mmproj.gguf (optional, present if mmproj exists and it was downloaded from URL)
|
||
// Contents of model.yml can be found in ModelConfig interface
|
||
//
|
||
// - backends/<backend_version>/<backend_type>/
|
||
// - build/bin/llama-server (or llama-server.exe on Windows)
|
||
//
|
||
// - lib/
|
||
// - e.g. libcudart.so.12
|
||
|
||
export default class llamacpp_extension extends AIEngine {
|
||
provider: string = 'llamacpp'
|
||
autoUnload: boolean = true
|
||
readonly providerId: string = 'llamacpp'
|
||
|
||
private config: LlamacppConfig
|
||
private activeSessions: Map<number, SessionInfo> = new Map()
|
||
private providerPath!: string
|
||
private apiSecret: string = 'JustAskNow'
|
||
private pendingDownloads: Map<string, Promise<void>> = new Map()
|
||
private isConfiguringBackends: boolean = false
|
||
private loadingModels = new Map<string, Promise<SessionInfo>>() // Track loading promises
|
||
|
||
override async onLoad(): Promise<void> {
|
||
super.onLoad() // Calls registerEngine() from AIEngine
|
||
|
||
let settings = structuredClone(SETTINGS) // Clone to modify settings definition before registration
|
||
|
||
// This makes the settings (including the backend options and initial value) available to the Jan UI.
|
||
this.registerSettings(settings)
|
||
|
||
let loadedConfig: any = {}
|
||
for (const item of settings) {
|
||
const defaultValue = item.controllerProps.value
|
||
// Use the potentially updated default value from the settings array as the fallback for getSetting
|
||
loadedConfig[item.key] = await this.getSetting<typeof defaultValue>(
|
||
item.key,
|
||
defaultValue
|
||
)
|
||
}
|
||
this.config = loadedConfig as LlamacppConfig
|
||
|
||
this.autoUnload = this.config.auto_unload
|
||
|
||
// This sets the base directory where model files for this provider are stored.
|
||
this.providerPath = await joinPath([
|
||
await getJanDataFolderPath(),
|
||
this.providerId,
|
||
])
|
||
this.configureBackends()
|
||
}
|
||
|
||
private getStoredBackendType(): string | null {
|
||
try {
|
||
return localStorage.getItem('llama_cpp_backend_type')
|
||
} catch (error) {
|
||
logger.warn('Failed to read backend type from localStorage:', error)
|
||
return null
|
||
}
|
||
}
|
||
|
||
private setStoredBackendType(backendType: string): void {
|
||
try {
|
||
localStorage.setItem('llama_cpp_backend_type', backendType)
|
||
logger.info(`Stored backend type preference: ${backendType}`)
|
||
} catch (error) {
|
||
logger.warn('Failed to store backend type in localStorage:', error)
|
||
}
|
||
}
|
||
|
||
private clearStoredBackendType(): void {
|
||
try {
|
||
localStorage.removeItem('llama_cpp_backend_type')
|
||
logger.info('Cleared stored backend type preference')
|
||
} catch (error) {
|
||
logger.warn('Failed to clear backend type from localStorage:', error)
|
||
}
|
||
}
|
||
|
||
private findLatestVersionForBackend(
|
||
version_backends: { version: string; backend: string }[],
|
||
backendType: string
|
||
): string | null {
|
||
const matchingBackends = version_backends.filter(
|
||
(vb) => vb.backend === backendType
|
||
)
|
||
if (matchingBackends.length === 0) {
|
||
return null
|
||
}
|
||
|
||
// Sort by version (newest first) and get the latest
|
||
matchingBackends.sort((a, b) => b.version.localeCompare(a.version))
|
||
return `${matchingBackends[0].version}/${matchingBackends[0].backend}`
|
||
}
|
||
|
||
async configureBackends(): Promise<void> {
|
||
if (this.isConfiguringBackends) {
|
||
logger.info(
|
||
'configureBackends already in progress, skipping duplicate call'
|
||
)
|
||
return
|
||
}
|
||
|
||
this.isConfiguringBackends = true
|
||
|
||
try {
|
||
let version_backends: { version: string; backend: string }[] = []
|
||
|
||
try {
|
||
version_backends = await listSupportedBackends()
|
||
if (version_backends.length === 0) {
|
||
throw new Error(
|
||
'No supported backend binaries found for this system. Backend selection and auto-update will be unavailable.'
|
||
)
|
||
} else {
|
||
version_backends.sort((a, b) => b.version.localeCompare(a.version))
|
||
}
|
||
} catch (error) {
|
||
throw new Error(
|
||
`Failed to fetch supported backends: ${
|
||
error instanceof Error ? error.message : error
|
||
}`
|
||
)
|
||
}
|
||
|
||
// Get stored backend preference
|
||
const storedBackendType = this.getStoredBackendType()
|
||
let bestAvailableBackendString = ''
|
||
|
||
if (storedBackendType) {
|
||
// Find the latest version of the stored backend type
|
||
const preferredBackendString = this.findLatestVersionForBackend(
|
||
version_backends,
|
||
storedBackendType
|
||
)
|
||
if (preferredBackendString) {
|
||
bestAvailableBackendString = preferredBackendString
|
||
logger.info(
|
||
`Using stored backend preference: ${bestAvailableBackendString}`
|
||
)
|
||
} else {
|
||
logger.warn(
|
||
`Stored backend type '${storedBackendType}' not available, falling back to best backend`
|
||
)
|
||
// Clear the invalid stored preference
|
||
this.clearStoredBackendType()
|
||
bestAvailableBackendString =
|
||
this.determineBestBackend(version_backends)
|
||
}
|
||
} else {
|
||
bestAvailableBackendString = this.determineBestBackend(version_backends)
|
||
}
|
||
|
||
let settings = structuredClone(SETTINGS)
|
||
const backendSettingIndex = settings.findIndex(
|
||
(item) => item.key === 'version_backend'
|
||
)
|
||
|
||
let originalDefaultBackendValue = ''
|
||
if (backendSettingIndex !== -1) {
|
||
const backendSetting = settings[backendSettingIndex]
|
||
originalDefaultBackendValue = backendSetting.controllerProps
|
||
.value as string
|
||
|
||
backendSetting.controllerProps.options = version_backends.map((b) => {
|
||
const key = `${b.version}/${b.backend}`
|
||
return { value: key, name: key }
|
||
})
|
||
|
||
const savedBackendSetting = await this.getSetting<string>(
|
||
'version_backend',
|
||
originalDefaultBackendValue
|
||
)
|
||
|
||
// Determine initial UI default based on priority:
|
||
// 1. Saved setting (if valid and not original default)
|
||
// 2. Best available for stored backend type
|
||
// 3. Original default
|
||
let initialUiDefault = originalDefaultBackendValue
|
||
|
||
if (
|
||
savedBackendSetting &&
|
||
savedBackendSetting !== originalDefaultBackendValue
|
||
) {
|
||
initialUiDefault = savedBackendSetting
|
||
// Store the backend type from the saved setting only if different
|
||
const [, backendType] = savedBackendSetting.split('/')
|
||
if (backendType) {
|
||
const currentStoredBackend = this.getStoredBackendType()
|
||
if (currentStoredBackend !== backendType) {
|
||
this.setStoredBackendType(backendType)
|
||
logger.info(
|
||
`Stored backend type preference from saved setting: ${backendType}`
|
||
)
|
||
}
|
||
}
|
||
} else if (bestAvailableBackendString) {
|
||
initialUiDefault = bestAvailableBackendString
|
||
// Store the backend type from the best available only if different
|
||
const [, backendType] = bestAvailableBackendString.split('/')
|
||
if (backendType) {
|
||
const currentStoredBackend = this.getStoredBackendType()
|
||
if (currentStoredBackend !== backendType) {
|
||
this.setStoredBackendType(backendType)
|
||
logger.info(
|
||
`Stored backend type preference from best available: ${backendType}`
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
backendSetting.controllerProps.value = initialUiDefault
|
||
logger.info(
|
||
`Initial UI default for version_backend set to: ${initialUiDefault}`
|
||
)
|
||
} else {
|
||
logger.error(
|
||
'Critical setting "version_backend" definition not found in SETTINGS.'
|
||
)
|
||
throw new Error('Critical setting "version_backend" not found.')
|
||
}
|
||
|
||
this.registerSettings(settings)
|
||
|
||
let effectiveBackendString = this.config.version_backend
|
||
let backendWasDownloaded = false
|
||
|
||
// Handle fresh installation case where version_backend might be 'none' or invalid
|
||
if (
|
||
!effectiveBackendString ||
|
||
effectiveBackendString === 'none' ||
|
||
!effectiveBackendString.includes('/')
|
||
) {
|
||
effectiveBackendString = bestAvailableBackendString
|
||
logger.info(
|
||
`Fresh installation or invalid backend detected, using: ${effectiveBackendString}`
|
||
)
|
||
|
||
// Update the config immediately
|
||
this.config.version_backend = effectiveBackendString
|
||
|
||
// Update the settings to reflect the change in UI
|
||
const updatedSettings = await this.getSettings()
|
||
await this.updateSettings(
|
||
updatedSettings.map((item) => {
|
||
if (item.key === 'version_backend') {
|
||
item.controllerProps.value = effectiveBackendString
|
||
}
|
||
return item
|
||
})
|
||
)
|
||
logger.info(`Updated UI settings to show: ${effectiveBackendString}`)
|
||
}
|
||
|
||
// Download and install the backend if not already present
|
||
if (effectiveBackendString) {
|
||
const [version, backend] = effectiveBackendString.split('/')
|
||
if (version && backend) {
|
||
const isInstalled = await isBackendInstalled(backend, version)
|
||
if (!isInstalled) {
|
||
logger.info(`Installing initial backend: ${effectiveBackendString}`)
|
||
await this.ensureBackendReady(backend, version)
|
||
backendWasDownloaded = true
|
||
logger.info(
|
||
`Successfully installed initial backend: ${effectiveBackendString}`
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
if (this.config.auto_update_engine) {
|
||
const updateResult = await this.handleAutoUpdate(
|
||
bestAvailableBackendString
|
||
)
|
||
if (updateResult.wasUpdated) {
|
||
effectiveBackendString = updateResult.newBackend
|
||
backendWasDownloaded = true
|
||
}
|
||
}
|
||
|
||
if (!backendWasDownloaded && effectiveBackendString) {
|
||
await this.ensureFinalBackendInstallation(effectiveBackendString)
|
||
}
|
||
} finally {
|
||
this.isConfiguringBackends = false
|
||
}
|
||
}
|
||
|
||
private determineBestBackend(
|
||
version_backends: { version: string; backend: string }[]
|
||
): string {
|
||
if (version_backends.length === 0) return ''
|
||
|
||
// Priority list for backend types (more specific/performant ones first)
|
||
const backendPriorities: string[] = [
|
||
'cuda-cu12.0',
|
||
'cuda-cu11.7',
|
||
'vulkan',
|
||
'avx512',
|
||
'avx2',
|
||
'avx',
|
||
'noavx',
|
||
'arm64',
|
||
'x64',
|
||
]
|
||
|
||
// Helper to map backend string to a priority category
|
||
const getBackendCategory = (backendString: string): string | undefined => {
|
||
if (backendString.includes('cu12.0')) return 'cuda-cu12.0'
|
||
if (backendString.includes('cu11.7')) return 'cuda-cu11.7'
|
||
if (backendString.includes('vulkan')) return 'vulkan'
|
||
if (backendString.includes('avx512')) return 'avx512'
|
||
if (backendString.includes('avx2')) return 'avx2'
|
||
if (
|
||
backendString.includes('avx') &&
|
||
!backendString.includes('avx2') &&
|
||
!backendString.includes('avx512')
|
||
)
|
||
return 'avx'
|
||
if (backendString.includes('noavx')) return 'noavx'
|
||
if (backendString.endsWith('arm64')) return 'arm64'
|
||
if (backendString.endsWith('x64')) return 'x64'
|
||
return undefined
|
||
}
|
||
|
||
let foundBestBackend: { version: string; backend: string } | undefined
|
||
for (const priorityCategory of backendPriorities) {
|
||
const matchingBackends = version_backends.filter((vb) => {
|
||
const category = getBackendCategory(vb.backend)
|
||
return category === priorityCategory
|
||
})
|
||
|
||
if (matchingBackends.length > 0) {
|
||
foundBestBackend = matchingBackends[0]
|
||
logger.info(
|
||
`Determined best available backend: ${foundBestBackend.version}/${foundBestBackend.backend} (Category: "${priorityCategory}")`
|
||
)
|
||
break
|
||
}
|
||
}
|
||
|
||
if (foundBestBackend) {
|
||
return `${foundBestBackend.version}/${foundBestBackend.backend}`
|
||
} else {
|
||
// Fallback to newest version
|
||
return `${version_backends[0].version}/${version_backends[0].backend}`
|
||
}
|
||
}
|
||
|
||
private async handleAutoUpdate(
|
||
bestAvailableBackendString: string
|
||
): Promise<{ wasUpdated: boolean; newBackend: string }> {
|
||
logger.info(
|
||
`Auto-update engine is enabled. Current backend: ${this.config.version_backend}. Best available: ${bestAvailableBackendString}`
|
||
)
|
||
|
||
if (!bestAvailableBackendString) {
|
||
logger.warn(
|
||
'Auto-update enabled, but no best available backend determined'
|
||
)
|
||
return { wasUpdated: false, newBackend: this.config.version_backend }
|
||
}
|
||
|
||
// If version_backend is empty, invalid, or 'none', use the best available backend
|
||
if (
|
||
!this.config.version_backend ||
|
||
this.config.version_backend === '' ||
|
||
this.config.version_backend === 'none' ||
|
||
!this.config.version_backend.includes('/')
|
||
) {
|
||
logger.info(
|
||
'No valid backend currently selected, using best available backend'
|
||
)
|
||
try {
|
||
const [bestVersion, bestBackend] = bestAvailableBackendString.split('/')
|
||
|
||
// Download new backend
|
||
await this.ensureBackendReady(bestBackend, bestVersion)
|
||
|
||
// Add delay on Windows
|
||
if (IS_WINDOWS) {
|
||
await new Promise((resolve) => setTimeout(resolve, 1000))
|
||
}
|
||
|
||
// Update configuration
|
||
this.config.version_backend = bestAvailableBackendString
|
||
|
||
// Store the backend type preference only if it changed
|
||
const currentStoredBackend = this.getStoredBackendType()
|
||
if (currentStoredBackend !== bestBackend) {
|
||
this.setStoredBackendType(bestBackend)
|
||
logger.info(`Stored new backend type preference: ${bestBackend}`)
|
||
}
|
||
|
||
// Update settings
|
||
const settings = await this.getSettings()
|
||
await this.updateSettings(
|
||
settings.map((item) => {
|
||
if (item.key === 'version_backend') {
|
||
item.controllerProps.value = bestAvailableBackendString
|
||
}
|
||
return item
|
||
})
|
||
)
|
||
|
||
logger.info(
|
||
`Successfully set initial backend: ${bestAvailableBackendString}`
|
||
)
|
||
return { wasUpdated: true, newBackend: bestAvailableBackendString }
|
||
} catch (error) {
|
||
logger.error('Failed to set initial backend:', error)
|
||
return { wasUpdated: false, newBackend: this.config.version_backend }
|
||
}
|
||
}
|
||
|
||
// Parse current backend configuration
|
||
const [currentVersion, currentBackend] = (
|
||
this.config.version_backend || ''
|
||
).split('/')
|
||
|
||
if (!currentVersion || !currentBackend) {
|
||
logger.warn(
|
||
`Invalid current backend format: ${this.config.version_backend}`
|
||
)
|
||
return { wasUpdated: false, newBackend: this.config.version_backend }
|
||
}
|
||
|
||
// Find the latest version for the currently selected backend type
|
||
const version_backends = await listSupportedBackends()
|
||
const targetBackendString = this.findLatestVersionForBackend(
|
||
version_backends,
|
||
currentBackend
|
||
)
|
||
|
||
if (!targetBackendString) {
|
||
logger.warn(
|
||
`No available versions found for current backend type: ${currentBackend}`
|
||
)
|
||
return { wasUpdated: false, newBackend: this.config.version_backend }
|
||
}
|
||
|
||
const [latestVersion] = targetBackendString.split('/')
|
||
|
||
// Check if update is needed (only version comparison for same backend type)
|
||
if (currentVersion === latestVersion) {
|
||
logger.info(
|
||
'Auto-update: Already using the latest version of the selected backend'
|
||
)
|
||
return { wasUpdated: false, newBackend: this.config.version_backend }
|
||
}
|
||
|
||
// Perform version update for the same backend type
|
||
try {
|
||
logger.info(
|
||
`Auto-updating from ${this.config.version_backend} to ${targetBackendString} (preserving backend type)`
|
||
)
|
||
|
||
// Download new version of the same backend type
|
||
await this.ensureBackendReady(currentBackend, latestVersion)
|
||
|
||
// Add delay on Windows
|
||
if (IS_WINDOWS) {
|
||
await new Promise((resolve) => setTimeout(resolve, 1000))
|
||
}
|
||
|
||
// Update configuration
|
||
this.config.version_backend = targetBackendString
|
||
|
||
// Update stored backend type preference only if it changed
|
||
const currentStoredBackend = this.getStoredBackendType()
|
||
if (currentStoredBackend !== currentBackend) {
|
||
this.setStoredBackendType(currentBackend)
|
||
logger.info(`Updated stored backend type preference: ${currentBackend}`)
|
||
}
|
||
|
||
// Update settings
|
||
const settings = await this.getSettings()
|
||
await this.updateSettings(
|
||
settings.map((item) => {
|
||
if (item.key === 'version_backend') {
|
||
item.controllerProps.value = targetBackendString
|
||
}
|
||
return item
|
||
})
|
||
)
|
||
|
||
logger.info(
|
||
`Successfully updated to backend: ${targetBackendString} (preserved backend type: ${currentBackend})`
|
||
)
|
||
|
||
// Emit event so the frontend picks up the updated setting
|
||
if (events && typeof events.emit === 'function') {
|
||
logger.info(
|
||
`Emitting settingsChanged event for version_backend with value: ${targetBackendString}`
|
||
)
|
||
events.emit('settingsChanged', {
|
||
key: 'version_backend',
|
||
value: targetBackendString,
|
||
})
|
||
}
|
||
|
||
// Clean up old versions of the same backend type
|
||
if (IS_WINDOWS) {
|
||
await new Promise((resolve) => setTimeout(resolve, 500))
|
||
}
|
||
await this.removeOldBackend(latestVersion, currentBackend)
|
||
|
||
return { wasUpdated: true, newBackend: targetBackendString }
|
||
} catch (error) {
|
||
logger.error('Auto-update failed:', error)
|
||
return { wasUpdated: false, newBackend: this.config.version_backend }
|
||
}
|
||
}
|
||
|
||
private async removeOldBackend(
|
||
latestVersion: string,
|
||
backendType: string
|
||
): Promise<void> {
|
||
try {
|
||
const janDataFolderPath = await getJanDataFolderPath()
|
||
const backendsDir = await joinPath([
|
||
janDataFolderPath,
|
||
'llamacpp',
|
||
'backends',
|
||
])
|
||
|
||
if (!(await fs.existsSync(backendsDir))) {
|
||
return
|
||
}
|
||
|
||
const versionDirs = await fs.readdirSync(backendsDir)
|
||
|
||
for (const versionDir of versionDirs) {
|
||
const versionPath = await joinPath([backendsDir, versionDir])
|
||
const versionName = await basename(versionDir)
|
||
|
||
// Skip the latest version
|
||
if (versionName === latestVersion) {
|
||
continue
|
||
}
|
||
|
||
// Check if this version has the specific backend type we're interested in
|
||
const backendTypePath = await joinPath([versionPath, backendType])
|
||
|
||
if (await fs.existsSync(backendTypePath)) {
|
||
const isInstalled = await isBackendInstalled(backendType, versionName)
|
||
if (isInstalled) {
|
||
try {
|
||
await fs.rm(backendTypePath)
|
||
logger.info(
|
||
`Removed old version of ${backendType}: ${backendTypePath}`
|
||
)
|
||
} catch (e) {
|
||
logger.warn(
|
||
`Failed to remove old backend version: ${backendTypePath}`,
|
||
e
|
||
)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
} catch (error) {
|
||
logger.error('Error during old backend version cleanup:', error)
|
||
}
|
||
}
|
||
|
||
private async ensureFinalBackendInstallation(
|
||
backendString: string
|
||
): Promise<void> {
|
||
if (!backendString) {
|
||
logger.warn('No backend specified for final installation check')
|
||
return
|
||
}
|
||
|
||
const [selectedVersion, selectedBackend] = backendString
|
||
.split('/')
|
||
.map((part) => part?.trim())
|
||
|
||
if (!selectedVersion || !selectedBackend) {
|
||
logger.warn(`Invalid backend format: ${backendString}`)
|
||
return
|
||
}
|
||
|
||
try {
|
||
const isInstalled = await isBackendInstalled(
|
||
selectedBackend,
|
||
selectedVersion
|
||
)
|
||
if (!isInstalled) {
|
||
logger.info(`Final check: Installing backend ${backendString}`)
|
||
await this.ensureBackendReady(selectedBackend, selectedVersion)
|
||
logger.info(`Successfully installed backend: ${backendString}`)
|
||
} else {
|
||
logger.info(
|
||
`Final check: Backend ${backendString} is already installed`
|
||
)
|
||
}
|
||
} catch (error) {
|
||
logger.error(
|
||
`Failed to ensure backend ${backendString} installation:`,
|
||
error
|
||
)
|
||
throw error // Re-throw as this is critical
|
||
}
|
||
}
|
||
|
||
async getProviderPath(): Promise<string> {
|
||
if (!this.providerPath) {
|
||
this.providerPath = await joinPath([
|
||
await getJanDataFolderPath(),
|
||
this.providerId,
|
||
])
|
||
}
|
||
return this.providerPath
|
||
}
|
||
|
||
override async onUnload(): Promise<void> {
|
||
// Terminate all active sessions
|
||
for (const [_, sInfo] of this.activeSessions) {
|
||
try {
|
||
await this.unload(sInfo.model_id)
|
||
} catch (error) {
|
||
logger.error(`Failed to unload model ${sInfo.model_id}:`, error)
|
||
}
|
||
}
|
||
|
||
// Clear the sessions map
|
||
this.activeSessions.clear()
|
||
}
|
||
|
||
onSettingUpdate<T>(key: string, value: T): void {
|
||
this.config[key] = value
|
||
|
||
if (key === 'version_backend') {
|
||
const valueStr = value as string
|
||
const [version, backend] = valueStr.split('/')
|
||
|
||
// Store the backend type preference in localStorage only if it changed
|
||
if (backend) {
|
||
const currentStoredBackend = this.getStoredBackendType()
|
||
if (currentStoredBackend !== backend) {
|
||
this.setStoredBackendType(backend)
|
||
logger.info(`Updated backend type preference to: ${backend}`)
|
||
}
|
||
}
|
||
|
||
// Reset device setting when backend changes
|
||
this.config.device = ''
|
||
|
||
const closure = async () => {
|
||
await this.ensureBackendReady(backend, version)
|
||
}
|
||
closure()
|
||
} else if (key === 'auto_unload') {
|
||
this.autoUnload = value as boolean
|
||
}
|
||
}
|
||
|
||
private async generateApiKey(modelId: string, port: string): Promise<string> {
|
||
const hash = await invoke<string>('generate_api_key', {
|
||
modelId: modelId + port,
|
||
apiSecret: this.apiSecret,
|
||
})
|
||
return hash
|
||
}
|
||
|
||
// Implement the required LocalProvider interface methods
|
||
override async list(): Promise<modelInfo[]> {
|
||
const modelsDir = await joinPath([await this.getProviderPath(), 'models'])
|
||
if (!(await fs.existsSync(modelsDir))) {
|
||
await fs.mkdir(modelsDir)
|
||
}
|
||
|
||
await this.migrateLegacyModels()
|
||
|
||
let modelIds: string[] = []
|
||
|
||
// DFS
|
||
let stack = [modelsDir]
|
||
while (stack.length > 0) {
|
||
const currentDir = stack.pop()
|
||
|
||
// check if model.yml exists
|
||
const modelConfigPath = await joinPath([currentDir, 'model.yml'])
|
||
if (await fs.existsSync(modelConfigPath)) {
|
||
// +1 to remove the leading slash
|
||
// NOTE: this does not handle Windows path \\
|
||
modelIds.push(currentDir.slice(modelsDir.length + 1))
|
||
continue
|
||
}
|
||
|
||
// otherwise, look into subdirectories
|
||
const children = await fs.readdirSync(currentDir)
|
||
for (const child of children) {
|
||
// skip files
|
||
const dirInfo = await fs.fileStat(child)
|
||
if (!dirInfo.isDirectory) {
|
||
continue
|
||
}
|
||
|
||
stack.push(child)
|
||
}
|
||
}
|
||
|
||
let modelInfos: modelInfo[] = []
|
||
for (const modelId of modelIds) {
|
||
const path = await joinPath([modelsDir, modelId, 'model.yml'])
|
||
const modelConfig = await invoke<ModelConfig>('read_yaml', { path })
|
||
|
||
const modelInfo = {
|
||
id: modelId,
|
||
name: modelConfig.name ?? modelId,
|
||
quant_type: undefined, // TODO: parse quantization type from model.yml or model.gguf
|
||
providerId: this.provider,
|
||
port: 0, // port is not known until the model is loaded
|
||
sizeBytes: modelConfig.size_bytes ?? 0,
|
||
} as modelInfo
|
||
modelInfos.push(modelInfo)
|
||
}
|
||
|
||
return modelInfos
|
||
}
|
||
|
||
private async migrateLegacyModels() {
|
||
// Attempt to migrate only once
|
||
if (localStorage.getItem('cortex_models_migrated') === 'true') return
|
||
|
||
const janDataFolderPath = await getJanDataFolderPath()
|
||
const modelsDir = await joinPath([janDataFolderPath, 'models'])
|
||
if (!(await fs.existsSync(modelsDir))) return
|
||
|
||
// DFS
|
||
let stack = [modelsDir]
|
||
while (stack.length > 0) {
|
||
const currentDir = stack.pop()
|
||
|
||
const files = await fs.readdirSync(currentDir)
|
||
for (const child of files) {
|
||
try {
|
||
const childPath = await joinPath([currentDir, child])
|
||
const stat = await fs.fileStat(childPath)
|
||
if (
|
||
files.some((e) => e.endsWith('model.yml')) &&
|
||
!child.endsWith('model.yml')
|
||
)
|
||
continue
|
||
if (!stat.isDirectory && child.endsWith('.yml')) {
|
||
// check if model.yml exists
|
||
const modelConfigPath = child
|
||
if (await fs.existsSync(modelConfigPath)) {
|
||
const legacyModelConfig = await invoke<{
|
||
files: string[]
|
||
model: string
|
||
}>('read_yaml', {
|
||
path: modelConfigPath,
|
||
})
|
||
const legacyModelPath = legacyModelConfig.files?.[0]
|
||
if (!legacyModelPath) continue
|
||
// +1 to remove the leading slash
|
||
// NOTE: this does not handle Windows path \\
|
||
let modelId = currentDir.slice(modelsDir.length + 1)
|
||
|
||
modelId =
|
||
modelId !== 'imported'
|
||
? modelId.replace(/^(cortex\.so|huggingface\.co)[\/\\]/, '')
|
||
: (await basename(child)).replace('.yml', '')
|
||
|
||
const modelName = legacyModelConfig.model ?? modelId
|
||
const configPath = await joinPath([
|
||
await this.getProviderPath(),
|
||
'models',
|
||
modelId,
|
||
'model.yml',
|
||
])
|
||
if (await fs.existsSync(configPath)) continue // Don't reimport
|
||
|
||
// this is relative to Jan's data folder
|
||
const modelDir = `${this.providerId}/models/${modelId}`
|
||
|
||
let size_bytes = (
|
||
await fs.fileStat(
|
||
await joinPath([janDataFolderPath, legacyModelPath])
|
||
)
|
||
).size
|
||
|
||
const modelConfig = {
|
||
model_path: legacyModelPath,
|
||
mmproj_path: undefined, // legacy models do not have mmproj
|
||
name: modelName,
|
||
size_bytes,
|
||
} as ModelConfig
|
||
await fs.mkdir(await joinPath([janDataFolderPath, modelDir]))
|
||
await invoke<void>('write_yaml', {
|
||
data: modelConfig,
|
||
savePath: configPath,
|
||
})
|
||
continue
|
||
}
|
||
}
|
||
} catch (error) {
|
||
console.error(`Error migrating model ${child}:`, error)
|
||
}
|
||
}
|
||
|
||
// otherwise, look into subdirectories
|
||
const children = await fs.readdirSync(currentDir)
|
||
for (const child of children) {
|
||
// skip files
|
||
const dirInfo = await fs.fileStat(child)
|
||
if (!dirInfo.isDirectory) {
|
||
continue
|
||
}
|
||
|
||
stack.push(child)
|
||
}
|
||
}
|
||
localStorage.setItem('cortex_models_migrated', 'true')
|
||
}
|
||
|
||
override async import(modelId: string, opts: ImportOptions): Promise<void> {
|
||
const isValidModelId = (id: string) => {
|
||
// only allow alphanumeric, underscore, hyphen, and dot characters in modelId
|
||
if (!/^[a-zA-Z0-9/_\-\.]+$/.test(id)) return false
|
||
|
||
// check for empty parts or path traversal
|
||
const parts = id.split('/')
|
||
return parts.every((s) => s !== '' && s !== '.' && s !== '..')
|
||
}
|
||
|
||
if (!isValidModelId(modelId))
|
||
throw new Error(
|
||
`Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`
|
||
)
|
||
|
||
const configPath = await joinPath([
|
||
await this.getProviderPath(),
|
||
'models',
|
||
modelId,
|
||
'model.yml',
|
||
])
|
||
if (await fs.existsSync(configPath))
|
||
throw new Error(`Model ${modelId} already exists`)
|
||
|
||
// this is relative to Jan's data folder
|
||
const modelDir = `${this.providerId}/models/${modelId}`
|
||
|
||
// we only use these from opts
|
||
// opts.modelPath: URL to the model file
|
||
// opts.mmprojPath: URL to the mmproj file
|
||
|
||
let downloadItems: DownloadItem[] = []
|
||
|
||
const maybeDownload = async (path: string, saveName: string) => {
|
||
// if URL, add to downloadItems, and return local path
|
||
if (path.startsWith('https://')) {
|
||
const localPath = `${modelDir}/${saveName}`
|
||
downloadItems.push({
|
||
url: path,
|
||
save_path: localPath,
|
||
proxy: getProxyConfig(),
|
||
})
|
||
return localPath
|
||
}
|
||
|
||
// if local file (absolute path), check if it exists
|
||
// and return the path
|
||
if (!(await fs.existsSync(path)))
|
||
throw new Error(`File not found: ${path}`)
|
||
return path
|
||
}
|
||
|
||
let modelPath = await maybeDownload(opts.modelPath, 'model.gguf')
|
||
let mmprojPath = opts.mmprojPath
|
||
? await maybeDownload(opts.mmprojPath, 'mmproj.gguf')
|
||
: undefined
|
||
|
||
if (downloadItems.length > 0) {
|
||
let downloadCompleted = false
|
||
|
||
try {
|
||
// emit download update event on progress
|
||
const onProgress = (transferred: number, total: number) => {
|
||
events.emit('onFileDownloadUpdate', {
|
||
modelId,
|
||
percent: transferred / total,
|
||
size: { transferred, total },
|
||
downloadType: 'Model',
|
||
})
|
||
downloadCompleted = transferred === total
|
||
}
|
||
const downloadManager = window.core.extensionManager.getByName(
|
||
'@janhq/download-extension'
|
||
)
|
||
await downloadManager.downloadFiles(
|
||
downloadItems,
|
||
this.createDownloadTaskId(modelId),
|
||
onProgress
|
||
)
|
||
|
||
const eventName = downloadCompleted
|
||
? 'onFileDownloadSuccess'
|
||
: 'onFileDownloadStopped'
|
||
events.emit(eventName, { modelId, downloadType: 'Model' })
|
||
} catch (error) {
|
||
logger.error('Error downloading model:', modelId, opts, error)
|
||
events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
|
||
throw error
|
||
}
|
||
}
|
||
|
||
// TODO: check if files are valid GGUF files
|
||
// NOTE: modelPath and mmprojPath can be either relative to Jan's data folder (if they are downloaded)
|
||
// or absolute paths (if they are provided as local files)
|
||
const janDataFolderPath = await getJanDataFolderPath()
|
||
let size_bytes = (
|
||
await fs.fileStat(await joinPath([janDataFolderPath, modelPath]))
|
||
).size
|
||
if (mmprojPath) {
|
||
size_bytes += (
|
||
await fs.fileStat(await joinPath([janDataFolderPath, mmprojPath]))
|
||
).size
|
||
}
|
||
|
||
// TODO: add name as import() argument
|
||
// TODO: add updateModelConfig() method
|
||
const modelConfig = {
|
||
model_path: modelPath,
|
||
mmproj_path: mmprojPath,
|
||
name: modelId,
|
||
size_bytes,
|
||
} as ModelConfig
|
||
await fs.mkdir(await joinPath([janDataFolderPath, modelDir]))
|
||
await invoke<void>('write_yaml', {
|
||
data: modelConfig,
|
||
savePath: configPath,
|
||
})
|
||
}
|
||
|
||
override async abortImport(modelId: string): Promise<void> {
|
||
// prepend provider name to avoid name collision
|
||
const taskId = this.createDownloadTaskId(modelId)
|
||
const downloadManager = window.core.extensionManager.getByName(
|
||
'@janhq/download-extension'
|
||
)
|
||
await downloadManager.cancelDownload(taskId)
|
||
}
|
||
|
||
/**
|
||
* Function to find a random port
|
||
*/
|
||
private async getRandomPort(): Promise<number> {
|
||
const MAX_ATTEMPTS = 20000
|
||
let attempts = 0
|
||
|
||
while (attempts < MAX_ATTEMPTS) {
|
||
const port = Math.floor(Math.random() * 1000) + 3000
|
||
|
||
const isAlreadyUsed = Array.from(this.activeSessions.values()).some(
|
||
(info) => info.port === port
|
||
)
|
||
|
||
if (!isAlreadyUsed) {
|
||
const isAvailable = await invoke<boolean>('is_port_available', { port })
|
||
if (isAvailable) return port
|
||
}
|
||
|
||
attempts++
|
||
}
|
||
|
||
throw new Error('Failed to find an available port for the model to load')
|
||
}
|
||
|
||
private async sleep(ms: number): Promise<void> {
|
||
return new Promise((resolve) => setTimeout(resolve, ms))
|
||
}
|
||
|
||
private async waitForModelLoad(
|
||
sInfo: SessionInfo,
|
||
timeoutMs = 240_000
|
||
): Promise<void> {
|
||
await this.sleep(500) // Wait before first check
|
||
const start = Date.now()
|
||
while (Date.now() - start < timeoutMs) {
|
||
try {
|
||
const res = await fetch(`http://localhost:${sInfo.port}/health`)
|
||
|
||
if (res.status === 503) {
|
||
const body = await res.json()
|
||
const msg = body?.error?.message ?? 'Model loading'
|
||
logger.info(`waiting for model load... (${msg})`)
|
||
} else if (res.ok) {
|
||
const body = await res.json()
|
||
if (body.status === 'ok') {
|
||
return
|
||
} else {
|
||
logger.warn('Unexpected OK response from /health:', body)
|
||
}
|
||
} else {
|
||
logger.warn(`Unexpected status ${res.status} from /health`)
|
||
}
|
||
} catch (e) {
|
||
await this.unload(sInfo.model_id)
|
||
throw new Error(`Model appears to have crashed: ${e}`)
|
||
}
|
||
|
||
await this.sleep(800) // Retry interval
|
||
}
|
||
|
||
await this.unload(sInfo.model_id)
|
||
throw new Error(
|
||
`Timed out loading model after ${timeoutMs} ms... killing llamacpp`
|
||
)
|
||
}
|
||
|
||
override async load(
|
||
modelId: string,
|
||
overrideSettings?: Partial<LlamacppConfig>,
|
||
isEmbedding: boolean = false
|
||
): Promise<SessionInfo> {
|
||
const sInfo = this.findSessionByModel(modelId)
|
||
if (sInfo) {
|
||
throw new Error('Model already loaded!!')
|
||
}
|
||
|
||
// If this model is already being loaded, return the existing promise
|
||
if (this.loadingModels.has(modelId)) {
|
||
return this.loadingModels.get(modelId)!
|
||
}
|
||
|
||
// Create the loading promise
|
||
const loadingPromise = this.performLoad(
|
||
modelId,
|
||
overrideSettings,
|
||
isEmbedding
|
||
)
|
||
this.loadingModels.set(modelId, loadingPromise)
|
||
|
||
try {
|
||
const result = await loadingPromise
|
||
return result
|
||
} finally {
|
||
this.loadingModels.delete(modelId)
|
||
}
|
||
}
|
||
|
||
private async performLoad(
|
||
modelId: string,
|
||
overrideSettings?: Partial<LlamacppConfig>,
|
||
isEmbedding: boolean = false
|
||
): Promise<SessionInfo> {
|
||
const loadedModels = await this.getLoadedModels()
|
||
|
||
// Get OTHER models that are currently loading (exclude current model)
|
||
const otherLoadingPromises = Array.from(this.loadingModels.entries())
|
||
.filter(([id, _]) => id !== modelId)
|
||
.map(([_, promise]) => promise)
|
||
|
||
if (
|
||
this.autoUnload &&
|
||
(loadedModels.length > 0 || otherLoadingPromises.length > 0)
|
||
) {
|
||
// Wait for OTHER loading models to finish, then unload everything
|
||
if (otherLoadingPromises.length > 0) {
|
||
await Promise.all(otherLoadingPromises)
|
||
}
|
||
|
||
// Now unload all loaded models
|
||
const allLoadedModels = await this.getLoadedModels()
|
||
if (allLoadedModels.length > 0) {
|
||
await Promise.all(allLoadedModels.map((model) => this.unload(model)))
|
||
}
|
||
}
|
||
const args: string[] = []
|
||
const cfg = { ...this.config, ...(overrideSettings ?? {}) }
|
||
const [version, backend] = cfg.version_backend.split('/')
|
||
if (!version || !backend) {
|
||
throw new Error(
|
||
`Invalid version/backend format: ${cfg.version_backend}. Expected format: <version>/<backend>`
|
||
)
|
||
}
|
||
|
||
// Ensure backend is downloaded and ready before proceeding
|
||
await this.ensureBackendReady(backend, version)
|
||
|
||
const janDataFolderPath = await getJanDataFolderPath()
|
||
const modelConfigPath = await joinPath([
|
||
this.providerPath,
|
||
'models',
|
||
modelId,
|
||
'model.yml',
|
||
])
|
||
const modelConfig = await invoke<ModelConfig>('read_yaml', {
|
||
path: modelConfigPath,
|
||
})
|
||
const port = await this.getRandomPort()
|
||
|
||
// disable llama-server webui
|
||
args.push('--no-webui')
|
||
const api_key = await this.generateApiKey(modelId, String(port))
|
||
args.push('--api-key', api_key)
|
||
|
||
// model option is required
|
||
// NOTE: model_path and mmproj_path can be either relative to Jan's data folder or absolute path
|
||
const modelPath = await joinPath([
|
||
janDataFolderPath,
|
||
modelConfig.model_path,
|
||
])
|
||
args.push('--jinja')
|
||
args.push('--reasoning-format', 'none')
|
||
args.push('-m', modelPath)
|
||
args.push('-a', modelId)
|
||
args.push('--port', String(port))
|
||
if (modelConfig.mmproj_path) {
|
||
const mmprojPath = await joinPath([
|
||
janDataFolderPath,
|
||
modelConfig.mmproj_path,
|
||
])
|
||
args.push('--mmproj', mmprojPath)
|
||
}
|
||
// Add remaining options from the interface
|
||
if (cfg.chat_template) args.push('--chat-template', cfg.chat_template)
|
||
const gpu_layers =
|
||
parseInt(String(cfg.n_gpu_layers)) >= 0 ? cfg.n_gpu_layers : 100
|
||
args.push('-ngl', String(gpu_layers))
|
||
if (cfg.threads > 0) args.push('--threads', String(cfg.threads))
|
||
if (cfg.threads_batch > 0)
|
||
args.push('--threads-batch', String(cfg.threads_batch))
|
||
if (cfg.batch_size > 0) args.push('--batch-size', String(cfg.batch_size))
|
||
if (cfg.ubatch_size > 0) args.push('--ubatch-size', String(cfg.ubatch_size))
|
||
if (cfg.device.length > 0) args.push('--device', cfg.device)
|
||
if (cfg.split_mode.length > 0 && cfg.split_mode != 'layer')
|
||
args.push('--split-mode', cfg.split_mode)
|
||
if (cfg.main_gpu !== undefined && cfg.main_gpu != 0)
|
||
args.push('--main-gpu', String(cfg.main_gpu))
|
||
|
||
// Boolean flags
|
||
if (!cfg.ctx_shift) args.push('--no-context-shift')
|
||
if (cfg.flash_attn) args.push('--flash-attn')
|
||
if (cfg.cont_batching) args.push('--cont-batching')
|
||
args.push('--no-mmap')
|
||
if (cfg.mlock) args.push('--mlock')
|
||
if (cfg.no_kv_offload) args.push('--no-kv-offload')
|
||
if (isEmbedding) {
|
||
args.push('--embedding')
|
||
args.push('--pooling', 'mean')
|
||
} else {
|
||
if (cfg.ctx_size > 0) args.push('--ctx-size', String(cfg.ctx_size))
|
||
if (cfg.n_predict > 0) args.push('--n-predict', String(cfg.n_predict))
|
||
if (cfg.cache_type_k && cfg.cache_type_k != 'f16')
|
||
args.push('--cache-type-k', cfg.cache_type_k)
|
||
if (
|
||
cfg.flash_attn &&
|
||
(cfg.cache_type_v != 'f16' && cfg.cache_type_v != 'f32')
|
||
) {
|
||
args.push('--cache-type-v', cfg.cache_type_v)
|
||
}
|
||
if (cfg.defrag_thold && cfg.defrag_thold != 0.1)
|
||
args.push('--defrag-thold', String(cfg.defrag_thold))
|
||
|
||
if (cfg.rope_scaling && cfg.rope_scaling != 'none')
|
||
args.push('--rope-scaling', cfg.rope_scaling)
|
||
if (cfg.rope_scale && cfg.rope_scale != 1)
|
||
args.push('--rope-scale', String(cfg.rope_scale))
|
||
if (cfg.rope_freq_base && cfg.rope_freq_base != 0)
|
||
args.push('--rope-freq-base', String(cfg.rope_freq_base))
|
||
if (cfg.rope_freq_scale && cfg.rope_freq_scale != 1)
|
||
args.push('--rope-freq-scale', String(cfg.rope_freq_scale))
|
||
}
|
||
|
||
logger.info('Calling Tauri command llama_load with args:', args)
|
||
const backendPath = await getBackendExePath(backend, version)
|
||
const libraryPath = await joinPath([await this.getProviderPath(), 'lib'])
|
||
|
||
try {
|
||
// TODO: add LIBRARY_PATH
|
||
const sInfo = await invoke<SessionInfo>('load_llama_model', {
|
||
backendPath,
|
||
libraryPath,
|
||
args,
|
||
})
|
||
|
||
// Store the session info for later use
|
||
this.activeSessions.set(sInfo.pid, sInfo)
|
||
await this.waitForModelLoad(sInfo)
|
||
|
||
return sInfo
|
||
} catch (error) {
|
||
logger.error('Error loading llama-server:\n', error)
|
||
throw new Error(`Failed to load llama-server: ${error}`)
|
||
}
|
||
}
|
||
|
||
override async unload(modelId: string): Promise<UnloadResult> {
|
||
const sInfo: SessionInfo = this.findSessionByModel(modelId)
|
||
if (!sInfo) {
|
||
throw new Error(`No active session found for model: ${modelId}`)
|
||
}
|
||
const pid = sInfo.pid
|
||
try {
|
||
this.activeSessions.delete(pid)
|
||
|
||
// Pass the PID as the session_id
|
||
const result = await invoke<UnloadResult>('unload_llama_model', {
|
||
pid: pid,
|
||
})
|
||
|
||
// If successful, remove from active sessions
|
||
if (result.success) {
|
||
logger.info(`Successfully unloaded model with PID ${pid}`)
|
||
} else {
|
||
logger.warn(`Failed to unload model: ${result.error}`)
|
||
this.activeSessions.set(sInfo.pid, sInfo)
|
||
}
|
||
|
||
return result
|
||
} catch (error) {
|
||
logger.error('Error in unload command:', error)
|
||
this.activeSessions.set(sInfo.pid, sInfo)
|
||
return {
|
||
success: false,
|
||
error: `Failed to unload model: ${error}`,
|
||
}
|
||
}
|
||
}
|
||
|
||
private createDownloadTaskId(modelId: string) {
|
||
// prepend provider to make the taskId unique across providers
|
||
const cleanModelId = modelId.includes('.')
|
||
? modelId.slice(0, modelId.indexOf('.'))
|
||
: modelId
|
||
return `${this.provider}/${cleanModelId}`
|
||
}
|
||
|
||
private async ensureBackendReady(
|
||
backend: string,
|
||
version: string
|
||
): Promise<void> {
|
||
const backendKey = `${version}/${backend}`
|
||
|
||
// Check if backend is already installed
|
||
const isInstalled = await isBackendInstalled(backend, version)
|
||
if (isInstalled) {
|
||
return
|
||
}
|
||
|
||
// Check if download is already in progress
|
||
if (this.pendingDownloads.has(backendKey)) {
|
||
logger.info(
|
||
`Backend ${backendKey} download already in progress, waiting...`
|
||
)
|
||
await this.pendingDownloads.get(backendKey)
|
||
return
|
||
}
|
||
|
||
// Start new download
|
||
logger.info(`Backend ${backendKey} not installed, downloading...`)
|
||
const downloadPromise = downloadBackend(backend, version).finally(() => {
|
||
this.pendingDownloads.delete(backendKey)
|
||
})
|
||
|
||
this.pendingDownloads.set(backendKey, downloadPromise)
|
||
await downloadPromise
|
||
logger.info(`Backend ${backendKey} download completed`)
|
||
}
|
||
|
||
private async *handleStreamingResponse(
|
||
url: string,
|
||
headers: HeadersInit,
|
||
body: string,
|
||
abortController?: AbortController
|
||
): AsyncIterable<chatCompletionChunk> {
|
||
const response = await fetch(url, {
|
||
method: 'POST',
|
||
headers,
|
||
body,
|
||
signal: abortController?.signal,
|
||
})
|
||
if (!response.ok) {
|
||
const errorData = await response.json().catch(() => null)
|
||
throw new Error(
|
||
`API request failed with status ${response.status}: ${JSON.stringify(
|
||
errorData
|
||
)}`
|
||
)
|
||
}
|
||
|
||
if (!response.body) {
|
||
throw new Error('Response body is null')
|
||
}
|
||
|
||
const reader = response.body.getReader()
|
||
const decoder = new TextDecoder('utf-8')
|
||
let buffer = ''
|
||
let jsonStr = ''
|
||
try {
|
||
while (true) {
|
||
const { done, value } = await reader.read()
|
||
|
||
if (done) {
|
||
break
|
||
}
|
||
|
||
buffer += decoder.decode(value, { stream: true })
|
||
|
||
// Process complete lines in the buffer
|
||
const lines = buffer.split('\n')
|
||
buffer = lines.pop() || '' // Keep the last incomplete line in the buffer
|
||
|
||
for (const line of lines) {
|
||
const trimmedLine = line.trim()
|
||
if (!trimmedLine || trimmedLine === 'data: [DONE]') {
|
||
continue
|
||
}
|
||
|
||
if (trimmedLine.startsWith('data: ')) {
|
||
jsonStr = trimmedLine.slice(6)
|
||
} else if (trimmedLine.startsWith('error: ')) {
|
||
jsonStr = trimmedLine.slice(7)
|
||
const error = JSON.parse(jsonStr)
|
||
throw new Error(error.message)
|
||
} else {
|
||
// it should not normally reach here
|
||
throw new Error('Malformed chunk')
|
||
}
|
||
try {
|
||
const data = JSON.parse(jsonStr)
|
||
const chunk = data as chatCompletionChunk
|
||
yield chunk
|
||
} catch (e) {
|
||
logger.error('Error parsing JSON from stream or server error:', e)
|
||
// re-throw so the async iterator terminates with an error
|
||
throw e
|
||
}
|
||
}
|
||
}
|
||
} finally {
|
||
reader.releaseLock()
|
||
}
|
||
}
|
||
|
||
private findSessionByModel(modelId: string): SessionInfo | undefined {
|
||
return Array.from(this.activeSessions.values()).find(
|
||
(session) => session.model_id === modelId
|
||
)
|
||
}
|
||
|
||
override async chat(
|
||
opts: chatCompletionRequest,
|
||
abortController?: AbortController
|
||
): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
|
||
const sessionInfo = this.findSessionByModel(opts.model)
|
||
if (!sessionInfo) {
|
||
throw new Error(`No active session found for model: ${opts.model}`)
|
||
}
|
||
// check if the process is alive
|
||
const result = await invoke<boolean>('is_process_running', {
|
||
pid: sessionInfo.pid,
|
||
})
|
||
if (result) {
|
||
try {
|
||
await fetch(`http://localhost:${sessionInfo.port}/health`)
|
||
} catch (e) {
|
||
this.unload(sessionInfo.model_id)
|
||
throw new Error('Model appears to have crashed! Please reload!')
|
||
}
|
||
} else {
|
||
this.activeSessions.delete(sessionInfo.pid)
|
||
throw new Error('Model has crashed! Please reload!')
|
||
}
|
||
const baseUrl = `http://localhost:${sessionInfo.port}/v1`
|
||
const url = `${baseUrl}/chat/completions`
|
||
const headers = {
|
||
'Content-Type': 'application/json',
|
||
'Authorization': `Bearer ${sessionInfo.api_key}`,
|
||
}
|
||
|
||
const body = JSON.stringify(opts)
|
||
if (opts.stream) {
|
||
return this.handleStreamingResponse(url, headers, body, abortController)
|
||
}
|
||
// Handle non-streaming response
|
||
const response = await fetch(url, {
|
||
method: 'POST',
|
||
headers,
|
||
body,
|
||
signal: abortController?.signal,
|
||
})
|
||
|
||
if (!response.ok) {
|
||
const errorData = await response.json().catch(() => null)
|
||
throw new Error(
|
||
`API request failed with status ${response.status}: ${JSON.stringify(
|
||
errorData
|
||
)}`
|
||
)
|
||
}
|
||
|
||
return (await response.json()) as chatCompletion
|
||
}
|
||
|
||
override async delete(modelId: string): Promise<void> {
|
||
const modelDir = await joinPath([
|
||
await this.getProviderPath(),
|
||
'models',
|
||
modelId,
|
||
])
|
||
|
||
if (!(await fs.existsSync(await joinPath([modelDir, 'model.yml'])))) {
|
||
throw new Error(`Model ${modelId} does not exist`)
|
||
}
|
||
|
||
await fs.rm(modelDir)
|
||
}
|
||
|
||
override async getLoadedModels(): Promise<string[]> {
|
||
let lmodels: string[] = []
|
||
for (const [_, sInfo] of this.activeSessions) {
|
||
lmodels.push(sInfo.model_id)
|
||
}
|
||
return lmodels
|
||
}
|
||
|
||
async getDevices(): Promise<DeviceList[]> {
|
||
const cfg = this.config
|
||
const [version, backend] = cfg.version_backend.split('/')
|
||
if (!version || !backend) {
|
||
throw new Error(
|
||
`Invalid version/backend format: ${cfg.version_backend}. Expected format: <version>/<backend>`
|
||
)
|
||
}
|
||
|
||
// Ensure backend is downloaded and ready before proceeding
|
||
await this.ensureBackendReady(backend, version)
|
||
logger.info('Calling Tauri command getDevices with arg --list-devices')
|
||
const backendPath = await getBackendExePath(backend, version)
|
||
const libraryPath = await joinPath([await this.getProviderPath(), 'lib'])
|
||
try {
|
||
const dList = await invoke<DeviceList[]>('get_devices', {
|
||
backendPath,
|
||
libraryPath,
|
||
})
|
||
return dList
|
||
} catch (error) {
|
||
logger.error('Failed to query devices:\n', error)
|
||
throw new Error(`Failed to query devices: ${error}`)
|
||
}
|
||
}
|
||
|
||
async embed(text: string[]): Promise<EmbeddingResponse> {
|
||
let sInfo = this.findSessionByModel('sentence-transformer-mini')
|
||
if (!sInfo) {
|
||
const downloadedModelList = await this.list()
|
||
if (
|
||
!downloadedModelList.some(
|
||
(model) => model.id === 'sentence-transformer-mini'
|
||
)
|
||
) {
|
||
await this.import('sentence-transformer-mini', {
|
||
modelPath:
|
||
'https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-ggml-model-f16.gguf?download=true',
|
||
})
|
||
}
|
||
sInfo = await this.load('sentence-transformer-mini')
|
||
}
|
||
const baseUrl = `http://localhost:${sInfo.port}/v1/embeddings`
|
||
const headers = {
|
||
'Content-Type': 'application/json',
|
||
'Authorization': `Bearer ${sInfo.api_key}`,
|
||
}
|
||
const body = JSON.stringify({
|
||
input: text,
|
||
model: sInfo.model_id,
|
||
encoding_format: 'float',
|
||
})
|
||
const response = await fetch(baseUrl, {
|
||
method: 'POST',
|
||
headers,
|
||
body,
|
||
})
|
||
|
||
if (!response.ok) {
|
||
const errorData = await response.json().catch(() => null)
|
||
throw new Error(
|
||
`API request failed with status ${response.status}: ${JSON.stringify(
|
||
errorData
|
||
)}`
|
||
)
|
||
}
|
||
const responseData = await response.json()
|
||
return responseData as EmbeddingResponse
|
||
}
|
||
|
||
// Optional method for direct client access
|
||
override getChatClient(sessionId: string): any {
|
||
throw new Error('method not implemented yet')
|
||
}
|
||
}
|