diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts
index f4cdd83c8..fe4f2f34c 100644
--- a/extensions/llamacpp-extension/src/index.ts
+++ b/extensions/llamacpp-extension/src/index.ts
@@ -35,7 +35,11 @@ import {
 import { invoke } from '@tauri-apps/api/core'
 import { getProxyConfig } from './util'
 import { basename } from '@tauri-apps/api/path'
-import { readGgufMetadata } from '@janhq/tauri-plugin-llamacpp-api'
+import {
+  GgufMetadata,
+  readGgufMetadata,
+} from '@janhq/tauri-plugin-llamacpp-api'
+import { getSystemUsage } from '@janhq/tauri-plugin-hardware-api'
 
 type LlamacppConfig = {
   version_backend: string
@@ -1742,7 +1746,7 @@ export default class llamacpp_extension extends AIEngine {
     const [version, backend] = cfg.version_backend.split('/')
     if (!version || !backend) {
       throw new Error(
-        `Invalid version/backend format: ${cfg.version_backend}. Expected format: <version>/<backend>`
+        'Backend setup was not successful. Please restart the app with a stable internet connection.'
       )
     }
     // set envs
@@ -1843,4 +1847,134 @@ export default class llamacpp_extension extends AIEngine {
       'tokenizer.chat_template'
     ]?.includes('tools')
   }
+
+  /**
+   * Estimate the KV cache size (in bytes) from the given GGUF metadata
+   *
+   */
+  private async estimateKVCache(
+    meta: Record<string, any>,
+    ctx_size?: number
+  ): Promise<number> {
+    const arch = meta['general.architecture']
+    if (!arch) throw new Error('Invalid metadata: architecture not found')
+
+    const nLayer = Number(meta[`${arch}.block_count`])
+    if (!nLayer) throw new Error('Invalid metadata: block_count not found')
+
+    const nHead = Number(meta[`${arch}.attention.head_count`])
+    if (!nHead) throw new Error('Invalid metadata: head_count not found')
+
+    // Try to get key/value lengths first (more accurate)
+    const keyLen = Number(meta[`${arch}.attention.key_length`])
+    const valLen = Number(meta[`${arch}.attention.value_length`])
+
+    let headDim: number
+
+    if (keyLen && valLen) {
+      // Use explicit key/value lengths if available
+      logger.info(
+        `Using explicit key_length: ${keyLen}, value_length: ${valLen}`
+      )
+      headDim = keyLen + valLen
+    } else {
+      // Fall back to embedding_length estimation
+      const embeddingLen = Number(meta[`${arch}.embedding_length`])
+      if (!embeddingLen)
+        throw new Error('Invalid metadata: embedding_length not found')
+
+      // Standard transformer: head_dim = embedding_dim / num_heads
+      // For the KV cache we need both K and V, so 2 * head_dim per head
+      headDim = (embeddingLen / nHead) * 2
+      logger.info(
+        `Using embedding_length estimation: ${embeddingLen}, calculated head_dim: ${headDim}`
+      )
+    }
+    let ctxLen: number
+    if (!ctx_size) {
+      ctxLen = Number(meta[`${arch}.context_length`])
+    } else {
+      ctxLen = ctx_size
+    }
+
+    logger.info(`ctxLen: ${ctxLen}`)
+    logger.info(`nLayer: ${nLayer}`)
+    logger.info(`nHead: ${nHead}`)
+    logger.info(`headDim: ${headDim}`)
+
+    // Assume an f16 cache by default.
+    // This can be extended by checking cache-type-k and cache-type-v,
+    // but we are checking overall compatibility with the default settings.
+    // f16 = 16 bits = 2 bytes per element
+    const bytesPerElement = 2
+
+    // Total KV cache size per token = nHead * headDim * bytesPerElement
+    const kvPerToken = nHead * headDim * bytesPerElement
+
+    return ctxLen * nLayer * kvPerToken
+  }
+
+  private async getModelSize(path: string): Promise<number> {
+    if (path.startsWith('https://')) {
+      const res = await fetch(path, { method: 'HEAD' })
+      const len = res.headers.get('content-length')
+      return len ? parseInt(len, 10) : 0
+    } else {
+      return (await fs.fileStat(path)).size
+    }
+  }
+
+  /*
+   * Check the support status of a model by its path (local/remote)
+   *
+   * Returns:
+   *  - "RED"    → weights don't fit
+   *  - "YELLOW" → weights fit, KV cache doesn't
+   *  - "GREEN"  → both weights + KV cache fit
+   */
+  async isModelSupported(
+    path: string,
+    ctx_size?: number
+  ): Promise<'RED' | 'YELLOW' | 'GREEN'> {
+    try {
+      const modelSize = await this.getModelSize(path)
+      logger.info(`modelSize: ${modelSize}`)
+      let gguf: GgufMetadata
+      gguf = await readGgufMetadata(path)
+      let kvCacheSize: number
+      if (ctx_size) {
+        kvCacheSize = await this.estimateKVCache(gguf.metadata, ctx_size)
+      } else {
+        kvCacheSize = await this.estimateKVCache(gguf.metadata)
+      }
+      // total memory consumption = model weights + KV cache + a small buffer for outputs;
+      // the output buffer is small, so it is not considered here
+      const totalRequired = modelSize + kvCacheSize
+      logger.info(
+        `isModelSupported: Total memory requirement: ${totalRequired} for ${path}`
+      )
+      let availableMemBytes: number
+      const devices = await this.getDevices()
+      if (devices.length > 0) {
+        // Sum free memory across all GPUs
+        availableMemBytes = devices
+          .map((d) => d.free * 1024 * 1024)
+          .reduce((a, b) => a + b, 0)
+      } else {
+        // CPU fallback
+        const sys = await getSystemUsage()
+        availableMemBytes = (sys.total_memory - sys.used_memory) * 1024 * 1024
+      }
+      // check model size w.r.t. available memory
+      if (modelSize > availableMemBytes) {
+        return 'RED'
+      } else if (modelSize + kvCacheSize > availableMemBytes) {
+        return 'YELLOW'
+      } else {
+        return 'GREEN'
+      }
+    } catch (e) {
+      throw new Error(String(e))
+    }
+  }
 }
diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/Cargo.toml b/src-tauri/plugins/tauri-plugin-llamacpp/Cargo.toml
index d4abaf88d..0defc62b7 100644
--- a/src-tauri/plugins/tauri-plugin-llamacpp/Cargo.toml
+++ b/src-tauri/plugins/tauri-plugin-llamacpp/Cargo.toml
@@ -23,12 +23,13 @@ sysinfo = "0.34.2"
 tauri = { version = "2.5.0", default-features = false, features = [] }
 thiserror = "2.0.12"
 tokio = { version = "1", features = ["full"] }
+reqwest = { version = "0.11", features = ["json", "blocking", "stream"] }
 
 # Windows-specific dependencies
 [target.'cfg(windows)'.dependencies]
 windows-sys = { version = "0.60.2", features = ["Win32_Storage_FileSystem"] }
 
-# Unix-specific dependencies 
+# Unix-specific dependencies
 [target.'cfg(unix)'.dependencies]
 nix = { version = "=0.30.1", features = ["signal", "process"] }
diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/commands.rs b/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/commands.rs
index 5d005a241..ae38f56f3 100644
--- a/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/commands.rs
+++ b/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/commands.rs
@@ -1,8 +1,58 @@
 use super::helpers;
 use super::types::GgufMetadata;
+use reqwest;
+use std::fs::File;
+use std::io::BufReader;
 
 /// Read GGUF metadata from a model file
 #[tauri::command]
 pub async fn read_gguf_metadata(path: String) -> Result<GgufMetadata, String> {
-    helpers::read_gguf_metadata(&path).map_err(|e| format!("Failed to read GGUF metadata: {}", e))
+    if path.starts_with("http://") || path.starts_with("https://") {
+        // Remote: read in 2 MB chunks until parsing succeeds
+        let client = reqwest::Client::new();
+        let chunk_size = 2 * 1024 * 1024; // Fixed 2 MB chunks
+        let max_total_size = 120 * 1024 * 1024; // Don't exceed 120 MB total
+        let mut total_downloaded = 0;
+        let mut accumulated_data = Vec::new();
+
+        while total_downloaded < max_total_size {
+            let start = total_downloaded;
+            let end = std::cmp::min(start + chunk_size - 1, max_total_size - 1);
+
+            let resp = client
+                .get(&path)
+                .header("Range", format!("bytes={}-{}", start, end))
+                .send()
+                .await
+                .map_err(|e| format!("Failed to fetch chunk {}-{}: {}", start, end, e))?;
+
+            let chunk_data = resp
+                .bytes()
+                .await
+                .map_err(|e| format!("Failed to read chunk response: {}", e))?;
+
+            accumulated_data.extend_from_slice(&chunk_data);
+            total_downloaded += chunk_data.len();
+
+            // Try parsing after each chunk
+            let cursor = std::io::Cursor::new(&accumulated_data);
+            if let Ok(metadata) = helpers::read_gguf_metadata(cursor) {
+                return Ok(metadata);
+            }
+
+            // If we got less data than expected, we've reached EOF
+            if chunk_data.len() < chunk_size {
+                break;
+            }
+        }
+        Err("Could not parse GGUF metadata from downloaded data".to_string())
+    } else {
+        // Local: use a streaming file reader
+        let file =
+            File::open(&path).map_err(|e| format!("Failed to open local file {}: {}", path, e))?;
+        let reader = BufReader::new(file);
+
+        helpers::read_gguf_metadata(reader)
+            .map_err(|e| format!("Failed to parse GGUF metadata: {}", e))
+    }
 }
diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/helpers.rs b/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/helpers.rs
index 245b986a1..b728522c3 100644
--- a/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/helpers.rs
+++ b/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/helpers.rs
@@ -1,13 +1,11 @@
 use byteorder::{LittleEndian, ReadBytesExt};
 use std::convert::TryFrom;
-use std::fs::File;
 use std::io::{self, BufReader, Read, Seek};
-use std::path::Path;
 
 use super::types::{GgufMetadata, GgufValueType};
 
-pub fn read_gguf_metadata<P: AsRef<Path>>(path: P) -> io::Result<GgufMetadata> {
-    let mut file = BufReader::new(File::open(path)?);
+pub fn read_gguf_metadata<R: Read + Seek>(reader: R) -> io::Result<GgufMetadata> {
+    let mut file = BufReader::new(reader);
 
     let mut magic = [0u8; 4];
     file.read_exact(&mut magic)?;
diff --git a/web-app/src/containers/ChatInput.tsx b/web-app/src/containers/ChatInput.tsx
index 53e47db51..0fa7a4b32 100644
--- a/web-app/src/containers/ChatInput.tsx
+++ b/web-app/src/containers/ChatInput.tsx
@@ -17,7 +17,6 @@ import {
   IconPhoto,
   IconWorld,
   IconAtom,
-  IconEye,
   IconTool,
   IconCodeCircle2,
   IconPlayerStopFilled,
@@ -537,7 +536,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
          {/* File attachment - show only for models with mmproj */}
          {hasMmproj && (
@@ -554,20 +553,6 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { {/*
*/} - {selectedModel?.capabilities?.includes('vision') && ( - - - -
- -
-
- -

{t('vision')}

-
-
-
- )} {selectedModel?.capabilities?.includes('embeddings') && ( diff --git a/web-app/src/containers/DropdownModelProvider.tsx b/web-app/src/containers/DropdownModelProvider.tsx index ac2bbc535..005580890 100644 --- a/web-app/src/containers/DropdownModelProvider.tsx +++ b/web-app/src/containers/DropdownModelProvider.tsx @@ -14,12 +14,16 @@ import { route } from '@/constants/routes' import { useThreads } from '@/hooks/useThreads' import { ModelSetting } from '@/containers/ModelSetting' import ProvidersAvatar from '@/containers/ProvidersAvatar' +import { ModelSupportStatus } from '@/containers/ModelSupportStatus' import { Fzf } from 'fzf' import { localStorageKey } from '@/constants/localStorage' import { useTranslation } from '@/i18n/react-i18next-compat' import { useFavoriteModel } from '@/hooks/useFavoriteModel' import { predefinedProviders } from '@/consts/providers' -import { checkMmprojExistsAndUpdateOffloadMMprojSetting } from '@/services/models' +import { + checkMmprojExistsAndUpdateOffloadMMprojSetting, + checkMmprojExists, +} from '@/services/models' type DropdownModelProviderProps = { model?: ThreadModel @@ -91,6 +95,50 @@ const DropdownModelProvider = ({ [providers] ) + // Helper function to get context size from model settings + const getContextSize = useCallback((): number => { + if (!selectedModel?.settings?.ctx_len?.controller_props?.value) { + return 8192 // Default context size + } + return selectedModel.settings.ctx_len.controller_props.value as number + }, [selectedModel?.settings?.ctx_len?.controller_props?.value]) + + // Function to check if a llamacpp model has vision capabilities and update model capabilities + const checkAndUpdateModelVisionCapability = useCallback( + async (modelId: string) => { + try { + const hasVision = await checkMmprojExists(modelId) + if (hasVision) { + // Update the model capabilities to include 'vision' + const provider = getProviderByName('llamacpp') + if (provider) { + const modelIndex = provider.models.findIndex( + (m) => m.id === modelId + ) + if (modelIndex !== -1) { + const model = provider.models[modelIndex] + const capabilities = model.capabilities || [] + + // Add 'vision' capability if not already present + if (!capabilities.includes('vision')) { + const updatedModels = [...provider.models] + updatedModels[modelIndex] = { + ...model, + capabilities: [...capabilities, 'vision'], + } + + updateProvider('llamacpp', { models: updatedModels }) + } + } + } + } + } catch (error) { + console.debug('Error checking mmproj for model:', modelId, error) + } + }, + [getProviderByName, updateProvider] + ) + // Initialize model provider only once useEffect(() => { const initializeModel = async () => { @@ -107,6 +155,8 @@ const DropdownModelProvider = ({ updateProvider, getProviderByName ) + // Also check vision capability + await checkAndUpdateModelVisionCapability(model.id as string) } } else if (useLastUsedModel) { // Try to use last used model only when explicitly requested (for new chat) @@ -119,6 +169,8 @@ const DropdownModelProvider = ({ updateProvider, getProviderByName ) + // Also check vision capability + await checkAndUpdateModelVisionCapability(lastUsed.model) } } else { selectModelProvider('', '') @@ -136,6 +188,7 @@ const DropdownModelProvider = ({ checkModelExists, updateProvider, getProviderByName, + checkAndUpdateModelVisionCapability, ]) // Update display model when selection changes @@ -147,6 +200,25 @@ const DropdownModelProvider = ({ } }, [selectedProvider, selectedModel, t]) + // Check vision capabilities for all llamacpp models 
+ useEffect(() => { + const checkAllLlamacppModelsForVision = async () => { + const llamacppProvider = providers.find( + (p) => p.provider === 'llamacpp' && p.active + ) + if (llamacppProvider) { + const checkPromises = llamacppProvider.models.map((model) => + checkAndUpdateModelVisionCapability(model.id) + ) + await Promise.allSettled(checkPromises) + } + } + + if (open) { + checkAllLlamacppModelsForVision() + } + }, [open, providers, checkAndUpdateModelVisionCapability]) + // Reset search value when dropdown closes const onOpenChange = useCallback((open: boolean) => { setOpen(open) @@ -287,6 +359,8 @@ const DropdownModelProvider = ({ updateProvider, getProviderByName ) + // Also check vision capability + await checkAndUpdateModelVisionCapability(searchableModel.model.id) } // Store the selected model as last used @@ -305,6 +379,7 @@ const DropdownModelProvider = ({ useLastUsedModel, updateProvider, getProviderByName, + checkAndUpdateModelVisionCapability, ] ) @@ -318,7 +393,7 @@ const DropdownModelProvider = ({ return ( -
+
+ onCheckModelSupport: (variant: ModelQuant) => void + children?: React.ReactNode +} + +export const ModelInfoHoverCard = ({ + model, + variant, + defaultModelQuantizations, + modelSupportStatus, + onCheckModelSupport, + children, +}: ModelInfoHoverCardProps) => { + const isVariantMode = !!variant + const displayVariant = + variant || + model.quants.find((m) => + defaultModelQuantizations.some((e) => + m.model_id.toLowerCase().includes(e) + ) + ) || + model.quants?.[0] + + const handleMouseEnter = () => { + if (displayVariant) { + onCheckModelSupport(displayVariant) + } + } + + const getCompatibilityStatus = () => { + const status = displayVariant + ? modelSupportStatus[displayVariant.model_id] + : null + + if (status === 'LOADING') { + return ( +
+
+ Checking... +
+ ) + } else if (status === 'GREEN') { + return ( +
+
+ + Recommended for your device + +
+ ) + } else if (status === 'YELLOW') { + return ( +
+
+ + May be slow on your device + +
+ ) + } else if (status === 'RED') { + return ( +
+
+ + May be incompatible with your device + +
+ ) + } else { + return ( +
+
+ Unknown +
+ ) + } + } + + return ( + + + {children || ( +
+ +
+ )} +
+ +
+ {/* Header */} +
+

+ {isVariantMode ? variant.model_id : model.model_name} +

+

+ {isVariantMode + ? 'Model Variant Information' + : 'Model Information'} +

+
+ + {/* Main Info Grid */} +
+
+ {isVariantMode ? ( + <> +
+ + File Size + + + {variant.file_size} + +
+
+ + Quantization + + + {variant.model_id.split('-').pop()?.toUpperCase() || + 'N/A'} + +
+ + ) : ( + <> +
+ + Downloads + + + {model.downloads?.toLocaleString() || '0'} + +
+
+ Variants + + {model.quants?.length || 0} + +
+ + )} +
+ +
+ {!isVariantMode && ( +
+ + Default Size + + + {displayVariant?.file_size || 'N/A'} + +
+ )} +
+ + Compatibility + +
+ {getCompatibilityStatus()} +
+
+
+
+ + {/* Features Section */} + {(model.num_mmproj > 0 || model.tools) && ( +
+
+ Features +
+
+ {model.num_mmproj > 0 && ( +
+ + Vision + +
+ )} + {model.tools && ( +
+ + Tools + +
+ )} +
+
+ )} + + {/* Content Section */} +
+
+ {isVariantMode ? 'Download URL' : 'Description'} +
+
+ {isVariantMode ? ( +
{variant.path}
+ ) : ( + extractDescription(model?.description) || + 'No description available' + )} +
+
+
+
+
+ ) +} diff --git a/web-app/src/containers/ModelSupportStatus.tsx b/web-app/src/containers/ModelSupportStatus.tsx new file mode 100644 index 000000000..3667f4461 --- /dev/null +++ b/web-app/src/containers/ModelSupportStatus.tsx @@ -0,0 +1,142 @@ +import { useCallback, useEffect, useState } from 'react' +import { cn } from '@/lib/utils' +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from '@/components/ui/tooltip' +import { isModelSupported } from '@/services/models' +import { getJanDataFolderPath, joinPath } from '@janhq/core' + +interface ModelSupportStatusProps { + modelId: string | undefined + provider: string | undefined + contextSize: number + className?: string +} + +export const ModelSupportStatus = ({ + modelId, + provider, + contextSize, + className, +}: ModelSupportStatusProps) => { + const [modelSupportStatus, setModelSupportStatus] = useState< + 'RED' | 'YELLOW' | 'GREEN' | 'LOADING' | null + >(null) + + // Helper function to check model support with proper path resolution + const checkModelSupportWithPath = useCallback( + async ( + id: string, + ctxSize: number + ): Promise<'RED' | 'YELLOW' | 'GREEN'> => { + try { + // Get Jan's data folder path and construct the full model file path + // Following the llamacpp extension structure: /llamacpp/models//model.gguf + const janDataFolder = await getJanDataFolderPath() + const modelFilePath = await joinPath([ + janDataFolder, + 'llamacpp', + 'models', + id, + 'model.gguf', + ]) + + return await isModelSupported(modelFilePath, ctxSize) + } catch (error) { + console.error( + 'Error checking model support with constructed path:', + error + ) + // If path construction or model support check fails, assume not supported + return 'RED' + } + }, + [] + ) + + // Helper function to get icon color based on model support status + const getStatusColor = (): string => { + switch (modelSupportStatus) { + case 'GREEN': + return 'bg-green-500' + case 'YELLOW': + return 'bg-yellow-500' + case 'RED': + return 'bg-red-500' + case 'LOADING': + return 'bg-main-view-fg/50' + default: + return 'bg-main-view-fg/50' + } + } + + // Helper function to get tooltip text based on model support status + const getStatusTooltip = (): string => { + switch (modelSupportStatus) { + case 'GREEN': + return `Works Well on your device (ctx: ${contextSize})` + case 'YELLOW': + return `Might work on your device (ctx: ${contextSize})` + case 'RED': + return `Doesn't work on your device (ctx: ${contextSize})` + case 'LOADING': + return 'Checking device compatibility...' + default: + return 'Unknown' + } + } + + // Check model support when model changes + useEffect(() => { + const checkModelSupport = async () => { + if (modelId && provider === 'llamacpp') { + // Set loading state immediately + setModelSupportStatus('LOADING') + try { + const supportStatus = await checkModelSupportWithPath( + modelId, + contextSize + ) + setModelSupportStatus(supportStatus) + } catch (error) { + console.error('Error checking model support:', error) + setModelSupportStatus('RED') + } + } else { + // Only show status for llamacpp models since isModelSupported is specific to llamacpp + setModelSupportStatus(null) + } + } + + checkModelSupport() + }, [modelId, provider, contextSize, checkModelSupportWithPath]) + + // Don't render anything if no status or not llamacpp + if (!modelSupportStatus || provider !== 'llamacpp') { + return null + } + + return ( + + + +
+ + +

{getStatusTooltip()}

+
+ + + ) +} diff --git a/web-app/src/containers/__tests__/ChatInput.test.tsx b/web-app/src/containers/__tests__/ChatInput.test.tsx index 292484006..0580821e5 100644 --- a/web-app/src/containers/__tests__/ChatInput.test.tsx +++ b/web-app/src/containers/__tests__/ChatInput.test.tsx @@ -291,15 +291,6 @@ describe('ChatInput', () => { expect(stopButton).toBeInTheDocument() }) - it('shows capability icons when model supports them', () => { - act(() => { - renderWithRouter() - }) - - // Should show vision icon (rendered as SVG with tabler-icon-eye class) - const visionIcon = document.querySelector('.tabler-icon-eye') - expect(visionIcon).toBeInTheDocument() - }) it('shows model selection dropdown', () => { act(() => { diff --git a/web-app/src/routes/hub/$modelId.tsx b/web-app/src/routes/hub/$modelId.tsx index e5f2f44b3..2d0eecc70 100644 --- a/web-app/src/routes/hub/$modelId.tsx +++ b/web-app/src/routes/hub/$modelId.tsx @@ -20,19 +20,24 @@ import { useModelProvider } from '@/hooks/useModelProvider' import { useDownloadStore } from '@/hooks/useDownloadStore' import { CatalogModel, + ModelQuant, convertHfRepoToCatalogModel, fetchHuggingFaceRepo, pullModelWithMetadata, + isModelSupported, } from '@/services/models' import { Progress } from '@/components/ui/progress' import { Button } from '@/components/ui/button' import { cn } from '@/lib/utils' import { useGeneralSetting } from '@/hooks/useGeneralSetting' +import { ModelInfoHoverCard } from '@/containers/ModelInfoHoverCard' type SearchParams = { repo: string } +const defaultModelQuantizations = ['iq4_xs', 'q4_k_m'] + export const Route = createFileRoute('/hub/$modelId')({ component: HubModelDetail, validateSearch: (search: Record): SearchParams => ({ @@ -57,6 +62,11 @@ function HubModelDetail() { const [readmeContent, setReadmeContent] = useState('') const [isLoadingReadme, setIsLoadingReadme] = useState(false) + // State for model support status + const [modelSupportStatus, setModelSupportStatus] = useState< + Record + >({}) + useEffect(() => { fetchSources() }, [fetchSources]) @@ -131,6 +141,41 @@ function HubModelDetail() { } } + // Check model support function + const checkModelSupport = useCallback( + async (variant: ModelQuant) => { + const modelKey = variant.model_id + + // Don't check again if already checking or checked + if (modelSupportStatus[modelKey]) { + return + } + + // Set loading state + setModelSupportStatus((prev) => ({ + ...prev, + [modelKey]: 'LOADING', + })) + + try { + // Use the HuggingFace path for the model + const modelPath = variant.path + const supported = await isModelSupported(modelPath, 8192) + setModelSupportStatus((prev) => ({ + ...prev, + [modelKey]: supported, + })) + } catch (error) { + console.error('Error checking model support:', error) + setModelSupportStatus((prev) => ({ + ...prev, + [modelKey]: 'RED', + })) + } + }, + [modelSupportStatus] + ) + // Extract tags from quants (model variants) const tags = useMemo(() => { if (!modelData?.quants) return [] @@ -318,6 +363,7 @@ function HubModelDetail() { Size + Action @@ -372,7 +418,18 @@ function HubModelDetail() { {variant.file_size} - + + + + {(() => { if (isDownloading && !isDownloaded) { return ( diff --git a/web-app/src/routes/hub/index.tsx b/web-app/src/routes/hub/index.tsx index 93658816b..07dd0f85b 100644 --- a/web-app/src/routes/hub/index.tsx +++ b/web-app/src/routes/hub/index.tsx @@ -31,6 +31,7 @@ import { TooltipProvider, TooltipTrigger, } from '@/components/ui/tooltip' +import { ModelInfoHoverCard } from '@/containers/ModelInfoHoverCard' import 
Joyride, { CallBackProps, STATUS } from 'react-joyride' import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide' import { @@ -44,6 +45,7 @@ import { pullModelWithMetadata, fetchHuggingFaceRepo, convertHfRepoToCatalogModel, + isModelSupported, } from '@/services/models' import { useDownloadStore } from '@/hooks/useDownloadStore' import { Progress } from '@/components/ui/progress' @@ -97,6 +99,9 @@ function Hub() { const [huggingFaceRepo, setHuggingFaceRepo] = useState( null ) + const [modelSupportStatus, setModelSupportStatus] = useState< + Record + >({}) const [joyrideReady, setJoyrideReady] = useState(false) const [currentStepIndex, setCurrentStepIndex] = useState(0) const addModelSourceTimeoutRef = useRef | null>( @@ -270,6 +275,41 @@ function Hub() { [navigate] ) + const checkModelSupport = useCallback( + async (variant: any) => { + const modelKey = variant.model_id + + // Don't check again if already checking or checked + if (modelSupportStatus[modelKey]) { + return + } + + // Set loading state + setModelSupportStatus((prev) => ({ + ...prev, + [modelKey]: 'LOADING', + })) + + try { + // Use the HuggingFace path for the model + const modelPath = variant.path + const supportStatus = await isModelSupported(modelPath, 8192) + + setModelSupportStatus((prev) => ({ + ...prev, + [modelKey]: supportStatus, + })) + } catch (error) { + console.error('Error checking model support:', error) + setModelSupportStatus((prev) => ({ + ...prev, + [modelKey]: 'RED', + })) + } + }, + [modelSupportStatus] + ) + const DownloadButtonPlaceholder = useMemo(() => { return ({ model }: ModelProps) => { // Check if this is a HuggingFace repository (no quants) @@ -616,6 +656,14 @@ function Hub() { )?.file_size } + @@ -671,45 +719,47 @@ function Hub() { ?.length || 0}
- {filteredModels[virtualItem.index].tools && ( -
- - - -
- -
-
- -

{t('tools')}

-
-
-
-
- )} - {filteredModels[virtualItem.index].num_mmproj > - 0 && ( -
- - - -
- -
-
- -

{t('vision')}

-
-
-
-
- )} +
+ {filteredModels[virtualItem.index].num_mmproj > + 0 && ( +
+ + + +
+ +
+
+ +

{t('vision')}

+
+
+
+
+ )} + {filteredModels[virtualItem.index].tools && ( +
+ + + +
+ +
+
+ +

{t('tools')}

+
+
+
+
+ )} +
{filteredModels[virtualItem.index].quants.length > 1 && (
@@ -744,12 +794,75 @@ function Hub() { (variant) => ( +
+ + {variant.model_id} + + {filteredModels[virtualItem.index] + .num_mmproj > 0 && ( +
+ + + +
+ +
+
+ +

{t('vision')}

+
+
+
+
+ )} + {filteredModels[virtualItem.index] + .tools && ( +
+ + + +
+ +
+
+ +

{t('tools')}

+
+
+
+
+ )} +
+ + } actions={

{variant.file_size}

+ {(() => { const isDownloading = localDownloadingModels.has( diff --git a/web-app/src/services/__tests__/models.test.ts b/web-app/src/services/__tests__/models.test.ts index 368fd19be..286fb01a4 100644 --- a/web-app/src/services/__tests__/models.test.ts +++ b/web-app/src/services/__tests__/models.test.ts @@ -13,6 +13,7 @@ import { stopModel, stopAllModels, startModel, + isModelSupported, HuggingFaceRepo, CatalogModel, } from '../models' @@ -845,4 +846,95 @@ describe('models service', () => { expect(result.quants[0].file_size).toBe('Unknown size') }) }) + + describe('isModelSupported', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should return GREEN when model is fully supported', async () => { + const mockEngineWithSupport = { + ...mockEngine, + isModelSupported: vi.fn().mockResolvedValue('GREEN'), + } + + mockEngineManager.get.mockReturnValue(mockEngineWithSupport) + + const result = await isModelSupported('/path/to/model.gguf', 4096) + + expect(result).toBe('GREEN') + expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith( + '/path/to/model.gguf', + 4096 + ) + }) + + it('should return YELLOW when model weights fit but KV cache does not', async () => { + const mockEngineWithSupport = { + ...mockEngine, + isModelSupported: vi.fn().mockResolvedValue('YELLOW'), + } + + mockEngineManager.get.mockReturnValue(mockEngineWithSupport) + + const result = await isModelSupported('/path/to/model.gguf', 8192) + + expect(result).toBe('YELLOW') + expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith( + '/path/to/model.gguf', + 8192 + ) + }) + + it('should return RED when model is not supported', async () => { + const mockEngineWithSupport = { + ...mockEngine, + isModelSupported: vi.fn().mockResolvedValue('RED'), + } + + mockEngineManager.get.mockReturnValue(mockEngineWithSupport) + + const result = await isModelSupported('/path/to/large-model.gguf') + + expect(result).toBe('RED') + expect(mockEngineWithSupport.isModelSupported).toHaveBeenCalledWith( + '/path/to/large-model.gguf', + undefined + ) + }) + + it('should return YELLOW as fallback when engine method is not available', async () => { + const mockEngineWithoutSupport = { + ...mockEngine, + // isModelSupported method not available + } + + mockEngineManager.get.mockReturnValue(mockEngineWithoutSupport) + + const result = await isModelSupported('/path/to/model.gguf') + + expect(result).toBe('YELLOW') + }) + + it('should return RED when engine is not available', async () => { + mockEngineManager.get.mockReturnValue(null) + + const result = await isModelSupported('/path/to/model.gguf') + + expect(result).toBe('YELLOW') // Should use fallback + }) + + it('should return RED when there is an error', async () => { + const mockEngineWithError = { + ...mockEngine, + isModelSupported: vi.fn().mockRejectedValue(new Error('Test error')), + } + + mockEngineManager.get.mockReturnValue(mockEngineWithError) + + const result = await isModelSupported('/path/to/model.gguf') + + expect(result).toBe('RED') + }) + }) }) diff --git a/web-app/src/services/models.ts b/web-app/src/services/models.ts index 0edfe165a..df7dacf00 100644 --- a/web-app/src/services/models.ts +++ b/web-app/src/services/models.ts @@ -579,3 +579,35 @@ export const checkMmprojExists = async (modelId: string): Promise => { } return false } + +/** + * Checks if a model is supported by analyzing memory requirements and system resources. 
+ * @param modelPath - The path to the model file (local path or URL)
+ * @param ctxSize - The context size to check against (optional; if omitted, the engine falls back to the model's trained context length)
+ * @returns Promise<'RED' | 'YELLOW' | 'GREEN'> - Support status:
+ *   - 'RED': Model weights don't fit in available memory
+ *   - 'YELLOW': Model weights fit, but the KV cache doesn't
+ *   - 'GREEN': Both model weights and KV cache fit in available memory
+ */
+export const isModelSupported = async (
+  modelPath: string,
+  ctxSize?: number
+): Promise<'RED' | 'YELLOW' | 'GREEN'> => {
+  try {
+    const engine = getEngine('llamacpp') as AIEngine & {
+      isModelSupported?: (
+        path: string,
+        ctx_size?: number
+      ) => Promise<'RED' | 'YELLOW' | 'GREEN'>
+    }
+    if (engine && typeof engine.isModelSupported === 'function') {
+      return await engine.isModelSupported(modelPath, ctxSize)
+    }
+    // Fallback if the method is not available
+    console.warn('isModelSupported method not available in llamacpp engine')
+    return 'YELLOW' // Conservative fallback
+  } catch (error) {
+    console.error(`Error checking model support for ${modelPath}:`, error)
+    return 'RED' // Error state, assume not supported
+  }
+}
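Reviewer note (not part of the patch): below is a minimal, self-contained TypeScript sketch of the sizing math that `estimateKVCache` and `isModelSupported` implement above — KV cache bytes ≈ ctx_len × n_layers × n_heads × (key_length + value_length) × 2 bytes for an f16 cache, compared against free memory to produce the RED/YELLOW/GREEN verdict. The model dimensions and memory figures here are made-up example values, not read from a real GGUF file.

```ts
type Verdict = 'RED' | 'YELLOW' | 'GREEN'

// KV cache bytes for an f16 cache: one K and one V vector per head, per layer, per token.
function estimateKvCacheBytes(
  ctxLen: number,
  nLayers: number,
  nHeads: number,
  kvDimPerHead: number // key_length + value_length (or 2 * embedding_length / n_heads)
): number {
  const bytesPerElement = 2 // f16
  return ctxLen * nLayers * nHeads * kvDimPerHead * bytesPerElement
}

// Same decision rule as isModelSupported: weights must fit, then weights + KV cache must fit.
function classify(modelBytes: number, kvBytes: number, freeBytes: number): Verdict {
  if (modelBytes > freeBytes) return 'RED'
  if (modelBytes + kvBytes > freeBytes) return 'YELLOW'
  return 'GREEN'
}

// Example: a ~4.9 GiB quantized model at 8192 context (32 layers, 32 heads, 128-dim K and V)
// on a machine with 16 GiB free => KV cache ≈ 4 GiB, total ≈ 8.9 GiB => GREEN.
const GiB = 1024 ** 3
const kvBytes = estimateKvCacheBytes(8192, 32, 32, 128 + 128)
console.log(classify(4.9 * GiB, kvBytes, 16 * GiB)) // GREEN
```

One observation on the estimate: like the patch, the sketch multiplies by `attention.head_count`. If I read the GGUF keys correctly, grouped-query-attention models store a smaller `attention.head_count_kv`, so this approach over-estimates their KV cache somewhat, which errs on the conservative side for the compatibility check.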