* refactor: move session management & port allocation to backend
- Remove the in‑process `activeSessions` map and its cleanup logic from the TypeScript side.
- Introduce new Tauri commands in Rust:
- `get_random_port` – picks an unused port using a seeded RNG and checks availability.
- `find_session_by_model` – returns the `SessionInfo` for a given model ID.
- `get_loaded_models` – returns a list of currently loaded model IDs.
- Update the extension’s TypeScript code to use these commands via `invoke`:
- `findSessionByModel`, `load`, `unload`, `chat`, `getLoadedModels`, and `embed` now operate asynchronously and query the backend.
- Remove the old `is_port_available` command and the custom port‑checking loop.
- Simplify `onUnload` – session termination is now handled by the backend.
- Drop unused helpers (`sleep`, `waitForModelLoad`) and related port‑availability code.
- Add missing Rust imports (`rand::{StdRng,Rng,SeedableRng}`, `HashSet`) and improve error handling.
- Register the new commands in `src-tauri/src/lib.rs` (replace `is_port_available` with the three new commands).
This refactor centralises session state and port allocation in the Rust backend, eliminates duplicated logic, and resolves race conditions around model loading and session cleanup.
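A minimal sketch of the TypeScript-side wrappers for the new commands (mirroring the `invoke` calls in the extension source below; error handling omitted for brevity, this is illustrative rather than the verbatim implementation):

```ts
import { invoke } from '@tauri-apps/api/core'
import { SessionInfo } from '@janhq/core'

// Thin wrappers over the new Rust commands; the backend now owns the
// session map and port bookkeeping, so the TypeScript side only forwards calls.
async function getRandomPort(): Promise<number> {
  // The Rust side picks an unused port and verifies availability before returning it.
  return await invoke<number>('get_random_port')
}

async function findSessionByModel(modelId: string): Promise<SessionInfo> {
  return await invoke<SessionInfo>('find_session_by_model', { modelId })
}

async function getLoadedModels(): Promise<string[]> {
  return await invoke<string[]>('get_loaded_models')
}
```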
* Use String(e) for error
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
---------
Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
/**
 * @file This file exports a class that implements the InferenceExtension interface from the @janhq/core package.
 * The class provides methods for initializing and stopping a model, and for making inference requests.
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 * @version 1.0.0
 * @module llamacpp-extension/src/index
 */

import {
  AIEngine,
  getJanDataFolderPath,
  fs,
  joinPath,
  modelInfo,
  SessionInfo,
  UnloadResult,
  chatCompletion,
  chatCompletionChunk,
  ImportOptions,
  chatCompletionRequest,
  events,
} from '@janhq/core'

import { error, info, warn } from '@tauri-apps/plugin-log'

import {
  listSupportedBackends,
  downloadBackend,
  isBackendInstalled,
  getBackendExePath,
} from './backend'
import { invoke } from '@tauri-apps/api/core'
import { getProxyConfig } from './util'
import { basename } from '@tauri-apps/api/path'

type LlamacppConfig = {
  version_backend: string
  auto_update_engine: boolean
  auto_unload: boolean
  chat_template: string
  n_gpu_layers: number
  override_tensor_buffer_t: string
  ctx_size: number
  threads: number
  threads_batch: number
  n_predict: number
  batch_size: number
  ubatch_size: number
  device: string
  split_mode: string
  main_gpu: number
  flash_attn: boolean
  cont_batching: boolean
  no_mmap: boolean
  mlock: boolean
  no_kv_offload: boolean
  cache_type_k: string
  cache_type_v: string
  defrag_thold: number
  rope_scaling: string
  rope_scale: number
  rope_freq_base: number
  rope_freq_scale: number
  ctx_shift: boolean
}

interface DownloadItem {
  url: string
  save_path: string
  proxy?: Record<string, string | string[] | boolean>
}

interface ModelConfig {
  model_path: string
  mmproj_path?: string
  name: string // user-friendly
  // some model info that we cache upon import
  size_bytes: number
}

interface EmbeddingResponse {
  model: string
  object: string
  usage: {
    prompt_tokens: number
    total_tokens: number
  }
  data: EmbeddingData[]
}

interface EmbeddingData {
  embedding: number[]
  index: number
  object: string
}

interface DeviceList {
  id: string
  name: string
  mem: number
  free: number
}

/**
 * Override the default app.log function to use Jan's logging system.
 * @param args
 */
const logger = {
  info: function (...args: any[]) {
    console.log(...args)
    info(args.map((arg) => ` ${arg}`).join(` `))
  },
  warn: function (...args: any[]) {
    console.warn(...args)
    warn(args.map((arg) => ` ${arg}`).join(` `))
  },
  error: function (...args: any[]) {
    console.error(...args)
    error(args.map((arg) => ` ${arg}`).join(` `))
  },
}

/**
 * A class that implements the InferenceExtension interface from the @janhq/core package.
 * The class provides methods for initializing and stopping a model, and for making inference requests.
 * It also subscribes to events emitted by the @janhq/core package and handles new message requests.
 */

// Folder structure for llamacpp extension:
// <Jan's data folder>/llamacpp
//   - models/<modelId>/
//     - model.yml (required)
//     - model.gguf (optional, present if downloaded from URL)
//     - mmproj.gguf (optional, present if mmproj exists and it was downloaded from URL)
//     Contents of model.yml can be found in ModelConfig interface
//
//   - backends/<backend_version>/<backend_type>/
//     - build/bin/llama-server (or llama-server.exe on Windows)
//
//   - lib/
//     - e.g. libcudart.so.12

export default class llamacpp_extension extends AIEngine {
  provider: string = 'llamacpp'
  autoUnload: boolean = true
  readonly providerId: string = 'llamacpp'

  private config: LlamacppConfig
  private providerPath!: string
  private apiSecret: string = 'JustAskNow'
  private pendingDownloads: Map<string, Promise<void>> = new Map()
  private isConfiguringBackends: boolean = false
  private loadingModels = new Map<string, Promise<SessionInfo>>() // Track loading promises

  override async onLoad(): Promise<void> {
    super.onLoad() // Calls registerEngine() from AIEngine

    let settings = structuredClone(SETTINGS) // Clone to modify settings definition before registration

    // This makes the settings (including the backend options and initial value) available to the Jan UI.
    this.registerSettings(settings)

    let loadedConfig: any = {}
    for (const item of settings) {
      const defaultValue = item.controllerProps.value
      // Use the potentially updated default value from the settings array as the fallback for getSetting
      loadedConfig[item.key] = await this.getSetting<typeof defaultValue>(
        item.key,
        defaultValue
      )
    }
    this.config = loadedConfig as LlamacppConfig

    this.autoUnload = this.config.auto_unload

    // This sets the base directory where model files for this provider are stored.
    this.providerPath = await joinPath([
      await getJanDataFolderPath(),
      this.providerId,
    ])
    this.configureBackends()
  }

  private getStoredBackendType(): string | null {
    try {
      return localStorage.getItem('llama_cpp_backend_type')
    } catch (error) {
      logger.warn('Failed to read backend type from localStorage:', error)
      return null
    }
  }

  private setStoredBackendType(backendType: string): void {
    try {
      localStorage.setItem('llama_cpp_backend_type', backendType)
      logger.info(`Stored backend type preference: ${backendType}`)
    } catch (error) {
      logger.warn('Failed to store backend type in localStorage:', error)
    }
  }

  private clearStoredBackendType(): void {
    try {
      localStorage.removeItem('llama_cpp_backend_type')
      logger.info('Cleared stored backend type preference')
    } catch (error) {
      logger.warn('Failed to clear backend type from localStorage:', error)
    }
  }

  private findLatestVersionForBackend(
    version_backends: { version: string; backend: string }[],
    backendType: string
  ): string | null {
    const matchingBackends = version_backends.filter(
      (vb) => vb.backend === backendType
    )
    if (matchingBackends.length === 0) {
      return null
    }

    // Sort by version (newest first) and get the latest
    matchingBackends.sort((a, b) => b.version.localeCompare(a.version))
    return `${matchingBackends[0].version}/${matchingBackends[0].backend}`
  }

  async configureBackends(): Promise<void> {
    if (this.isConfiguringBackends) {
      logger.info(
        'configureBackends already in progress, skipping duplicate call'
      )
      return
    }

    this.isConfiguringBackends = true

    try {
      let version_backends: { version: string; backend: string }[] = []

      try {
        version_backends = await listSupportedBackends()
        if (version_backends.length === 0) {
          throw new Error(
            'No supported backend binaries found for this system. Backend selection and auto-update will be unavailable.'
          )
        } else {
          version_backends.sort((a, b) => b.version.localeCompare(a.version))
        }
      } catch (error) {
        throw new Error(
          `Failed to fetch supported backends: ${
            error instanceof Error ? error.message : error
          }`
        )
      }

      // Get stored backend preference
      const storedBackendType = this.getStoredBackendType()
      let bestAvailableBackendString = ''

      if (storedBackendType) {
        // Find the latest version of the stored backend type
        const preferredBackendString = this.findLatestVersionForBackend(
          version_backends,
          storedBackendType
        )
        if (preferredBackendString) {
          bestAvailableBackendString = preferredBackendString
          logger.info(
            `Using stored backend preference: ${bestAvailableBackendString}`
          )
        } else {
          logger.warn(
            `Stored backend type '${storedBackendType}' not available, falling back to best backend`
          )
          // Clear the invalid stored preference
          this.clearStoredBackendType()
          bestAvailableBackendString =
            this.determineBestBackend(version_backends)
        }
      } else {
        bestAvailableBackendString = this.determineBestBackend(version_backends)
      }

      let settings = structuredClone(SETTINGS)
      const backendSettingIndex = settings.findIndex(
        (item) => item.key === 'version_backend'
      )

      let originalDefaultBackendValue = ''
      if (backendSettingIndex !== -1) {
        const backendSetting = settings[backendSettingIndex]
        originalDefaultBackendValue = backendSetting.controllerProps
          .value as string

        backendSetting.controllerProps.options = version_backends.map((b) => {
          const key = `${b.version}/${b.backend}`
          return { value: key, name: key }
        })

        // Set the recommended backend based on bestAvailableBackendString
        if (bestAvailableBackendString) {
          backendSetting.controllerProps.recommended =
            bestAvailableBackendString
        }

        const savedBackendSetting = await this.getSetting<string>(
          'version_backend',
          originalDefaultBackendValue
        )

        // Determine initial UI default based on priority:
        // 1. Saved setting (if valid and not original default)
        // 2. Best available for stored backend type
        // 3. Original default
        let initialUiDefault = originalDefaultBackendValue

        if (
          savedBackendSetting &&
          savedBackendSetting !== originalDefaultBackendValue
        ) {
          initialUiDefault = savedBackendSetting
          // Store the backend type from the saved setting only if different
          const [, backendType] = savedBackendSetting.split('/')
          if (backendType) {
            const currentStoredBackend = this.getStoredBackendType()
            if (currentStoredBackend !== backendType) {
              this.setStoredBackendType(backendType)
              logger.info(
                `Stored backend type preference from saved setting: ${backendType}`
              )
            }
          }
        } else if (bestAvailableBackendString) {
          initialUiDefault = bestAvailableBackendString
          // Store the backend type from the best available only if different
          const [, backendType] = bestAvailableBackendString.split('/')
          if (backendType) {
            const currentStoredBackend = this.getStoredBackendType()
            if (currentStoredBackend !== backendType) {
              this.setStoredBackendType(backendType)
              logger.info(
                `Stored backend type preference from best available: ${backendType}`
              )
            }
          }
        }

        backendSetting.controllerProps.value = initialUiDefault
        logger.info(
          `Initial UI default for version_backend set to: ${initialUiDefault}`
        )
      } else {
        logger.error(
          'Critical setting "version_backend" definition not found in SETTINGS.'
        )
        throw new Error('Critical setting "version_backend" not found.')
      }

      this.registerSettings(settings)

      let effectiveBackendString = this.config.version_backend
      let backendWasDownloaded = false

      // Handle fresh installation case where version_backend might be 'none' or invalid
      if (
        (!effectiveBackendString ||
          effectiveBackendString === 'none' ||
          !effectiveBackendString.includes('/') ||
          // If the selected backend is not in the list of supported backends,
          // we need to reset it as well
          !version_backends.some(
            (e) => `${e.version}/${e.backend}` === effectiveBackendString
          )) &&
        // Ensure we have a valid best available backend
        bestAvailableBackendString
      ) {
        effectiveBackendString = bestAvailableBackendString
        logger.info(
          `Fresh installation or invalid backend detected, using: ${effectiveBackendString}`
        )

        // Update the config immediately
        this.config.version_backend = effectiveBackendString

        // Update the settings to reflect the change in UI
        const updatedSettings = await this.getSettings()
        await this.updateSettings(
          updatedSettings.map((item) => {
            if (item.key === 'version_backend') {
              item.controllerProps.value = effectiveBackendString
            }
            return item
          })
        )
        logger.info(`Updated UI settings to show: ${effectiveBackendString}`)

        // Emit so the frontend picks up the new value
        if (events && typeof events.emit === 'function') {
          logger.info(
            `Emitting settingsChanged event for version_backend with value: ${effectiveBackendString}`
          )
          events.emit('settingsChanged', {
            key: 'version_backend',
            value: effectiveBackendString,
          })
        }
      }

      // Download and install the backend if not already present
      if (effectiveBackendString) {
        const [version, backend] = effectiveBackendString.split('/')
        if (version && backend) {
          const isInstalled = await isBackendInstalled(backend, version)
          if (!isInstalled) {
            logger.info(`Installing initial backend: ${effectiveBackendString}`)
            await this.ensureBackendReady(backend, version)
            backendWasDownloaded = true
            logger.info(
              `Successfully installed initial backend: ${effectiveBackendString}`
            )
          }
        }
      }

      if (this.config.auto_update_engine) {
        const updateResult = await this.handleAutoUpdate(
          bestAvailableBackendString
        )
        if (updateResult.wasUpdated) {
          effectiveBackendString = updateResult.newBackend
          backendWasDownloaded = true
        }
      }

      if (!backendWasDownloaded && effectiveBackendString) {
        await this.ensureFinalBackendInstallation(effectiveBackendString)
      }
    } finally {
      this.isConfiguringBackends = false
    }
  }

  private determineBestBackend(
    version_backends: { version: string; backend: string }[]
  ): string {
    if (version_backends.length === 0) return ''

    // Priority list for backend types (more specific/performant ones first)
    const backendPriorities: string[] = [
      'cuda-cu12.0',
      'cuda-cu11.7',
      'vulkan',
      'avx512',
      'avx2',
      'avx',
      'noavx',
      'arm64',
      'x64',
    ]

    // Helper to map backend string to a priority category
    const getBackendCategory = (backendString: string): string | undefined => {
      if (backendString.includes('cu12.0')) return 'cuda-cu12.0'
      if (backendString.includes('cu11.7')) return 'cuda-cu11.7'
      if (backendString.includes('vulkan')) return 'vulkan'
      if (backendString.includes('avx512')) return 'avx512'
      if (backendString.includes('avx2')) return 'avx2'
      // plain avx: exclude avx2/avx512/noavx, which also contain the substring 'avx'
      if (
        backendString.includes('avx') &&
        !backendString.includes('avx2') &&
        !backendString.includes('avx512') &&
        !backendString.includes('noavx')
      )
        return 'avx'
      if (backendString.includes('noavx')) return 'noavx'
      if (backendString.endsWith('arm64')) return 'arm64'
      if (backendString.endsWith('x64')) return 'x64'
      return undefined
    }

    let foundBestBackend: { version: string; backend: string } | undefined
    for (const priorityCategory of backendPriorities) {
      const matchingBackends = version_backends.filter((vb) => {
        const category = getBackendCategory(vb.backend)
        return category === priorityCategory
      })

      if (matchingBackends.length > 0) {
        foundBestBackend = matchingBackends[0]
        logger.info(
          `Determined best available backend: ${foundBestBackend.version}/${foundBestBackend.backend} (Category: "${priorityCategory}")`
        )
        break
      }
    }

    if (foundBestBackend) {
      return `${foundBestBackend.version}/${foundBestBackend.backend}`
    } else {
      // Fallback to newest version
      return `${version_backends[0].version}/${version_backends[0].backend}`
    }
  }

  private async handleAutoUpdate(
    bestAvailableBackendString: string
  ): Promise<{ wasUpdated: boolean; newBackend: string }> {
    logger.info(
      `Auto-update engine is enabled. Current backend: ${this.config.version_backend}. Best available: ${bestAvailableBackendString}`
    )

    if (!bestAvailableBackendString) {
      logger.warn(
        'Auto-update enabled, but no best available backend determined'
      )
      return { wasUpdated: false, newBackend: this.config.version_backend }
    }

    // If version_backend is empty, invalid, or 'none', use the best available backend
    if (
      !this.config.version_backend ||
      this.config.version_backend === '' ||
      this.config.version_backend === 'none' ||
      !this.config.version_backend.includes('/')
    ) {
      logger.info(
        'No valid backend currently selected, using best available backend'
      )
      try {
        const [bestVersion, bestBackend] = bestAvailableBackendString.split('/')

        // Download new backend
        await this.ensureBackendReady(bestBackend, bestVersion)

        // Add delay on Windows
        if (IS_WINDOWS) {
          await new Promise((resolve) => setTimeout(resolve, 1000))
        }

        // Update configuration
        this.config.version_backend = bestAvailableBackendString

        // Store the backend type preference only if it changed
        const currentStoredBackend = this.getStoredBackendType()
        if (currentStoredBackend !== bestBackend) {
          this.setStoredBackendType(bestBackend)
          logger.info(`Stored new backend type preference: ${bestBackend}`)
        }

        // Update settings
        const settings = await this.getSettings()
        await this.updateSettings(
          settings.map((item) => {
            if (item.key === 'version_backend') {
              item.controllerProps.value = bestAvailableBackendString
            }
            return item
          })
        )

        logger.info(
          `Successfully set initial backend: ${bestAvailableBackendString}`
        )
        return { wasUpdated: true, newBackend: bestAvailableBackendString }
      } catch (error) {
        logger.error('Failed to set initial backend:', error)
        return { wasUpdated: false, newBackend: this.config.version_backend }
      }
    }

    // Parse current backend configuration
    const [currentVersion, currentBackend] = (
      this.config.version_backend || ''
    ).split('/')

    if (!currentVersion || !currentBackend) {
      logger.warn(
        `Invalid current backend format: ${this.config.version_backend}`
      )
      return { wasUpdated: false, newBackend: this.config.version_backend }
    }

    // Find the latest version for the currently selected backend type
    const version_backends = await listSupportedBackends()
    const targetBackendString = this.findLatestVersionForBackend(
      version_backends,
      currentBackend
    )

    if (!targetBackendString) {
      logger.warn(
        `No available versions found for current backend type: ${currentBackend}`
      )
      return { wasUpdated: false, newBackend: this.config.version_backend }
    }

    const [latestVersion] = targetBackendString.split('/')

    // Check if update is needed (only version comparison for same backend type)
    if (currentVersion === latestVersion) {
      logger.info(
        'Auto-update: Already using the latest version of the selected backend'
      )
      return { wasUpdated: false, newBackend: this.config.version_backend }
    }

    // Perform version update for the same backend type
    try {
      logger.info(
        `Auto-updating from ${this.config.version_backend} to ${targetBackendString} (preserving backend type)`
      )

      // Download new version of the same backend type
      await this.ensureBackendReady(currentBackend, latestVersion)

      // Add delay on Windows
      if (IS_WINDOWS) {
        await new Promise((resolve) => setTimeout(resolve, 1000))
      }

      // Update configuration
      this.config.version_backend = targetBackendString

      // Update stored backend type preference only if it changed
      const currentStoredBackend = this.getStoredBackendType()
      if (currentStoredBackend !== currentBackend) {
        this.setStoredBackendType(currentBackend)
        logger.info(`Updated stored backend type preference: ${currentBackend}`)
      }

      // Update settings
      const settings = await this.getSettings()
      await this.updateSettings(
        settings.map((item) => {
          if (item.key === 'version_backend') {
            item.controllerProps.value = targetBackendString
          }
          return item
        })
      )

      logger.info(
        `Successfully updated to backend: ${targetBackendString} (preserved backend type: ${currentBackend})`
      )

      // Emit so the frontend picks up the new value
      if (events && typeof events.emit === 'function') {
        logger.info(
          `Emitting settingsChanged event for version_backend with value: ${targetBackendString}`
        )
        events.emit('settingsChanged', {
          key: 'version_backend',
          value: targetBackendString,
        })
      }

      // Clean up old versions of the same backend type
      if (IS_WINDOWS) {
        await new Promise((resolve) => setTimeout(resolve, 500))
      }
      await this.removeOldBackend(latestVersion, currentBackend)

      return { wasUpdated: true, newBackend: targetBackendString }
    } catch (error) {
      logger.error('Auto-update failed:', error)
      return { wasUpdated: false, newBackend: this.config.version_backend }
    }
  }

  private async removeOldBackend(
    latestVersion: string,
    backendType: string
  ): Promise<void> {
    try {
      const janDataFolderPath = await getJanDataFolderPath()
      const backendsDir = await joinPath([
        janDataFolderPath,
        'llamacpp',
        'backends',
      ])

      if (!(await fs.existsSync(backendsDir))) {
        return
      }

      const versionDirs = await fs.readdirSync(backendsDir)

      for (const versionDir of versionDirs) {
        const versionPath = await joinPath([backendsDir, versionDir])
        const versionName = await basename(versionDir)

        // Skip the latest version
        if (versionName === latestVersion) {
          continue
        }

        // Check if this version has the specific backend type we're interested in
        const backendTypePath = await joinPath([versionPath, backendType])

        if (await fs.existsSync(backendTypePath)) {
          const isInstalled = await isBackendInstalled(backendType, versionName)
          if (isInstalled) {
            try {
              await fs.rm(backendTypePath)
              logger.info(
                `Removed old version of ${backendType}: ${backendTypePath}`
              )
            } catch (e) {
              logger.warn(
                `Failed to remove old backend version: ${backendTypePath}`,
                e
              )
            }
          }
        }
      }
    } catch (error) {
      logger.error('Error during old backend version cleanup:', error)
    }
  }

  private async ensureFinalBackendInstallation(
    backendString: string
  ): Promise<void> {
    if (!backendString) {
      logger.warn('No backend specified for final installation check')
      return
    }

    const [selectedVersion, selectedBackend] = backendString
      .split('/')
      .map((part) => part?.trim())

    if (!selectedVersion || !selectedBackend) {
      logger.warn(`Invalid backend format: ${backendString}`)
      return
    }

    try {
      const isInstalled = await isBackendInstalled(
        selectedBackend,
        selectedVersion
      )
      if (!isInstalled) {
        logger.info(`Final check: Installing backend ${backendString}`)
        await this.ensureBackendReady(selectedBackend, selectedVersion)
        logger.info(`Successfully installed backend: ${backendString}`)
      } else {
        logger.info(
          `Final check: Backend ${backendString} is already installed`
        )
      }
    } catch (error) {
      logger.error(
        `Failed to ensure backend ${backendString} installation:`,
        error
      )
      throw error // Re-throw as this is critical
    }
  }

  async getProviderPath(): Promise<string> {
    if (!this.providerPath) {
      this.providerPath = await joinPath([
        await getJanDataFolderPath(),
        this.providerId,
      ])
    }
    return this.providerPath
  }

  override async onUnload(): Promise<void> {
    // Session termination is handled by the Rust backend, so there is
    // nothing to clean up on the TypeScript side.
  }

  onSettingUpdate<T>(key: string, value: T): void {
    this.config[key] = value

    if (key === 'version_backend') {
      const valueStr = value as string
      const [version, backend] = valueStr.split('/')

      // Store the backend type preference in localStorage only if it changed
      if (backend) {
        const currentStoredBackend = this.getStoredBackendType()
        if (currentStoredBackend !== backend) {
          this.setStoredBackendType(backend)
          logger.info(`Updated backend type preference to: ${backend}`)
        }
      }

      // Reset device setting when backend changes
      this.config.device = ''

      const closure = async () => {
        await this.ensureBackendReady(backend, version)
      }
      closure()
    } else if (key === 'auto_unload') {
      this.autoUnload = value as boolean
    }
  }

  private async generateApiKey(modelId: string, port: string): Promise<string> {
    const hash = await invoke<string>('generate_api_key', {
      modelId: modelId + port,
      apiSecret: this.apiSecret,
    })
    return hash
  }

  // Implement the required LocalProvider interface methods
  override async list(): Promise<modelInfo[]> {
    const modelsDir = await joinPath([await this.getProviderPath(), 'models'])
    if (!(await fs.existsSync(modelsDir))) {
      await fs.mkdir(modelsDir)
    }

    await this.migrateLegacyModels()

    let modelIds: string[] = []

    // DFS
    let stack = [modelsDir]
    while (stack.length > 0) {
      const currentDir = stack.pop()

      // check if model.yml exists
      const modelConfigPath = await joinPath([currentDir, 'model.yml'])
      if (await fs.existsSync(modelConfigPath)) {
        // +1 to remove the leading slash
        // NOTE: this does not handle Windows path \\
        modelIds.push(currentDir.slice(modelsDir.length + 1))
        continue
      }

      // otherwise, look into subdirectories
      const children = await fs.readdirSync(currentDir)
      for (const child of children) {
        // skip files
        const dirInfo = await fs.fileStat(child)
        if (!dirInfo.isDirectory) {
          continue
        }

        stack.push(child)
      }
    }

    let modelInfos: modelInfo[] = []
    for (const modelId of modelIds) {
      const path = await joinPath([modelsDir, modelId, 'model.yml'])
      const modelConfig = await invoke<ModelConfig>('read_yaml', { path })

      const modelInfo = {
        id: modelId,
        name: modelConfig.name ?? modelId,
        quant_type: undefined, // TODO: parse quantization type from model.yml or model.gguf
        providerId: this.provider,
        port: 0, // port is not known until the model is loaded
        sizeBytes: modelConfig.size_bytes ?? 0,
      } as modelInfo
      modelInfos.push(modelInfo)
    }

    return modelInfos
  }

  private async migrateLegacyModels() {
    // Attempt to migrate only once
    if (localStorage.getItem('cortex_models_migrated') === 'true') return

    const janDataFolderPath = await getJanDataFolderPath()
    const modelsDir = await joinPath([janDataFolderPath, 'models'])
    if (!(await fs.existsSync(modelsDir))) return

    // DFS
    let stack = [modelsDir]
    while (stack.length > 0) {
      const currentDir = stack.pop()

      const files = await fs.readdirSync(currentDir)
      for (const child of files) {
        try {
          const childPath = await joinPath([currentDir, child])
          const stat = await fs.fileStat(childPath)
          if (
            files.some((e) => e.endsWith('model.yml')) &&
            !child.endsWith('model.yml')
          )
            continue
          if (!stat.isDirectory && child.endsWith('.yml')) {
            // check if model.yml exists
            const modelConfigPath = child
            if (await fs.existsSync(modelConfigPath)) {
              const legacyModelConfig = await invoke<{
                files: string[]
                model: string
              }>('read_yaml', {
                path: modelConfigPath,
              })
              const legacyModelPath = legacyModelConfig.files?.[0]
              if (!legacyModelPath) continue
              // +1 to remove the leading slash
              // NOTE: this does not handle Windows path \\
              let modelId = currentDir.slice(modelsDir.length + 1)

              modelId =
                modelId !== 'imported'
                  ? modelId.replace(/^(cortex\.so|huggingface\.co)[\/\\]/, '')
                  : (await basename(child)).replace('.yml', '')

              const modelName = legacyModelConfig.model ?? modelId
              const configPath = await joinPath([
                await this.getProviderPath(),
                'models',
                modelId,
                'model.yml',
              ])
              if (await fs.existsSync(configPath)) continue // Don't reimport

              // this is relative to Jan's data folder
              const modelDir = `${this.providerId}/models/${modelId}`

              let size_bytes = (
                await fs.fileStat(
                  await joinPath([janDataFolderPath, legacyModelPath])
                )
              ).size

              const modelConfig = {
                model_path: legacyModelPath,
                mmproj_path: undefined, // legacy models do not have mmproj
                name: modelName,
                size_bytes,
              } as ModelConfig
              await fs.mkdir(await joinPath([janDataFolderPath, modelDir]))
              await invoke<void>('write_yaml', {
                data: modelConfig,
                savePath: configPath,
              })
              continue
            }
          }
        } catch (error) {
          console.error(`Error migrating model ${child}:`, error)
        }
      }

      // otherwise, look into subdirectories
      const children = await fs.readdirSync(currentDir)
      for (const child of children) {
        // skip files
        const dirInfo = await fs.fileStat(child)
        if (!dirInfo.isDirectory) {
          continue
        }

        stack.push(child)
      }
    }
    localStorage.setItem('cortex_models_migrated', 'true')
  }

  override async import(modelId: string, opts: ImportOptions): Promise<void> {
    const isValidModelId = (id: string) => {
      // only allow alphanumeric, underscore, hyphen, and dot characters in modelId
      if (!/^[a-zA-Z0-9/_\-\.]+$/.test(id)) return false

      // check for empty parts or path traversal
      const parts = id.split('/')
      return parts.every((s) => s !== '' && s !== '.' && s !== '..')
    }

    if (!isValidModelId(modelId))
      throw new Error(
        `Invalid modelId: ${modelId}. Only alphanumeric and / _ - . characters are allowed.`
      )

    const configPath = await joinPath([
      await this.getProviderPath(),
      'models',
      modelId,
      'model.yml',
    ])
    if (await fs.existsSync(configPath))
      throw new Error(`Model ${modelId} already exists`)

    // this is relative to Jan's data folder
    const modelDir = `${this.providerId}/models/${modelId}`

    // we only use these from opts
    // opts.modelPath: URL to the model file
    // opts.mmprojPath: URL to the mmproj file

    let downloadItems: DownloadItem[] = []

    const maybeDownload = async (path: string, saveName: string) => {
      // if URL, add to downloadItems, and return local path
      if (path.startsWith('https://')) {
        const localPath = `${modelDir}/${saveName}`
        downloadItems.push({
          url: path,
          save_path: localPath,
          proxy: getProxyConfig(),
        })
        return localPath
      }

      // if local file (absolute path), check if it exists
      // and return the path
      if (!(await fs.existsSync(path)))
        throw new Error(`File not found: ${path}`)
      return path
    }

    let modelPath = await maybeDownload(opts.modelPath, 'model.gguf')
    let mmprojPath = opts.mmprojPath
      ? await maybeDownload(opts.mmprojPath, 'mmproj.gguf')
      : undefined

    if (downloadItems.length > 0) {
      let downloadCompleted = false

      try {
        // emit download update event on progress
        const onProgress = (transferred: number, total: number) => {
          events.emit('onFileDownloadUpdate', {
            modelId,
            percent: transferred / total,
            size: { transferred, total },
            downloadType: 'Model',
          })
          downloadCompleted = transferred === total
        }
        const downloadManager = window.core.extensionManager.getByName(
          '@janhq/download-extension'
        )
        await downloadManager.downloadFiles(
          downloadItems,
          this.createDownloadTaskId(modelId),
          onProgress
        )

        const eventName = downloadCompleted
          ? 'onFileDownloadSuccess'
          : 'onFileDownloadStopped'
        events.emit(eventName, { modelId, downloadType: 'Model' })
      } catch (error) {
        logger.error('Error downloading model:', modelId, opts, error)
        events.emit('onFileDownloadError', { modelId, downloadType: 'Model' })
        throw error
      }
    }

    // TODO: check if files are valid GGUF files
    // NOTE: modelPath and mmprojPath can be either relative to Jan's data folder (if they are downloaded)
    // or absolute paths (if they are provided as local files)
    const janDataFolderPath = await getJanDataFolderPath()
    let size_bytes = (
      await fs.fileStat(await joinPath([janDataFolderPath, modelPath]))
    ).size
    if (mmprojPath) {
      size_bytes += (
        await fs.fileStat(await joinPath([janDataFolderPath, mmprojPath]))
      ).size
    }

    // TODO: add name as import() argument
    // TODO: add updateModelConfig() method
    const modelConfig = {
      model_path: modelPath,
      mmproj_path: mmprojPath,
      name: modelId,
      size_bytes,
    } as ModelConfig
    await fs.mkdir(await joinPath([janDataFolderPath, modelDir]))
    await invoke<void>('write_yaml', {
      data: modelConfig,
      savePath: configPath,
    })
  }

  override async abortImport(modelId: string): Promise<void> {
    // prepend provider name to avoid name collision
    const taskId = this.createDownloadTaskId(modelId)
    const downloadManager = window.core.extensionManager.getByName(
      '@janhq/download-extension'
    )
    await downloadManager.cancelDownload(taskId)
  }

  /**
   * Asks the backend for a random unused port.
   */
  private async getRandomPort(): Promise<number> {
    try {
      const port = await invoke<number>('get_random_port')
      return port
    } catch {
      logger.error('Unable to find a suitable port')
      throw new Error('Unable to find a suitable port for model')
    }
  }

  override async load(
    modelId: string,
    overrideSettings?: Partial<LlamacppConfig>,
    isEmbedding: boolean = false
  ): Promise<SessionInfo> {
    const sInfo = await this.findSessionByModel(modelId)
    if (sInfo) {
      throw new Error('Model already loaded!!')
    }

    // If this model is already being loaded, return the existing promise
    if (this.loadingModels.has(modelId)) {
      return this.loadingModels.get(modelId)!
    }

    // Create the loading promise
    const loadingPromise = this.performLoad(
      modelId,
      overrideSettings,
      isEmbedding
    )
    this.loadingModels.set(modelId, loadingPromise)

    try {
      const result = await loadingPromise
      return result
    } finally {
      this.loadingModels.delete(modelId)
    }
  }

  private async performLoad(
    modelId: string,
    overrideSettings?: Partial<LlamacppConfig>,
    isEmbedding: boolean = false
  ): Promise<SessionInfo> {
    const loadedModels = await this.getLoadedModels()

    // Get OTHER models that are currently loading (exclude current model)
    const otherLoadingPromises = Array.from(this.loadingModels.entries())
      .filter(([id, _]) => id !== modelId)
      .map(([_, promise]) => promise)

    if (
      this.autoUnload &&
      (loadedModels.length > 0 || otherLoadingPromises.length > 0)
    ) {
      // Wait for OTHER loading models to finish, then unload everything
      if (otherLoadingPromises.length > 0) {
        await Promise.all(otherLoadingPromises)
      }

      // Now unload all loaded models
      const allLoadedModels = await this.getLoadedModels()
      if (allLoadedModels.length > 0) {
        await Promise.all(allLoadedModels.map((model) => this.unload(model)))
      }
    }
    const args: string[] = []
    const cfg = { ...this.config, ...(overrideSettings ?? {}) }
    const [version, backend] = cfg.version_backend.split('/')
    if (!version || !backend) {
      throw new Error(
        `Invalid version/backend format: ${cfg.version_backend}. Expected format: <version>/<backend>`
      )
    }

    // Ensure backend is downloaded and ready before proceeding
    await this.ensureBackendReady(backend, version)

    const janDataFolderPath = await getJanDataFolderPath()
    const modelConfigPath = await joinPath([
      this.providerPath,
      'models',
      modelId,
      'model.yml',
    ])
    const modelConfig = await invoke<ModelConfig>('read_yaml', {
      path: modelConfigPath,
    })
    const port = await this.getRandomPort()

    // disable llama-server webui
    args.push('--no-webui')
    const api_key = await this.generateApiKey(modelId, String(port))
    args.push('--api-key', api_key)

    // model option is required
    // NOTE: model_path and mmproj_path can be either relative to Jan's data folder or absolute path
    const modelPath = await joinPath([
      janDataFolderPath,
      modelConfig.model_path,
    ])
    args.push('--jinja')
    args.push('--reasoning-format', 'none')
    args.push('-m', modelPath)
    // For overriding the tensor buffer type, useful where massive MoE models
    // can be made faster by keeping attention on the GPU and offloading the
    // expert FFNs to the CPU.
    // This is an expert-level setting and should only be used by people
    // who know what they are doing.
    // Takes a regex matching tensor names as input.
    if (cfg.override_tensor_buffer_t)
      args.push('--override-tensor', cfg.override_tensor_buffer_t)
    args.push('-a', modelId)
    args.push('--port', String(port))
    if (modelConfig.mmproj_path) {
      const mmprojPath = await joinPath([
        janDataFolderPath,
        modelConfig.mmproj_path,
      ])
      args.push('--mmproj', mmprojPath)
    }
    // Add remaining options from the interface
    if (cfg.chat_template) args.push('--chat-template', cfg.chat_template)
    const gpu_layers =
      parseInt(String(cfg.n_gpu_layers)) >= 0 ? cfg.n_gpu_layers : 100
    args.push('-ngl', String(gpu_layers))
    if (cfg.threads > 0) args.push('--threads', String(cfg.threads))
    if (cfg.threads_batch > 0)
      args.push('--threads-batch', String(cfg.threads_batch))
    if (cfg.batch_size > 0) args.push('--batch-size', String(cfg.batch_size))
    if (cfg.ubatch_size > 0) args.push('--ubatch-size', String(cfg.ubatch_size))
    if (cfg.device.length > 0) args.push('--device', cfg.device)
    if (cfg.split_mode.length > 0 && cfg.split_mode != 'layer')
      args.push('--split-mode', cfg.split_mode)
    if (cfg.main_gpu !== undefined && cfg.main_gpu != 0)
      args.push('--main-gpu', String(cfg.main_gpu))

    // Boolean flags
    if (!cfg.ctx_shift) args.push('--no-context-shift')
    if (cfg.flash_attn) args.push('--flash-attn')
    if (cfg.cont_batching) args.push('--cont-batching')
    args.push('--no-mmap')
    if (cfg.mlock) args.push('--mlock')
    if (cfg.no_kv_offload) args.push('--no-kv-offload')
    if (isEmbedding) {
      args.push('--embedding')
      args.push('--pooling', 'mean')
    } else {
      if (cfg.ctx_size > 0) args.push('--ctx-size', String(cfg.ctx_size))
      if (cfg.n_predict > 0) args.push('--n-predict', String(cfg.n_predict))
      if (cfg.cache_type_k && cfg.cache_type_k != 'f16')
        args.push('--cache-type-k', cfg.cache_type_k)
      if (
        cfg.flash_attn &&
        cfg.cache_type_v != 'f16' &&
        cfg.cache_type_v != 'f32'
      ) {
        args.push('--cache-type-v', cfg.cache_type_v)
      }
      if (cfg.defrag_thold && cfg.defrag_thold != 0.1)
        args.push('--defrag-thold', String(cfg.defrag_thold))

      if (cfg.rope_scaling && cfg.rope_scaling != 'none')
        args.push('--rope-scaling', cfg.rope_scaling)
      if (cfg.rope_scale && cfg.rope_scale != 1)
        args.push('--rope-scale', String(cfg.rope_scale))
      if (cfg.rope_freq_base && cfg.rope_freq_base != 0)
        args.push('--rope-freq-base', String(cfg.rope_freq_base))
      if (cfg.rope_freq_scale && cfg.rope_freq_scale != 1)
        args.push('--rope-freq-scale', String(cfg.rope_freq_scale))
    }

    logger.info('Calling Tauri command load_llama_model with args:', args)
    const backendPath = await getBackendExePath(backend, version)
    const libraryPath = await joinPath([await this.getProviderPath(), 'lib'])

    try {
      // TODO: add LIBRARY_PATH
      const sInfo = await invoke<SessionInfo>('load_llama_model', {
        backendPath,
        libraryPath,
        args,
      })
      return sInfo
    } catch (error) {
      logger.error('Error in load command:\n', error)
      throw new Error(`Failed to load model:\n${error}`)
    }
  }

  override async unload(modelId: string): Promise<UnloadResult> {
    const sInfo: SessionInfo = await this.findSessionByModel(modelId)
    if (!sInfo) {
      throw new Error(`No active session found for model: ${modelId}`)
    }
    const pid = sInfo.pid
    try {
      // Pass the PID as the session_id
      const result = await invoke<UnloadResult>('unload_llama_model', {
        pid: pid,
      })

      // Session bookkeeping is handled by the backend; just log the outcome
      if (result.success) {
        logger.info(`Successfully unloaded model with PID ${pid}`)
      } else {
        logger.warn(`Failed to unload model: ${result.error}`)
      }

      return result
    } catch (error) {
      logger.error('Error in unload command:', error)
      return {
        success: false,
        error: `Failed to unload model: ${error}`,
      }
    }
  }

  private createDownloadTaskId(modelId: string) {
    // prepend provider to make taskId unique across providers
    const cleanModelId = modelId.includes('.')
      ? modelId.slice(0, modelId.indexOf('.'))
      : modelId
    return `${this.provider}/${cleanModelId}`
  }

  private async ensureBackendReady(
    backend: string,
    version: string
  ): Promise<void> {
    const backendKey = `${version}/${backend}`

    // Check if backend is already installed
    const isInstalled = await isBackendInstalled(backend, version)
    if (isInstalled) {
      return
    }

    // Check if download is already in progress
    if (this.pendingDownloads.has(backendKey)) {
      logger.info(
        `Backend ${backendKey} download already in progress, waiting...`
      )
      await this.pendingDownloads.get(backendKey)
      return
    }

    // Start new download
    logger.info(`Backend ${backendKey} not installed, downloading...`)
    const downloadPromise = downloadBackend(backend, version).finally(() => {
      this.pendingDownloads.delete(backendKey)
    })

    this.pendingDownloads.set(backendKey, downloadPromise)
    await downloadPromise
    logger.info(`Backend ${backendKey} download completed`)
  }

  private async *handleStreamingResponse(
    url: string,
    headers: HeadersInit,
    body: string,
    abortController?: AbortController
  ): AsyncIterable<chatCompletionChunk> {
    const response = await fetch(url, {
      method: 'POST',
      headers,
      body,
      signal: abortController?.signal,
    })
    if (!response.ok) {
      const errorData = await response.json().catch(() => null)
      throw new Error(
        `API request failed with status ${response.status}: ${JSON.stringify(
          errorData
        )}`
      )
    }

    if (!response.body) {
      throw new Error('Response body is null')
    }

    const reader = response.body.getReader()
    const decoder = new TextDecoder('utf-8')
    let buffer = ''
    let jsonStr = ''
    try {
      while (true) {
        const { done, value } = await reader.read()

        if (done) {
          break
        }

        buffer += decoder.decode(value, { stream: true })

        // Process complete lines in the buffer
        const lines = buffer.split('\n')
        buffer = lines.pop() || '' // Keep the last incomplete line in the buffer

        for (const line of lines) {
          const trimmedLine = line.trim()
          if (!trimmedLine || trimmedLine === 'data: [DONE]') {
            continue
          }

          if (trimmedLine.startsWith('data: ')) {
            jsonStr = trimmedLine.slice(6)
          } else if (trimmedLine.startsWith('error: ')) {
            jsonStr = trimmedLine.slice(7)
            const error = JSON.parse(jsonStr)
            throw new Error(error.message)
          } else {
            // it should not normally reach here
            throw new Error('Malformed chunk')
          }
          try {
            const data = JSON.parse(jsonStr)
            const chunk = data as chatCompletionChunk
            yield chunk
          } catch (e) {
            logger.error('Error parsing JSON from stream or server error:', e)
            // re-throw so the async iterator terminates with an error
            throw e
          }
        }
      }
    } finally {
      reader.releaseLock()
    }
  }

  private async findSessionByModel(modelId: string): Promise<SessionInfo> {
    try {
      let sInfo = await invoke<SessionInfo>('find_session_by_model', { modelId })
      return sInfo
    } catch (e) {
      logger.error(e)
      throw new Error(String(e))
    }
  }

  override async chat(
    opts: chatCompletionRequest,
    abortController?: AbortController
  ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
    const sessionInfo = await this.findSessionByModel(opts.model)
    if (!sessionInfo) {
      throw new Error(`No active session found for model: ${opts.model}`)
    }
    // check if the process is alive
    const result = await invoke<boolean>('is_process_running', {
      pid: sessionInfo.pid,
    })
    if (result) {
      try {
        await fetch(`http://localhost:${sessionInfo.port}/health`)
      } catch (e) {
        this.unload(sessionInfo.model_id)
        throw new Error('Model appears to have crashed! Please reload!')
      }
    } else {
      throw new Error('Model has crashed! Please reload!')
    }
    const baseUrl = `http://localhost:${sessionInfo.port}/v1`
    const url = `${baseUrl}/chat/completions`
    const headers = {
      'Content-Type': 'application/json',
      'Authorization': `Bearer ${sessionInfo.api_key}`,
    }

    const body = JSON.stringify(opts)
    if (opts.stream) {
      return this.handleStreamingResponse(url, headers, body, abortController)
    }
    // Handle non-streaming response
    const response = await fetch(url, {
      method: 'POST',
      headers,
      body,
      signal: abortController?.signal,
    })

    if (!response.ok) {
      const errorData = await response.json().catch(() => null)
      throw new Error(
        `API request failed with status ${response.status}: ${JSON.stringify(
          errorData
        )}`
      )
    }

    return (await response.json()) as chatCompletion
  }

  override async delete(modelId: string): Promise<void> {
    const modelDir = await joinPath([
      await this.getProviderPath(),
      'models',
      modelId,
    ])

    if (!(await fs.existsSync(await joinPath([modelDir, 'model.yml'])))) {
      throw new Error(`Model ${modelId} does not exist`)
    }

    await fs.rm(modelDir)
  }

  override async getLoadedModels(): Promise<string[]> {
    try {
      let models: string[] = await invoke<string[]>('get_loaded_models')
      return models
    } catch (e) {
      logger.error(e)
      throw new Error(String(e))
    }
  }

  async getDevices(): Promise<DeviceList[]> {
    const cfg = this.config
    const [version, backend] = cfg.version_backend.split('/')
    if (!version || !backend) {
      throw new Error(
        `Invalid version/backend format: ${cfg.version_backend}. Expected format: <version>/<backend>`
      )
    }

    // Ensure backend is downloaded and ready before proceeding
    await this.ensureBackendReady(backend, version)
    logger.info('Calling Tauri command get_devices with arg --list-devices')
    const backendPath = await getBackendExePath(backend, version)
    const libraryPath = await joinPath([await this.getProviderPath(), 'lib'])
    try {
      const dList = await invoke<DeviceList[]>('get_devices', {
        backendPath,
        libraryPath,
      })
      return dList
    } catch (error) {
      logger.error('Failed to query devices:\n', error)
      throw new Error(`Failed to load llama-server: ${error}`)
    }
  }

  async embed(text: string[]): Promise<EmbeddingResponse> {
    let sInfo = await this.findSessionByModel('sentence-transformer-mini')
    if (!sInfo) {
      const downloadedModelList = await this.list()
      if (
        !downloadedModelList.some(
          (model) => model.id === 'sentence-transformer-mini'
        )
      ) {
        await this.import('sentence-transformer-mini', {
          modelPath:
            'https://huggingface.co/second-state/All-MiniLM-L6-v2-Embedding-GGUF/resolve/main/all-MiniLM-L6-v2-ggml-model-f16.gguf?download=true',
        })
      }
      sInfo = await this.load('sentence-transformer-mini')
    }
    const baseUrl = `http://localhost:${sInfo.port}/v1/embeddings`
    const headers = {
      'Content-Type': 'application/json',
      'Authorization': `Bearer ${sInfo.api_key}`,
    }
    const body = JSON.stringify({
      input: text,
      model: sInfo.model_id,
      encoding_format: 'float',
    })
    const response = await fetch(baseUrl, {
      method: 'POST',
      headers,
      body,
    })

    if (!response.ok) {
      const errorData = await response.json().catch(() => null)
      throw new Error(
        `API request failed with status ${response.status}: ${JSON.stringify(
          errorData
        )}`
      )
    }
    const responseData = await response.json()
    return responseData as EmbeddingResponse
  }

  // Optional method for direct client access
  override getChatClient(sessionId: string): any {
    throw new Error('method not implemented yet')
  }
}