/**
 * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package.
 * The class provides methods for initializing and stopping a model, and for making inference requests.
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 * @version 1.0.0
 * @module llamacpp-extension/src/index
 */

import {
  AIEngine,
  getJanDataFolderPath,
  fs,
  joinPath,
  modelInfo,
  SessionInfo,
  UnloadResult,
  chatCompletion,
  chatCompletionChunk,
  ImportOptions,
  chatCompletionRequest,
  events,
  AppEvent,
  DownloadEvent,
  chatCompletionRequestMessage,
} from '@janhq/core'

import { error, info, warn } from '@tauri-apps/plugin-log'
import { listen } from '@tauri-apps/api/event'

import {
  listSupportedBackends,
  downloadBackend,
  isBackendInstalled,
  getBackendExePath,
  getBackendDir,
} from './backend'
import { invoke } from '@tauri-apps/api/core'
import { getProxyConfig } from './util'
import { basename } from '@tauri-apps/api/path'
import {
  readGgufMetadata,
  getModelSize,
  isModelSupported,
  planModelLoadInternal,
} from '@janhq/tauri-plugin-llamacpp-api'
import { getSystemUsage, getSystemInfo } from '@janhq/tauri-plugin-hardware-api'

// Error message constant - matches web-app/src/utils/error.ts
const OUT_OF_CONTEXT_SIZE = 'the request exceeds the available context size.'

type LlamacppConfig = {
  version_backend: string
  auto_update_engine: boolean
  auto_unload: boolean
  llamacpp_env: string
  memory_util: string
  chat_template: string
  n_gpu_layers: number
  offload_mmproj: boolean
  cpu_moe: boolean
  n_cpu_moe: number
  override_tensor_buffer_t: string
  ctx_size: number
  threads: number
  threads_batch: number
  n_predict: number
  batch_size: number
  ubatch_size: number
  device: string
  split_mode: string
  main_gpu: number
  flash_attn: boolean
  cont_batching: boolean
  no_mmap: boolean
  mlock: boolean
  no_kv_offload: boolean
  cache_type_k: string
  cache_type_v: string
  defrag_thold: number
  rope_scaling: string
  rope_scale: number
  rope_freq_base: number
  rope_freq_scale: number
  ctx_shift: boolean
}

type ModelPlan = {
  gpuLayers: number
  maxContextLength: number
  noOffloadKVCache: boolean
  offloadMmproj?: boolean
  batchSize: number
  mode: 'GPU' | 'Hybrid' | 'CPU' | 'Unsupported'
}

interface DownloadItem {
  url: string
  save_path: string
  proxy?: Record<string, string | string[] | boolean>
  sha256?: string
  size?: number
}

interface ModelConfig {
  model_path: string
  mmproj_path?: string
  name: string // user-friendly
  // some model info that we cache upon import
  size_bytes: number
  sha256?: string
  mmproj_sha256?: string
  mmproj_size_bytes?: number
}

interface EmbeddingResponse {
  model: string
  object: string
  usage: {
    prompt_tokens: number
    total_tokens: number
  }
  data: EmbeddingData[]
}

interface EmbeddingData {
  embedding: number[]
  index: number
  object: string
}

interface DeviceList {
  id: string
  name: string
  mem: number
  free: number
}

interface SystemMemory {
  totalVRAM: number
  totalRAM: number
  totalMemory: number
}

/**
 * Override the default app.log function to use Jan's logging system.
 * @param args
 */
const logger = {
  info: function (...args: any[]) {
    console.log(...args)
    info(args.map((arg) => ` ${arg}`).join(` `))
  },
  warn: function (...args: any[]) {
    console.warn(...args)
    warn(args.map((arg) => ` ${arg}`).join(` `))
  },
  error: function (...args: any[]) {
    console.error(...args)
    error(args.map((arg) => ` ${arg}`).join(` `))
  },
}

/**
 * A class that implements the InferenceExtension interface from the @janhq/core package.
 * The class provides methods for initializing and stopping a model, and for making inference requests.
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 */

// Folder structure for llamacpp extension:
// <Jan's data folder>/llamacpp
//   - models/<modelId>/
//     - model.yml (required)
//     - model.gguf (optional, present if downloaded from URL)
//     - mmproj.gguf (optional, present if mmproj exists and it was downloaded from URL)
//     Contents of model.yml can be found in the ModelConfig interface
//
//   - backends/<backend_version>/<backend_type>/
//     - build/bin/llama-server (or llama-server.exe on Windows)
//
//   - lib/
//     - e.g. libcudart.so.12
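//
// Example model.yml (illustrative values only; the schema is the ModelConfig
// interface above, and model_path is relative to Jan's data folder):
//   model_path: llamacpp/models/qwen3-0.6b/model.gguf
//   name: Qwen3 0.6B
//   size_bytes: 523000000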

export default class llamacpp_extension extends AIEngine {
  provider: string = 'llamacpp'
  autoUnload: boolean = true
  llamacpp_env: string = ''
  memoryMode: string = ''
  readonly providerId: string = 'llamacpp'

  private config: LlamacppConfig
  private providerPath!: string
  private apiSecret: string = 'JustAskNow'
  private pendingDownloads: Map<string, Promise<void>> = new Map()
  private isConfiguringBackends: boolean = false
  private loadingModels = new Map<string, Promise<SessionInfo>>() // Track loading promises
  private unlistenValidationStarted?: () => void

  override async onLoad(): Promise<void> {
    super.onLoad() // Calls registerEngine() from AIEngine

    let settings = structuredClone(SETTINGS) // Clone to modify settings definition before registration

    // This makes the settings (including the backend options and initial value) available to the Jan UI.
    this.registerSettings(settings)

    let loadedConfig: any = {}
    for (const item of settings) {
      const defaultValue = item.controllerProps.value
      // Use the potentially updated default value from the settings array as the fallback for getSetting
      loadedConfig[item.key] = await this.getSetting<typeof defaultValue>(
        item.key,
        defaultValue
      )
    }
    this.config = loadedConfig as LlamacppConfig

    this.autoUnload = this.config.auto_unload
    this.llamacpp_env = this.config.llamacpp_env
    this.memoryMode = this.config.memory_util || 'high'

    // This sets the base directory where model files for this provider are stored.
    this.providerPath = await joinPath([
      await getJanDataFolderPath(),
      this.providerId,
    ])

    // Set up validation event listeners to bridge Tauri events to frontend
    this.unlistenValidationStarted = await listen<{
      modelId: string
      downloadType: string
    }>('onModelValidationStarted', (event) => {
      console.debug(
        'LlamaCPP: bridging onModelValidationStarted event',
        event.payload
      )
      events.emit(DownloadEvent.onModelValidationStarted, event.payload)
    })

    this.configureBackends()
  }

  private getStoredBackendType(): string | null {
    try {
      return localStorage.getItem('llama_cpp_backend_type')
    } catch (error) {
      logger.warn('Failed to read backend type from localStorage:', error)
      return null
    }
  }

  private setStoredBackendType(backendType: string): void {
    try {
      localStorage.setItem('llama_cpp_backend_type', backendType)
      logger.info(`Stored backend type preference: ${backendType}`)
    } catch (error) {
      logger.warn('Failed to store backend type in localStorage:', error)
    }
  }

  private clearStoredBackendType(): void {
    try {
      localStorage.removeItem('llama_cpp_backend_type')
      logger.info('Cleared stored backend type preference')
    } catch (error) {
      logger.warn('Failed to clear backend type from localStorage:', error)
    }
  }

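  /**
   * Returns the newest available build of the given backend type as a
   * combined '<version>/<backend>' string (e.g. 'b6325/avx2' — illustrative
   * naming), or null if no build of that type is available.
   */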
  private findLatestVersionForBackend(
    version_backends: { version: string; backend: string }[],
    backendType: string
  ): string | null {
    const matchingBackends = version_backends.filter(
      (vb) => vb.backend === backendType
    )
    if (matchingBackends.length === 0) {
      return null
    }

    // Sort by version (newest first) and get the latest. Compare numerically:
    // a lexicographic compare would mis-order tags like b999 and b1000.
    matchingBackends.sort(
      (a, b) =>
        this.parseBackendVersion(b.version) -
        this.parseBackendVersion(a.version)
    )
    return `${matchingBackends[0].version}/${matchingBackends[0].backend}`
  }

  async configureBackends(): Promise<void> {
    if (this.isConfiguringBackends) {
      logger.info(
        'configureBackends already in progress, skipping duplicate call'
      )
      return
    }

    this.isConfiguringBackends = true

    try {
      let version_backends: { version: string; backend: string }[] = []

      try {
        version_backends = await listSupportedBackends()
        if (version_backends.length === 0) {
          throw new Error(
            'No supported backend binaries found for this system. Backend selection and auto-update will be unavailable.'
          )
        } else {
          // Sort numerically, newest first; lexicographic compare would
          // mis-order tags like b999 and b1000.
          version_backends.sort(
            (a, b) =>
              this.parseBackendVersion(b.version) -
              this.parseBackendVersion(a.version)
          )
        }
      } catch (error) {
        throw new Error(
          `Failed to fetch supported backends: ${
            error instanceof Error ? error.message : error
          }`
        )
      }

      // Get stored backend preference
      const storedBackendType = this.getStoredBackendType()
      let bestAvailableBackendString = ''

      if (storedBackendType) {
        // Find the latest version of the stored backend type
        const preferredBackendString = this.findLatestVersionForBackend(
          version_backends,
          storedBackendType
        )
        if (preferredBackendString) {
          bestAvailableBackendString = preferredBackendString
          logger.info(
            `Using stored backend preference: ${bestAvailableBackendString}`
          )
        } else {
          logger.warn(
            `Stored backend type '${storedBackendType}' not available, falling back to best backend`
          )
          // Clear the invalid stored preference
          this.clearStoredBackendType()
          bestAvailableBackendString =
            await this.determineBestBackend(version_backends)
        }
      } else {
        bestAvailableBackendString =
          await this.determineBestBackend(version_backends)
      }

      let settings = structuredClone(SETTINGS)
      const backendSettingIndex = settings.findIndex(
        (item) => item.key === 'version_backend'
      )

      let originalDefaultBackendValue = ''
      if (backendSettingIndex !== -1) {
        const backendSetting = settings[backendSettingIndex]
        originalDefaultBackendValue = backendSetting.controllerProps
          .value as string

        backendSetting.controllerProps.options = version_backends.map((b) => {
          const key = `${b.version}/${b.backend}`
          return { value: key, name: key }
        })

        // Set the recommended backend based on bestAvailableBackendString
        if (bestAvailableBackendString) {
          backendSetting.controllerProps.recommended =
            bestAvailableBackendString
        }

        const savedBackendSetting = await this.getSetting<string>(
          'version_backend',
          originalDefaultBackendValue
        )

        // Determine initial UI default based on priority:
        // 1. Saved setting (if valid and not original default)
        // 2. Best available for stored backend type
        // 3. Original default
        let initialUiDefault = originalDefaultBackendValue

        if (
          savedBackendSetting &&
          savedBackendSetting !== originalDefaultBackendValue
        ) {
          initialUiDefault = savedBackendSetting
          // Store the backend type from the saved setting only if different
          const [, backendType] = savedBackendSetting.split('/')
          if (backendType) {
            const currentStoredBackend = this.getStoredBackendType()
            if (currentStoredBackend !== backendType) {
              this.setStoredBackendType(backendType)
              logger.info(
                `Stored backend type preference from saved setting: ${backendType}`
              )
            }
          }
        } else if (bestAvailableBackendString) {
          initialUiDefault = bestAvailableBackendString
          // Store the backend type from the best available only if different
          const [, backendType] = bestAvailableBackendString.split('/')
          if (backendType) {
            const currentStoredBackend = this.getStoredBackendType()
            if (currentStoredBackend !== backendType) {
              this.setStoredBackendType(backendType)
              logger.info(
                `Stored backend type preference from best available: ${backendType}`
              )
            }
          }
        }

        backendSetting.controllerProps.value = initialUiDefault
        logger.info(
          `Initial UI default for version_backend set to: ${initialUiDefault}`
        )
      } else {
        logger.error(
          'Critical setting "version_backend" definition not found in SETTINGS.'
        )
        throw new Error('Critical setting "version_backend" not found.')
      }

      this.registerSettings(settings)

      let effectiveBackendString = this.config.version_backend
      let backendWasDownloaded = false

      // Handle fresh installation case where version_backend might be 'none' or invalid
      if (
        (!effectiveBackendString ||
          effectiveBackendString === 'none' ||
          !effectiveBackendString.includes('/') ||
          // If the selected backend is not in the list of supported backends,
          // it needs to be reset too
          !version_backends.some(
            (e) => `${e.version}/${e.backend}` === effectiveBackendString
          )) &&
        // Ensure we have a valid best available backend
        bestAvailableBackendString
      ) {
        effectiveBackendString = bestAvailableBackendString
        logger.info(
          `Fresh installation or invalid backend detected, using: ${effectiveBackendString}`
        )

        // Update the config immediately
        this.config.version_backend = effectiveBackendString

        // Update the settings to reflect the change in UI
        const updatedSettings = await this.getSettings()
        await this.updateSettings(
          updatedSettings.map((item) => {
            if (item.key === 'version_backend') {
              item.controllerProps.value = effectiveBackendString
            }
            return item
          })
        )
        logger.info(`Updated UI settings to show: ${effectiveBackendString}`)

        // Emit for updating the frontend
        if (events && typeof events.emit === 'function') {
          logger.info(
            `Emitting settingsChanged event for version_backend with value: ${effectiveBackendString}`
          )
          events.emit('settingsChanged', {
            key: 'version_backend',
            value: effectiveBackendString,
          })
        }
      }

      // Download and install the backend if not already present
      if (effectiveBackendString) {
        const [version, backend] = effectiveBackendString.split('/')
        if (version && backend) {
          const isInstalled = await isBackendInstalled(backend, version)
          if (!isInstalled) {
            logger.info(`Installing initial backend: ${effectiveBackendString}`)
            await this.ensureBackendReady(backend, version)
            backendWasDownloaded = true
            logger.info(
              `Successfully installed initial backend: ${effectiveBackendString}`
            )
          }
        }
      }

      if (this.config.auto_update_engine) {
        const updateResult = await this.handleAutoUpdate(
          bestAvailableBackendString
        )
        if (updateResult.wasUpdated) {
          effectiveBackendString = updateResult.newBackend
          backendWasDownloaded = true
        }
      }

      if (!backendWasDownloaded && effectiveBackendString) {
        await this.ensureFinalBackendInstallation(effectiveBackendString)
      }
    } finally {
      this.isConfiguringBackends = false
    }
  }

  private async determineBestBackend(
    version_backends: { version: string; backend: string }[]
  ): Promise<string> {
    if (version_backends.length === 0) return ''

    // Check GPU memory availability
    let hasEnoughGpuMemory = false
    try {
      const sysInfo = await getSystemInfo()
      for (const gpuInfo of sysInfo.gpus) {
        if (gpuInfo.total_memory >= 6 * 1024) {
          hasEnoughGpuMemory = true
          break
        }
      }
    } catch (error) {
      logger.warn('Failed to get system info for GPU memory check:', error)
      // Default to false if we can't determine GPU memory
      hasEnoughGpuMemory = false
    }

    // Priority list for backend types (more specific/performant ones first)
    // Vulkan will be conditionally prioritized based on GPU memory
    const backendPriorities: string[] = hasEnoughGpuMemory
      ? [
          'cuda-cu12.0',
          'cuda-cu11.7',
          'vulkan', // Include vulkan if we have enough GPU memory
          'avx512',
          'avx2',
          'avx',
          'noavx',
          'arm64',
          'x64',
        ]
      : [
          'cuda-cu12.0',
          'cuda-cu11.7',
          'avx512',
          'avx2',
          'avx',
          'noavx',
          'arm64',
          'x64',
          'vulkan', // demote to last if we don't have enough memory
        ]

    // Helper to map a backend string to a priority category
    const getBackendCategory = (backendString: string): string | undefined => {
      if (backendString.includes('cu12.0')) return 'cuda-cu12.0'
      if (backendString.includes('cu11.7')) return 'cuda-cu11.7'
      if (backendString.includes('vulkan')) return 'vulkan'
      if (backendString.includes('avx512')) return 'avx512'
      if (backendString.includes('avx2')) return 'avx2'
      if (
        backendString.includes('avx') &&
        !backendString.includes('avx2') &&
        !backendString.includes('avx512')
      )
        return 'avx'
      if (backendString.includes('noavx')) return 'noavx'
      if (backendString.endsWith('arm64')) return 'arm64'
      if (backendString.endsWith('x64')) return 'x64'
      return undefined
    }

    let foundBestBackend: { version: string; backend: string } | undefined
    for (const priorityCategory of backendPriorities) {
      const matchingBackends = version_backends.filter((vb) => {
        const category = getBackendCategory(vb.backend)
        return category === priorityCategory
      })

      if (matchingBackends.length > 0) {
        foundBestBackend = matchingBackends[0]
        logger.info(
          `Determined best available backend: ${foundBestBackend.version}/${foundBestBackend.backend} (Category: "${priorityCategory}")`
        )
        break
      }
    }

    if (foundBestBackend) {
      return `${foundBestBackend.version}/${foundBestBackend.backend}`
    } else {
      // Fallback to newest version
      logger.info(
        `Fallback to: ${version_backends[0].version}/${version_backends[0].backend}`
      )
      return `${version_backends[0].version}/${version_backends[0].backend}`
    }
  }

  async updateBackend(
    targetBackendString: string
  ): Promise<{ wasUpdated: boolean; newBackend: string }> {
    try {
      if (!targetBackendString)
        throw new Error(
          `Invalid backend string: ${targetBackendString} supplied to update function`
        )

      const [version, backend] = targetBackendString.split('/')

      logger.info(
        `Updating backend to ${targetBackendString} (backend type: ${backend})`
      )

      // Download new backend
      await this.ensureBackendReady(backend, version)

      // Add delay on Windows
      if (IS_WINDOWS) {
        await new Promise((resolve) => setTimeout(resolve, 1000))
      }

      // Update configuration
      this.config.version_backend = targetBackendString

      // Store the backend type preference only if it changed
      const currentStoredBackend = this.getStoredBackendType()
      if (currentStoredBackend !== backend) {
        this.setStoredBackendType(backend)
        logger.info(`Updated stored backend type preference: ${backend}`)
      }

      // Update settings
      const settings = await this.getSettings()
      await this.updateSettings(
        settings.map((item) => {
          if (item.key === 'version_backend') {
            item.controllerProps.value = targetBackendString
          }
          return item
        })
      )

      logger.info(`Successfully updated to backend: ${targetBackendString}`)

      // Emit for updating frontend
      if (events && typeof events.emit === 'function') {
        logger.info(
          `Emitting settingsChanged event for version_backend with value: ${targetBackendString}`
        )
        events.emit('settingsChanged', {
          key: 'version_backend',
          value: targetBackendString,
        })
      }

      // Clean up old versions of the same backend type
      if (IS_WINDOWS) {
        await new Promise((resolve) => setTimeout(resolve, 500))
      }
      await this.removeOldBackend(version, backend)

      return { wasUpdated: true, newBackend: targetBackendString }
    } catch (error) {
      logger.error('Backend update failed:', error)
      return { wasUpdated: false, newBackend: this.config.version_backend }
    }
  }

  private async handleAutoUpdate(
    bestAvailableBackendString: string
  ): Promise<{ wasUpdated: boolean; newBackend: string }> {
    logger.info(
      `Auto-update engine is enabled. Current backend: ${this.config.version_backend}. Best available: ${bestAvailableBackendString}`
    )

    if (!bestAvailableBackendString) {
      logger.warn(
        'Auto-update enabled, but no best available backend determined'
      )
      return { wasUpdated: false, newBackend: this.config.version_backend }
    }

    // If version_backend is empty, invalid, or 'none', use the best available backend
    if (
      !this.config.version_backend ||
      this.config.version_backend === '' ||
      this.config.version_backend === 'none' ||
      !this.config.version_backend.includes('/')
    ) {
      logger.info(
        'No valid backend currently selected, using best available backend'
      )

      return await this.updateBackend(bestAvailableBackendString)
    }

    // Parse current backend configuration
    const [currentVersion, currentBackend] = (
      this.config.version_backend || ''
    ).split('/')

    if (!currentVersion || !currentBackend) {
      logger.warn(
        `Invalid current backend format: ${this.config.version_backend}`
      )
      return { wasUpdated: false, newBackend: this.config.version_backend }
    }

    // Find the latest version for the currently selected backend type
    const version_backends = await listSupportedBackends()
    const targetBackendString = this.findLatestVersionForBackend(
      version_backends,
      currentBackend
    )

    if (!targetBackendString) {
      logger.warn(
        `No available versions found for current backend type: ${currentBackend}`
      )
      return { wasUpdated: false, newBackend: this.config.version_backend }
    }

    const [latestVersion] = targetBackendString.split('/')

    // Check if update is needed (only version comparison for same backend type)
    if (currentVersion === latestVersion) {
      logger.info(
        'Auto-update: Already using the latest version of the selected backend'
      )
      return { wasUpdated: false, newBackend: this.config.version_backend }
    }

    // Perform version update for the same backend type
    logger.info(
      `Auto-updating from ${this.config.version_backend} to ${targetBackendString} (preserving backend type)`
    )

    return await this.updateBackend(targetBackendString)
  }

  private parseBackendVersion(v: string): number {
    // Remove any leading non-digit characters (e.g. the "b")
    const numeric = v.replace(/^[^\d]*/, '')
    const n = Number(numeric)
    return Number.isNaN(n) ? 0 : n
  }

  async checkBackendForUpdates(): Promise<{
    updateNeeded: boolean
    newVersion: string
  }> {
    // Parse current backend configuration
    const [currentVersion, currentBackend] = (
      this.config.version_backend || ''
    ).split('/')

    if (!currentVersion || !currentBackend) {
      logger.warn(
        `Invalid current backend format: ${this.config.version_backend}`
      )
      return { updateNeeded: false, newVersion: '0' }
    }

    // Find the latest version for the currently selected backend type
    const version_backends = await listSupportedBackends()
    const targetBackendString = this.findLatestVersionForBackend(
      version_backends,
      currentBackend
    )
    if (!targetBackendString) {
      logger.warn(
        `No available versions found for current backend type: ${currentBackend}`
      )
      return { updateNeeded: false, newVersion: '0' }
    }
    const [latestVersion] = targetBackendString.split('/')
    if (
      this.parseBackendVersion(latestVersion) >
      this.parseBackendVersion(currentVersion)
    ) {
      logger.info(`New update available: ${latestVersion}`)
      return { updateNeeded: true, newVersion: latestVersion }
    } else {
      logger.info(
        `Already at latest version: ${currentVersion} = ${latestVersion}`
      )
      return { updateNeeded: false, newVersion: '0' }
    }
  }

  private async removeOldBackend(
    latestVersion: string,
    backendType: string
  ): Promise<void> {
    try {
      const janDataFolderPath = await getJanDataFolderPath()
      const backendsDir = await joinPath([
        janDataFolderPath,
        'llamacpp',
        'backends',
      ])

      if (!(await fs.existsSync(backendsDir))) {
        return
      }

      const versionDirs = await fs.readdirSync(backendsDir)

      for (const versionDir of versionDirs) {
        const versionPath = await joinPath([backendsDir, versionDir])
        const versionName = await basename(versionDir)

        // Skip the latest version
        if (versionName === latestVersion) {
          continue
        }

        // Check if this version has the specific backend type we're interested in
        const backendTypePath = await joinPath([versionPath, backendType])

        if (await fs.existsSync(backendTypePath)) {
          const isInstalled = await isBackendInstalled(backendType, versionName)
          if (isInstalled) {
            try {
              await fs.rm(backendTypePath)
              logger.info(
                `Removed old version of ${backendType}: ${backendTypePath}`
              )
            } catch (e) {
              logger.warn(
                `Failed to remove old backend version: ${backendTypePath}`,
                e
              )
            }
          }
        }
      }
    } catch (error) {
      logger.error('Error during old backend version cleanup:', error)
    }
  }

  private async ensureFinalBackendInstallation(
    backendString: string
  ): Promise<void> {
    if (!backendString) {
      logger.warn('No backend specified for final installation check')
      return
    }

    const [selectedVersion, selectedBackend] = backendString
      .split('/')
      .map((part) => part?.trim())

    if (!selectedVersion || !selectedBackend) {
      logger.warn(`Invalid backend format: ${backendString}`)
      return
    }

    try {
      const isInstalled = await isBackendInstalled(
        selectedBackend,
        selectedVersion
      )
      if (!isInstalled) {
        logger.info(`Final check: Installing backend ${backendString}`)
        await this.ensureBackendReady(selectedBackend, selectedVersion)
        logger.info(`Successfully installed backend: ${backendString}`)
      } else {
        logger.info(
          `Final check: Backend ${backendString} is already installed`
        )
      }
    } catch (error) {
      logger.error(
        `Failed to ensure backend ${backendString} installation:`,
        error
      )
      throw error // Re-throw as this is critical
    }
  }

  async getProviderPath(): Promise<string> {
    if (!this.providerPath) {
      this.providerPath = await joinPath([
        await getJanDataFolderPath(),
        this.providerId,
      ])
    }
    return this.providerPath
  }

  override async onUnload(): Promise<void> {
    // Terminate all active sessions

    // Clean up validation event listeners
    if (this.unlistenValidationStarted) {
      this.unlistenValidationStarted()
    }
  }

  onSettingUpdate<T>(key: string, value: T): void {
    this.config[key] = value

    if (key === 'version_backend') {
      const valueStr = value as string
      const [version, backend] = valueStr.split('/')

      // Store the backend type preference in localStorage only if it changed
      if (backend) {
        const currentStoredBackend = this.getStoredBackendType()
        if (currentStoredBackend !== backend) {
          this.setStoredBackendType(backend)
          logger.info(`Updated backend type preference to: ${backend}`)
        }
      }

      // Reset device setting when backend changes
      this.config.device = ''

      const closure = async () => {
        await this.ensureBackendReady(backend, version)
      }
      closure()
    } else if (key === 'auto_unload') {
      this.autoUnload = value as boolean
    } else if (key === 'llamacpp_env') {
      this.llamacpp_env = value as string
    } else if (key === 'memory_util') {
      this.memoryMode = value as string
    }
  }

  private async generateApiKey(modelId: string, port: string): Promise<string> {
    const hash = await invoke<string>('plugin:llamacpp|generate_api_key', {
      modelId: modelId + port,
      apiSecret: this.apiSecret,
    })
    return hash
  }

  override async get(modelId: string): Promise<modelInfo | undefined> {
    const modelPath = await joinPath([
      await this.getProviderPath(),
      'models',
      modelId,
    ])
    const path = await joinPath([modelPath, 'model.yml'])

    if (!(await fs.existsSync(path))) return undefined

    const modelConfig = await invoke<ModelConfig>('read_yaml', {
      path,
    })

    return {
      id: modelId,
      name: modelConfig.name ?? modelId,
      quant_type: undefined, // TODO: parse quantization type from model.yml or model.gguf
      providerId: this.provider,
      port: 0, // port is not known until the model is loaded
      sizeBytes: modelConfig.size_bytes ?? 0,
    } as modelInfo
  }

  // Implement the required LocalProvider interface methods
  override async list(): Promise<modelInfo[]> {
    const modelsDir = await joinPath([await this.getProviderPath(), 'models'])
    if (!(await fs.existsSync(modelsDir))) {
      await fs.mkdir(modelsDir)
    }

    await this.migrateLegacyModels()

    let modelIds: string[] = []

    // DFS
    let stack = [modelsDir]
    while (stack.length > 0) {
      const currentDir = stack.pop()

      // check if model.yml exists
      const modelConfigPath = await joinPath([currentDir, 'model.yml'])
      if (await fs.existsSync(modelConfigPath)) {
        // +1 to remove the leading slash
        // NOTE: this does not handle Windows path \\
        modelIds.push(currentDir.slice(modelsDir.length + 1))
        continue
      }

      // otherwise, look into subdirectories
      const children = await fs.readdirSync(currentDir)
      for (const child of children) {
        // skip files
        const dirInfo = await fs.fileStat(child)
        if (!dirInfo.isDirectory) {
          continue
        }

        stack.push(child)
      }
    }

    let modelInfos: modelInfo[] = []
    for (const modelId of modelIds) {
      const path = await joinPath([modelsDir, modelId, 'model.yml'])
      const modelConfig = await invoke<ModelConfig>('read_yaml', { path })

      const modelInfo = {
        id: modelId,
        name: modelConfig.name ?? modelId,
        quant_type: undefined, // TODO: parse quantization type from model.yml or model.gguf
        providerId: this.provider,
        port: 0, // port is not known until the model is loaded
        sizeBytes: modelConfig.size_bytes ?? 0,
      } as modelInfo
      modelInfos.push(modelInfo)
    }

    return modelInfos
  }

  private async migrateLegacyModels() {
    // Attempt to migrate only once
    if (localStorage.getItem('cortex_models_migrated') === 'true') return

    const janDataFolderPath = await getJanDataFolderPath()
    const modelsDir = await joinPath([janDataFolderPath, 'models'])
    if (!(await fs.existsSync(modelsDir))) return

    // DFS
    let stack = [modelsDir]
    while (stack.length > 0) {
      const currentDir = stack.pop()

      const files = await fs.readdirSync(currentDir)
      for (const child of files) {
        try {
          const childPath = await joinPath([currentDir, child])
          const stat = await fs.fileStat(childPath)
          if (
            files.some((e) => e.endsWith('model.yml')) &&
            !child.endsWith('model.yml')
          )
            continue
          if (!stat.isDirectory && child.endsWith('.yml')) {
            // check if model.yml exists
            const modelConfigPath = child
            if (await fs.existsSync(modelConfigPath)) {
              const legacyModelConfig = await invoke<{
                files: string[]
                model: string
              }>('read_yaml', {
                path: modelConfigPath,
              })
              const legacyModelPath = legacyModelConfig.files?.[0]
              if (!legacyModelPath) continue
              // +1 to remove the leading slash
              // NOTE: this does not handle Windows path \\
              let modelId = currentDir.slice(modelsDir.length + 1)

              modelId =
                modelId !== 'imported'
                  ? modelId.replace(/^(cortex\.so|huggingface\.co)[\/\\]/, '')
                  : (await basename(child)).replace('.yml', '')

              const modelName = legacyModelConfig.model ?? modelId
              const configPath = await joinPath([
                await this.getProviderPath(),
                'models',
                modelId,
                'model.yml',
              ])
              if (await fs.existsSync(configPath)) continue // Don't reimport

              // this is relative to Jan's data folder
              const modelDir = `${this.providerId}/models/${modelId}`

              let size_bytes = (
                await fs.fileStat(
                  await joinPath([janDataFolderPath, legacyModelPath])
                )
              ).size

              const modelConfig = {
                model_path: legacyModelPath,
                mmproj_path: undefined, // legacy models do not have mmproj
                name: modelName,
                size_bytes,
              } as ModelConfig
              await fs.mkdir(await joinPath([janDataFolderPath, modelDir]))
              await invoke<void>('write_yaml', {
                data: modelConfig,
                savePath: configPath,
              })
              continue
            }
          }
        } catch (error) {
          console.error(`Error migrating model ${child}:`, error)
        }
      }

      // otherwise, look into subdirectories
      const children = await fs.readdirSync(currentDir)
      for (const child of children) {
        // skip files
        const dirInfo = await fs.fileStat(child)
        if (!dirInfo.isDirectory) {
          continue
        }

        stack.push(child)
      }
    }
    localStorage.setItem('cortex_models_migrated', 'true')
  }

  /**
   * Manually installs a supported backend archive.
   */
  async installBackend(path: string): Promise<void> {
    const platformName = IS_WINDOWS ? 'win' : 'linux'
    const re = /^llama-(b\d+)-bin-(.+?)\.(?:tar\.gz|zip)$/
    const archiveName = await basename(path)
    logger.info(`Installing backend from path: ${path}`)

    if (
      !(await fs.existsSync(path)) ||
      (!path.endsWith('tar.gz') && !path.endsWith('zip'))
    ) {
      logger.error(`Invalid path or file ${path}`)
      throw new Error(`Invalid path or file ${path}`)
    }
    const match = re.exec(archiveName)
    if (!match) throw new Error('Failed to parse archive name')
    const [, version, backend] = match
    if (!version || !backend) {
      throw new Error(`Invalid backend archive name: ${archiveName}`)
    }
    const backendDir = await getBackendDir(backend, version)
    try {
      await invoke('decompress', { path: path, outputDir: backendDir })
    } catch (e) {
      logger.error(`Failed to install: ${String(e)}`)
    }
    const binPath =
      platformName === 'win'
        ? await joinPath([backendDir, 'build', 'bin', 'llama-server.exe'])
        : await joinPath([backendDir, 'build', 'bin', 'llama-server'])

    if (!(await fs.existsSync(binPath))) {
      await fs.rm(backendDir)
      throw new Error('Not a supported backend archive!')
    }
    try {
      await this.configureBackends()
      logger.info(`Backend ${backend}/${version} installed and UI refreshed`)
    } catch (e) {
      logger.error('Backend installed but failed to refresh UI', e)
      throw new Error(
        `Backend installed but failed to refresh UI: ${String(e)}`
      )
    }
  }

  /**
   * Update a model with new information.
   * @param modelId
   * @param model
   */
  async update(modelId: string, model: Partial<modelInfo>): Promise<void> {
    const modelFolderPath = await joinPath([
      await this.getProviderPath(),
      'models',
      modelId,
    ])
    const modelConfig = await invoke<ModelConfig>('read_yaml', {
      path: await joinPath([modelFolderPath, 'model.yml']),
    })
    const newFolderPath = await joinPath([
      await this.getProviderPath(),
      'models',
      model.id,
    ])
    // Check if newFolderPath exists
    if (await fs.existsSync(newFolderPath)) {
      throw new Error(`Model with ID ${model.id} already exists`)
    }
    const newModelConfigPath = await joinPath([newFolderPath, 'model.yml'])
    await fs.mv(modelFolderPath, newFolderPath)
    // rewrite any paths that still reference the previous model id
    await invoke('write_yaml', {
      data: {
        ...modelConfig,
        model_path: modelConfig?.model_path?.replace(
          `${this.providerId}/models/${modelId}`,
          `${this.providerId}/models/${model.id}`
        ),
        mmproj_path: modelConfig?.mmproj_path?.replace(
          `${this.providerId}/models/${modelId}`,
          `${this.providerId}/models/${model.id}`
        ),
      },
      savePath: newModelConfigPath,
    })
  }

  override async import(modelId: string, opts: ImportOptions): Promise<void> {
    const isValidModelId = (id: string) => {
      // only allow alphanumeric, slash, underscore, hyphen, and dot characters in modelId
      if (!/^[a-zA-Z0-9/_\-\.]+$/.test(id)) return false

      // check for empty parts or path traversal
      const parts = id.split('/')
      return parts.every((s) => s !== '' && s !== '.' && s !== '..')
    }
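    // Illustrative examples: ids like 'qwen3-0.6b' or 'org/model' pass, while
    // 'a/../b', 'a//b', and ids with spaces are rejected (charset, empty
    // segment, or path-traversal checks).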

    if (!isValidModelId(modelId))
      throw new Error(
        `Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`
      )

    const configPath = await joinPath([
      await this.getProviderPath(),
      'models',
      modelId,
      'model.yml',
    ])
    if (await fs.existsSync(configPath))
      throw new Error(`Model ${modelId} already exists`)

    // this is relative to Jan's data folder
    const modelDir = `${this.providerId}/models/${modelId}`

    // we only use these from opts
    // opts.modelPath: URL to the model file
    // opts.mmprojPath: URL to the mmproj file

    let downloadItems: DownloadItem[] = []

    const maybeDownload = async (path: string, saveName: string) => {
      // if URL, add to downloadItems, and return local path
      if (path.startsWith('https://')) {
        const localPath = `${modelDir}/${saveName}`
        downloadItems.push({
          url: path,
          save_path: localPath,
          proxy: getProxyConfig(),
          sha256:
            saveName === 'model.gguf' ? opts.modelSha256 : opts.mmprojSha256,
          size: saveName === 'model.gguf' ? opts.modelSize : opts.mmprojSize,
        })
        return localPath
      }

      // if local file (absolute path), check if it exists
      // and return the path
      if (!(await fs.existsSync(path)))
        throw new Error(`File not found: ${path}`)
      return path
    }

    let modelPath = await maybeDownload(opts.modelPath, 'model.gguf')
    let mmprojPath = opts.mmprojPath
      ? await maybeDownload(opts.mmprojPath, 'mmproj.gguf')
      : undefined

    if (downloadItems.length > 0) {
      try {
        // emit download update event on progress
        const onProgress = (transferred: number, total: number) => {
          events.emit(DownloadEvent.onFileDownloadUpdate, {
            modelId,
            percent: transferred / total,
            size: { transferred, total },
            downloadType: 'Model',
          })
        }
        const downloadManager = window.core.extensionManager.getByName(
          '@janhq/download-extension'
        )
        await downloadManager.downloadFiles(
          downloadItems,
          this.createDownloadTaskId(modelId),
          onProgress
        )

        // If we reach here, download completed successfully (including validation)
        // The downloadFiles function only returns successfully if all files downloaded AND validated
        events.emit(DownloadEvent.onFileDownloadAndVerificationSuccess, {
          modelId,
          downloadType: 'Model',
        })
      } catch (error) {
        logger.error('Error downloading model:', modelId, opts, error)
        const errorMessage =
          error instanceof Error ? error.message : String(error)

        // Check if this is a cancellation
        const isCancellationError =
          errorMessage.includes('Download cancelled') ||
          errorMessage.includes('Validation cancelled') ||
          errorMessage.includes('Hash computation cancelled') ||
          errorMessage.includes('cancelled') ||
          errorMessage.includes('aborted')

        // Check if this is a validation failure
        const isValidationError =
          errorMessage.includes('Hash verification failed') ||
          errorMessage.includes('Size verification failed') ||
          errorMessage.includes('Failed to verify file')

        if (isCancellationError) {
          logger.info('Download cancelled for model:', modelId)
          // Emit download stopped event instead of error
          events.emit(DownloadEvent.onFileDownloadStopped, {
            modelId,
            downloadType: 'Model',
          })
        } else if (isValidationError) {
          logger.error(
            'Validation failed for model:',
            modelId,
            'Error:',
            errorMessage
          )

          // Cancel any other download tasks for this model
          try {
            this.abortImport(modelId)
          } catch (cancelError) {
            logger.warn('Failed to cancel download task:', cancelError)
          }

          // Emit validation failure event
          events.emit(DownloadEvent.onModelValidationFailed, {
            modelId,
            downloadType: 'Model',
            error: errorMessage,
            reason: 'validation_failed',
          })
        } else {
          // Regular download error
          events.emit(DownloadEvent.onFileDownloadError, {
            modelId,
            downloadType: 'Model',
            error: errorMessage,
          })
        }
        throw error
      }
    }

    // Validate GGUF files
    const janDataFolderPath = await getJanDataFolderPath()
    const fullModelPath = await joinPath([janDataFolderPath, modelPath])

    try {
      // Validate main model file
      const modelMetadata = await readGgufMetadata(fullModelPath)
      logger.info(
        `Model GGUF validation successful: version ${modelMetadata.version}, tensors: ${modelMetadata.tensor_count}`
      )

      // Validate mmproj file if present
      if (mmprojPath) {
        const fullMmprojPath = await joinPath([janDataFolderPath, mmprojPath])
        const mmprojMetadata = await readGgufMetadata(fullMmprojPath)
        logger.info(
          `Mmproj GGUF validation successful: version ${mmprojMetadata.version}, tensors: ${mmprojMetadata.tensor_count}`
        )
      }
    } catch (error) {
      logger.error('GGUF validation failed:', error)
      throw new Error(
        `Invalid GGUF file(s): ${
          error.message || 'File format validation failed'
        }`
      )
    }

    // Calculate file sizes
    let size_bytes = (await fs.fileStat(fullModelPath)).size
    if (mmprojPath) {
      size_bytes += (
        await fs.fileStat(await joinPath([janDataFolderPath, mmprojPath]))
      ).size
    }

    // TODO: add name as import() argument
    // TODO: add updateModelConfig() method
    const modelConfig = {
      model_path: modelPath,
      mmproj_path: mmprojPath,
      name: modelId,
      size_bytes,
      model_sha256: opts.modelSha256,
      model_size_bytes: opts.modelSize,
      mmproj_sha256: opts.mmprojSha256,
      mmproj_size_bytes: opts.mmprojSize,
    } as ModelConfig
    await fs.mkdir(await joinPath([janDataFolderPath, modelDir]))
    await invoke<void>('write_yaml', {
      data: modelConfig,
      savePath: configPath,
    })
    events.emit(AppEvent.onModelImported, {
      modelId,
      modelPath,
      mmprojPath,
      size_bytes,
      model_sha256: opts.modelSha256,
      model_size_bytes: opts.modelSize,
      mmproj_sha256: opts.mmprojSha256,
      mmproj_size_bytes: opts.mmprojSize,
    })
  }

  /**
   * Deletes the entire model folder for a given modelId
   * @param modelId The model ID to delete
   */
  private async deleteModelFolder(modelId: string): Promise<void> {
    try {
      const modelDir = await joinPath([
        await this.getProviderPath(),
        'models',
        modelId,
      ])

      if (await fs.existsSync(modelDir)) {
        logger.info(`Cleaning up model directory: ${modelDir}`)
        await fs.rm(modelDir)
      }
    } catch (deleteError) {
      logger.warn('Failed to delete model directory:', deleteError)
    }
  }

  override async abortImport(modelId: string): Promise<void> {
    // Cancel any active download task
    // prepend provider name to avoid name collision
    const taskId = this.createDownloadTaskId(modelId)
    const downloadManager = window.core.extensionManager.getByName(
      '@janhq/download-extension'
    )

    try {
      await downloadManager.cancelDownload(taskId)
    } catch (cancelError) {
      logger.warn('Failed to cancel download task:', cancelError)
    }

    // Delete the entire model folder if it exists (for validation failures)
    await this.deleteModelFolder(modelId)
  }

  /**
   * Function to find a random port
   */
  private async getRandomPort(): Promise<number> {
    try {
      const port = await invoke<number>('plugin:llamacpp|get_random_port')
      return port
    } catch {
      logger.error('Unable to find a suitable port')
      throw new Error('Unable to find a suitable port for model')
    }
  }

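  /**
   * Parses a ';'-separated list of KEY=VALUE pairs into `target`. Keys
   * starting with 'LLAMA' are skipped, so user-supplied env vars cannot
   * clobber variables the extension manages itself (e.g. LLAMA_API_KEY).
   *
   * Illustrative example (hypothetical variable names):
   *   const envs: Record<string, string> = {}
   *   this.parseEnvFromString(envs, 'GGML_LOG_LEVEL=debug;HTTPS_PROXY=http://127.0.0.1:3128')
   *   // envs => { GGML_LOG_LEVEL: 'debug', HTTPS_PROXY: 'http://127.0.0.1:3128' }
   */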
  private parseEnvFromString(
    target: Record<string, string>,
    envString: string
  ): void {
    envString
      .split(';')
      .filter((pair) => pair.trim())
      .forEach((pair) => {
        const [key, ...valueParts] = pair.split('=')
        const cleanKey = key?.trim()

        if (
          cleanKey &&
          valueParts.length > 0 &&
          !cleanKey.startsWith('LLAMA')
        ) {
          target[cleanKey] = valueParts.join('=').trim()
        }
      })
  }

  override async load(
    modelId: string,
    overrideSettings?: Partial<LlamacppConfig>,
    isEmbedding: boolean = false
  ): Promise<SessionInfo> {
    const sInfo = await this.findSessionByModel(modelId)
    if (sInfo) {
      throw new Error('Model already loaded!')
    }

    // If this model is already being loaded, return the existing promise
    if (this.loadingModels.has(modelId)) {
      return this.loadingModels.get(modelId)!
    }

    // Create the loading promise
    const loadingPromise = this.performLoad(
      modelId,
      overrideSettings,
      isEmbedding
    )
    this.loadingModels.set(modelId, loadingPromise)

    try {
      const result = await loadingPromise
      return result
    } finally {
      this.loadingModels.delete(modelId)
    }
  }

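  /**
   * Spawns a llama-server process for the model and returns its session info.
   * A typical invocation assembled below looks roughly like (illustrative;
   * actual flags depend on the user's settings):
   *   llama-server --no-webui --jinja -m <model.gguf> -a <modelId>
   *     --port <random port> -ngl 100 --ctx-size 8192
   * with LLAMA_API_KEY set in the child process environment.
   */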
  private async performLoad(
    modelId: string,
    overrideSettings?: Partial<LlamacppConfig>,
    isEmbedding: boolean = false
  ): Promise<SessionInfo> {
    const loadedModels = await this.getLoadedModels()

    // Get OTHER models that are currently loading (exclude current model)
    const otherLoadingPromises = Array.from(this.loadingModels.entries())
      .filter(([id, _]) => id !== modelId)
      .map(([_, promise]) => promise)

    if (
      this.autoUnload &&
      !isEmbedding &&
      (loadedModels.length > 0 || otherLoadingPromises.length > 0)
    ) {
      // Wait for OTHER loading models to finish, then unload everything
      if (otherLoadingPromises.length > 0) {
        await Promise.all(otherLoadingPromises)
      }

      // Now unload all loaded text models, excluding embedding models
      const allLoadedModels = await this.getLoadedModels()
      if (allLoadedModels.length > 0) {
        const sessionInfos: (SessionInfo | null)[] = await Promise.all(
          allLoadedModels.map(async (modelId) => {
            try {
              return await this.findSessionByModel(modelId)
            } catch (e) {
              logger.warn(`Unable to find session for model "${modelId}": ${e}`)
              return null // treat as "not eligible for unload"
            }
          })
        )

        logger.info(JSON.stringify(sessionInfos))

        const nonEmbeddingModels: string[] = sessionInfos
          .filter(
            (s): s is SessionInfo => s !== null && s.is_embedding === false
          )
          .map((s) => s.model_id)

        if (nonEmbeddingModels.length > 0) {
          await Promise.all(
            nonEmbeddingModels.map((modelId) => this.unload(modelId))
          )
        }
      }
    }
    const args: string[] = []
    const envs: Record<string, string> = {}
    const cfg = { ...this.config, ...(overrideSettings ?? {}) }
    const [version, backend] = cfg.version_backend.split('/')
    if (!version || !backend) {
      throw new Error(
        'Initial setup for the backend failed due to a network issue. Please restart the app!'
      )
    }

    // Ensure backend is downloaded and ready before proceeding
    await this.ensureBackendReady(backend, version)

    const janDataFolderPath = await getJanDataFolderPath()
    const modelConfigPath = await joinPath([
      this.providerPath,
      'models',
      modelId,
      'model.yml',
    ])
    const modelConfig = await invoke<ModelConfig>('read_yaml', {
      path: modelConfigPath,
    })
    const port = await this.getRandomPort()

    // disable llama-server webui
    args.push('--no-webui')
    const api_key = await this.generateApiKey(modelId, String(port))
    envs['LLAMA_API_KEY'] = api_key

    // set user envs
    if (this.llamacpp_env) this.parseEnvFromString(envs, this.llamacpp_env)

    // model option is required
    // NOTE: model_path and mmproj_path can be either relative to Jan's data folder or absolute paths
    const modelPath = await joinPath([
      janDataFolderPath,
      modelConfig.model_path,
    ])
    args.push('--jinja')
    args.push('-m', modelPath)
    if (cfg.cpu_moe) args.push('--cpu-moe')
    if (cfg.n_cpu_moe && cfg.n_cpu_moe > 0) {
      args.push('--n-cpu-moe', String(cfg.n_cpu_moe))
    }
    // For overriding the tensor buffer type. Useful where massive MoE models
    // can be made faster by keeping attention on the GPU and offloading the
    // expert FFNs to the CPU.
    // This is an expert-level setting and should only be used by people who
    // know what they are doing.
    // Takes a regex matching tensor names as input.
    if (cfg.override_tensor_buffer_t)
      args.push('--override-tensor', cfg.override_tensor_buffer_t)
    // Offload the multimodal projector model to the GPU by default. If there
    // is not enough memory, turning this setting off keeps the projector on
    // the CPU, but image processing can take longer.
    if (cfg.offload_mmproj === false) args.push('--no-mmproj-offload')
    args.push('-a', modelId)
    args.push('--port', String(port))
    if (modelConfig.mmproj_path) {
      const mmprojPath = await joinPath([
        janDataFolderPath,
        modelConfig.mmproj_path,
      ])
      args.push('--mmproj', mmprojPath)
    }
    // Add remaining options from the interface
    if (cfg.chat_template) args.push('--chat-template', cfg.chat_template)
    const gpu_layers =
      parseInt(String(cfg.n_gpu_layers)) >= 0 ? cfg.n_gpu_layers : 100
    args.push('-ngl', String(gpu_layers))
    if (cfg.threads > 0) args.push('--threads', String(cfg.threads))
    if (cfg.threads_batch > 0)
      args.push('--threads-batch', String(cfg.threads_batch))
    if (cfg.batch_size > 0) args.push('--batch-size', String(cfg.batch_size))
    if (cfg.ubatch_size > 0) args.push('--ubatch-size', String(cfg.ubatch_size))
    if (cfg.device.length > 0) args.push('--device', cfg.device)
    if (cfg.split_mode.length > 0 && cfg.split_mode != 'layer')
      args.push('--split-mode', cfg.split_mode)
    if (cfg.main_gpu !== undefined && cfg.main_gpu != 0)
      args.push('--main-gpu', String(cfg.main_gpu))

    // Boolean flags
    if (cfg.ctx_shift) args.push('--context-shift')
    if (Number(version.replace(/^b/, '')) >= 6325) {
      if (!cfg.flash_attn) args.push('--flash-attn', 'off') // default: auto = on when supported
    } else {
      if (cfg.flash_attn) args.push('--flash-attn')
    }
    if (cfg.cont_batching) args.push('--cont-batching')
    if (cfg.no_mmap) args.push('--no-mmap')
    if (cfg.mlock) args.push('--mlock')
    if (cfg.no_kv_offload) args.push('--no-kv-offload')
    if (isEmbedding) {
      args.push('--embedding')
      args.push('--pooling', 'mean')
    } else {
      if (cfg.ctx_size > 0) args.push('--ctx-size', String(cfg.ctx_size))
      if (cfg.n_predict > 0) args.push('--n-predict', String(cfg.n_predict))
      if (cfg.cache_type_k && cfg.cache_type_k != 'f16')
        args.push('--cache-type-k', cfg.cache_type_k)
      if (
        cfg.flash_attn &&
        cfg.cache_type_v != 'f16' &&
        cfg.cache_type_v != 'f32'
      ) {
        args.push('--cache-type-v', cfg.cache_type_v)
      }
      if (cfg.defrag_thold && cfg.defrag_thold != 0.1)
        args.push('--defrag-thold', String(cfg.defrag_thold))

      if (cfg.rope_scaling && cfg.rope_scaling != 'none')
        args.push('--rope-scaling', cfg.rope_scaling)
      if (cfg.rope_scale && cfg.rope_scale != 1)
        args.push('--rope-scale', String(cfg.rope_scale))
      if (cfg.rope_freq_base && cfg.rope_freq_base != 0)
        args.push('--rope-freq-base', String(cfg.rope_freq_base))
      if (cfg.rope_freq_scale && cfg.rope_freq_scale != 1)
        args.push('--rope-freq-scale', String(cfg.rope_freq_scale))
    }

    logger.info('Calling Tauri command llama_load with args:', args)
    const backendPath = await getBackendExePath(backend, version)
    const libraryPath = await joinPath([await this.getProviderPath(), 'lib'])

    try {
      // TODO: add LIBRARY_PATH
      const sInfo = await invoke<SessionInfo>(
        'plugin:llamacpp|load_llama_model',
        {
          backendPath,
          libraryPath,
          args,
          envs,
          isEmbedding,
        }
      )
      return sInfo
    } catch (error) {
      logger.error('Error in load command:\n', error)
      throw error
    }
  }

  override async unload(modelId: string): Promise<UnloadResult> {
    const sInfo: SessionInfo = await this.findSessionByModel(modelId)
    if (!sInfo) {
      throw new Error(`No active session found for model: ${modelId}`)
    }
    const pid = sInfo.pid
    try {
      // Pass the PID as the session_id
      const result = await invoke<UnloadResult>(
        'plugin:llamacpp|unload_llama_model',
        {
          pid: pid,
        }
      )

      // If successful, remove from active sessions
      if (result.success) {
        logger.info(`Successfully unloaded model with PID ${pid}`)
      } else {
        logger.warn(`Failed to unload model: ${result.error}`)
      }

      return result
    } catch (error) {
      logger.error('Error in unload command:', error)
      return {
        success: false,
        error: `Failed to unload model: ${error}`,
      }
    }
  }

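  /**
   * Builds a provider-scoped download task id so ids cannot collide across
   * providers. Illustrative: modelId 'qwen3-0.6b' becomes 'llamacpp/qwen3-0',
   * since anything from the first '.' onward is dropped.
   */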
  private createDownloadTaskId(modelId: string) {
    // prepend provider to make taskId unique across providers
    const cleanModelId = modelId.includes('.')
      ? modelId.slice(0, modelId.indexOf('.'))
      : modelId
    return `${this.provider}/${cleanModelId}`
  }

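  /**
   * Ensures the requested backend build is installed, downloading it if
   * necessary. Concurrent callers for the same '<version>/<backend>' key
   * share a single in-flight download via the pendingDownloads map.
   */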
  private async ensureBackendReady(
    backend: string,
    version: string
  ): Promise<void> {
    const backendKey = `${version}/${backend}`

    // Check if backend is already installed
    const isInstalled = await isBackendInstalled(backend, version)
    if (isInstalled) {
      return
    }

    // Check if download is already in progress
    if (this.pendingDownloads.has(backendKey)) {
      logger.info(
        `Backend ${backendKey} download already in progress, waiting...`
      )
      await this.pendingDownloads.get(backendKey)
      return
    }

    // Start new download
    logger.info(`Backend ${backendKey} not installed, downloading...`)
    const downloadPromise = downloadBackend(backend, version).finally(() => {
      this.pendingDownloads.delete(backendKey)
    })

    this.pendingDownloads.set(backendKey, downloadPromise)
    await downloadPromise
    logger.info(`Backend ${backendKey} download completed`)
  }

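  /**
   * Consumes a streaming response from llama-server (OpenAI-style server-sent
   * events). Each chunk arrives on its own line, roughly of the form (sketch):
   *   data: {"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}
   * and the stream is terminated by a 'data: [DONE]' line.
   */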
  private async *handleStreamingResponse(
    url: string,
    headers: HeadersInit,
    body: string,
    abortController?: AbortController
  ): AsyncIterable<chatCompletionChunk> {
    const response = await fetch(url, {
      method: 'POST',
      headers,
      body,
      connectTimeout: 600000, // 10 minutes
      // Combine the caller's abort signal (if any) with a 10-minute timeout;
      // AbortSignal.any() must not receive undefined entries.
      signal: abortController
        ? AbortSignal.any([AbortSignal.timeout(600000), abortController.signal])
        : AbortSignal.timeout(600000),
    })
    if (!response.ok) {
      const errorData = await response.json().catch(() => null)
      throw new Error(
        `API request failed with status ${response.status}: ${JSON.stringify(
          errorData
        )}`
      )
    }

    if (!response.body) {
      throw new Error('Response body is null')
    }

    const reader = response.body.getReader()
    const decoder = new TextDecoder('utf-8')
    let buffer = ''
    let jsonStr = ''
    try {
      while (true) {
        const { done, value } = await reader.read()

        if (done) {
          break
        }

        buffer += decoder.decode(value, { stream: true })

        // Process complete lines in the buffer
        const lines = buffer.split('\n')
        buffer = lines.pop() || '' // Keep the last incomplete line in the buffer

        for (const line of lines) {
          const trimmedLine = line.trim()
          if (!trimmedLine || trimmedLine === 'data: [DONE]') {
            continue
          }

          if (trimmedLine.startsWith('data: ')) {
            jsonStr = trimmedLine.slice(6)
          } else if (trimmedLine.startsWith('error: ')) {
            jsonStr = trimmedLine.slice(7)
            const error = JSON.parse(jsonStr)
            throw new Error(error.message)
          } else {
            // it should not normally reach here
            throw new Error('Malformed chunk')
          }

          try {
            const data = JSON.parse(jsonStr)
            const chunk = data as chatCompletionChunk

            // Check for out-of-context error conditions
            if (chunk.choices?.[0]?.finish_reason === 'length') {
              // finish_reason 'length' indicates context limit was hit
              throw new Error(OUT_OF_CONTEXT_SIZE)
            }

            yield chunk
          } catch (e) {
            logger.error('Error parsing JSON from stream or server error:', e)
            // re-throw so the async iterator terminates with an error
            throw e
          }
        }
      }
    } finally {
      reader.releaseLock()
    }
  }

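  // Example (hypothetical consumer sketch): the generator above is typically
  // drained with for-await, appending each streamed delta as it arrives:
  //
  //   let text = ''
  //   for await (const chunk of stream) {
  //     text += chunk.choices?.[0]?.delta?.content ?? ''
  //   }
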
  private async findSessionByModel(modelId: string): Promise<SessionInfo> {
    try {
      const sInfo = await invoke<SessionInfo>(
        'plugin:llamacpp|find_session_by_model',
        {
          modelId,
        }
      )
      return sInfo
    } catch (e) {
      logger.error(e)
      throw new Error(String(e))
    }
  }

  override async chat(
    opts: chatCompletionRequest,
    abortController?: AbortController
  ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
    const sessionInfo = await this.findSessionByModel(opts.model)
    if (!sessionInfo) {
      throw new Error(`No active session found for model: ${opts.model}`)
    }
    // check if the process is alive
    const result = await invoke<boolean>('plugin:llamacpp|is_process_running', {
      pid: sessionInfo.pid,
    })
    if (result) {
      try {
        await fetch(`http://localhost:${sessionInfo.port}/health`)
      } catch (e) {
        this.unload(sessionInfo.model_id)
        throw new Error('Model appears to have crashed! Please reload!')
      }
    } else {
      throw new Error('Model has crashed! Please reload!')
    }
    const baseUrl = `http://localhost:${sessionInfo.port}/v1`
    const url = `${baseUrl}/chat/completions`
    const headers = {
      'Content-Type': 'application/json',
      'Authorization': `Bearer ${sessionInfo.api_key}`,
    }
    // Always request prompt-progress updates (only meaningful when streaming)
    // Requires llamacpp version > b6399
    // Example json returned from server
    // {"choices":[{"finish_reason":null,"index":0,"delta":{"role":"assistant","content":null}}],"created":1758113912,"id":"chatcmpl-UwZwgxQKyJMo7WzMzXlsi90YTUK2BJro","model":"qwen","system_fingerprint":"b1-e4912fc","object":"chat.completion.chunk","prompt_progress":{"total":36,"cache":0,"processed":36,"time_ms":5706760300}}
    // (chunk.prompt_progress?.processed / chunk.prompt_progress?.total) * 100
    // chunk.prompt_progress?.cache is for past tokens already in kv cache
    opts.return_progress = true

    const body = JSON.stringify(opts)
    if (opts.stream) {
      return this.handleStreamingResponse(url, headers, body, abortController)
    }
    // Handle non-streaming response
    const response = await fetch(url, {
      method: 'POST',
      headers,
      body,
      signal: abortController?.signal,
    })

    if (!response.ok) {
      const errorData = await response.json().catch(() => null)
      throw new Error(
        `API request failed with status ${response.status}: ${JSON.stringify(
          errorData
        )}`
      )
    }

    const completionResponse = (await response.json()) as chatCompletion

    // Check for out-of-context error conditions
    if (completionResponse.choices?.[0]?.finish_reason === 'length') {
      // finish_reason 'length' indicates context limit was hit
      throw new Error(OUT_OF_CONTEXT_SIZE)
    }

    return completionResponse
  }

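  // Example (hypothetical usage sketch): callers branch on the return type,
  // since chat() yields a full chatCompletion or an AsyncIterable of chunks:
  //
  //   const res = await engine.chat({ model, messages, stream: true })
  //   if (Symbol.asyncIterator in Object(res)) {
  //     for await (const chunk of res as AsyncIterable<chatCompletionChunk>) {
  //       /* render chunk.choices[0].delta */
  //     }
  //   }
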
  override async delete(modelId: string): Promise<void> {
    const modelDir = await joinPath([
      await this.getProviderPath(),
      'models',
      modelId,
    ])

    if (!(await fs.existsSync(await joinPath([modelDir, 'model.yml'])))) {
      throw new Error(`Model ${modelId} does not exist`)
    }

    await fs.rm(modelDir)
  }

  override async getLoadedModels(): Promise<string[]> {
    try {
      const models: string[] = await invoke<string[]>(
        'plugin:llamacpp|get_loaded_models'
      )
      return models
    } catch (e) {
      logger.error(e)
      throw new Error(String(e))
    }
  }

  /**
   * Check if an mmproj.gguf file exists for a given model ID
   * @param modelId - The model ID to check for mmproj.gguf
   * @returns Promise<boolean> - true if mmproj.gguf exists, false otherwise
   */
  async checkMmprojExists(modelId: string): Promise<boolean> {
    try {
      const modelConfigPath = await joinPath([
        await this.getProviderPath(),
        'models',
        modelId,
        'model.yml',
      ])

      const modelConfig = await invoke<ModelConfig>('read_yaml', {
        path: modelConfigPath,
      })

      // If mmproj_path is defined in the YAML, the model has an mmproj file
      if (modelConfig.mmproj_path) {
        return true
      }

      // Otherwise fall back to checking for an mmproj.gguf file on disk
      const mmprojPath = await joinPath([
        await this.getProviderPath(),
        'models',
        modelId,
        'mmproj.gguf',
      ])
      return await fs.existsSync(mmprojPath)
    } catch (e) {
      logger.error(`Error checking mmproj.gguf for model ${modelId}:`, e)
      return false
    }
  }

  async getDevices(): Promise<DeviceList[]> {
    const cfg = this.config
    const [version, backend] = cfg.version_backend.split('/')
    if (!version || !backend) {
      throw new Error(
        'Backend setup was not successful. Please restart the app with a stable internet connection.'
      )
    }
    // set envs
    const envs: Record<string, string> = {}
    if (this.llamacpp_env) this.parseEnvFromString(envs, this.llamacpp_env)

    // Ensure backend is downloaded and ready before proceeding
    await this.ensureBackendReady(backend, version)
    logger.info('Calling Tauri command getDevices with arg --list-devices')
    const backendPath = await getBackendExePath(backend, version)
    const libraryPath = await joinPath([await this.getProviderPath(), 'lib'])
    try {
      const dList = await invoke<DeviceList[]>('plugin:llamacpp|get_devices', {
        backendPath,
        libraryPath,
        envs,
      })
      // On Linux with AMD GPUs, llama.cpp via Vulkan may report UMA (shared) memory as device-local.
      // For clearer UX, override with dedicated VRAM from the hardware plugin when available.
      try {
        const sysInfo = await getSystemInfo()
        if (sysInfo?.os_type === 'linux' && Array.isArray(sysInfo.gpus)) {
          const usage = await getSystemUsage()
          if (usage && Array.isArray(usage.gpus)) {
            const uuidToUsage: Record<
              string,
              { total_memory: number; used_memory: number }
            > = {}
            for (const u of usage.gpus as any[]) {
              if (u && typeof u.uuid === 'string') {
                uuidToUsage[u.uuid] = u
              }
            }

            const indexToAmdUuid = new Map<number, string>()
            for (const gpu of sysInfo.gpus as any[]) {
              const vendorStr =
                typeof gpu?.vendor === 'string'
                  ? gpu.vendor
                  : typeof gpu?.vendor === 'object' && gpu.vendor !== null
                    ? String(gpu.vendor)
                    : ''
              if (
                vendorStr.toUpperCase().includes('AMD') &&
                gpu?.vulkan_info &&
                typeof gpu.vulkan_info.index === 'number' &&
                typeof gpu.uuid === 'string'
              ) {
                indexToAmdUuid.set(gpu.vulkan_info.index, gpu.uuid)
              }
            }

            if (indexToAmdUuid.size > 0) {
              const adjusted = dList.map((dev) => {
                if (dev.id?.startsWith('Vulkan')) {
                  const match = /^Vulkan(\d+)/.exec(dev.id)
                  if (match) {
                    const vIdx = Number(match[1])
                    const uuid = indexToAmdUuid.get(vIdx)
                    if (uuid) {
                      const u = uuidToUsage[uuid]
                      if (
                        u &&
                        typeof u.total_memory === 'number' &&
                        typeof u.used_memory === 'number'
                      ) {
                        const total = Math.max(0, Math.floor(u.total_memory))
                        const free = Math.max(
                          0,
                          Math.floor(u.total_memory - u.used_memory)
                        )
                        return { ...dev, mem: total, free }
                      }
                    }
                  }
                }
                return dev
              })
              return adjusted
            }
          }
        }
      } catch (e) {
        logger.warn('Device memory override (AMD/Linux) failed:', e)
      }

      return dList
    } catch (error) {
      logger.error('Failed to query devices:\n', error)
      throw new Error('Failed to load llamacpp backend')
    }
  }

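  // Illustrative DeviceList entry (hypothetical values; `mem`/`free` are in MiB,
  // which is why getTotalSystemMemory() below converts them to bytes):
  //   { id: 'Vulkan0', name: 'AMD Radeon RX 7800 XT', mem: 16368, free: 15800 }
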
  async embed(text: string[]): Promise<EmbeddingResponse> {
    // Ensure the sentence-transformer model is present
    let sInfo = await this.findSessionByModel('sentence-transformer-mini')
    if (!sInfo) {
      const downloadedModelList = await this.list()
      if (
        !downloadedModelList.some(
          (model) => model.id === 'sentence-transformer-mini'
        )
      ) {
        await this.import('sentence-transformer-mini', {
          modelPath:
            'https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-ggml-model-f16.gguf?download=true',
        })
      }
      // Load specifically in embedding mode
      sInfo = await this.load('sentence-transformer-mini', undefined, true)
    }

    const attemptRequest = async (session: SessionInfo) => {
      const baseUrl = `http://localhost:${session.port}/v1/embeddings`
      const headers = {
        'Content-Type': 'application/json',
        'Authorization': `Bearer ${session.api_key}`,
      }
      const body = JSON.stringify({
        input: text,
        model: session.model_id,
        encoding_format: 'float',
      })
      const response = await fetch(baseUrl, {
        method: 'POST',
        headers,
        body,
      })
      return response
    }

    // First try with the existing session (may have been started without --embedding previously)
    let response = await attemptRequest(sInfo)

    // If embeddings endpoint is not available (501), reload with embedding mode and retry once
    if (response.status === 501) {
      try {
        await this.unload('sentence-transformer-mini')
      } catch {}
      sInfo = await this.load('sentence-transformer-mini', undefined, true)
      response = await attemptRequest(sInfo)
    }

    if (!response.ok) {
      const errorData = await response.json().catch(() => null)
      throw new Error(
        `API request failed with status ${response.status}: ${JSON.stringify(errorData)}`
      )
    }
    const responseData = await response.json()
    return responseData as EmbeddingResponse
  }

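  // Example (hypothetical usage sketch): embedding a small batch and reading
  // the vectors back from the OpenAI-compatible response shape:
  //
  //   const res = await engine.embed(['hello', 'world'])
  //   const vectors = res.data.map((d) => d.embedding) // number[][]
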
  /**
   * Check if tool calling is supported by the model.
   * Currently read from the GGUF chat_template.
   * @param modelId
   * @returns true if the model's chat template references tools
   */
  async isToolSupported(modelId: string): Promise<boolean> {
    const janDataFolderPath = await getJanDataFolderPath()
    const modelConfigPath = await joinPath([
      this.providerPath,
      'models',
      modelId,
      'model.yml',
    ])
    const modelConfig = await invoke<ModelConfig>('read_yaml', {
      path: modelConfigPath,
    })
    // model option is required
    // NOTE: model_path and mmproj_path can be either relative to Jan's data folder or absolute paths
    const modelPath = await joinPath([
      janDataFolderPath,
      modelConfig.model_path,
    ])
    return (
      (await readGgufMetadata(modelPath)).metadata?.[
        'tokenizer.chat_template'
      ]?.includes('tools') ?? false
    )
  }

  /**
   * Get total system memory including both VRAM and RAM
   */
  private async getTotalSystemMemory(): Promise<SystemMemory> {
    const devices = await this.getDevices()
    let totalVRAM = 0

    if (devices.length > 0) {
      // Sum total VRAM across all GPUs (device `mem` is in MiB)
      totalVRAM = devices
        .map((d) => d.mem * 1024 * 1024)
        .reduce((a, b) => a + b, 0)
    }

    // Get total system RAM
    const sys = await getSystemUsage()
    const totalRAM = sys.total_memory * 1024 * 1024

    const totalMemory = totalVRAM + totalRAM

    logger.info(
      `Total VRAM: ${totalVRAM} bytes, Total RAM: ${totalRAM} bytes, Total Memory: ${totalMemory} bytes`
    )

    return {
      totalVRAM,
      totalRAM,
      totalMemory,
    }
  }

  private async getLayerSize(
    path: string,
    meta: Record<string, string>
  ): Promise<{ layerSize: number; totalLayers: number }> {
    const modelSize = await getModelSize(path)
    const arch = meta['general.architecture']
    const totalLayers = Number(meta[`${arch}.block_count`]) + 2 // 1 for lm_head layer and 1 for embedding layer
    if (!totalLayers) throw new Error('Invalid metadata: block_count not found')
    return { layerSize: modelSize / totalLayers, totalLayers }
  }

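  // Worked example (hypothetical numbers): a 4.8 GB GGUF whose metadata reports
  // block_count = 30 gives totalLayers = 32 and layerSize = 150 MB per layer,
  // a rough average the load planner can weigh against free VRAM.
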
  private isAbsolutePath(p: string): boolean {
    // Normalize back-slashes to forward-slashes first.
    const norm = p.replace(/\\/g, '/')
    return (
      norm.startsWith('/') || // POSIX absolute
      /^[a-zA-Z]:/.test(norm) || // Drive-letter Windows (C: or D:)
      /^\/\/[^/]+/.test(norm) // UNC path //server/share
    )
  }

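  // e.g. isAbsolutePath('/opt/models/a.gguf') → true (POSIX),
  //      isAbsolutePath('C:\\models\\a.gguf') → true (drive letter),
  //      isAbsolutePath('//nas/share/a.gguf') → true (UNC),
  //      isAbsolutePath('llamacpp/models/x/model.gguf') → false (relative)
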
  async planModelLoad(
    path: string,
    mmprojPath?: string,
    requestedCtx?: number
  ): Promise<ModelPlan> {
    if (!this.isAbsolutePath(path)) {
      path = await joinPath([await getJanDataFolderPath(), path])
    }
    if (mmprojPath && !this.isAbsolutePath(mmprojPath)) {
      mmprojPath = await joinPath([await getJanDataFolderPath(), mmprojPath])
    }
    try {
      const result = await planModelLoadInternal(
        path,
        this.memoryMode,
        mmprojPath,
        requestedCtx
      )
      return result
    } catch (e) {
      throw new Error(String(e))
    }
  }

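  // Illustrative ModelPlan result (hypothetical values):
  //   { gpuLayers: 28, maxContextLength: 8192, noOffloadKVCache: false,
  //     offloadMmproj: true, batchSize: 2048, mode: 'Hybrid' }
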
  /**
   * Check the support status of a model by its path (local/remote)
   *
   * Returns:
   * - "RED" → weights don't fit in total memory
   * - "YELLOW" → weights fit in VRAM but need system RAM, or KV cache doesn't fit
   * - "GREEN" → both weights + KV cache fit in VRAM
   */
  async isModelSupported(
    path: string,
    ctxSize?: number
  ): Promise<'RED' | 'YELLOW' | 'GREEN'> {
    try {
      const result = await isModelSupported(path, Number(ctxSize))
      return result
    } catch (e) {
      throw new Error(String(e))
    }
  }

  /**
   * Validate GGUF file and check for unsupported architectures like CLIP
   */
  async validateGgufFile(filePath: string): Promise<{
    isValid: boolean
    error?: string
    metadata?: any
  }> {
    try {
      logger.info(`Validating GGUF file: ${filePath}`)
      const metadata = await readGgufMetadata(filePath)

      // Log full metadata for debugging
      logger.info('Full GGUF metadata:', JSON.stringify(metadata, null, 2))

      // Check if architecture is 'clip' which is not supported for text generation
      const architecture = metadata.metadata?.['general.architecture']
      logger.info(`Model architecture: ${architecture}`)

      if (architecture === 'clip') {
        const errorMessage =
          'This model has CLIP architecture and cannot be imported as a text generation model. CLIP models are designed for vision tasks and require different handling.'
        logger.error('CLIP architecture detected:', architecture)
        return {
          isValid: false,
          error: errorMessage,
          metadata,
        }
      }

      logger.info('Model validation passed. Architecture:', architecture)
      return {
        isValid: true,
        metadata,
      }
    } catch (error) {
      logger.error('Failed to validate GGUF file:', error)
      return {
        isValid: false,
        error: `Failed to read model metadata: ${
          error instanceof Error ? error.message : 'Unknown error'
        }`,
      }
    }
  }

  async getTokensCount(opts: chatCompletionRequest): Promise<number> {
    const sessionInfo = await this.findSessionByModel(opts.model)
    if (!sessionInfo) {
      throw new Error(`No active session found for model: ${opts.model}`)
    }

    // Check if the process is alive
    const result = await invoke<boolean>('plugin:llamacpp|is_process_running', {
      pid: sessionInfo.pid,
    })
    if (result) {
      try {
        await fetch(`http://localhost:${sessionInfo.port}/health`)
      } catch (e) {
        this.unload(sessionInfo.model_id)
        throw new Error('Model appears to have crashed! Please reload!')
      }
    } else {
      throw new Error('Model has crashed! Please reload!')
    }

    const baseUrl = `http://localhost:${sessionInfo.port}`
    const headers = {
      'Content-Type': 'application/json',
      'Authorization': `Bearer ${sessionInfo.api_key}`,
    }

    // Count image tokens first
    let imageTokens = 0
    const hasImages = opts.messages.some(
      (msg) =>
        Array.isArray(msg.content) &&
        msg.content.some((content) => content.type === 'image_url')
    )

    if (hasImages) {
      logger.info('Conversation has images')
      try {
        // Read mmproj metadata to get vision parameters
        logger.info(`MMPROJ PATH: ${sessionInfo.mmproj_path}`)

        const metadata = await readGgufMetadata(sessionInfo.mmproj_path)
        logger.info(`mmproj metadata: ${JSON.stringify(metadata.metadata)}`)
        imageTokens = await this.calculateImageTokens(
          opts.messages,
          metadata.metadata
        )
      } catch (error) {
        logger.warn('Failed to calculate image tokens:', error)
        // Fall back to a rough estimate if metadata reading fails
        imageTokens = this.estimateImageTokensFallback(opts.messages)
      }
    }

    // Calculate text tokens
    // Use chat_template_kwargs from opts if provided; otherwise disable thinking by default
    const tokenizeRequest = {
      messages: opts.messages,
      chat_template_kwargs: opts.chat_template_kwargs || {
        enable_thinking: false,
      },
    }

    const parseResponse = await fetch(`${baseUrl}/apply-template`, {
      method: 'POST',
      headers: headers,
      body: JSON.stringify(tokenizeRequest),
    })

    if (!parseResponse.ok) {
      const errorData = await parseResponse.json().catch(() => null)
      throw new Error(
        `API request failed with status ${
          parseResponse.status
        }: ${JSON.stringify(errorData)}`
      )
    }

    const parsedPrompt = await parseResponse.json()

    const response = await fetch(`${baseUrl}/tokenize`, {
      method: 'POST',
      headers: headers,
      body: JSON.stringify({
        content: parsedPrompt.prompt,
      }),
    })

    if (!response.ok) {
      const errorData = await response.json().catch(() => null)
      throw new Error(
        `API request failed with status ${response.status}: ${JSON.stringify(
          errorData
        )}`
      )
    }

    const dataTokens = await response.json()
    const textTokens = dataTokens.tokens?.length || 0

    return textTokens + imageTokens
  }

  private async calculateImageTokens(
    messages: chatCompletionRequestMessage[],
    metadata: Record<string, string>
  ): Promise<number> {
    // Extract vision parameters from metadata
    const projectionDim =
      Math.floor(Number(metadata['clip.vision.projection_dim']) / 10) || 256

    // Count images in messages
    let imageCount = 0
    for (const message of messages) {
      if (Array.isArray(message.content)) {
        imageCount += message.content.filter(
          (content) => content.type === 'image_url'
        ).length
      }
    }

    logger.info(
      `Calculated ${projectionDim} tokens per image, ${imageCount} images total`
    )
    return projectionDim * imageCount - imageCount // remove the lingering <__image__> placeholder token
  }

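  // Worked example: with clip.vision.projection_dim = 2560 the heuristic above
  // gives floor(2560 / 10) = 256 tokens per image; a conversation carrying two
  // images then counts 256 * 2 - 2 = 510 tokens after dropping placeholders.
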
  private estimateImageTokensFallback(
    messages: chatCompletionRequestMessage[]
  ): number {
    // Fallback estimation if metadata reading fails
    const estimatedTokensPerImage = 256 // Gemma's siglip

    let imageCount = 0
    for (const message of messages) {
      if (Array.isArray(message.content)) {
        imageCount += message.content.filter(
          (content) => content.type === 'image_url'
        ).length
      }
    }

    logger.warn(
      `Fallback estimation: ${estimatedTokensPerImage} tokens per image, ${imageCount} images total`
    )
    return imageCount * estimatedTokensPerImage - imageCount // remove the lingering <__image__> placeholder token
  }
}