Merge pull request #6134 from menloresearch/feat/attachment-ui

feat: attachment UI
Faisal Amir 2025-08-20 10:04:32 +07:00 committed by GitHub
commit 5481ee9e35
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 833 additions and 157 deletions

View File

@@ -31,6 +31,7 @@
     "@janhq/tauri-plugin-hardware-api": "link:../../src-tauri/plugins/tauri-plugin-hardware",
     "@janhq/tauri-plugin-llamacpp-api": "link:../../src-tauri/plugins/tauri-plugin-llamacpp",
     "@tauri-apps/api": "^2.5.0",
+    "@tauri-apps/plugin-http": "^2.5.1",
     "@tauri-apps/plugin-log": "^2.6.0",
     "fetch-retry": "^5.0.6",
     "ulidx": "^2.3.0"

View File

@@ -17,4 +17,7 @@ export default defineConfig({
     IS_MAC: JSON.stringify(process.platform === 'darwin'),
     IS_LINUX: JSON.stringify(process.platform === 'linux'),
   },
+  inject: {
+    fetch: ['@tauri-apps/plugin-http', 'fetch'],
+  },
 })
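With this `inject`, bare `fetch` calls in the extension are rewritten at build time to use the Tauri HTTP plugin, so requests are executed by the Rust core rather than the webview (presumably to sidestep webview CORS and body limits when talking to the local llama.cpp server). A minimal sketch of what the injected binding resolves to:

// Sketch only: using the plugin fetch directly. The signature mirrors the
// WHATWG fetch API, but the request is performed on the Tauri (Rust) side.
import { fetch } from '@tauri-apps/plugin-http'

async function listLocalModels(): Promise<unknown> {
  const res = await fetch('http://127.0.0.1:8080/v1/models') // illustrative URL
  if (!res.ok) throw new Error(`HTTP ${res.status}`)
  return res.json()
}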

View File

@@ -41,6 +41,7 @@ type LlamacppConfig = {
   auto_unload: boolean
   chat_template: string
   n_gpu_layers: number
+  offload_mmproj: boolean
   override_tensor_buffer_t: string
   ctx_size: number
   threads: number
@@ -1222,6 +1223,10 @@ export default class llamacpp_extension extends AIEngine {
     // Takes a regex with matching tensor name as input
     if (cfg.override_tensor_buffer_t)
       args.push('--override-tensor', cfg.override_tensor_buffer_t)
+    // The multimodal projector model is offloaded to the GPU by default. If
+    // there is not enough memory, turning this setting off keeps the projector
+    // on the CPU, at the cost of slower image processing.
+    if (cfg.offload_mmproj === false) args.push('--no-mmproj-offload')
     args.push('-a', modelId)
     args.push('--port', String(port))
     if (modelConfig.mmproj_path) {
@@ -1383,7 +1388,8 @@ export default class llamacpp_extension extends AIEngine {
       method: 'POST',
       headers,
       body,
-      signal: abortController?.signal,
+      connectTimeout: 600000, // 10 minutes
+      signal: AbortSignal.any([AbortSignal.timeout(600000), abortController?.signal]),
     })
     if (!response.ok) {
       const errorData = await response.json().catch(() => null)
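The rewritten request combines a hard timeout with the caller's manual abort (`connectTimeout` is an option specific to the plugin-http fetch). A small sketch of the `AbortSignal` semantics, with the 10-minute figure taken from the diff:

// AbortSignal.any() aborts as soon as any input signal aborts, so generation
// stops on whichever comes first: the timeout or a user-initiated cancel.
const abortController = new AbortController()
const signal = AbortSignal.any([
  AbortSignal.timeout(600_000), // aborts with a TimeoutError after 10 minutes
  abortController.signal, // aborts when the user cancels generation
])
signal.addEventListener('abort', () => console.log('aborted:', signal.reason))
abortController.abort('user cancelled') // manual cancellation wins here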
@@ -1542,6 +1548,26 @@ export default class llamacpp_extension extends AIEngine {
     }
   }

+  /**
+   * Check if mmproj.gguf file exists for a given model ID
+   * @param modelId - The model ID to check for mmproj.gguf
+   * @returns Promise<boolean> - true if mmproj.gguf exists, false otherwise
+   */
+  async checkMmprojExists(modelId: string): Promise<boolean> {
+    try {
+      const mmprojPath = await joinPath([
+        await this.getProviderPath(),
+        'models',
+        modelId,
+        'mmproj.gguf',
+      ])
+      return await fs.existsSync(mmprojPath)
+    } catch (e) {
+      logger.error(`Error checking mmproj.gguf for model ${modelId}:`, e)
+      return false
+    }
+  }
+
   async getDevices(): Promise<DeviceList[]> {
     const cfg = this.config
     const [version, backend] = cfg.version_backend.split('/')
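Vision support is thus inferred purely from the on-disk layout, not from model metadata. A hedged sketch of the convention the helper encodes (directory names are the ones used in the code; the model id is made up):

// <providerPath>/models/<modelId>/model.gguf
// <providerPath>/models/<modelId>/mmproj.gguf  <- presence enables attachments
// Inside an async context, with `engine` being the llamacpp extension:
const canAttachImages = await engine.checkMmprojExists('llava-v1.6-7b') // hypothetical id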
@@ -1644,4 +1670,18 @@ export default class llamacpp_extension extends AIEngine {
       'tokenizer.chat_template'
     ]?.includes('tools')
   }
+
+  private async loadMetadata(path: string): Promise<GgufMetadata> {
+    try {
+      const data = await invoke<GgufMetadata>(
+        'plugin:llamacpp|read_gguf_metadata',
+        {
+          path: path,
+        }
+      )
+      return data
+    } catch (err) {
+      throw err
+    }
+  }
 }
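Read together with the `tokenizer.chat_template` check just above it, the helper appears to exist for tool-support detection. A hedged sketch of the implied call chain (inside the class; assumes `GgufMetadata` is a flat key/value record):

// Not the verbatim implementation, just the implied usage:
const meta = await this.loadMetadata(modelPath) // invokes read_gguf_metadata
const supportsTools =
  meta['tokenizer.chat_template']?.includes('tools') ?? false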

View File

@@ -12,7 +12,7 @@ use tokio::time::Instant;
 use crate::device::{get_devices_from_backend, DeviceInfo};
 use crate::error::{ErrorCode, LlamacppError, ServerError, ServerResult};
-use crate::path::{validate_binary_path, validate_model_path};
+use crate::path::{validate_binary_path, validate_model_path, validate_mmproj_path};
 use crate::process::{
     find_session_by_model_id, get_all_active_sessions, get_all_loaded_model_ids,
     get_random_available_port, is_process_running_by_pid,
@@ -55,6 +55,7 @@ pub async fn load_llama_model<R: Runtime>(
     let port = parse_port_from_args(&args);
     let model_path_pb = validate_model_path(&mut args)?;
+    let _mmproj_path_pb = validate_mmproj_path(&mut args)?;

     let api_key: String;

View File

@@ -98,3 +98,50 @@ pub fn validate_model_path(args: &mut Vec<String>) -> ServerResult<PathBuf> {
     Ok(model_path_pb)
 }
+
+/// Validate mmproj path exists and update args with platform-appropriate path format
+pub fn validate_mmproj_path(args: &mut Vec<String>) -> ServerResult<Option<PathBuf>> {
+    let mmproj_path_index = match args.iter().position(|arg| arg == "--mmproj") {
+        Some(index) => index,
+        None => return Ok(None), // mmproj is optional
+    };
+
+    let mmproj_path = args.get(mmproj_path_index + 1).cloned().ok_or_else(|| {
+        LlamacppError::new(
+            ErrorCode::ModelLoadFailed,
+            "Mmproj path was not provided after '--mmproj' flag.".into(),
+            None,
+        )
+    })?;
+
+    let mmproj_path_pb = PathBuf::from(&mmproj_path);
+    if !mmproj_path_pb.exists() {
+        let err_msg = format!(
+            "Invalid or inaccessible mmproj path: {}",
+            mmproj_path_pb.display()
+        );
+        log::error!("{}", &err_msg);
+        return Err(LlamacppError::new(
+            ErrorCode::ModelFileNotFound,
+            "The specified mmproj file does not exist or is not accessible.".into(),
+            Some(err_msg),
+        )
+        .into());
+    }
+
+    #[cfg(windows)]
+    {
+        // use short path on Windows
+        if let Some(short) = get_short_path(&mmproj_path_pb) {
+            args[mmproj_path_index + 1] = short;
+        } else {
+            args[mmproj_path_index + 1] = mmproj_path_pb.display().to_string();
+        }
+    }
+
+    #[cfg(not(windows))]
+    {
+        args[mmproj_path_index + 1] = mmproj_path_pb.display().to_string();
+    }
+
+    Ok(Some(mmproj_path_pb))
+}
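Seen from the TypeScript side that assembles the command line, the contract is purely positional: the validator looks for `--mmproj` and inspects the token that follows it. An illustrative argument list (paths are made up):

// The validator rewrites the value after '--mmproj' (short path on Windows)
// or fails early with ModelFileNotFound; omitting '--mmproj' entirely is fine.
const args = [
  '--model', '/models/qwen2.5-vl/model.gguf', // hypothetical paths
  '--mmproj', '/models/qwen2.5-vl/mmproj.gguf',
  '--port', '3930',
]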

View File

@@ -35,7 +35,8 @@
       "effects": ["fullScreenUI", "mica", "tabbed", "blur", "acrylic"],
       "state": "active",
       "radius": 8
-      }
+      },
+      "dragDropEnabled": false
     }
   ],
   "security": {

View File

@ -1,7 +1,7 @@
'use client' 'use client'
import TextareaAutosize from 'react-textarea-autosize' import TextareaAutosize from 'react-textarea-autosize'
import { cn, toGigabytes } from '@/lib/utils' import { cn } from '@/lib/utils'
import { usePrompt } from '@/hooks/usePrompt' import { usePrompt } from '@/hooks/usePrompt'
import { useThreads } from '@/hooks/useThreads' import { useThreads } from '@/hooks/useThreads'
import { useCallback, useEffect, useRef, useState } from 'react' import { useCallback, useEffect, useRef, useState } from 'react'
@ -14,7 +14,7 @@ import {
} from '@/components/ui/tooltip' } from '@/components/ui/tooltip'
import { ArrowRight } from 'lucide-react' import { ArrowRight } from 'lucide-react'
import { import {
IconPaperclip, IconPhoto,
IconWorld, IconWorld,
IconAtom, IconAtom,
IconEye, IconEye,
@ -34,6 +34,7 @@ import DropdownModelProvider from '@/containers/DropdownModelProvider'
import { ModelLoader } from '@/containers/loaders/ModelLoader' import { ModelLoader } from '@/containers/loaders/ModelLoader'
import DropdownToolsAvailable from '@/containers/DropdownToolsAvailable' import DropdownToolsAvailable from '@/containers/DropdownToolsAvailable'
import { getConnectedServers } from '@/services/mcp' import { getConnectedServers } from '@/services/mcp'
import { checkMmprojExists } from '@/services/models'
type ChatInputProps = { type ChatInputProps = {
className?: string className?: string
@@ -60,7 +61,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
   const maxRows = 10
-  const { selectedModel } = useModelProvider()
+  const { selectedModel, selectedProvider } = useModelProvider()
   const { sendMessage } = useChat()
   const [message, setMessage] = useState('')
   const [dropdownToolsAvailable, setDropdownToolsAvailable] = useState(false)
@@ -75,6 +76,8 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
   }>
   >([])
   const [connectedServers, setConnectedServers] = useState<string[]>([])
+  const [isDragOver, setIsDragOver] = useState(false)
+  const [hasMmproj, setHasMmproj] = useState(false)

   // Check for connected MCP servers
   useEffect(() => {
@@ -96,6 +99,29 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
     return () => clearInterval(intervalId)
   }, [])

+  // Check for mmproj existence or vision capability when model changes
+  useEffect(() => {
+    const checkMmprojSupport = async () => {
+      if (selectedModel?.id) {
+        try {
+          // Only check mmproj for llamacpp provider
+          if (selectedProvider === 'llamacpp') {
+            const hasLocalMmproj = await checkMmprojExists(selectedModel.id)
+            setHasMmproj(hasLocalMmproj)
+          } else {
+            // For non-llamacpp providers, only check vision capability
+            setHasMmproj(true)
+          }
+        } catch (error) {
+          console.error('Error checking mmproj:', error)
+          setHasMmproj(false)
+        }
+      }
+    }
+
+    checkMmprojSupport()
+  }, [selectedModel?.id, selectedProvider])
+
   // Check if there are active MCP servers
   const hasActiveMCPServers = connectedServers.length > 0 || tools.length > 0
@@ -104,11 +130,16 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
       setMessage('Please select a model to start chatting.')
       return
     }
-    if (!prompt.trim()) {
+    if (!prompt.trim() && uploadedFiles.length === 0) {
       return
     }
     setMessage('')
-    sendMessage(prompt)
+    sendMessage(
+      prompt,
+      true,
+      uploadedFiles.length > 0 ? uploadedFiles : undefined
+    )
+    setUploadedFiles([])
   }

   useEffect(() => {
@@ -191,8 +222,6 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
         return 'image/jpeg'
       case 'png':
         return 'image/png'
-      case 'pdf':
-        return 'application/pdf'
       default:
         return ''
     }
@@ -226,17 +255,12 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
       const detectedType = file.type || getFileTypeFromExtension(file.name)
       const actualType = getFileTypeFromExtension(file.name) || detectedType

-      // Check file type
-      const allowedTypes = [
-        'image/jpg',
-        'image/jpeg',
-        'image/png',
-        'application/pdf',
-      ]
+      // Check file type - images only
+      const allowedTypes = ['image/jpg', 'image/jpeg', 'image/png']

       if (!allowedTypes.includes(actualType)) {
         setMessage(
-          `File is not supported. Only JPEG, JPG, PNG, and PDF files are allowed.`
+          `File attachments not supported currently. Only JPEG, JPG, and PNG files are allowed.`
         )
         // Reset file input to allow re-uploading
         if (fileInputRef.current) {
@@ -287,6 +311,104 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
     }
   }

+  const handleDragEnter = (e: React.DragEvent) => {
+    e.preventDefault()
+    e.stopPropagation()
+    // Only allow drag if model supports mmproj
+    if (hasMmproj) {
+      setIsDragOver(true)
+    }
+  }
+
+  const handleDragLeave = (e: React.DragEvent) => {
+    e.preventDefault()
+    e.stopPropagation()
+    // Only set dragOver to false if we're leaving the drop zone entirely
+    // In Tauri, relatedTarget can be null, so we need to handle that case
+    const relatedTarget = e.relatedTarget as Node | null
+    if (!relatedTarget || !e.currentTarget.contains(relatedTarget)) {
+      setIsDragOver(false)
+    }
+  }
+
+  const handleDragOver = (e: React.DragEvent) => {
+    e.preventDefault()
+    e.stopPropagation()
+    // Ensure drag state is maintained during drag over
+    if (hasMmproj) {
+      setIsDragOver(true)
+    }
+  }
+
+  const handleDrop = (e: React.DragEvent) => {
+    e.preventDefault()
+    e.stopPropagation()
+    setIsDragOver(false)
+
+    // Only allow drop if model supports mmproj
+    if (!hasMmproj) {
+      return
+    }
+
+    // Check if dataTransfer exists (it might not in some Tauri scenarios)
+    if (!e.dataTransfer) {
+      console.warn('No dataTransfer available in drop event')
+      return
+    }
+
+    const files = e.dataTransfer.files
+    if (files && files.length > 0) {
+      // Create a synthetic event to reuse existing file handling logic
+      const syntheticEvent = {
+        target: {
+          files: files,
+        },
+      } as React.ChangeEvent<HTMLInputElement>
+      handleFileChange(syntheticEvent)
+    }
+  }
+
+  const handlePaste = (e: React.ClipboardEvent) => {
+    const clipboardItems = e.clipboardData?.items
+    if (!clipboardItems) return
+
+    // Only allow paste if model supports mmproj
+    if (!hasMmproj) {
+      return
+    }
+
+    const imageItems = Array.from(clipboardItems).filter((item) =>
+      item.type.startsWith('image/')
+    )
+
+    if (imageItems.length > 0) {
+      e.preventDefault()
+
+      const files: File[] = []
+      let processedCount = 0
+
+      imageItems.forEach((item) => {
+        const file = item.getAsFile()
+        if (file) {
+          files.push(file)
+        }
+        processedCount++
+
+        // When all items are processed, handle the valid files
+        if (processedCount === imageItems.length && files.length > 0) {
+          const syntheticEvent = {
+            target: {
+              files: files,
+            },
+          } as unknown as React.ChangeEvent<HTMLInputElement>
+          handleFileChange(syntheticEvent)
+        }
+      })
+    }
+  }
+
   return (
     <div className="relative">
       <div className="relative">
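All three input paths (file picker, drop, paste) converge on `handleFileChange` by faking the shape of an input change event; only `target.files` is ever read, which is why the double cast is safe. The trick in isolation (sketch):

import type React from 'react'

// Wrap loose File objects in the minimal shape handleFileChange consumes.
const toChangeEvent = (files: File[] | FileList) =>
  ({ target: { files } }) as unknown as React.ChangeEvent<HTMLInputElement>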
@@ -311,8 +433,14 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
         <div
           className={cn(
             'relative z-20 px-0 pb-10 border border-main-view-fg/5 rounded-lg text-main-view-fg bg-main-view',
-            isFocused && 'ring-1 ring-main-view-fg/10'
+            isFocused && 'ring-1 ring-main-view-fg/10',
+            isDragOver && 'ring-2 ring-accent border-accent'
           )}
+          data-drop-zone={hasMmproj ? 'true' : undefined}
+          onDragEnter={hasMmproj ? handleDragEnter : undefined}
+          onDragLeave={hasMmproj ? handleDragLeave : undefined}
+          onDragOver={hasMmproj ? handleDragOver : undefined}
+          onDrop={hasMmproj ? handleDrop : undefined}
         >
           {uploadedFiles.length > 0 && (
             <div className="flex gap-3 items-center p-2 pb-0">
@@ -332,25 +460,6 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
                       alt={`${file.name} - ${index}`}
                     />
                   )}
-                  {file.type === 'application/pdf' && (
-                    <div className="bg-main-view-fg/4 h-full rounded-lg p-2 max-w-[400px] pr-4">
-                      <div className="flex gap-2 items-center justify-center h-full">
-                        <div className="size-10 rounded-md bg-main-view shrink-0 flex items-center justify-center">
-                          <span className="uppercase font-bold">
-                            {file.name.split('.').pop()}
-                          </span>
-                        </div>
-                        <div className="truncate">
-                          <h6 className="truncate mb-0.5 text-main-view-fg/80">
-                            {file.name}
-                          </h6>
-                          <p className="text-xs text-main-view-fg/70">
-                            {toGigabytes(file.size)}
-                          </p>
-                        </div>
-                      </div>
-                    </div>
-                  )}
                   <div
                     className="absolute -top-1 -right-2.5 bg-destructive size-5 flex rounded-full items-center justify-center cursor-pointer"
                     onClick={() => handleRemoveFile(index)}
@@ -369,7 +478,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
             rows={1}
             maxRows={10}
             value={prompt}
-            data-test-id={'chat-input'}
+            data-testid={'chat-input'}
             onChange={(e) => {
               setPrompt(e.target.value)
               // Count the number of newlines to estimate rows
@@ -378,14 +487,21 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
             }}
             onKeyDown={(e) => {
               // e.keyCode 229 is for IME input with Safari
-              const isComposing = e.nativeEvent.isComposing || e.keyCode === 229;
-              if (e.key === 'Enter' && !e.shiftKey && prompt.trim() && !isComposing) {
+              const isComposing =
+                e.nativeEvent.isComposing || e.keyCode === 229
+              if (
+                e.key === 'Enter' &&
+                !e.shiftKey &&
+                prompt.trim() &&
+                !isComposing
+              ) {
                 e.preventDefault()
                 // Submit the message when Enter is pressed without Shift
                 handleSendMesage(prompt)
                 // When Shift+Enter is pressed, a new line is added (default behavior)
               }
             }}
+            onPaste={hasMmproj ? handlePaste : undefined}
             placeholder={t('common:placeholder.chatInput')}
             autoFocus
             spellCheck={spellCheckChatInput}
@@ -406,7 +522,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
         <div className="px-1 flex items-center gap-1">
           <div
             className={cn(
-              'px-1 flex items-center gap-1',
+              'px-1 flex items-center',
               streamingContent && 'opacity-50 pointer-events-none'
             )}
           >
@@ -418,19 +534,22 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
                 useLastUsedModel={initialMessage}
               />
             )}
-            {/* File attachment - always available */}
-            <div
-              className="h-6 hidden p-1 items-center justify-center rounded-sm hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out gap-1"
-              onClick={handleAttachmentClick}
-            >
-              <IconPaperclip size={18} className="text-main-view-fg/50" />
-              <input
-                type="file"
-                ref={fileInputRef}
-                className="hidden"
-                onChange={handleFileChange}
-              />
-            </div>
+            {/* File attachment - show only for models with mmproj */}
+            {hasMmproj && (
+              <div
+                className="h-6 p-1 ml-1 flex items-center justify-center rounded-sm hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out gap-1"
+                onClick={handleAttachmentClick}
+              >
+                <IconPhoto size={18} className="text-main-view-fg/50" />
+                <input
+                  type="file"
+                  ref={fileInputRef}
+                  className="hidden"
+                  multiple
+                  onChange={handleFileChange}
+                />
+              </div>
+            )}
             {/* Microphone - always available - Temp Hide */}
             {/* <div className="h-6 p-1 flex items-center justify-center rounded-sm hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out gap-1">
               <IconMicrophone size={18} className="text-main-view-fg/50" />
@@ -574,9 +693,13 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
               </Button>
             ) : (
               <Button
-                variant={!prompt.trim() ? null : 'default'}
+                variant={
+                  !prompt.trim() && uploadedFiles.length === 0
+                    ? null
+                    : 'default'
+                }
                 size="icon"
-                disabled={!prompt.trim()}
+                disabled={!prompt.trim() && uploadedFiles.length === 0}
                 data-test-id="send-message-button"
                 onClick={() => handleSendMesage(prompt)}
               >
@@ -590,6 +713,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
           </div>
         </div>
       </div>
+
       {message && (
         <div className="bg-main-view-fg/2 -mt-0.5 mx-2 pb-2 px-3 pt-1.5 rounded-b-lg text-xs text-destructive transition-all duration-200 ease-in-out">
           <div className="flex items-center gap-1 justify-between">

View File

@@ -19,6 +19,7 @@ import { localStorageKey } from '@/constants/localStorage'
 import { useTranslation } from '@/i18n/react-i18next-compat'
 import { useFavoriteModel } from '@/hooks/useFavoriteModel'
 import { predefinedProviders } from '@/consts/providers'
+import { checkMmprojExistsAndUpdateOffloadMMprojSetting } from '@/services/models'

 type DropdownModelProviderProps = {
   model?: ThreadModel
@@ -66,6 +67,7 @@ const DropdownModelProvider = ({
     getModelBy,
     selectedProvider,
     selectedModel,
+    updateProvider,
   } = useModelProvider()
   const [displayModel, setDisplayModel] = useState<string>('')
   const { updateCurrentThreadModel } = useThreads()
@@ -79,31 +81,52 @@
   const searchInputRef = useRef<HTMLInputElement>(null)

   // Helper function to check if a model exists in providers
-  const checkModelExists = useCallback((providerName: string, modelId: string) => {
-    const provider = providers.find(
-      (p) => p.provider === providerName && p.active
-    )
-    return provider?.models.find((m) => m.id === modelId)
-  }, [providers])
+  const checkModelExists = useCallback(
+    (providerName: string, modelId: string) => {
+      const provider = providers.find(
+        (p) => p.provider === providerName && p.active
+      )
+      return provider?.models.find((m) => m.id === modelId)
+    },
+    [providers]
+  )

   // Initialize model provider only once
   useEffect(() => {
-    // Auto select model when existing thread is passed
-    if (model) {
-      selectModelProvider(model?.provider as string, model?.id as string)
-      if (!checkModelExists(model.provider, model.id)) {
-        selectModelProvider('', '')
-      }
-    } else if (useLastUsedModel) {
-      // Try to use last used model only when explicitly requested (for new chat)
-      const lastUsed = getLastUsedModel()
-      if (lastUsed && checkModelExists(lastUsed.provider, lastUsed.model)) {
-        selectModelProvider(lastUsed.provider, lastUsed.model)
-      } else {
-        // Fallback to default model if last used model no longer exists
-        selectModelProvider('', '')
-      }
-    }
+    const initializeModel = async () => {
+      // Auto select model when existing thread is passed
+      if (model) {
+        selectModelProvider(model?.provider as string, model?.id as string)
+        if (!checkModelExists(model.provider, model.id)) {
+          selectModelProvider('', '')
+        }
+        // Check mmproj existence for llamacpp models
+        if (model?.provider === 'llamacpp') {
+          await checkMmprojExistsAndUpdateOffloadMMprojSetting(
+            model.id as string,
+            updateProvider,
+            getProviderByName
+          )
+        }
+      } else if (useLastUsedModel) {
+        // Try to use last used model only when explicitly requested (for new chat)
+        const lastUsed = getLastUsedModel()
+        if (lastUsed && checkModelExists(lastUsed.provider, lastUsed.model)) {
+          selectModelProvider(lastUsed.provider, lastUsed.model)
+          if (lastUsed.provider === 'llamacpp') {
+            await checkMmprojExistsAndUpdateOffloadMMprojSetting(
+              lastUsed.model,
+              updateProvider,
+              getProviderByName
+            )
+          }
+        } else {
+          selectModelProvider('', '')
+        }
+      }
+    }

+    initializeModel()
   }, [
     model,
     selectModelProvider,
@@ -111,6 +134,8 @@
     providers,
     useLastUsedModel,
     checkModelExists,
+    updateProvider,
+    getProviderByName,
   ])

   // Update display model when selection changes
@@ -245,7 +270,7 @@
   }, [filteredItems, providers, searchValue, favoriteModels])
   const handleSelect = useCallback(
-    (searchableModel: SearchableModel) => {
+    async (searchableModel: SearchableModel) => {
       selectModelProvider(
         searchableModel.provider.provider,
         searchableModel.model.id
@@ -254,6 +279,16 @@
         id: searchableModel.model.id,
         provider: searchableModel.provider.provider,
       })
+
+      // Check mmproj existence for llamacpp models
+      if (searchableModel.provider.provider === 'llamacpp') {
+        await checkMmprojExistsAndUpdateOffloadMMprojSetting(
+          searchableModel.model.id,
+          updateProvider,
+          getProviderByName
+        )
+      }
+
       // Store the selected model as last used
       if (useLastUsedModel) {
         setLastUsedModel(
@@ -264,7 +299,13 @@
       setSearchValue('')
       setOpen(false)
     },
-    [selectModelProvider, updateCurrentThreadModel, useLastUsedModel]
+    [
+      selectModelProvider,
+      updateCurrentThreadModel,
+      useLastUsedModel,
+      updateProvider,
+      getProviderByName,
+    ]
   )

   const currentModel = selectedModel?.id

View File

@@ -70,8 +70,8 @@ export function ModelSetting({
       models: updatedModels,
     })

-    // Call debounced stopModel only when updating ctx_len or ngl
-    if (key === 'ctx_len' || key === 'ngl' || key === 'chat_template') {
+    // Call debounced stopModel only when updating ctx_len, ngl, chat_template, or offload_mmproj
+    if (key === 'ctx_len' || key === 'ngl' || key === 'chat_template' || key === 'offload_mmproj') {
       debouncedStopModel(model.id)
     }
   }

View File

@@ -1,3 +1,4 @@
+/* eslint-disable @typescript-eslint/no-explicit-any */
 import { ThreadMessage } from '@janhq/core'
 import { RenderMarkdown } from './RenderMarkdown'
 import React, { Fragment, memo, useCallback, useMemo, useState } from 'react'
@@ -144,7 +145,7 @@ export const ThreadContent = memo(
     isLastMessage?: boolean
     index?: number
     showAssistant?: boolean
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
     streamTools?: any
     contextOverflowModal?: React.ReactNode | null
     updateMessage?: (item: ThreadMessage, message: string) => void
@@ -172,9 +173,12 @@ export const ThreadContent = memo(
     const { reasoningSegment, textSegment } = useMemo(() => {
       // Check for thinking formats
       const hasThinkTag = text.includes('<think>') && !text.includes('</think>')
-      const hasAnalysisChannel = text.includes('<|channel|>analysis<|message|>') && !text.includes('<|start|>assistant<|channel|>final<|message|>')
+      const hasAnalysisChannel =
+        text.includes('<|channel|>analysis<|message|>') &&
+        !text.includes('<|start|>assistant<|channel|>final<|message|>')

-      if (hasThinkTag || hasAnalysisChannel) return { reasoningSegment: text, textSegment: '' }
+      if (hasThinkTag || hasAnalysisChannel)
+        return { reasoningSegment: text, textSegment: '' }

       // Check for completed think tag format
       const thinkMatch = text.match(/<think>([\s\S]*?)<\/think>/)
@@ -187,7 +191,9 @@ export const ThreadContent = memo(
       }

       // Check for completed analysis channel format
-      const analysisMatch = text.match(/<\|channel\|>analysis<\|message\|>([\s\S]*?)<\|start\|>assistant<\|channel\|>final<\|message\|>/)
+      const analysisMatch = text.match(
+        /<\|channel\|>analysis<\|message\|>([\s\S]*?)<\|start\|>assistant<\|channel\|>final<\|message\|>/
+      )
       if (analysisMatch?.index !== undefined) {
         const splitIndex = analysisMatch.index + analysisMatch[0].length
         return {
@@ -213,7 +219,36 @@ export const ThreadContent = memo(
       }
       if (toSendMessage) {
         deleteMessage(toSendMessage.thread_id, toSendMessage.id ?? '')
-        sendMessage(toSendMessage.content?.[0]?.text?.value || '')
+        // Extract text content and any attachments
+        const textContent =
+          toSendMessage.content?.find((c) => c.type === 'text')?.text?.value ||
+          ''
+        const attachments = toSendMessage.content
+          ?.filter((c) => (c.type === 'image_url' && c.image_url?.url) || false)
+          .map((c) => {
+            if (c.type === 'image_url' && c.image_url?.url) {
+              const url = c.image_url.url
+              const [mimeType, base64] = url
+                .replace('data:', '')
+                .split(';base64,')
+              return {
+                name: 'image', // We don't have the original filename
+                type: mimeType,
+                size: 0, // We don't have the original size
+                base64: base64,
+                dataUrl: url,
+              }
+            }
+            return null
+          })
+          .filter(Boolean) as Array<{
+          name: string
+          type: string
+          size: number
+          base64: string
+          dataUrl: string
+        }>
+        sendMessage(textContent, true, attachments)
       }
     }, [deleteMessage, getMessages, item, sendMessage])
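Regeneration reverses the encoding done at send time: each stored data URL is split back into a MIME type and a base64 payload (the original filename and size are gone, hence the placeholders above). The round trip in isolation:

// Values are illustrative; the payload is truncated.
const dataUrl = 'data:image/png;base64,iVBORw0KGgo'
const [mimeType, base64] = dataUrl.replace('data:', '').split(';base64,')
// mimeType === 'image/png', base64 === 'iVBORw0KGgo'
const rebuilt = `data:${mimeType};base64,${base64}` // round-trips to dataUrl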
@@ -255,22 +290,68 @@ export const ThreadContent = memo(
     return (
       <Fragment>
-        {item.content?.[0]?.text && item.role === 'user' && (
+        {item.role === 'user' && (
           <div className="w-full">
-            <div className="flex justify-end w-full h-full text-start break-words whitespace-normal">
-              <div className="bg-main-view-fg/4 relative text-main-view-fg p-2 rounded-md inline-block max-w-[80%] ">
-                <div className="select-text">
-                  <RenderMarkdown
-                    content={item.content?.[0].text.value}
-                    components={linkComponents}
-                    isUser
-                  />
+            {/* Render attachments above the message bubble */}
+            {item.content?.some(
+              (c) => (c.type === 'image_url' && c.image_url?.url) || false
+            ) && (
+              <div className="flex justify-end w-full mb-2">
+                <div className="flex flex-wrap gap-2 max-w-[80%] justify-end">
+                  {item.content
+                    ?.filter(
+                      (c) =>
+                        (c.type === 'image_url' && c.image_url?.url) || false
+                    )
+                    .map((contentPart, index) => {
+                      // Handle images
+                      if (
+                        contentPart.type === 'image_url' &&
+                        contentPart.image_url?.url
+                      ) {
+                        return (
+                          <div key={index} className="relative">
+                            <img
+                              src={contentPart.image_url.url}
+                              alt="Uploaded attachment"
+                              className="size-40 rounded-md object-cover border border-main-view-fg/10"
+                            />
+                          </div>
+                        )
+                      }
+                      return null
+                    })}
                 </div>
               </div>
-            </div>
+            )}
+            {/* Render text content in the message bubble */}
+            {item.content?.some((c) => c.type === 'text' && c.text?.value) && (
+              <div className="flex justify-end w-full h-full text-start break-words whitespace-normal">
+                <div className="bg-main-view-fg/4 relative text-main-view-fg p-2 rounded-md inline-block max-w-[80%] ">
+                  <div className="select-text">
+                    {item.content
+                      ?.filter((c) => c.type === 'text' && c.text?.value)
+                      .map((contentPart, index) => (
+                        <div key={index}>
+                          <RenderMarkdown
+                            content={contentPart.text!.value}
+                            components={linkComponents}
+                            isUser
+                          />
+                        </div>
+                      ))}
+                  </div>
+                </div>
+              </div>
+            )}
             <div className="flex items-center justify-end gap-2 text-main-view-fg/60 text-xs mt-2">
               <EditDialog
-                message={item.content?.[0]?.text.value}
+                message={
+                  item.content?.find((c) => c.type === 'text')?.text?.value ||
+                  ''
+                }
                 setMessage={(message) => {
                   if (item.updateMessage) {
                     item.updateMessage(item, message)

View File

@@ -73,6 +73,11 @@ vi.mock('@/services/mcp', () => ({
 vi.mock('@/services/models', () => ({
   stopAllModels: vi.fn(),
+  checkMmprojExists: vi.fn(() => Promise.resolve(true)),
+}))
+
+vi.mock('../MovingBorder', () => ({
+  MovingBorder: ({ children }: { children: React.ReactNode }) => <div data-testid="moving-border">{children}</div>,
 }))

 describe('ChatInput', () => {
@@ -231,7 +236,7 @@ describe('ChatInput', () => {
     const sendButton = document.querySelector('[data-test-id="send-message-button"]')
     await user.click(sendButton)

-    expect(mockSendMessage).toHaveBeenCalledWith('Hello world')
+    expect(mockSendMessage).toHaveBeenCalledWith('Hello world', true, undefined)
   })

   it('sends message when Enter key is pressed', async () => {
@@ -248,7 +253,7 @@ describe('ChatInput', () => {
     const textarea = screen.getByRole('textbox')
     await user.type(textarea, '{Enter}')

-    expect(mockSendMessage).toHaveBeenCalledWith('Hello world')
+    expect(mockSendMessage).toHaveBeenCalledWith('Hello world', true, undefined)
   })

   it('does not send message when Shift+Enter is pressed', async () => {
@@ -343,9 +348,12 @@ describe('ChatInput', () => {
     const user = userEvent.setup()
     renderWithRouter()

-    // File upload is rendered as hidden input element
-    const fileInput = document.querySelector('input[type="file"]')
-    expect(fileInput).toBeInTheDocument()
+    // Wait for async effects to complete (mmproj check)
+    await waitFor(() => {
+      // File upload is rendered as hidden input element
+      const fileInput = document.querySelector('input[type="file"]')
+      expect(fileInput).toBeInTheDocument()
+    })
   })

   it('disables input when streaming', () => {
@@ -361,7 +369,7 @@ describe('ChatInput', () => {
       renderWithRouter()
     })

-    const textarea = screen.getByRole('textbox')
+    const textarea = screen.getByTestId('chat-input')
     expect(textarea).toBeDisabled()
   })
@@ -378,4 +386,28 @@ describe('ChatInput', () => {
       expect(toolsIcon).toBeInTheDocument()
     })
   })
+
+  it('uses selectedProvider for provider checks', () => {
+    // Test that the component correctly uses selectedProvider instead of selectedModel.provider
+    vi.mocked(useModelProvider).mockReturnValue({
+      selectedModel: {
+        id: 'test-model',
+        capabilities: ['vision'],
+      },
+      providers: [],
+      getModelBy: vi.fn(),
+      selectModelProvider: vi.fn(),
+      selectedProvider: 'llamacpp',
+      setProviders: vi.fn(),
+      getProviderByName: vi.fn(),
+      updateProvider: vi.fn(),
+      addProvider: vi.fn(),
+      deleteProvider: vi.fn(),
+      deleteModel: vi.fn(),
+      deletedModels: [],
+    })
+
+    // This test ensures the component renders without errors when using selectedProvider
+    expect(() => renderWithRouter()).not.toThrow()
+  })
 })

View File

@@ -203,7 +203,17 @@ export const useChat = () => {
   )

   const sendMessage = useCallback(
-    async (message: string, troubleshooting = true) => {
+    async (
+      message: string,
+      troubleshooting = true,
+      attachments?: Array<{
+        name: string
+        type: string
+        size: number
+        base64: string
+        dataUrl: string
+      }>
+    ) => {
       const activeThread = await getCurrentThread()
       resetTokenSpeed()
@@ -217,7 +227,7 @@ export const useChat = () => {
       updateStreamingContent(emptyThreadContent)
       // Do not add new message on retry
       if (troubleshooting)
-        addMessage(newUserThreadContent(activeThread.id, message))
+        addMessage(newUserThreadContent(activeThread.id, message, attachments))
       updateThreadTimestamp(activeThread.id)
       setPrompt('')
       try {
@@ -231,7 +241,7 @@ export const useChat = () => {
         messages,
         currentAssistant?.instructions
       )
-      if (troubleshooting) builder.addUserMessage(message)
+      if (troubleshooting) builder.addUserMessage(message, attachments)

       let isCompleted = false

View File

@@ -1,3 +1,4 @@
+/* eslint-disable @typescript-eslint/no-explicit-any */
 import {
   ContentType,
   ChatCompletionRole,
@@ -51,11 +52,16 @@ export type ChatCompletionResponse =
  */
 export const newUserThreadContent = (
   threadId: string,
-  content: string
-): ThreadMessage => ({
-  type: 'text',
-  role: ChatCompletionRole.User,
-  content: [
+  content: string,
+  attachments?: Array<{
+    name: string
+    type: string
+    size: number
+    base64: string
+    dataUrl: string
+  }>
+): ThreadMessage => {
+  const contentParts = [
     {
       type: ContentType.Text,
       text: {
@@ -63,14 +69,35 @@ export const newUserThreadContent = (
         annotations: [],
       },
     },
-  ],
-  id: ulid(),
-  object: 'thread.message',
-  thread_id: threadId,
-  status: MessageStatus.Ready,
-  created_at: 0,
-  completed_at: 0,
-})
+  ]
+
+  // Add attachments to content array
+  if (attachments) {
+    attachments.forEach((attachment) => {
+      if (attachment.type.startsWith('image/')) {
+        contentParts.push({
+          type: ContentType.Image,
+          image_url: {
+            url: `data:${attachment.type};base64,${attachment.base64}`,
+            detail: 'auto',
+          },
+        } as any)
+      }
+    })
+  }
+
+  return {
+    type: 'text',
+    role: ChatCompletionRole.User,
+    content: contentParts,
+    id: ulid(),
+    object: 'thread.message',
+    thread_id: threadId,
+    status: MessageStatus.Ready,
+    created_at: 0,
+    completed_at: 0,
+  }
+}

 /**
  * @fileoverview Helper functions for creating thread content.
  * These functions are used to create thread content objects
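For reference, a user message with one image attachment produced by the function above looks roughly like this (ids, timestamps, and the payload are illustrative; the sketch assumes `ContentType.Image` serializes to 'image_url' and `MessageStatus.Ready` to 'ready', which is what the rendering code filters on):

const example = {
  type: 'text',
  role: 'user',
  content: [
    { type: 'text', text: { value: 'What is in this image?', annotations: [] } },
    {
      type: 'image_url', // ContentType.Image as stored on the thread
      image_url: { url: 'data:image/png;base64,iVBORw0KGgo', detail: 'auto' },
    },
  ],
  object: 'thread.message',
  thread_id: 'thread-01', // illustrative
  status: 'ready',
  created_at: 0,
  completed_at: 0,
}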
@@ -162,13 +189,11 @@ export const sendCompletion = async (
   if (
     thread.model.id &&
     !Object.values(models[providerName]).flat().includes(thread.model.id) &&
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
     !tokenJS.extendedModelExist(providerName as any, thread.model.id) &&
     provider.provider !== 'llamacpp'
   ) {
     try {
       tokenJS.extendModelList(
-        // eslint-disable-next-line @typescript-eslint/no-explicit-any
         providerName as any,
         thread.model.id,
         // This is to inherit the model capabilities from another built-in model
@@ -201,7 +226,7 @@ export const sendCompletion = async (
         ? await tokenJS.chat.completions.create(
             {
               stream: true,
-              // eslint-disable-next-line @typescript-eslint/no-explicit-any
               provider: providerName as any,
               model: thread.model?.id,
               messages,

View File

@@ -1,3 +1,4 @@
+/* eslint-disable @typescript-eslint/no-explicit-any */
 import { ChatCompletionMessageParam } from 'token.js'
 import { ChatCompletionMessageToolCall } from 'openai/resources'
 import { ThreadMessage } from '@janhq/core'
@@ -19,32 +20,106 @@ export class CompletionMessagesBuilder {
     this.messages.push(
       ...messages
         .filter((e) => !e.metadata?.error)
-        .map<ChatCompletionMessageParam>(
-          (msg) =>
-            ({
-              role: msg.role,
-              content:
-                msg.role === 'assistant'
-                  ? this.normalizeContent(msg.content[0]?.text?.value || '.')
-                  : msg.content[0]?.text?.value || '.',
-            }) as ChatCompletionMessageParam
-        )
+        .map<ChatCompletionMessageParam>((msg) => {
+          if (msg.role === 'assistant') {
+            return {
+              role: msg.role,
+              content: this.normalizeContent(
+                msg.content[0]?.text?.value || '.'
+              ),
+            } as ChatCompletionMessageParam
+          } else {
+            // For user messages, handle multimodal content
+            if (msg.content.length > 1) {
+              // Multiple content parts (text + images + files)
+              const content = msg.content.map((contentPart) => {
+                if (contentPart.type === 'text') {
+                  return {
+                    type: 'text',
+                    text: contentPart.text?.value || '',
+                  }
+                } else if (contentPart.type === 'image_url') {
+                  return {
+                    type: 'image_url',
+                    image_url: {
+                      url: contentPart.image_url?.url || '',
+                      detail: contentPart.image_url?.detail || 'auto',
+                    },
+                  }
+                } else {
+                  return contentPart
+                }
+              })
+              return {
+                role: msg.role,
+                content,
+              } as ChatCompletionMessageParam
+            } else {
+              // Single text content
+              return {
+                role: msg.role,
+                content: msg.content[0]?.text?.value || '.',
+              } as ChatCompletionMessageParam
+            }
+          }
+        })
     )
   }

   /**
    * Add a user message to the messages array.
    * @param content - The content of the user message.
+   * @param attachments - Optional attachments for the message.
    */
-  addUserMessage(content: string) {
+  addUserMessage(
+    content: string,
+    attachments?: Array<{
+      name: string
+      type: string
+      size: number
+      base64: string
+      dataUrl: string
+    }>
+  ) {
     // Ensure no consecutive user messages
     if (this.messages[this.messages.length - 1]?.role === 'user') {
       this.messages.pop()
     }
-    this.messages.push({
-      role: 'user',
-      content: content,
-    })
+
+    // Handle multimodal content with attachments
+    if (attachments && attachments.length > 0) {
+      const messageContent: any[] = [
+        {
+          type: 'text',
+          text: content,
+        },
+      ]
+
+      // Add image attachments
+      attachments.forEach((attachment) => {
+        if (attachment.type.startsWith('image/')) {
+          messageContent.push({
+            type: 'image_url',
+            image_url: {
+              url: `data:${attachment.type};base64,${attachment.base64}`,
+              detail: 'auto',
+            },
+          })
+        }
+      })

+      this.messages.push({
+        role: 'user',
+        content: messageContent,
+      } as any)
+    } else {
+      // Text-only message
+      this.messages.push({
+        role: 'user',
+        content: content,
+      })
+    }
   }

   /**
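On the wire, the builder emits OpenAI-style content parts. Note the shape differs from the stored ThreadMessage sketched earlier: here `text` is a plain string, not a `{ value }` object. An illustrative payload for a user message with one image:

const wireMessage = {
  role: 'user',
  content: [
    { type: 'text', text: 'What is in this image?' },
    {
      type: 'image_url',
      image_url: { url: 'data:image/png;base64,iVBORw0KGgo', detail: 'auto' },
    },
  ],
}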

View File

@@ -26,7 +26,7 @@ import {
   ResizablePanel,
   ResizableHandle,
 } from '@/components/ui/resizable'
-import { useCallback } from 'react'
+import { useCallback, useEffect } from 'react'
 import GlobalError from '@/containers/GlobalError'
 import { GlobalEventHandler } from '@/providers/GlobalEventHandler'
@@ -65,6 +65,41 @@ const AppLayout = () => {
     [setLeftPanelSize, setLeftPanel]
   )

+  // Prevent default drag and drop behavior globally
+  useEffect(() => {
+    const preventDefaults = (e: DragEvent) => {
+      e.preventDefault()
+      e.stopPropagation()
+    }
+
+    const handleGlobalDrop = (e: DragEvent) => {
+      e.preventDefault()
+      e.stopPropagation()
+
+      // Only prevent if the target is not within a chat input or other valid drop zone
+      const target = e.target as Element
+      const isValidDropZone =
+        target?.closest('[data-drop-zone="true"]') ||
+        target?.closest('.chat-input-drop-zone') ||
+        target?.closest('[data-tauri-drag-region]')
+
+      if (!isValidDropZone) {
+        // Prevent the file from opening in the window
+        return false
+      }
+    }
+
+    // Add event listeners to prevent default drag/drop behavior
+    window.addEventListener('dragenter', preventDefaults)
+    window.addEventListener('dragover', preventDefaults)
+    window.addEventListener('drop', handleGlobalDrop)
+
+    return () => {
+      window.removeEventListener('dragenter', preventDefaults)
+      window.removeEventListener('dragover', preventDefaults)
+      window.removeEventListener('drop', handleGlobalDrop)
+    }
+  }, [])
+
   return (
     <Fragment>
       <AnalyticProvider />
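The window-level guard and the chat input cooperate through the `data-drop-zone` attribute that ChatInput sets when a vision model is active: drops outside a marked zone are neutralized so the webview never navigates to the dropped file, while drops inside bubble to the React handlers. The check in isolation (sketch):

// True when the drop target sits inside an element marked as a drop zone,
// e.g. the ChatInput wrapper rendering data-drop-zone="true".
const isValidDropZone = (target: EventTarget | null): boolean =>
  target instanceof Element && target.closest('[data-drop-zone="true"]') !== null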

View File

@@ -290,7 +290,7 @@ describe('models service', () => {
         likes: 100,
         tags: ['conversational', 'pytorch'],
         pipeline_tag: 'text-generation',
-        created_at: '2023-01-01T00:00:00Z',
+        createdAt: '2023-01-01T00:00:00Z',
         last_modified: '2023-12-01T00:00:00Z',
         private: false,
         disabled: false,
@@ -443,7 +443,7 @@ describe('models service', () => {
         likes: 100,
         tags: ['conversational'],
         pipeline_tag: 'text-generation',
-        created_at: '2023-01-01T00:00:00Z',
+        createdAt: '2023-01-01T00:00:00Z',
         last_modified: '2023-12-01T00:00:00Z',
         private: false,
         disabled: false,
@@ -471,7 +471,7 @@ describe('models service', () => {
         likes: 100,
         tags: ['conversational'],
         pipeline_tag: 'text-generation',
-        created_at: '2023-01-01T00:00:00Z',
+        createdAt: '2023-01-01T00:00:00Z',
         last_modified: '2023-12-01T00:00:00Z',
         private: false,
         disabled: false,
@@ -510,7 +510,7 @@ describe('models service', () => {
         likes: 100,
         tags: ['conversational'],
         pipeline_tag: 'text-generation',
-        created_at: '2023-01-01T00:00:00Z',
+        createdAt: '2023-01-01T00:00:00Z',
         last_modified: '2023-12-01T00:00:00Z',
         private: false,
         disabled: false,
@@ -559,7 +559,7 @@ describe('models service', () => {
         likes: 75,
         tags: ['pytorch', 'transformers', 'text-generation'],
         pipeline_tag: 'text-generation',
-        created_at: '2021-01-01T00:00:00Z',
+        createdAt: '2021-01-01T00:00:00Z',
         last_modified: '2021-12-01T00:00:00Z',
         private: false,
         disabled: false,
@@ -605,6 +605,8 @@ describe('models service', () => {
             file_size: '4.0 GB',
           },
         ],
+        num_mmproj: 0,
+        mmproj_models: [],
         created_at: '2021-01-01T00:00:00Z',
         readme:
           'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md',
@@ -820,7 +822,7 @@ describe('models service', () => {
         downloads: 0,
         likes: 0,
         tags: [],
-        created_at: '2021-01-01T00:00:00Z',
+        createdAt: '2021-01-01T00:00:00Z',
         last_modified: '2021-12-01T00:00:00Z',
         private: false,
         disabled: false,

View File

@@ -1,3 +1,4 @@
+/* eslint-disable @typescript-eslint/no-explicit-any */
 import { sanitizeModelId } from '@/lib/utils'
 import {
   AIEngine,
@@ -27,6 +28,7 @@ export interface CatalogModel {
   num_quants: number
   quants: ModelQuant[]
   mmproj_models?: MMProjModel[]
+  num_mmproj: number
   created_at?: string
   readme?: string
   tools?: boolean
@@ -44,7 +46,7 @@ export interface HuggingFaceRepo {
   library_name?: string
   tags: string[]
   pipeline_tag?: string
-  created_at: string
+  createdAt: string
   last_modified: string
   private: boolean
   disabled: boolean
@@ -155,21 +157,30 @@ export const fetchHuggingFaceRepo = async (
 export const convertHfRepoToCatalogModel = (
   repo: HuggingFaceRepo
 ): CatalogModel => {
+  // Format file size helper
+  const formatFileSize = (size?: number) => {
+    if (!size) return 'Unknown size'
+    if (size < 1024 ** 3) return `${(size / 1024 ** 2).toFixed(1)} MB`
+    return `${(size / 1024 ** 3).toFixed(1)} GB`
+  }
+
   // Extract GGUF files from the repository siblings
   const ggufFiles =
     repo.siblings?.filter((file) =>
       file.rfilename.toLowerCase().endsWith('.gguf')
     ) || []

-  // Convert GGUF files to quants format
-  const quants = ggufFiles.map((file) => {
-    // Format file size
-    const formatFileSize = (size?: number) => {
-      if (!size) return 'Unknown size'
-      if (size < 1024 ** 3) return `${(size / 1024 ** 2).toFixed(1)} MB`
-      return `${(size / 1024 ** 3).toFixed(1)} GB`
-    }
+  // Separate regular GGUF files from mmproj files
+  const regularGgufFiles = ggufFiles.filter(
+    (file) => !file.rfilename.toLowerCase().includes('mmproj')
+  )

+  const mmprojFiles = ggufFiles.filter((file) =>
+    file.rfilename.toLowerCase().includes('mmproj')
+  )
+
+  // Convert regular GGUF files to quants format
+  const quants = regularGgufFiles.map((file) => {
     // Generate model_id from filename (remove .gguf extension, case-insensitive)
     const modelId = file.rfilename.replace(/\.gguf$/i, '')

@@ -180,15 +191,28 @@ export const convertHfRepoToCatalogModel = (
     }
   })

+  // Convert mmproj files to mmproj_models format
+  const mmprojModels = mmprojFiles.map((file) => {
+    const modelId = file.rfilename.replace(/\.gguf$/i, '')
+    return {
+      model_id: sanitizeModelId(modelId),
+      path: `https://huggingface.co/${repo.modelId}/resolve/main/${file.rfilename}`,
+      file_size: formatFileSize(file.size),
+    }
+  })
+
   return {
     model_name: repo.modelId,
+    description: `**Tags**: ${repo.tags?.join(', ')}`,
     developer: repo.author,
     downloads: repo.downloads || 0,
+    created_at: repo.createdAt,
     num_quants: quants.length,
     quants: quants,
-    created_at: repo.created_at,
+    num_mmproj: mmprojModels.length,
+    mmproj_models: mmprojModels,
     readme: `https://huggingface.co/${repo.modelId}/resolve/main/README.md`,
-    description: `**Tags**: ${repo.tags?.join(', ')}`,
   }
 }
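An illustrative before/after for this conversion (repository and file names are made up): one quant plus one projector file yields one entry in each list, with mmproj files excluded from `quants`.

const repo = {
  modelId: 'someorg/awesome-vl-gguf', // hypothetical repo
  author: 'someorg',
  downloads: 1234,
  tags: ['image-text-to-text'],
  createdAt: '2025-01-01T00:00:00Z',
  siblings: [
    { rfilename: 'awesome-vl-Q4_K_M.gguf', size: 4.2 * 1024 ** 3 },
    { rfilename: 'mmproj-awesome-vl-f16.gguf', size: 0.8 * 1024 ** 3 },
  ],
}
// convertHfRepoToCatalogModel(repo as HuggingFaceRepo) =>
//   num_quants: 1, quants[0].model_id === 'awesome-vl-Q4_K_M'
//   num_mmproj: 1, mmproj_models[0].model_id === 'mmproj-awesome-vl-f16'
//   file_size: '4.2 GB' and '819.2 MB' respectively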
@@ -327,3 +351,137 @@ export const isToolSupported = async (modelId: string): Promise<boolean> => {
   return engine.isToolSupported(modelId)
 }
+
+/**
+ * Checks if mmproj.gguf file exists for a given model ID in the llamacpp provider.
+ * Also checks if the model has an offload_mmproj setting.
+ * If mmproj.gguf exists, adds the offload_mmproj setting with value true.
+ * @param modelId - The model ID to check for mmproj.gguf
+ * @param updateProvider - Function to update the provider state
+ * @param getProviderByName - Function to get provider by name
+ * @returns Promise<{exists: boolean, settingsUpdated: boolean}> - exists: true if mmproj.gguf exists, settingsUpdated: true if settings were modified
+ */
+export const checkMmprojExistsAndUpdateOffloadMMprojSetting = async (
+  modelId: string,
+  updateProvider?: (providerName: string, data: Partial<ModelProvider>) => void,
+  getProviderByName?: (providerName: string) => ModelProvider | undefined
+): Promise<{ exists: boolean; settingsUpdated: boolean }> => {
+  let settingsUpdated = false
+  try {
+    const engine = getEngine('llamacpp') as AIEngine & {
+      checkMmprojExists?: (id: string) => Promise<boolean>
+    }
+    if (engine && typeof engine.checkMmprojExists === 'function') {
+      const exists = await engine.checkMmprojExists(modelId)
+
+      // If we have the store functions, use them; otherwise fall back to localStorage
+      if (updateProvider && getProviderByName) {
+        const provider = getProviderByName('llamacpp')
+        if (provider) {
+          const model = provider.models.find((m) => m.id === modelId)
+          if (model?.settings) {
+            const hasOffloadMmproj = 'offload_mmproj' in model.settings
+
+            // If mmproj exists, add offload_mmproj setting (only if it doesn't exist)
+            if (exists && !hasOffloadMmproj) {
+              // Create updated models array with the new setting
+              const updatedModels = provider.models.map((m) => {
+                if (m.id === modelId) {
+                  return {
+                    ...m,
+                    settings: {
+                      ...m.settings,
+                      offload_mmproj: {
+                        key: 'offload_mmproj',
+                        title: 'Offload MMProj',
+                        description:
+                          'Offload multimodal projection layers to GPU',
+                        controller_type: 'checkbox',
+                        controller_props: {
+                          value: true,
+                        },
+                      },
+                    },
+                  }
+                }
+                return m
+              })
+
+              // Update the provider with the new models array
+              updateProvider('llamacpp', { models: updatedModels })
+              settingsUpdated = true
+            }
+          }
+        }
+      } else {
+        // Fall back to localStorage approach for backwards compatibility
+        try {
+          const modelProviderData = JSON.parse(
+            localStorage.getItem('model-provider') || '{}'
+          )
+          const llamacppProvider = modelProviderData.state?.providers?.find(
+            (p: any) => p.provider === 'llamacpp'
+          )
+          const model = llamacppProvider?.models?.find(
+            (m: any) => m.id === modelId
+          )
+
+          if (model?.settings) {
+            // If mmproj exists, add offload_mmproj setting (only if it doesn't exist)
+            if (exists) {
+              if (!model.settings.offload_mmproj) {
+                model.settings.offload_mmproj = {
+                  key: 'offload_mmproj',
+                  title: 'Offload MMProj',
+                  description: 'Offload multimodal projection layers to GPU',
+                  controller_type: 'checkbox',
+                  controller_props: {
+                    value: true,
+                  },
+                }
+                // Save updated settings back to localStorage
+                localStorage.setItem(
+                  'model-provider',
+                  JSON.stringify(modelProviderData)
+                )
+                settingsUpdated = true
+              }
+            }
+          }
+        } catch (localStorageError) {
+          console.error(
+            `Error checking localStorage for model ${modelId}:`,
+            localStorageError
+          )
+        }
+      }
+      return { exists, settingsUpdated }
+    }
+  } catch (error) {
+    console.error(`Error checking mmproj for model ${modelId}:`, error)
+  }
+  return { exists: false, settingsUpdated }
+}
+/**
+ * Checks if mmproj.gguf exists for a given model ID in the llamacpp provider.
+ * Unlike the variant above, this is a read-only check and does not modify settings.
+ * @param modelId - The model ID to check for mmproj.gguf
+ * @returns Promise<boolean> - true if mmproj.gguf exists, false otherwise
+ */
+export const checkMmprojExists = async (modelId: string): Promise<boolean> => {
+  try {
+    const engine = getEngine('llamacpp') as AIEngine & {
+      checkMmprojExists?: (id: string) => Promise<boolean>
+    }
+    if (engine && typeof engine.checkMmprojExists === 'function') {
+      return await engine.checkMmprojExists(modelId)
+    }
+  } catch (error) {
+    console.error(`Error checking mmproj for model ${modelId}:`, error)
+  }
+  return false
+}