From 510c70bdf77d9a8951c964ff900e45c2ff014d5e Mon Sep 17 00:00:00 2001
From: Akarshan Biswas
Date: Thu, 21 Aug 2025 16:13:50 +0530
Subject: [PATCH] feat: Add model compatibility check and memory estimation
 (#6243)

* feat: Add model compatibility check and memory estimation

This commit introduces a new feature to check if a given model is supported based on available device memory. The change includes:

- A new `estimateKVCache` method that calculates the required memory for the model's KV cache. It uses GGUF metadata such as `block_count`, `head_count`, `key_length`, and `value_length` to perform the calculation.
- An `isModelSupported` method that combines the model file size and the estimated KV cache size to determine the total memory required. It then checks if any available device has sufficient free memory to load the model.
- An updated error message for the `version_backend` check to be more user-friendly, suggesting a stable internet connection as a potential solution for backend setup failures.

This functionality helps prevent the application from attempting to load models that would exceed the device's memory capacity, leading to more stable and predictable behavior.

fixes: #5505

* Update extensions/llamacpp-extension/src/index.ts

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* Update extensions/llamacpp-extension/src/index.ts

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* Extend this to available system RAM if GGML device is not available

* fix: Improve model metadata and memory checks

This commit refactors the logic for checking if a model is supported by a system's available memory.

**Key changes:**

- **Remote model support**: The `read_gguf_metadata` function can now fetch metadata from a remote URL by reading the file in chunks.
- **Improved KV cache size calculation**: The KV cache size is now estimated more accurately by using `attention.key_length` and `attention.value_length` from the GGUF metadata, with a fallback to `embedding_length`.
- **Granular memory check statuses**: The `isModelSupported` function now returns a more specific status (`'RED'`, `'YELLOW'`, `'GREEN'`) to indicate whether the model weights or the KV cache are too large for the available memory.
- **Consolidated logic**: The logic for checking local and remote models has been consolidated into a single `isModelSupported` function, improving code clarity and maintainability.

These changes provide more robust and informative model compatibility checks, especially for models hosted on remote servers.
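For illustration, the KV cache estimation introduced here reduces to:

    kv_bytes = ctx_len * block_count * head_count * (key_length + value_length) * 2   // 2 bytes per element for f16

Assuming a hypothetical Llama-7B-style GGUF (block_count = 32, head_count = 32, key_length = value_length = 128) and ctx_len = 8192, this works out to 8192 * 32 * 32 * 256 * 2 bytes ≈ 4 GiB, which `isModelSupported` adds to the GGUF file size before comparing against free GPU memory (or free system RAM when no GGML device is reported).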
* Update extensions/llamacpp-extension/src/index.ts

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>

* Make ctx_size optional and use sum free memory across ggml devices

* feat: hub and dropdown model selection handle model compatibility

* feat: update badge model info color

* chore: enable detail page to get compatibility model

* chore: update copy

* chore: update shrink indicator UI

---------

Co-authored-by: ellipsis-dev[bot] <65095814+ellipsis-dev[bot]@users.noreply.github.com>
Co-authored-by: Faisal Amir
---
 extensions/llamacpp-extension/src/index.ts    | 138 ++++++++++-
 .../plugins/tauri-plugin-llamacpp/Cargo.toml  |   3 +-
 .../src/gguf/commands.rs                      |  52 +++-
 .../tauri-plugin-llamacpp/src/gguf/helpers.rs |   6 +-
 web-app/src/containers/ChatInput.tsx          |  17 +-
 .../src/containers/DropdownModelProvider.tsx  |  85 ++++++-
 web-app/src/containers/ModelInfoHoverCard.tsx | 226 ++++++++++++++++++
 web-app/src/containers/ModelSupportStatus.tsx | 142 +++++++++++
 .../containers/__tests__/ChatInput.test.tsx   |   9 -
 web-app/src/routes/hub/$modelId.tsx           |  59 ++++-
 web-app/src/routes/hub/index.tsx              | 193 +++++++++++----
 web-app/src/services/__tests__/models.test.ts |  92 +++++++
 web-app/src/services/models.ts                |  32 +++
 13 files changed, 978 insertions(+), 76 deletions(-)
 create mode 100644 web-app/src/containers/ModelInfoHoverCard.tsx
 create mode 100644 web-app/src/containers/ModelSupportStatus.tsx

diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts
index f4cdd83c8..fe4f2f34c 100644
--- a/extensions/llamacpp-extension/src/index.ts
+++ b/extensions/llamacpp-extension/src/index.ts
@@ -35,7 +35,11 @@ import {
 import { invoke } from '@tauri-apps/api/core'
 import { getProxyConfig } from './util'
 import { basename } from '@tauri-apps/api/path'
-import { readGgufMetadata } from '@janhq/tauri-plugin-llamacpp-api'
+import {
+  GgufMetadata,
+  readGgufMetadata,
+} from '@janhq/tauri-plugin-llamacpp-api'
+import { getSystemUsage } from '@janhq/tauri-plugin-hardware-api'
 
 type LlamacppConfig = {
   version_backend: string
@@ -1742,7 +1746,7 @@ export default class llamacpp_extension extends AIEngine {
     const [version, backend] = cfg.version_backend.split('/')
     if (!version || !backend) {
       throw new Error(
-        `Invalid version/backend format: ${cfg.version_backend}. Expected format: <version>/<backend>`
+        'Backend setup was not successful. Please restart the app with a stable internet connection.'
       )
     }
     // set envs
@@ -1843,4 +1847,134 @@ export default class llamacpp_extension extends AIEngine {
       'tokenizer.chat_template'
     ]?.includes('tools')
   }
+
+  /**
+   * Estimate the KV cache size (in bytes) from the given GGUF metadata.
+   *
+   */
+  private async estimateKVCache(
+    meta: Record<string, any>,
+    ctx_size?: number
+  ): Promise<number> {
+    const arch = meta['general.architecture']
+    if (!arch) throw new Error('Invalid metadata: architecture not found')
+
+    const nLayer = Number(meta[`${arch}.block_count`])
+    if (!nLayer) throw new Error('Invalid metadata: block_count not found')
+
+    const nHead = Number(meta[`${arch}.attention.head_count`])
+    if (!nHead) throw new Error('Invalid metadata: head_count not found')
+
+    // Try to get key/value lengths first (more accurate)
+    const keyLen = Number(meta[`${arch}.attention.key_length`])
+    const valLen = Number(meta[`${arch}.attention.value_length`])
+
+    let headDim: number
+
+    if (keyLen && valLen) {
+      // Use explicit key/value lengths if available
+      logger.info(
+        `Using explicit key_length: ${keyLen}, value_length: ${valLen}`
+      )
+      headDim = keyLen + valLen
+    } else {
+      // Fall back to embedding_length estimation
+      const embeddingLen = Number(meta[`${arch}.embedding_length`])
+      if (!embeddingLen)
+        throw new Error('Invalid metadata: embedding_length not found')
+
+      // Standard transformer: head_dim = embedding_dim / num_heads
+      // For KV cache: we need both K and V, so 2 * head_dim per head
+      headDim = (embeddingLen / nHead) * 2
+      logger.info(
+        `Using embedding_length estimation: ${embeddingLen}, calculated head_dim: ${headDim}`
+      )
+    }
+    let ctxLen: number
+    if (!ctx_size) {
+      ctxLen = Number(meta[`${arch}.context_length`])
+    } else {
+      ctxLen = ctx_size
+    }
+
+    logger.info(`ctxLen: ${ctxLen}`)
+    logger.info(`nLayer: ${nLayer}`)
+    logger.info(`nHead: ${nHead}`)
+    logger.info(`headDim: ${headDim}`)
+
+    // Assume an f16 KV cache by default.
+    // This can be extended by checking cache-type-k and cache-type-v,
+    // but we are checking overall compatibility with the default settings.
+    // f16 = 16 bits = 2 bytes per element
+    const bytesPerElement = 2
+
+    // Total KV cache size per token = nHead * headDim * bytesPerElement
+    const kvPerToken = nHead * headDim * bytesPerElement
+
+    return ctxLen * nLayer * kvPerToken
+  }
+
+  private async getModelSize(path: string): Promise<number> {
+    if (path.startsWith('https://')) {
+      const res = await fetch(path, { method: 'HEAD' })
+      const len = res.headers.get('content-length')
+      return len ?
parseInt(len, 10) : 0 + } else { + return (await fs.fileStat(path)).size + } + } + + /* + * check the support status of a model by its path (local/remote) + * + * * Returns: + * - "RED" → weights don't fit + * - "YELLOW" → weights fit, KV cache doesn't + * - "GREEN" → both weights + KV cache fit + */ + async isModelSupported( + path: string, + ctx_size?: number + ): Promise<'RED' | 'YELLOW' | 'GREEN'> { + try { + const modelSize = await this.getModelSize(path) + logger.info(`modelSize: ${modelSize}`) + let gguf: GgufMetadata + gguf = await readGgufMetadata(path) + let kvCacheSize: number + if (ctx_size) { + kvCacheSize = await this.estimateKVCache(gguf.metadata, ctx_size) + } else { + kvCacheSize = await this.estimateKVCache(gguf.metadata) + } + // total memory consumption = model weights + kvcache + a small buffer for outputs + // output buffer is small so not considering here + const totalRequired = modelSize + kvCacheSize + logger.info( + `isModelSupported: Total memory requirement: ${totalRequired} for ${path}` + ) + let availableMemBytes: number + const devices = await this.getDevices() + if (devices.length > 0) { + // Sum free memory across all GPUs + availableMemBytes = devices + .map((d) => d.free * 1024 * 1024) + .reduce((a, b) => a + b, 0) + } else { + // CPU fallback + const sys = await getSystemUsage() + availableMemBytes = (sys.total_memory - sys.used_memory) * 1024 * 1024 + } + // check model size wrt system memory + if (modelSize > availableMemBytes) { + return 'RED' + } else if (modelSize + kvCacheSize > availableMemBytes) { + return 'YELLOW' + } else { + return 'GREEN' + } + } catch (e) { + throw new Error(String(e)) + } + } } diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/Cargo.toml b/src-tauri/plugins/tauri-plugin-llamacpp/Cargo.toml index d4abaf88d..0defc62b7 100644 --- a/src-tauri/plugins/tauri-plugin-llamacpp/Cargo.toml +++ b/src-tauri/plugins/tauri-plugin-llamacpp/Cargo.toml @@ -23,12 +23,13 @@ sysinfo = "0.34.2" tauri = { version = "2.5.0", default-features = false, features = [] } thiserror = "2.0.12" tokio = { version = "1", features = ["full"] } +reqwest = { version = "0.11", features = ["json", "blocking", "stream"] } # Windows-specific dependencies [target.'cfg(windows)'.dependencies] windows-sys = { version = "0.60.2", features = ["Win32_Storage_FileSystem"] } -# Unix-specific dependencies +# Unix-specific dependencies [target.'cfg(unix)'.dependencies] nix = { version = "=0.30.1", features = ["signal", "process"] } diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/commands.rs b/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/commands.rs index 5d005a241..ae38f56f3 100644 --- a/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/commands.rs +++ b/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/commands.rs @@ -1,8 +1,58 @@ use super::helpers; use super::types::GgufMetadata; +use reqwest; +use std::fs::File; +use std::io::BufReader; /// Read GGUF metadata from a model file #[tauri::command] pub async fn read_gguf_metadata(path: String) -> Result { - helpers::read_gguf_metadata(&path).map_err(|e| format!("Failed to read GGUF metadata: {}", e)) + if path.starts_with("http://") || path.starts_with("https://") { + // Remote: read in 2MB chunks until successful + let client = reqwest::Client::new(); + let chunk_size = 2 * 1024 * 1024; // Fixed 2MB chunks + let max_total_size = 120 * 1024 * 1024; // Don't exceed 120MB total + let mut total_downloaded = 0; + let mut accumulated_data = Vec::new(); + + while total_downloaded < max_total_size { + let start 
= total_downloaded; + let end = std::cmp::min(start + chunk_size - 1, max_total_size - 1); + + let resp = client + .get(&path) + .header("Range", format!("bytes={}-{}", start, end)) + .send() + .await + .map_err(|e| format!("Failed to fetch chunk {}-{}: {}", start, end, e))?; + + let chunk_data = resp + .bytes() + .await + .map_err(|e| format!("Failed to read chunk response: {}", e))?; + + accumulated_data.extend_from_slice(&chunk_data); + total_downloaded += chunk_data.len(); + + // Try parsing after each chunk + let cursor = std::io::Cursor::new(&accumulated_data); + if let Ok(metadata) = helpers::read_gguf_metadata(cursor) { + return Ok(metadata); + } + + // If we got less data than expected, we've reached EOF + if chunk_data.len() < chunk_size { + break; + } + } + Err("Could not parse GGUF metadata from downloaded data".to_string()) + } else { + // Local: use streaming file reader + let file = + File::open(&path).map_err(|e| format!("Failed to open local file {}: {}", path, e))?; + let reader = BufReader::new(file); + + helpers::read_gguf_metadata(reader) + .map_err(|e| format!("Failed to parse GGUF metadata: {}", e)) + } } diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/helpers.rs b/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/helpers.rs index 245b986a1..b728522c3 100644 --- a/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/helpers.rs +++ b/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/helpers.rs @@ -1,13 +1,11 @@ use byteorder::{LittleEndian, ReadBytesExt}; use std::convert::TryFrom; -use std::fs::File; use std::io::{self, BufReader, Read, Seek}; -use std::path::Path; use super::types::{GgufMetadata, GgufValueType}; -pub fn read_gguf_metadata>(path: P) -> io::Result { - let mut file = BufReader::new(File::open(path)?); +pub fn read_gguf_metadata(reader: R) -> io::Result { + let mut file = BufReader::new(reader); let mut magic = [0u8; 4]; file.read_exact(&mut magic)?; diff --git a/web-app/src/containers/ChatInput.tsx b/web-app/src/containers/ChatInput.tsx index 53e47db51..0fa7a4b32 100644 --- a/web-app/src/containers/ChatInput.tsx +++ b/web-app/src/containers/ChatInput.tsx @@ -17,7 +17,6 @@ import { IconPhoto, IconWorld, IconAtom, - IconEye, IconTool, IconCodeCircle2, IconPlayerStopFilled, @@ -537,7 +536,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { {/* File attachment - show only for models with mmproj */} {hasMmproj && (
@@ -554,20 +553,6 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { {/*
*/} - {selectedModel?.capabilities?.includes('vision') && ( - - - -
- -
-
- -

{t('vision')}

-
-
-
- )} {selectedModel?.capabilities?.includes('embeddings') && ( diff --git a/web-app/src/containers/DropdownModelProvider.tsx b/web-app/src/containers/DropdownModelProvider.tsx index ac2bbc535..005580890 100644 --- a/web-app/src/containers/DropdownModelProvider.tsx +++ b/web-app/src/containers/DropdownModelProvider.tsx @@ -14,12 +14,16 @@ import { route } from '@/constants/routes' import { useThreads } from '@/hooks/useThreads' import { ModelSetting } from '@/containers/ModelSetting' import ProvidersAvatar from '@/containers/ProvidersAvatar' +import { ModelSupportStatus } from '@/containers/ModelSupportStatus' import { Fzf } from 'fzf' import { localStorageKey } from '@/constants/localStorage' import { useTranslation } from '@/i18n/react-i18next-compat' import { useFavoriteModel } from '@/hooks/useFavoriteModel' import { predefinedProviders } from '@/consts/providers' -import { checkMmprojExistsAndUpdateOffloadMMprojSetting } from '@/services/models' +import { + checkMmprojExistsAndUpdateOffloadMMprojSetting, + checkMmprojExists, +} from '@/services/models' type DropdownModelProviderProps = { model?: ThreadModel @@ -91,6 +95,50 @@ const DropdownModelProvider = ({ [providers] ) + // Helper function to get context size from model settings + const getContextSize = useCallback((): number => { + if (!selectedModel?.settings?.ctx_len?.controller_props?.value) { + return 8192 // Default context size + } + return selectedModel.settings.ctx_len.controller_props.value as number + }, [selectedModel?.settings?.ctx_len?.controller_props?.value]) + + // Function to check if a llamacpp model has vision capabilities and update model capabilities + const checkAndUpdateModelVisionCapability = useCallback( + async (modelId: string) => { + try { + const hasVision = await checkMmprojExists(modelId) + if (hasVision) { + // Update the model capabilities to include 'vision' + const provider = getProviderByName('llamacpp') + if (provider) { + const modelIndex = provider.models.findIndex( + (m) => m.id === modelId + ) + if (modelIndex !== -1) { + const model = provider.models[modelIndex] + const capabilities = model.capabilities || [] + + // Add 'vision' capability if not already present + if (!capabilities.includes('vision')) { + const updatedModels = [...provider.models] + updatedModels[modelIndex] = { + ...model, + capabilities: [...capabilities, 'vision'], + } + + updateProvider('llamacpp', { models: updatedModels }) + } + } + } + } + } catch (error) { + console.debug('Error checking mmproj for model:', modelId, error) + } + }, + [getProviderByName, updateProvider] + ) + // Initialize model provider only once useEffect(() => { const initializeModel = async () => { @@ -107,6 +155,8 @@ const DropdownModelProvider = ({ updateProvider, getProviderByName ) + // Also check vision capability + await checkAndUpdateModelVisionCapability(model.id as string) } } else if (useLastUsedModel) { // Try to use last used model only when explicitly requested (for new chat) @@ -119,6 +169,8 @@ const DropdownModelProvider = ({ updateProvider, getProviderByName ) + // Also check vision capability + await checkAndUpdateModelVisionCapability(lastUsed.model) } } else { selectModelProvider('', '') @@ -136,6 +188,7 @@ const DropdownModelProvider = ({ checkModelExists, updateProvider, getProviderByName, + checkAndUpdateModelVisionCapability, ]) // Update display model when selection changes @@ -147,6 +200,25 @@ const DropdownModelProvider = ({ } }, [selectedProvider, selectedModel, t]) + // Check vision capabilities for all llamacpp models 
+ useEffect(() => { + const checkAllLlamacppModelsForVision = async () => { + const llamacppProvider = providers.find( + (p) => p.provider === 'llamacpp' && p.active + ) + if (llamacppProvider) { + const checkPromises = llamacppProvider.models.map((model) => + checkAndUpdateModelVisionCapability(model.id) + ) + await Promise.allSettled(checkPromises) + } + } + + if (open) { + checkAllLlamacppModelsForVision() + } + }, [open, providers, checkAndUpdateModelVisionCapability]) + // Reset search value when dropdown closes const onOpenChange = useCallback((open: boolean) => { setOpen(open) @@ -287,6 +359,8 @@ const DropdownModelProvider = ({ updateProvider, getProviderByName ) + // Also check vision capability + await checkAndUpdateModelVisionCapability(searchableModel.model.id) } // Store the selected model as last used @@ -305,6 +379,7 @@ const DropdownModelProvider = ({ useLastUsedModel, updateProvider, getProviderByName, + checkAndUpdateModelVisionCapability, ] ) @@ -318,7 +393,7 @@ const DropdownModelProvider = ({ return ( -
+
+ onCheckModelSupport: (variant: ModelQuant) => void + children?: React.ReactNode +} + +export const ModelInfoHoverCard = ({ + model, + variant, + defaultModelQuantizations, + modelSupportStatus, + onCheckModelSupport, + children, +}: ModelInfoHoverCardProps) => { + const isVariantMode = !!variant + const displayVariant = + variant || + model.quants.find((m) => + defaultModelQuantizations.some((e) => + m.model_id.toLowerCase().includes(e) + ) + ) || + model.quants?.[0] + + const handleMouseEnter = () => { + if (displayVariant) { + onCheckModelSupport(displayVariant) + } + } + + const getCompatibilityStatus = () => { + const status = displayVariant + ? modelSupportStatus[displayVariant.model_id] + : null + + if (status === 'LOADING') { + return ( +
+
+ Checking... +
+ ) + } else if (status === 'GREEN') { + return ( +
+
+ + Recommended for your device + +
+ ) + } else if (status === 'YELLOW') { + return ( +
+
+ + May be slow on your device + +
+ ) + } else if (status === 'RED') { + return ( +
+
+ + May be incompatible with your device + +
+ ) + } else { + return ( +
+
+ Unknown +
+ ) + } + } + + return ( + + + {children || ( +
+ +
+ )} +
+ +
+ {/* Header */} +
+

+ {isVariantMode ? variant.model_id : model.model_name} +

+

+ {isVariantMode + ? 'Model Variant Information' + : 'Model Information'} +

+
+ + {/* Main Info Grid */} +
+
+ {isVariantMode ? ( + <> +
+ + File Size + + + {variant.file_size} + +
+
+ + Quantization + + + {variant.model_id.split('-').pop()?.toUpperCase() || + 'N/A'} + +
+ + ) : ( + <> +
+ + Downloads + + + {model.downloads?.toLocaleString() || '0'} + +
+
+ Variants + + {model.quants?.length || 0} + +
+ + )} +
+ +
+ {!isVariantMode && ( +
+ + Default Size + + + {displayVariant?.file_size || 'N/A'} + +
+ )} +
+ + Compatibility + +
+ {getCompatibilityStatus()} +
+
+
+
+ + {/* Features Section */} + {(model.num_mmproj > 0 || model.tools) && ( +
+
+ Features +
+
+ {model.num_mmproj > 0 && ( +
+ + Vision + +
+ )} + {model.tools && ( +
+ + Tools + +
+ )} +
+
+ )} + + {/* Content Section */} +
+
+ {isVariantMode ? 'Download URL' : 'Description'} +
+
+ {isVariantMode ? ( +
{variant.path}
+ ) : ( + extractDescription(model?.description) || + 'No description available' + )} +
+
+
+
+
+ ) +} diff --git a/web-app/src/containers/ModelSupportStatus.tsx b/web-app/src/containers/ModelSupportStatus.tsx new file mode 100644 index 000000000..3667f4461 --- /dev/null +++ b/web-app/src/containers/ModelSupportStatus.tsx @@ -0,0 +1,142 @@ +import { useCallback, useEffect, useState } from 'react' +import { cn } from '@/lib/utils' +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from '@/components/ui/tooltip' +import { isModelSupported } from '@/services/models' +import { getJanDataFolderPath, joinPath } from '@janhq/core' + +interface ModelSupportStatusProps { + modelId: string | undefined + provider: string | undefined + contextSize: number + className?: string +} + +export const ModelSupportStatus = ({ + modelId, + provider, + contextSize, + className, +}: ModelSupportStatusProps) => { + const [modelSupportStatus, setModelSupportStatus] = useState< + 'RED' | 'YELLOW' | 'GREEN' | 'LOADING' | null + >(null) + + // Helper function to check model support with proper path resolution + const checkModelSupportWithPath = useCallback( + async ( + id: string, + ctxSize: number + ): Promise<'RED' | 'YELLOW' | 'GREEN'> => { + try { + // Get Jan's data folder path and construct the full model file path + // Following the llamacpp extension structure: /llamacpp/models//model.gguf + const janDataFolder = await getJanDataFolderPath() + const modelFilePath = await joinPath([ + janDataFolder, + 'llamacpp', + 'models', + id, + 'model.gguf', + ]) + + return await isModelSupported(modelFilePath, ctxSize) + } catch (error) { + console.error( + 'Error checking model support with constructed path:', + error + ) + // If path construction or model support check fails, assume not supported + return 'RED' + } + }, + [] + ) + + // Helper function to get icon color based on model support status + const getStatusColor = (): string => { + switch (modelSupportStatus) { + case 'GREEN': + return 'bg-green-500' + case 'YELLOW': + return 'bg-yellow-500' + case 'RED': + return 'bg-red-500' + case 'LOADING': + return 'bg-main-view-fg/50' + default: + return 'bg-main-view-fg/50' + } + } + + // Helper function to get tooltip text based on model support status + const getStatusTooltip = (): string => { + switch (modelSupportStatus) { + case 'GREEN': + return `Works Well on your device (ctx: ${contextSize})` + case 'YELLOW': + return `Might work on your device (ctx: ${contextSize})` + case 'RED': + return `Doesn't work on your device (ctx: ${contextSize})` + case 'LOADING': + return 'Checking device compatibility...' + default: + return 'Unknown' + } + } + + // Check model support when model changes + useEffect(() => { + const checkModelSupport = async () => { + if (modelId && provider === 'llamacpp') { + // Set loading state immediately + setModelSupportStatus('LOADING') + try { + const supportStatus = await checkModelSupportWithPath( + modelId, + contextSize + ) + setModelSupportStatus(supportStatus) + } catch (error) { + console.error('Error checking model support:', error) + setModelSupportStatus('RED') + } + } else { + // Only show status for llamacpp models since isModelSupported is specific to llamacpp + setModelSupportStatus(null) + } + } + + checkModelSupport() + }, [modelId, provider, contextSize, checkModelSupportWithPath]) + + // Don't render anything if no status or not llamacpp + if (!modelSupportStatus || provider !== 'llamacpp') { + return null + } + + return ( + + + +
+ + +

{getStatusTooltip()}

+
+ + + ) +} diff --git a/web-app/src/containers/__tests__/ChatInput.test.tsx b/web-app/src/containers/__tests__/ChatInput.test.tsx index 292484006..0580821e5 100644 --- a/web-app/src/containers/__tests__/ChatInput.test.tsx +++ b/web-app/src/containers/__tests__/ChatInput.test.tsx @@ -291,15 +291,6 @@ describe('ChatInput', () => { expect(stopButton).toBeInTheDocument() }) - it('shows capability icons when model supports them', () => { - act(() => { - renderWithRouter() - }) - - // Should show vision icon (rendered as SVG with tabler-icon-eye class) - const visionIcon = document.querySelector('.tabler-icon-eye') - expect(visionIcon).toBeInTheDocument() - }) it('shows model selection dropdown', () => { act(() => { diff --git a/web-app/src/routes/hub/$modelId.tsx b/web-app/src/routes/hub/$modelId.tsx index e5f2f44b3..2d0eecc70 100644 --- a/web-app/src/routes/hub/$modelId.tsx +++ b/web-app/src/routes/hub/$modelId.tsx @@ -20,19 +20,24 @@ import { useModelProvider } from '@/hooks/useModelProvider' import { useDownloadStore } from '@/hooks/useDownloadStore' import { CatalogModel, + ModelQuant, convertHfRepoToCatalogModel, fetchHuggingFaceRepo, pullModelWithMetadata, + isModelSupported, } from '@/services/models' import { Progress } from '@/components/ui/progress' import { Button } from '@/components/ui/button' import { cn } from '@/lib/utils' import { useGeneralSetting } from '@/hooks/useGeneralSetting' +import { ModelInfoHoverCard } from '@/containers/ModelInfoHoverCard' type SearchParams = { repo: string } +const defaultModelQuantizations = ['iq4_xs', 'q4_k_m'] + export const Route = createFileRoute('/hub/$modelId')({ component: HubModelDetail, validateSearch: (search: Record): SearchParams => ({ @@ -57,6 +62,11 @@ function HubModelDetail() { const [readmeContent, setReadmeContent] = useState('') const [isLoadingReadme, setIsLoadingReadme] = useState(false) + // State for model support status + const [modelSupportStatus, setModelSupportStatus] = useState< + Record + >({}) + useEffect(() => { fetchSources() }, [fetchSources]) @@ -131,6 +141,41 @@ function HubModelDetail() { } } + // Check model support function + const checkModelSupport = useCallback( + async (variant: ModelQuant) => { + const modelKey = variant.model_id + + // Don't check again if already checking or checked + if (modelSupportStatus[modelKey]) { + return + } + + // Set loading state + setModelSupportStatus((prev) => ({ + ...prev, + [modelKey]: 'LOADING', + })) + + try { + // Use the HuggingFace path for the model + const modelPath = variant.path + const supported = await isModelSupported(modelPath, 8192) + setModelSupportStatus((prev) => ({ + ...prev, + [modelKey]: supported, + })) + } catch (error) { + console.error('Error checking model support:', error) + setModelSupportStatus((prev) => ({ + ...prev, + [modelKey]: 'RED', + })) + } + }, + [modelSupportStatus] + ) + // Extract tags from quants (model variants) const tags = useMemo(() => { if (!modelData?.quants) return [] @@ -318,6 +363,7 @@ function HubModelDetail() { Size + Action @@ -372,7 +418,18 @@ function HubModelDetail() { {variant.file_size} - + + + + {(() => { if (isDownloading && !isDownloaded) { return ( diff --git a/web-app/src/routes/hub/index.tsx b/web-app/src/routes/hub/index.tsx index 93658816b..07dd0f85b 100644 --- a/web-app/src/routes/hub/index.tsx +++ b/web-app/src/routes/hub/index.tsx @@ -31,6 +31,7 @@ import { TooltipProvider, TooltipTrigger, } from '@/components/ui/tooltip' +import { ModelInfoHoverCard } from '@/containers/ModelInfoHoverCard' import 
Joyride, { CallBackProps, STATUS } from 'react-joyride' import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide' import { @@ -44,6 +45,7 @@ import { pullModelWithMetadata, fetchHuggingFaceRepo, convertHfRepoToCatalogModel, + isModelSupported, } from '@/services/models' import { useDownloadStore } from '@/hooks/useDownloadStore' import { Progress } from '@/components/ui/progress' @@ -97,6 +99,9 @@ function Hub() { const [huggingFaceRepo, setHuggingFaceRepo] = useState( null ) + const [modelSupportStatus, setModelSupportStatus] = useState< + Record + >({}) const [joyrideReady, setJoyrideReady] = useState(false) const [currentStepIndex, setCurrentStepIndex] = useState(0) const addModelSourceTimeoutRef = useRef | null>( @@ -270,6 +275,41 @@ function Hub() { [navigate] ) + const checkModelSupport = useCallback( + async (variant: any) => { + const modelKey = variant.model_id + + // Don't check again if already checking or checked + if (modelSupportStatus[modelKey]) { + return + } + + // Set loading state + setModelSupportStatus((prev) => ({ + ...prev, + [modelKey]: 'LOADING', + })) + + try { + // Use the HuggingFace path for the model + const modelPath = variant.path + const supportStatus = await isModelSupported(modelPath, 8192) + + setModelSupportStatus((prev) => ({ + ...prev, + [modelKey]: supportStatus, + })) + } catch (error) { + console.error('Error checking model support:', error) + setModelSupportStatus((prev) => ({ + ...prev, + [modelKey]: 'RED', + })) + } + }, + [modelSupportStatus] + ) + const DownloadButtonPlaceholder = useMemo(() => { return ({ model }: ModelProps) => { // Check if this is a HuggingFace repository (no quants) @@ -616,6 +656,14 @@ function Hub() { )?.file_size } + @@ -671,45 +719,47 @@ function Hub() { ?.length || 0}
- {filteredModels[virtualItem.index].tools && ( -
- - - -
- -
-
- -

{t('tools')}

-
-
-
-
- )} - {filteredModels[virtualItem.index].num_mmproj > - 0 && ( -
- - - -
- -
-
- -

{t('vision')}

-
-
-
-
- )} +
+ {filteredModels[virtualItem.index].num_mmproj > + 0 && ( +
+ + + +
+ +
+
+ +

{t('vision')}

+
+
+
+
+ )} + {filteredModels[virtualItem.index].tools && ( +
+ + + +
+ +
+
+ +

{t('tools')}

+
+
+
+
+ )} +
{filteredModels[virtualItem.index].quants.length > 1 && (
@@ -744,12 +794,75 @@ function Hub() { (variant) => ( +
+ + {variant.model_id} + + {filteredModels[virtualItem.index] + .num_mmproj > 0 && ( +
+ + + +
+ +
+
+ +

{t('vision')}

+
+
+
+
+ )} + {filteredModels[virtualItem.index] + .tools && ( +
+ + + +
+ +
+
+ +

{t('tools')}

+
+
+
+
+ )} +
+ + } actions={

{variant.file_size}

+ {(() => { const isDownloading = localDownloadingModels.has( diff --git a/web-app/src/services/__tests__/models.test.ts b/web-app/src/services/__tests__/models.test.ts index 368fd19be..286fb01a4 100644 --- a/web-app/src/services/__tests__/models.test.ts +++ b/web-app/src/services/__tests__/models.test.ts @@ -13,6 +13,7 @@ import { stopModel, stopAllModels, startModel, + isModelSupported, HuggingFaceRepo, CatalogModel, } from '../models' @@ -845,4 +846,95 @@ describe('models service', () => { expect(result.quants[0].file_size).toBe('Unknown size') }) }) + + describe('isModelSupported', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should return GREEN when model is fully supported', async () => { + const mockEngineWithSupport = { + ...mockEngine, + isModelSupported: vi.fn().mockResolvedValue('GREEN'), + } + + mockEngineManager.get.mockReturnValue(mockEngineWithSupport) + + const result = await isModelSupported('/path/to/model.gguf', 4096) + + expect(result).toBe('GREEN') + expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith( + '/path/to/model.gguf', + 4096 + ) + }) + + it('should return YELLOW when model weights fit but KV cache does not', async () => { + const mockEngineWithSupport = { + ...mockEngine, + isModelSupported: vi.fn().mockResolvedValue('YELLOW'), + } + + mockEngineManager.get.mockReturnValue(mockEngineWithSupport) + + const result = await isModelSupported('/path/to/model.gguf', 8192) + + expect(result).toBe('YELLOW') + expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith( + '/path/to/model.gguf', + 8192 + ) + }) + + it('should return RED when model is not supported', async () => { + const mockEngineWithSupport = { + ...mockEngine, + isModelSupported: vi.fn().mockResolvedValue('RED'), + } + + mockEngineManager.get.mockReturnValue(mockEngineWithSupport) + + const result = await isModelSupported('/path/to/large-model.gguf') + + expect(result).toBe('RED') + expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith( + '/path/to/large-model.gguf', + undefined + ) + }) + + it('should return YELLOW as fallback when engine method is not available', async () => { + const mockEngineWithoutSupport = { + ...mockEngine, + // isModelSupported method not available + } + + mockEngineManager.get.mockReturnValue(mockEngineWithoutSupport) + + const result = await isModelSupported('/path/to/model.gguf') + + expect(result).toBe('YELLOW') + }) + + it('should return RED when engine is not available', async () => { + mockEngineManager.get.mockReturnValue(null) + + const result = await isModelSupported('/path/to/model.gguf') + + expect(result).toBe('YELLOW') // Should use fallback + }) + + it('should return RED when there is an error', async () => { + const mockEngineWithError = { + ...mockEngine, + isModelSupported: vi.fn().mockRejectedValue(new Error('Test error')), + } + + mockEngineManager.get.mockReturnValue(mockEngineWithError) + + const result = await isModelSupported('/path/to/model.gguf') + + expect(result).toBe('RED') + }) + }) }) diff --git a/web-app/src/services/models.ts b/web-app/src/services/models.ts index 0edfe165a..df7dacf00 100644 --- a/web-app/src/services/models.ts +++ b/web-app/src/services/models.ts @@ -579,3 +579,35 @@ export const checkMmprojExists = async (modelId: string): Promise => { } return false } + +/** + * Checks if a model is supported by analyzing memory requirements and system resources. 
+ * @param modelPath - The path to the model file (local path or URL)
+ * @param ctxSize - Optional context size for the model; when omitted, the engine falls back to the model's own context length
+ * @returns Promise<'RED' | 'YELLOW' | 'GREEN'> - Support status:
+ *   - 'RED': Model weights don't fit in available memory
+ *   - 'YELLOW': Model weights fit, but KV cache doesn't
+ *   - 'GREEN': Both model weights and KV cache fit in available memory
+ */
+export const isModelSupported = async (
+  modelPath: string,
+  ctxSize?: number
+): Promise<'RED' | 'YELLOW' | 'GREEN'> => {
+  try {
+    const engine = getEngine('llamacpp') as AIEngine & {
+      isModelSupported?: (
+        path: string,
+        ctx_size?: number
+      ) => Promise<'RED' | 'YELLOW' | 'GREEN'>
+    }
+    if (engine && typeof engine.isModelSupported === 'function') {
+      return await engine.isModelSupported(modelPath, ctxSize)
+    }
+    // Fallback if the engine does not expose isModelSupported
+    console.warn('isModelSupported method not available in llamacpp engine')
+    return 'YELLOW' // Conservative fallback
+  } catch (error) {
+    console.error(`Error checking model support for ${modelPath}:`, error)
+    return 'RED' // Error state, assume not supported
+  }
+}
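For reference, a minimal sketch of a hypothetical caller of this new service helper: the badge color mapping mirrors ModelSupportStatus.tsx above, and the 8192 context size mirrors what the hub pages in this patch pass. The wrapper name and defaults are illustrative assumptions, not part of the patch.

import { isModelSupported } from '@/services/models'

// Hypothetical helper: resolve a compatibility badge color class for a GGUF file.
async function getCompatibilityBadgeColor(
  modelFilePath: string,
  ctxSize: number = 8192 // assumed context size, matching the hub pages
): Promise<'bg-green-500' | 'bg-yellow-500' | 'bg-red-500'> {
  const status = await isModelSupported(modelFilePath, ctxSize)
  if (status === 'GREEN') return 'bg-green-500' // weights + KV cache fit
  if (status === 'YELLOW') return 'bg-yellow-500' // weights fit, KV cache may not
  return 'bg-red-500' // weights alone exceed available memory
}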