diff --git a/extensions/llamacpp-extension/package.json b/extensions/llamacpp-extension/package.json index b5db33c5e..585365130 100644 --- a/extensions/llamacpp-extension/package.json +++ b/extensions/llamacpp-extension/package.json @@ -31,6 +31,7 @@ "@janhq/tauri-plugin-hardware-api": "link:../../src-tauri/plugins/tauri-plugin-hardware", "@janhq/tauri-plugin-llamacpp-api": "link:../../src-tauri/plugins/tauri-plugin-llamacpp", "@tauri-apps/api": "^2.5.0", + "@tauri-apps/plugin-http": "^2.5.1", "@tauri-apps/plugin-log": "^2.6.0", "fetch-retry": "^5.0.6", "ulidx": "^2.3.0" diff --git a/extensions/llamacpp-extension/rolldown.config.mjs b/extensions/llamacpp-extension/rolldown.config.mjs index 86b6798d7..64f92f29a 100644 --- a/extensions/llamacpp-extension/rolldown.config.mjs +++ b/extensions/llamacpp-extension/rolldown.config.mjs @@ -17,4 +17,7 @@ export default defineConfig({ IS_MAC: JSON.stringify(process.platform === 'darwin'), IS_LINUX: JSON.stringify(process.platform === 'linux'), }, + inject: { + fetch: ['@tauri-apps/plugin-http', 'fetch'], + }, }) diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts index f4ad82f95..cd1fa4534 100644 --- a/extensions/llamacpp-extension/src/index.ts +++ b/extensions/llamacpp-extension/src/index.ts @@ -41,6 +41,7 @@ type LlamacppConfig = { auto_unload: boolean chat_template: string n_gpu_layers: number + offload_mmproj: boolean override_tensor_buffer_t: string ctx_size: number threads: number @@ -1222,6 +1223,10 @@ export default class llamacpp_extension extends AIEngine { // Takes a regex with matching tensor name as input if (cfg.override_tensor_buffer_t) args.push('--override-tensor', cfg.override_tensor_buffer_t) + // offload multimodal projector model to the GPU by default. 
if there is not enough memory + // turn this setting off will keep the projector model on the CPU but the image processing can + // take longer + if (cfg.offload_mmproj === false) args.push('--no-mmproj-offload') args.push('-a', modelId) args.push('--port', String(port)) if (modelConfig.mmproj_path) { @@ -1383,7 +1388,8 @@ export default class llamacpp_extension extends AIEngine { method: 'POST', headers, body, - signal: abortController?.signal, + connectTimeout: 600000, // 10 minutes + signal: AbortSignal.any([AbortSignal.timeout(600000), abortController?.signal]), }) if (!response.ok) { const errorData = await response.json().catch(() => null) @@ -1542,6 +1548,26 @@ export default class llamacpp_extension extends AIEngine { } } + /** + * Check if mmproj.gguf file exists for a given model ID + * @param modelId - The model ID to check for mmproj.gguf + * @returns Promise - true if mmproj.gguf exists, false otherwise + */ + async checkMmprojExists(modelId: string): Promise { + try { + const mmprojPath = await joinPath([ + await this.getProviderPath(), + 'models', + modelId, + 'mmproj.gguf', + ]) + return await fs.existsSync(mmprojPath) + } catch (e) { + logger.error(`Error checking mmproj.gguf for model ${modelId}:`, e) + return false + } + } + async getDevices(): Promise { const cfg = this.config const [version, backend] = cfg.version_backend.split('/') @@ -1644,4 +1670,18 @@ export default class llamacpp_extension extends AIEngine { 'tokenizer.chat_template' ]?.includes('tools') } + + private async loadMetadata(path: string): Promise { + try { + const data = await invoke( + 'plugin:llamacpp|read_gguf_metadata', + { + path: path, + } + ) + return data + } catch (err) { + throw err + } + } } diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/src/commands.rs b/src-tauri/plugins/tauri-plugin-llamacpp/src/commands.rs index 16590491e..35dc35c5e 100644 --- a/src-tauri/plugins/tauri-plugin-llamacpp/src/commands.rs +++ b/src-tauri/plugins/tauri-plugin-llamacpp/src/commands.rs @@ -12,7 +12,7 @@ use tokio::time::Instant; use crate::device::{get_devices_from_backend, DeviceInfo}; use crate::error::{ErrorCode, LlamacppError, ServerError, ServerResult}; -use crate::path::{validate_binary_path, validate_model_path}; +use crate::path::{validate_binary_path, validate_model_path, validate_mmproj_path}; use crate::process::{ find_session_by_model_id, get_all_active_sessions, get_all_loaded_model_ids, get_random_available_port, is_process_running_by_pid, @@ -55,6 +55,7 @@ pub async fn load_llama_model( let port = parse_port_from_args(&args); let model_path_pb = validate_model_path(&mut args)?; + let _mmproj_path_pb = validate_mmproj_path(&mut args)?; let api_key: String; diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/src/path.rs b/src-tauri/plugins/tauri-plugin-llamacpp/src/path.rs index 44ed00109..a62fb069a 100644 --- a/src-tauri/plugins/tauri-plugin-llamacpp/src/path.rs +++ b/src-tauri/plugins/tauri-plugin-llamacpp/src/path.rs @@ -98,3 +98,50 @@ pub fn validate_model_path(args: &mut Vec) -> ServerResult { Ok(model_path_pb) } + +/// Validate mmproj path exists and update args with platform-appropriate path format +pub fn validate_mmproj_path(args: &mut Vec) -> ServerResult> { + let mmproj_path_index = match args.iter().position(|arg| arg == "--mmproj") { + Some(index) => index, + None => return Ok(None), // mmproj is optional + }; + + let mmproj_path = args.get(mmproj_path_index + 1).cloned().ok_or_else(|| { + LlamacppError::new( + ErrorCode::ModelLoadFailed, + "Mmproj path was not provided after 
'--mmproj' flag.".into(), + None, + ) + })?; + + let mmproj_path_pb = PathBuf::from(&mmproj_path); + if !mmproj_path_pb.exists() { + let err_msg = format!( + "Invalid or inaccessible mmproj path: {}", + mmproj_path_pb.display() + ); + log::error!("{}", &err_msg); + return Err(LlamacppError::new( + ErrorCode::ModelFileNotFound, + "The specified mmproj file does not exist or is not accessible.".into(), + Some(err_msg), + ) + .into()); + } + + #[cfg(windows)] + { + // use short path on Windows + if let Some(short) = get_short_path(&mmproj_path_pb) { + args[mmproj_path_index + 1] = short; + } else { + args[mmproj_path_index + 1] = mmproj_path_pb.display().to_string(); + } + } + #[cfg(not(windows))] + { + args[mmproj_path_index + 1] = mmproj_path_pb.display().to_string(); + } + + Ok(Some(mmproj_path_pb)) +} diff --git a/src-tauri/tauri.conf.json b/src-tauri/tauri.conf.json index c2e37e483..c5dcb9c1b 100644 --- a/src-tauri/tauri.conf.json +++ b/src-tauri/tauri.conf.json @@ -35,7 +35,8 @@ "effects": ["fullScreenUI", "mica", "tabbed", "blur", "acrylic"], "state": "active", "radius": 8 - } + }, + "dragDropEnabled": false } ], "security": { diff --git a/web-app/src/containers/ChatInput.tsx b/web-app/src/containers/ChatInput.tsx index c6360253e..5383a170e 100644 --- a/web-app/src/containers/ChatInput.tsx +++ b/web-app/src/containers/ChatInput.tsx @@ -1,7 +1,7 @@ 'use client' import TextareaAutosize from 'react-textarea-autosize' -import { cn, toGigabytes } from '@/lib/utils' +import { cn } from '@/lib/utils' import { usePrompt } from '@/hooks/usePrompt' import { useThreads } from '@/hooks/useThreads' import { useCallback, useEffect, useRef, useState } from 'react' @@ -14,7 +14,7 @@ import { } from '@/components/ui/tooltip' import { ArrowRight } from 'lucide-react' import { - IconPaperclip, + IconPhoto, IconWorld, IconAtom, IconEye, @@ -34,6 +34,7 @@ import DropdownModelProvider from '@/containers/DropdownModelProvider' import { ModelLoader } from '@/containers/loaders/ModelLoader' import DropdownToolsAvailable from '@/containers/DropdownToolsAvailable' import { getConnectedServers } from '@/services/mcp' +import { checkMmprojExists } from '@/services/models' type ChatInputProps = { className?: string @@ -60,7 +61,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { const maxRows = 10 - const { selectedModel } = useModelProvider() + const { selectedModel, selectedProvider } = useModelProvider() const { sendMessage } = useChat() const [message, setMessage] = useState('') const [dropdownToolsAvailable, setDropdownToolsAvailable] = useState(false) @@ -75,6 +76,8 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { }> >([]) const [connectedServers, setConnectedServers] = useState([]) + const [isDragOver, setIsDragOver] = useState(false) + const [hasMmproj, setHasMmproj] = useState(false) // Check for connected MCP servers useEffect(() => { @@ -96,6 +99,29 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { return () => clearInterval(intervalId) }, []) + // Check for mmproj existence or vision capability when model changes + useEffect(() => { + const checkMmprojSupport = async () => { + if (selectedModel?.id) { + try { + // Only check mmproj for llamacpp provider + if (selectedProvider === 'llamacpp') { + const hasLocalMmproj = await checkMmprojExists(selectedModel.id) + setHasMmproj(hasLocalMmproj) + } else { + // For non-llamacpp providers, only check vision capability + setHasMmproj(true) + } + } catch (error) { + 
console.error('Error checking mmproj:', error) + setHasMmproj(false) + } + } + } + + checkMmprojSupport() + }, [selectedModel?.id, selectedProvider]) + // Check if there are active MCP servers const hasActiveMCPServers = connectedServers.length > 0 || tools.length > 0 @@ -104,11 +130,16 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { setMessage('Please select a model to start chatting.') return } - if (!prompt.trim()) { + if (!prompt.trim() && uploadedFiles.length === 0) { return } setMessage('') - sendMessage(prompt) + sendMessage( + prompt, + true, + uploadedFiles.length > 0 ? uploadedFiles : undefined + ) + setUploadedFiles([]) } useEffect(() => { @@ -191,8 +222,6 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { return 'image/jpeg' case 'png': return 'image/png' - case 'pdf': - return 'application/pdf' default: return '' } @@ -226,17 +255,12 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { const detectedType = file.type || getFileTypeFromExtension(file.name) const actualType = getFileTypeFromExtension(file.name) || detectedType - // Check file type - const allowedTypes = [ - 'image/jpg', - 'image/jpeg', - 'image/png', - 'application/pdf', - ] + // Check file type - images only + const allowedTypes = ['image/jpg', 'image/jpeg', 'image/png'] if (!allowedTypes.includes(actualType)) { setMessage( - `File is not supported. Only JPEG, JPG, PNG, and PDF files are allowed.` + `File attachments not supported currently. Only JPEG, JPG, and PNG files are allowed.` ) // Reset file input to allow re-uploading if (fileInputRef.current) { @@ -287,6 +311,104 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { } } + const handleDragEnter = (e: React.DragEvent) => { + e.preventDefault() + e.stopPropagation() + // Only allow drag if model supports mmproj + if (hasMmproj) { + setIsDragOver(true) + } + } + + const handleDragLeave = (e: React.DragEvent) => { + e.preventDefault() + e.stopPropagation() + // Only set dragOver to false if we're leaving the drop zone entirely + // In Tauri, relatedTarget can be null, so we need to handle that case + const relatedTarget = e.relatedTarget as Node | null + if (!relatedTarget || !e.currentTarget.contains(relatedTarget)) { + setIsDragOver(false) + } + } + + const handleDragOver = (e: React.DragEvent) => { + e.preventDefault() + e.stopPropagation() + // Ensure drag state is maintained during drag over + if (hasMmproj) { + setIsDragOver(true) + } + } + + const handleDrop = (e: React.DragEvent) => { + e.preventDefault() + e.stopPropagation() + setIsDragOver(false) + + // Only allow drop if model supports mmproj + if (!hasMmproj) { + return + } + + // Check if dataTransfer exists (it might not in some Tauri scenarios) + if (!e.dataTransfer) { + console.warn('No dataTransfer available in drop event') + return + } + + const files = e.dataTransfer.files + if (files && files.length > 0) { + // Create a synthetic event to reuse existing file handling logic + const syntheticEvent = { + target: { + files: files, + }, + } as React.ChangeEvent + + handleFileChange(syntheticEvent) + } + } + + const handlePaste = (e: React.ClipboardEvent) => { + const clipboardItems = e.clipboardData?.items + if (!clipboardItems) return + + // Only allow paste if model supports mmproj + if (!hasMmproj) { + return + } + + const imageItems = Array.from(clipboardItems).filter((item) => + item.type.startsWith('image/') + ) + + if (imageItems.length > 0) { + e.preventDefault() + + 
const files: File[] = [] + let processedCount = 0 + + imageItems.forEach((item) => { + const file = item.getAsFile() + if (file) { + files.push(file) + } + processedCount++ + + // When all items are processed, handle the valid files + if (processedCount === imageItems.length && files.length > 0) { + const syntheticEvent = { + target: { + files: files, + }, + } as unknown as React.ChangeEvent + + handleFileChange(syntheticEvent) + } + }) + } + } + return (
@@ -311,8 +433,14 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
{uploadedFiles.length > 0 && (
@@ -332,25 +460,6 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { alt={`${file.name} - ${index}`} /> )} - {file.type === 'application/pdf' && ( -
-
-
- - {file.name.split('.').pop()} - -
-
-
- {file.name} -
-

- {toGigabytes(file.size)} -

-
-
-
- )}
handleRemoveFile(index)} @@ -369,7 +478,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { rows={1} maxRows={10} value={prompt} - data-test-id={'chat-input'} + data-testid={'chat-input'} onChange={(e) => { setPrompt(e.target.value) // Count the number of newlines to estimate rows @@ -378,14 +487,21 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { }} onKeyDown={(e) => { // e.keyCode 229 is for IME input with Safari - const isComposing = e.nativeEvent.isComposing || e.keyCode === 229; - if (e.key === 'Enter' && !e.shiftKey && prompt.trim() && !isComposing) { + const isComposing = + e.nativeEvent.isComposing || e.keyCode === 229 + if ( + e.key === 'Enter' && + !e.shiftKey && + prompt.trim() && + !isComposing + ) { e.preventDefault() // Submit the message when Enter is pressed without Shift handleSendMesage(prompt) // When Shift+Enter is pressed, a new line is added (default behavior) } }} + onPaste={hasMmproj ? handlePaste : undefined} placeholder={t('common:placeholder.chatInput')} autoFocus spellCheck={spellCheckChatInput} @@ -406,7 +522,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
@@ -418,19 +534,22 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { useLastUsedModel={initialMessage} /> )} - {/* File attachment - always available */} -
- - -
+ {/* File attachment - show only for models with mmproj */} + {hasMmproj && ( +
+ + +
+ )} {/* Microphone - always available - Temp Hide */} {/*
@@ -574,9 +693,13 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => { ) : (
+ {message && (
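For context on the ChatInput changes above: image attachments are only offered when the selected llama.cpp model ships an mmproj file (via checkMmprojExists) or the provider is remote, and pasted/dropped files are funnelled through the same JPEG/JPG/PNG whitelist as the file picker before being sent as base64 attachments. A minimal standalone sketch of that conversion step, assuming a browser FileReader and the attachment shape used by sendMessage in this patch — the toAttachments helper itself is hypothetical, not part of the diff:

// Illustrative only: reduce dropped/pasted files to the attachment shape used by sendMessage().
type Attachment = {
  name: string
  type: string
  size: number
  base64: string
  dataUrl: string
}

// Mirrors the whitelist in handleFileChange ('image/jpg' is kept for parity with the patch).
const ALLOWED_TYPES = ['image/jpg', 'image/jpeg', 'image/png']

function readAsDataUrl(file: File): Promise<string> {
  return new Promise((resolve, reject) => {
    const reader = new FileReader()
    reader.onload = () => resolve(reader.result as string)
    reader.onerror = () => reject(reader.error)
    reader.readAsDataURL(file)
  })
}

async function toAttachments(files: FileList | File[]): Promise<Attachment[]> {
  // Keep only whitelisted image types, then encode each as a data URL + raw base64 payload.
  const images = Array.from(files).filter((f) => ALLOWED_TYPES.includes(f.type))
  return Promise.all(
    images.map(async (file) => {
      const dataUrl = await readAsDataUrl(file)
      return {
        name: file.name,
        type: file.type,
        size: file.size,
        base64: dataUrl.split(';base64,')[1] ?? '',
        dataUrl,
      }
    })
  )
}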
diff --git a/web-app/src/containers/DropdownModelProvider.tsx b/web-app/src/containers/DropdownModelProvider.tsx index 1425061d9..ac2bbc535 100644 --- a/web-app/src/containers/DropdownModelProvider.tsx +++ b/web-app/src/containers/DropdownModelProvider.tsx @@ -19,6 +19,7 @@ import { localStorageKey } from '@/constants/localStorage' import { useTranslation } from '@/i18n/react-i18next-compat' import { useFavoriteModel } from '@/hooks/useFavoriteModel' import { predefinedProviders } from '@/consts/providers' +import { checkMmprojExistsAndUpdateOffloadMMprojSetting } from '@/services/models' type DropdownModelProviderProps = { model?: ThreadModel @@ -66,6 +67,7 @@ const DropdownModelProvider = ({ getModelBy, selectedProvider, selectedModel, + updateProvider, } = useModelProvider() const [displayModel, setDisplayModel] = useState('') const { updateCurrentThreadModel } = useThreads() @@ -79,31 +81,52 @@ const DropdownModelProvider = ({ const searchInputRef = useRef(null) // Helper function to check if a model exists in providers - const checkModelExists = useCallback((providerName: string, modelId: string) => { - const provider = providers.find( - (p) => p.provider === providerName && p.active - ) - return provider?.models.find((m) => m.id === modelId) - }, [providers]) + const checkModelExists = useCallback( + (providerName: string, modelId: string) => { + const provider = providers.find( + (p) => p.provider === providerName && p.active + ) + return provider?.models.find((m) => m.id === modelId) + }, + [providers] + ) // Initialize model provider only once useEffect(() => { - // Auto select model when existing thread is passed - if (model) { - selectModelProvider(model?.provider as string, model?.id as string) - if (!checkModelExists(model.provider, model.id)) { - selectModelProvider('', '') - } - } else if (useLastUsedModel) { - // Try to use last used model only when explicitly requested (for new chat) - const lastUsed = getLastUsedModel() - if (lastUsed && checkModelExists(lastUsed.provider, lastUsed.model)) { - selectModelProvider(lastUsed.provider, lastUsed.model) - } else { - // Fallback to default model if last used model no longer exists - selectModelProvider('', '') + const initializeModel = async () => { + // Auto select model when existing thread is passed + if (model) { + selectModelProvider(model?.provider as string, model?.id as string) + if (!checkModelExists(model.provider, model.id)) { + selectModelProvider('', '') + } + // Check mmproj existence for llamacpp models + if (model?.provider === 'llamacpp') { + await checkMmprojExistsAndUpdateOffloadMMprojSetting( + model.id as string, + updateProvider, + getProviderByName + ) + } + } else if (useLastUsedModel) { + // Try to use last used model only when explicitly requested (for new chat) + const lastUsed = getLastUsedModel() + if (lastUsed && checkModelExists(lastUsed.provider, lastUsed.model)) { + selectModelProvider(lastUsed.provider, lastUsed.model) + if (lastUsed.provider === 'llamacpp') { + await checkMmprojExistsAndUpdateOffloadMMprojSetting( + lastUsed.model, + updateProvider, + getProviderByName + ) + } + } else { + selectModelProvider('', '') + } } } + + initializeModel() }, [ model, selectModelProvider, @@ -111,6 +134,8 @@ const DropdownModelProvider = ({ providers, useLastUsedModel, checkModelExists, + updateProvider, + getProviderByName, ]) // Update display model when selection changes @@ -245,7 +270,7 @@ const DropdownModelProvider = ({ }, [filteredItems, providers, searchValue, favoriteModels]) const handleSelect 
= useCallback( - (searchableModel: SearchableModel) => { + async (searchableModel: SearchableModel) => { selectModelProvider( searchableModel.provider.provider, searchableModel.model.id @@ -254,6 +279,16 @@ const DropdownModelProvider = ({ id: searchableModel.model.id, provider: searchableModel.provider.provider, }) + + // Check mmproj existence for llamacpp models + if (searchableModel.provider.provider === 'llamacpp') { + await checkMmprojExistsAndUpdateOffloadMMprojSetting( + searchableModel.model.id, + updateProvider, + getProviderByName + ) + } + // Store the selected model as last used if (useLastUsedModel) { setLastUsedModel( @@ -264,7 +299,13 @@ const DropdownModelProvider = ({ setSearchValue('') setOpen(false) }, - [selectModelProvider, updateCurrentThreadModel, useLastUsedModel] + [ + selectModelProvider, + updateCurrentThreadModel, + useLastUsedModel, + updateProvider, + getProviderByName, + ] ) const currentModel = selectedModel?.id diff --git a/web-app/src/containers/ModelSetting.tsx b/web-app/src/containers/ModelSetting.tsx index 29d996382..b3bb55e40 100644 --- a/web-app/src/containers/ModelSetting.tsx +++ b/web-app/src/containers/ModelSetting.tsx @@ -70,8 +70,8 @@ export function ModelSetting({ models: updatedModels, }) - // Call debounced stopModel only when updating ctx_len or ngl - if (key === 'ctx_len' || key === 'ngl' || key === 'chat_template') { + // Call debounced stopModel only when updating ctx_len, ngl, chat_template, or offload_mmproj + if (key === 'ctx_len' || key === 'ngl' || key === 'chat_template' || key === 'offload_mmproj') { debouncedStopModel(model.id) } } diff --git a/web-app/src/containers/ThreadContent.tsx b/web-app/src/containers/ThreadContent.tsx index 54ba342cb..a5a872b3e 100644 --- a/web-app/src/containers/ThreadContent.tsx +++ b/web-app/src/containers/ThreadContent.tsx @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ import { ThreadMessage } from '@janhq/core' import { RenderMarkdown } from './RenderMarkdown' import React, { Fragment, memo, useCallback, useMemo, useState } from 'react' @@ -144,7 +145,7 @@ export const ThreadContent = memo( isLastMessage?: boolean index?: number showAssistant?: boolean - // eslint-disable-next-line @typescript-eslint/no-explicit-any + streamTools?: any contextOverflowModal?: React.ReactNode | null updateMessage?: (item: ThreadMessage, message: string) => void @@ -172,9 +173,12 @@ export const ThreadContent = memo( const { reasoningSegment, textSegment } = useMemo(() => { // Check for thinking formats const hasThinkTag = text.includes('') && !text.includes('') - const hasAnalysisChannel = text.includes('<|channel|>analysis<|message|>') && !text.includes('<|start|>assistant<|channel|>final<|message|>') - - if (hasThinkTag || hasAnalysisChannel) return { reasoningSegment: text, textSegment: '' } + const hasAnalysisChannel = + text.includes('<|channel|>analysis<|message|>') && + !text.includes('<|start|>assistant<|channel|>final<|message|>') + + if (hasThinkTag || hasAnalysisChannel) + return { reasoningSegment: text, textSegment: '' } // Check for completed think tag format const thinkMatch = text.match(/([\s\S]*?)<\/think>/) @@ -187,7 +191,9 @@ export const ThreadContent = memo( } // Check for completed analysis channel format - const analysisMatch = text.match(/<\|channel\|>analysis<\|message\|>([\s\S]*?)<\|start\|>assistant<\|channel\|>final<\|message\|>/) + const analysisMatch = text.match( + /<\|channel\|>analysis<\|message\|>([\s\S]*?)<\|start\|>assistant<\|channel\|>final<\|message\|>/ + 
) if (analysisMatch?.index !== undefined) { const splitIndex = analysisMatch.index + analysisMatch[0].length return { @@ -213,7 +219,36 @@ export const ThreadContent = memo( } if (toSendMessage) { deleteMessage(toSendMessage.thread_id, toSendMessage.id ?? '') - sendMessage(toSendMessage.content?.[0]?.text?.value || '') + // Extract text content and any attachments + const textContent = + toSendMessage.content?.find((c) => c.type === 'text')?.text?.value || + '' + const attachments = toSendMessage.content + ?.filter((c) => (c.type === 'image_url' && c.image_url?.url) || false) + .map((c) => { + if (c.type === 'image_url' && c.image_url?.url) { + const url = c.image_url.url + const [mimeType, base64] = url + .replace('data:', '') + .split(';base64,') + return { + name: 'image', // We don't have the original filename + type: mimeType, + size: 0, // We don't have the original size + base64: base64, + dataUrl: url, + } + } + return null + }) + .filter(Boolean) as Array<{ + name: string + type: string + size: number + base64: string + dataUrl: string + }> + sendMessage(textContent, true, attachments) } }, [deleteMessage, getMessages, item, sendMessage]) @@ -255,22 +290,68 @@ export const ThreadContent = memo( return ( - {item.content?.[0]?.text && item.role === 'user' && ( + {item.role === 'user' && (
-
-
-
- + {/* Render attachments above the message bubble */} + {item.content?.some( + (c) => (c.type === 'image_url' && c.image_url?.url) || false + ) && ( +
+
+ {item.content + ?.filter( + (c) => + (c.type === 'image_url' && c.image_url?.url) || false + ) + .map((contentPart, index) => { + // Handle images + if ( + contentPart.type === 'image_url' && + contentPart.image_url?.url + ) { + return ( +
+ Uploaded attachment +
+ ) + } + return null + })}
-
+ )} + + {/* Render text content in the message bubble */} + {item.content?.some((c) => c.type === 'text' && c.text?.value) && ( +
+
+
+ {item.content + ?.filter((c) => c.type === 'text' && c.text?.value) + .map((contentPart, index) => ( +
+ +
+ ))} +
+
+
+ )} +
c.type === 'text')?.text?.value || + '' + } setMessage={(message) => { if (item.updateMessage) { item.updateMessage(item, message) diff --git a/web-app/src/containers/__tests__/ChatInput.test.tsx b/web-app/src/containers/__tests__/ChatInput.test.tsx index 7c1607191..292484006 100644 --- a/web-app/src/containers/__tests__/ChatInput.test.tsx +++ b/web-app/src/containers/__tests__/ChatInput.test.tsx @@ -73,6 +73,11 @@ vi.mock('@/services/mcp', () => ({ vi.mock('@/services/models', () => ({ stopAllModels: vi.fn(), + checkMmprojExists: vi.fn(() => Promise.resolve(true)), +})) + +vi.mock('../MovingBorder', () => ({ + MovingBorder: ({ children }: { children: React.ReactNode }) =>
{children}
, })) describe('ChatInput', () => { @@ -231,7 +236,7 @@ describe('ChatInput', () => { const sendButton = document.querySelector('[data-test-id="send-message-button"]') await user.click(sendButton) - expect(mockSendMessage).toHaveBeenCalledWith('Hello world') + expect(mockSendMessage).toHaveBeenCalledWith('Hello world', true, undefined) }) it('sends message when Enter key is pressed', async () => { @@ -248,7 +253,7 @@ describe('ChatInput', () => { const textarea = screen.getByRole('textbox') await user.type(textarea, '{Enter}') - expect(mockSendMessage).toHaveBeenCalledWith('Hello world') + expect(mockSendMessage).toHaveBeenCalledWith('Hello world', true, undefined) }) it('does not send message when Shift+Enter is pressed', async () => { @@ -343,9 +348,12 @@ describe('ChatInput', () => { const user = userEvent.setup() renderWithRouter() - // File upload is rendered as hidden input element - const fileInput = document.querySelector('input[type="file"]') - expect(fileInput).toBeInTheDocument() + // Wait for async effects to complete (mmproj check) + await waitFor(() => { + // File upload is rendered as hidden input element + const fileInput = document.querySelector('input[type="file"]') + expect(fileInput).toBeInTheDocument() + }) }) it('disables input when streaming', () => { @@ -361,7 +369,7 @@ describe('ChatInput', () => { renderWithRouter() }) - const textarea = screen.getByRole('textbox') + const textarea = screen.getByTestId('chat-input') expect(textarea).toBeDisabled() }) @@ -378,4 +386,28 @@ describe('ChatInput', () => { expect(toolsIcon).toBeInTheDocument() }) }) + + it('uses selectedProvider for provider checks', () => { + // Test that the component correctly uses selectedProvider instead of selectedModel.provider + vi.mocked(useModelProvider).mockReturnValue({ + selectedModel: { + id: 'test-model', + capabilities: ['vision'], + }, + providers: [], + getModelBy: vi.fn(), + selectModelProvider: vi.fn(), + selectedProvider: 'llamacpp', + setProviders: vi.fn(), + getProviderByName: vi.fn(), + updateProvider: vi.fn(), + addProvider: vi.fn(), + deleteProvider: vi.fn(), + deleteModel: vi.fn(), + deletedModels: [], + }) + + // This test ensures the component renders without errors when using selectedProvider + expect(() => renderWithRouter()).not.toThrow() + }) }) \ No newline at end of file diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts index e90d5b1c2..3ffd19318 100644 --- a/web-app/src/hooks/useChat.ts +++ b/web-app/src/hooks/useChat.ts @@ -203,7 +203,17 @@ export const useChat = () => { ) const sendMessage = useCallback( - async (message: string, troubleshooting = true) => { + async ( + message: string, + troubleshooting = true, + attachments?: Array<{ + name: string + type: string + size: number + base64: string + dataUrl: string + }> + ) => { const activeThread = await getCurrentThread() resetTokenSpeed() @@ -217,7 +227,7 @@ export const useChat = () => { updateStreamingContent(emptyThreadContent) // Do not add new message on retry if (troubleshooting) - addMessage(newUserThreadContent(activeThread.id, message)) + addMessage(newUserThreadContent(activeThread.id, message, attachments)) updateThreadTimestamp(activeThread.id) setPrompt('') try { @@ -231,7 +241,7 @@ export const useChat = () => { messages, currentAssistant?.instructions ) - if (troubleshooting) builder.addUserMessage(message) + if (troubleshooting) builder.addUserMessage(message, attachments) let isCompleted = false diff --git a/web-app/src/lib/completion.ts b/web-app/src/lib/completion.ts index 
c92a0b096..6f5f6cdab 100644 --- a/web-app/src/lib/completion.ts +++ b/web-app/src/lib/completion.ts @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ import { ContentType, ChatCompletionRole, @@ -51,11 +52,16 @@ export type ChatCompletionResponse = */ export const newUserThreadContent = ( threadId: string, - content: string -): ThreadMessage => ({ - type: 'text', - role: ChatCompletionRole.User, - content: [ + content: string, + attachments?: Array<{ + name: string + type: string + size: number + base64: string + dataUrl: string + }> +): ThreadMessage => { + const contentParts = [ { type: ContentType.Text, text: { @@ -63,14 +69,35 @@ export const newUserThreadContent = ( annotations: [], }, }, - ], - id: ulid(), - object: 'thread.message', - thread_id: threadId, - status: MessageStatus.Ready, - created_at: 0, - completed_at: 0, -}) + ] + + // Add attachments to content array + if (attachments) { + attachments.forEach((attachment) => { + if (attachment.type.startsWith('image/')) { + contentParts.push({ + type: ContentType.Image, + image_url: { + url: `data:${attachment.type};base64,${attachment.base64}`, + detail: 'auto', + }, + } as any) + } + }) + } + + return { + type: 'text', + role: ChatCompletionRole.User, + content: contentParts, + id: ulid(), + object: 'thread.message', + thread_id: threadId, + status: MessageStatus.Ready, + created_at: 0, + completed_at: 0, + } +} /** * @fileoverview Helper functions for creating thread content. * These functions are used to create thread content objects @@ -162,13 +189,11 @@ export const sendCompletion = async ( if ( thread.model.id && !Object.values(models[providerName]).flat().includes(thread.model.id) && - // eslint-disable-next-line @typescript-eslint/no-explicit-any !tokenJS.extendedModelExist(providerName as any, thread.model.id) && provider.provider !== 'llamacpp' ) { try { tokenJS.extendModelList( - // eslint-disable-next-line @typescript-eslint/no-explicit-any providerName as any, thread.model.id, // This is to inherit the model capabilities from another built-in model @@ -201,7 +226,7 @@ export const sendCompletion = async ( ? await tokenJS.chat.completions.create( { stream: true, - // eslint-disable-next-line @typescript-eslint/no-explicit-any + provider: providerName as any, model: thread.model?.id, messages, diff --git a/web-app/src/lib/messages.ts b/web-app/src/lib/messages.ts index b187fb514..c7eba13d7 100644 --- a/web-app/src/lib/messages.ts +++ b/web-app/src/lib/messages.ts @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ import { ChatCompletionMessageParam } from 'token.js' import { ChatCompletionMessageToolCall } from 'openai/resources' import { ThreadMessage } from '@janhq/core' @@ -19,32 +20,106 @@ export class CompletionMessagesBuilder { this.messages.push( ...messages .filter((e) => !e.metadata?.error) - .map( - (msg) => - ({ + .map((msg) => { + if (msg.role === 'assistant') { + return { role: msg.role, - content: - msg.role === 'assistant' - ? this.normalizeContent(msg.content[0]?.text?.value || '.') - : msg.content[0]?.text?.value || '.', - }) as ChatCompletionMessageParam - ) + content: this.normalizeContent( + msg.content[0]?.text?.value || '.' 
+ ), + } as ChatCompletionMessageParam + } else { + // For user messages, handle multimodal content + if (msg.content.length > 1) { + // Multiple content parts (text + images + files) + + const content = msg.content.map((contentPart) => { + if (contentPart.type === 'text') { + return { + type: 'text', + text: contentPart.text?.value || '', + } + } else if (contentPart.type === 'image_url') { + return { + type: 'image_url', + image_url: { + url: contentPart.image_url?.url || '', + detail: contentPart.image_url?.detail || 'auto', + }, + } + } else { + return contentPart + } + }) + return { + role: msg.role, + content, + } as ChatCompletionMessageParam + } else { + // Single text content + return { + role: msg.role, + content: msg.content[0]?.text?.value || '.', + } as ChatCompletionMessageParam + } + } + }) ) } /** * Add a user message to the messages array. * @param content - The content of the user message. + * @param attachments - Optional attachments for the message. */ - addUserMessage(content: string) { + addUserMessage( + content: string, + attachments?: Array<{ + name: string + type: string + size: number + base64: string + dataUrl: string + }> + ) { // Ensure no consecutive user messages if (this.messages[this.messages.length - 1]?.role === 'user') { this.messages.pop() } - this.messages.push({ - role: 'user', - content: content, - }) + + // Handle multimodal content with attachments + if (attachments && attachments.length > 0) { + const messageContent: any[] = [ + { + type: 'text', + text: content, + }, + ] + + // Add attachments (images and PDFs) + attachments.forEach((attachment) => { + if (attachment.type.startsWith('image/')) { + messageContent.push({ + type: 'image_url', + image_url: { + url: `data:${attachment.type};base64,${attachment.base64}`, + detail: 'auto', + }, + }) + } + }) + + this.messages.push({ + role: 'user', + content: messageContent, + } as any) + } else { + // Text-only message + this.messages.push({ + role: 'user', + content: content, + }) + } } /** diff --git a/web-app/src/routes/__root.tsx b/web-app/src/routes/__root.tsx index 5278f73fc..77d9f9d2b 100644 --- a/web-app/src/routes/__root.tsx +++ b/web-app/src/routes/__root.tsx @@ -26,7 +26,7 @@ import { ResizablePanel, ResizableHandle, } from '@/components/ui/resizable' -import { useCallback } from 'react' +import { useCallback, useEffect } from 'react' import GlobalError from '@/containers/GlobalError' import { GlobalEventHandler } from '@/providers/GlobalEventHandler' @@ -65,6 +65,41 @@ const AppLayout = () => { [setLeftPanelSize, setLeftPanel] ) + // Prevent default drag and drop behavior globally + useEffect(() => { + const preventDefaults = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + } + + const handleGlobalDrop = (e: DragEvent) => { + e.preventDefault() + e.stopPropagation() + + // Only prevent if the target is not within a chat input or other valid drop zone + const target = e.target as Element + const isValidDropZone = target?.closest('[data-drop-zone="true"]') || + target?.closest('.chat-input-drop-zone') || + target?.closest('[data-tauri-drag-region]') + + if (!isValidDropZone) { + // Prevent the file from opening in the window + return false + } + } + + // Add event listeners to prevent default drag/drop behavior + window.addEventListener('dragenter', preventDefaults) + window.addEventListener('dragover', preventDefaults) + window.addEventListener('drop', handleGlobalDrop) + + return () => { + window.removeEventListener('dragenter', preventDefaults) + 
window.removeEventListener('dragover', preventDefaults) + window.removeEventListener('drop', handleGlobalDrop) + } + }, []) + return ( diff --git a/web-app/src/services/__tests__/models.test.ts b/web-app/src/services/__tests__/models.test.ts index b783f6ab5..dc30dc54f 100644 --- a/web-app/src/services/__tests__/models.test.ts +++ b/web-app/src/services/__tests__/models.test.ts @@ -290,7 +290,7 @@ describe('models service', () => { likes: 100, tags: ['conversational', 'pytorch'], pipeline_tag: 'text-generation', - created_at: '2023-01-01T00:00:00Z', + createdAt: '2023-01-01T00:00:00Z', last_modified: '2023-12-01T00:00:00Z', private: false, disabled: false, @@ -443,7 +443,7 @@ describe('models service', () => { likes: 100, tags: ['conversational'], pipeline_tag: 'text-generation', - created_at: '2023-01-01T00:00:00Z', + createdAt: '2023-01-01T00:00:00Z', last_modified: '2023-12-01T00:00:00Z', private: false, disabled: false, @@ -471,7 +471,7 @@ describe('models service', () => { likes: 100, tags: ['conversational'], pipeline_tag: 'text-generation', - created_at: '2023-01-01T00:00:00Z', + createdAt: '2023-01-01T00:00:00Z', last_modified: '2023-12-01T00:00:00Z', private: false, disabled: false, @@ -510,7 +510,7 @@ describe('models service', () => { likes: 100, tags: ['conversational'], pipeline_tag: 'text-generation', - created_at: '2023-01-01T00:00:00Z', + createdAt: '2023-01-01T00:00:00Z', last_modified: '2023-12-01T00:00:00Z', private: false, disabled: false, @@ -559,7 +559,7 @@ describe('models service', () => { likes: 75, tags: ['pytorch', 'transformers', 'text-generation'], pipeline_tag: 'text-generation', - created_at: '2021-01-01T00:00:00Z', + createdAt: '2021-01-01T00:00:00Z', last_modified: '2021-12-01T00:00:00Z', private: false, disabled: false, @@ -605,6 +605,8 @@ describe('models service', () => { file_size: '4.0 GB', }, ], + num_mmproj: 0, + mmproj_models: [], created_at: '2021-01-01T00:00:00Z', readme: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md', @@ -820,7 +822,7 @@ describe('models service', () => { downloads: 0, likes: 0, tags: [], - created_at: '2021-01-01T00:00:00Z', + createdAt: '2021-01-01T00:00:00Z', last_modified: '2021-12-01T00:00:00Z', private: false, disabled: false, diff --git a/web-app/src/services/models.ts b/web-app/src/services/models.ts index 6f0bda5f9..790620f22 100644 --- a/web-app/src/services/models.ts +++ b/web-app/src/services/models.ts @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ import { sanitizeModelId } from '@/lib/utils' import { AIEngine, @@ -27,6 +28,7 @@ export interface CatalogModel { num_quants: number quants: ModelQuant[] mmproj_models?: MMProjModel[] + num_mmproj: number created_at?: string readme?: string tools?: boolean @@ -44,7 +46,7 @@ export interface HuggingFaceRepo { library_name?: string tags: string[] pipeline_tag?: string - created_at: string + createdAt: string last_modified: string private: boolean disabled: boolean @@ -155,21 +157,30 @@ export const fetchHuggingFaceRepo = async ( export const convertHfRepoToCatalogModel = ( repo: HuggingFaceRepo ): CatalogModel => { + // Format file size helper + const formatFileSize = (size?: number) => { + if (!size) return 'Unknown size' + if (size < 1024 ** 3) return `${(size / 1024 ** 2).toFixed(1)} MB` + return `${(size / 1024 ** 3).toFixed(1)} GB` + } + // Extract GGUF files from the repository siblings const ggufFiles = repo.siblings?.filter((file) => file.rfilename.toLowerCase().endsWith('.gguf') ) || [] - // Convert GGUF files to 
quants format - const quants = ggufFiles.map((file) => { - // Format file size - const formatFileSize = (size?: number) => { - if (!size) return 'Unknown size' - if (size < 1024 ** 3) return `${(size / 1024 ** 2).toFixed(1)} MB` - return `${(size / 1024 ** 3).toFixed(1)} GB` - } + // Separate regular GGUF files from mmproj files + const regularGgufFiles = ggufFiles.filter( + (file) => !file.rfilename.toLowerCase().includes('mmproj') + ) + const mmprojFiles = ggufFiles.filter((file) => + file.rfilename.toLowerCase().includes('mmproj') + ) + + // Convert regular GGUF files to quants format + const quants = regularGgufFiles.map((file) => { // Generate model_id from filename (remove .gguf extension, case-insensitive) const modelId = file.rfilename.replace(/\.gguf$/i, '') @@ -180,15 +191,28 @@ export const convertHfRepoToCatalogModel = ( } }) + // Convert mmproj files to mmproj_models format + const mmprojModels = mmprojFiles.map((file) => { + const modelId = file.rfilename.replace(/\.gguf$/i, '') + + return { + model_id: sanitizeModelId(modelId), + path: `https://huggingface.co/${repo.modelId}/resolve/main/${file.rfilename}`, + file_size: formatFileSize(file.size), + } + }) + return { model_name: repo.modelId, - description: `**Tags**: ${repo.tags?.join(', ')}`, developer: repo.author, downloads: repo.downloads || 0, + created_at: repo.createdAt, num_quants: quants.length, quants: quants, - created_at: repo.created_at, + num_mmproj: mmprojModels.length, + mmproj_models: mmprojModels, readme: `https://huggingface.co/${repo.modelId}/resolve/main/README.md`, + description: `**Tags**: ${repo.tags?.join(', ')}`, } } @@ -318,8 +342,8 @@ export const startModel = async ( /** * Check if model support tool use capability * Returned by backend engine - * @param modelId - * @returns + * @param modelId + * @returns */ export const isToolSupported = async (modelId: string): Promise => { const engine = getEngine() @@ -327,3 +351,137 @@ export const isToolSupported = async (modelId: string): Promise => { return engine.isToolSupported(modelId) } + +/** + * Checks if mmproj.gguf file exists for a given model ID in the llamacpp provider. + * Also checks if the model has offload_mmproj setting. + * If mmproj.gguf exists, adds offload_mmproj setting with value true. 
+ * @param modelId - The model ID to check for mmproj.gguf + * @param updateProvider - Function to update the provider state + * @param getProviderByName - Function to get provider by name + * @returns Promise<{exists: boolean, settingsUpdated: boolean}> - exists: true if mmproj.gguf exists, settingsUpdated: true if settings were modified + */ +export const checkMmprojExistsAndUpdateOffloadMMprojSetting = async ( + modelId: string, + updateProvider?: (providerName: string, data: Partial) => void, + getProviderByName?: (providerName: string) => ModelProvider | undefined +): Promise<{ exists: boolean; settingsUpdated: boolean }> => { + let settingsUpdated = false + + try { + const engine = getEngine('llamacpp') as AIEngine & { + checkMmprojExists?: (id: string) => Promise + } + if (engine && typeof engine.checkMmprojExists === 'function') { + const exists = await engine.checkMmprojExists(modelId) + + // If we have the store functions, use them; otherwise fall back to localStorage + if (updateProvider && getProviderByName) { + const provider = getProviderByName('llamacpp') + if (provider) { + const model = provider.models.find((m) => m.id === modelId) + + if (model?.settings) { + const hasOffloadMmproj = 'offload_mmproj' in model.settings + + // If mmproj exists, add offload_mmproj setting (only if it doesn't exist) + if (exists && !hasOffloadMmproj) { + // Create updated models array with the new setting + const updatedModels = provider.models.map((m) => { + if (m.id === modelId) { + return { + ...m, + settings: { + ...m.settings, + offload_mmproj: { + key: 'offload_mmproj', + title: 'Offload MMProj', + description: + 'Offload multimodal projection layers to GPU', + controller_type: 'checkbox', + controller_props: { + value: true, + }, + }, + }, + } + } + return m + }) + + // Update the provider with the new models array + updateProvider('llamacpp', { models: updatedModels }) + settingsUpdated = true + } + } + } + } else { + // Fall back to localStorage approach for backwards compatibility + try { + const modelProviderData = JSON.parse( + localStorage.getItem('model-provider') || '{}' + ) + const llamacppProvider = modelProviderData.state?.providers?.find( + (p: any) => p.provider === 'llamacpp' + ) + const model = llamacppProvider?.models?.find( + (m: any) => m.id === modelId + ) + + if (model?.settings) { + // If mmproj exists, add offload_mmproj setting (only if it doesn't exist) + if (exists) { + if (!model.settings.offload_mmproj) { + model.settings.offload_mmproj = { + key: 'offload_mmproj', + title: 'Offload MMProj', + description: 'Offload multimodal projection layers to GPU', + controller_type: 'checkbox', + controller_props: { + value: true, + }, + } + // Save updated settings back to localStorage + localStorage.setItem( + 'model-provider', + JSON.stringify(modelProviderData) + ) + settingsUpdated = true + } + } + } + } catch (localStorageError) { + console.error( + `Error checking localStorage for model ${modelId}:`, + localStorageError + ) + } + } + + return { exists, settingsUpdated } + } + } catch (error) { + console.error(`Error checking mmproj for model ${modelId}:`, error) + } + return { exists: false, settingsUpdated } +} + +/** + * Checks if mmproj.gguf file exists for a given model ID in the llamacpp provider. + * If mmproj.gguf exists, adds offload_mmproj setting with value true. 
+ * @param modelId - The model ID to check for mmproj.gguf
+ * @returns Promise<boolean> - true if mmproj.gguf exists, false otherwise
+ */
+export const checkMmprojExists = async (modelId: string): Promise<boolean> => {
+  try {
+    const engine = getEngine('llamacpp') as AIEngine & {
+      checkMmprojExists?: (id: string) => Promise<boolean>
+    }
+    if (engine && typeof engine.checkMmprojExists === 'function') {
+      return await engine.checkMmprojExists(modelId)
+    }
+  } catch (error) {
+    console.error(`Error checking mmproj for model ${modelId}:`, error)
+  }
+  return false
+}
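Taken together with the llamacpp-extension changes at the top of the diff, the offload_mmproj setting added here ends up steering the llama-server launch arguments. The flags --mmproj and --no-mmproj-offload are the ones used in the patch; the mmprojArgs helper below is a hypothetical sketch for illustration, not code from the repository:

// Illustrative sketch of how the per-model setting maps to server flags,
// mirroring the cfg.offload_mmproj check added in extensions/llamacpp-extension/src/index.ts.
type MmprojArgsInput = {
  mmprojPath?: string // resolved mmproj.gguf path, if the model ships one
  offloadMmproj?: boolean // the new per-model setting; defaults to true (offload to GPU)
}

function mmprojArgs({ mmprojPath, offloadMmproj }: MmprojArgsInput): string[] {
  const args: string[] = []
  if (!mmprojPath) return args // text-only model: nothing to pass
  args.push('--mmproj', mmprojPath) // load the multimodal projector alongside the model
  if (offloadMmproj === false) {
    // Keep the projector on the CPU; image pre-processing gets slower but VRAM is saved.
    args.push('--no-mmproj-offload')
  }
  return args
}

// e.g. mmprojArgs({ mmprojPath: '/models/x/mmproj.gguf', offloadMmproj: false })
//   -> ['--mmproj', '/models/x/mmproj.gguf', '--no-mmproj-offload']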
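On the message side, CompletionMessagesBuilder.addUserMessage and newUserThreadContent embed each image attachment as a data-URL image_url content part alongside the text part, falling back to a plain string for text-only messages. A hedged sketch of that payload shape, assuming the attachment fields used in the patch — buildUserContent is illustrative, not an export from the codebase:

// Illustrative sketch of the OpenAI-style multimodal content array built for user messages.
type UserAttachment = { name: string; type: string; size: number; base64: string; dataUrl: string }

type ContentPart =
  | { type: 'text'; text: string }
  | { type: 'image_url'; image_url: { url: string; detail: 'auto' | 'low' | 'high' } }

function buildUserContent(
  text: string,
  attachments: UserAttachment[] = []
): string | ContentPart[] {
  if (attachments.length === 0) return text // text-only messages keep the plain-string shape
  const parts: ContentPart[] = [{ type: 'text', text }]
  for (const a of attachments) {
    if (a.type.startsWith('image/')) {
      parts.push({
        type: 'image_url',
        image_url: { url: `data:${a.type};base64,${a.base64}`, detail: 'auto' },
      })
    }
  }
  return parts
}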