Merge pull request #6134 from menloresearch/feat/attachment-ui

feat: attachment UI

commit 5481ee9e35
@@ -31,6 +31,7 @@
     "@janhq/tauri-plugin-hardware-api": "link:../../src-tauri/plugins/tauri-plugin-hardware",
     "@janhq/tauri-plugin-llamacpp-api": "link:../../src-tauri/plugins/tauri-plugin-llamacpp",
     "@tauri-apps/api": "^2.5.0",
+    "@tauri-apps/plugin-http": "^2.5.1",
     "@tauri-apps/plugin-log": "^2.6.0",
     "fetch-retry": "^5.0.6",
     "ulidx": "^2.3.0"
@@ -17,4 +17,7 @@ export default defineConfig({
     IS_MAC: JSON.stringify(process.platform === 'darwin'),
     IS_LINUX: JSON.stringify(process.platform === 'linux'),
   },
+  inject: {
+    fetch: ['@tauri-apps/plugin-http', 'fetch'],
+  },
 })
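The new `inject` option wires the global `fetch` to the Tauri HTTP plugin at build time, so every existing `fetch` call site is served by the Rust HTTP client (which supports options like the `connectTimeout` used later in this PR). A minimal sketch of what the injection amounts to, not part of the diff; `ping` is a hypothetical helper:

```ts
// Equivalent to what the build-time `inject` option does in every module:
// the plugin's fetch keeps the WHATWG fetch signature.
import { fetch } from '@tauri-apps/plugin-http'

// Hypothetical helper for illustration only.
async function ping(url: string): Promise<number> {
  const response = await fetch(url, { method: 'GET' })
  return response.status
}
```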
@@ -41,6 +41,7 @@ type LlamacppConfig = {
   auto_unload: boolean
   chat_template: string
   n_gpu_layers: number
+  offload_mmproj: boolean
   override_tensor_buffer_t: string
   ctx_size: number
   threads: number
@@ -1222,6 +1223,10 @@ export default class llamacpp_extension extends AIEngine {
     // Takes a regex with matching tensor name as input
     if (cfg.override_tensor_buffer_t)
       args.push('--override-tensor', cfg.override_tensor_buffer_t)
+    // Offload the multimodal projector model to the GPU by default. If there is
+    // not enough memory, turning this setting off keeps the projector model on
+    // the CPU, but image processing can take longer.
+    if (cfg.offload_mmproj === false) args.push('--no-mmproj-offload')
     args.push('-a', modelId)
     args.push('--port', String(port))
     if (modelConfig.mmproj_path) {
@@ -1383,7 +1388,8 @@ export default class llamacpp_extension extends AIEngine {
       method: 'POST',
       headers,
       body,
-      signal: abortController?.signal,
+      connectTimeout: 600000, // 10 minutes
+      signal: AbortSignal.any([AbortSignal.timeout(600000), abortController?.signal]),
     })
     if (!response.ok) {
       const errorData = await response.json().catch(() => null)
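The request now races two abort sources: `AbortSignal.timeout()` fires after ten minutes, and the caller's controller can cancel earlier; `AbortSignal.any()` (a standard Web API) merges them. One caveat worth noting: `AbortSignal.any()` rejects non-signal entries, so if `abortController` can actually be undefined at this call site, the list needs filtering first. A sketch of the pattern with hypothetical names:

```ts
// Merge an overall timeout with an optional caller-supplied abort signal.
function combinedSignal(timeoutMs: number, external?: AbortSignal): AbortSignal {
  const signals: AbortSignal[] = [AbortSignal.timeout(timeoutMs)]
  // AbortSignal.any() throws on non-signal entries, so drop undefined first.
  if (external) signals.push(external)
  return AbortSignal.any(signals)
}

// Usage: fetch(url, { signal: combinedSignal(600_000, abortController?.signal) })
```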
@@ -1542,6 +1548,26 @@ export default class llamacpp_extension extends AIEngine {
     }
   }

+  /**
+   * Check if mmproj.gguf file exists for a given model ID
+   * @param modelId - The model ID to check for mmproj.gguf
+   * @returns Promise<boolean> - true if mmproj.gguf exists, false otherwise
+   */
+  async checkMmprojExists(modelId: string): Promise<boolean> {
+    try {
+      const mmprojPath = await joinPath([
+        await this.getProviderPath(),
+        'models',
+        modelId,
+        'mmproj.gguf',
+      ])
+      return await fs.existsSync(mmprojPath)
+    } catch (e) {
+      logger.error(`Error checking mmproj.gguf for model ${modelId}:`, e)
+      return false
+    }
+  }
+
   async getDevices(): Promise<DeviceList[]> {
     const cfg = this.config
     const [version, backend] = cfg.version_backend.split('/')
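`checkMmprojExists` resolves the provider's `models/<modelId>/mmproj.gguf` path and reports whether the projector file is present; the UI later uses it to gate image attachments. A sketch of the intended gating logic, assuming the `@/services/models` wrapper simply proxies to this extension method:

```ts
import { checkMmprojExists } from '@/services/models'

// Gate image attachments: local llamacpp models need an mmproj file on disk;
// other providers are assumed to advertise vision via model capabilities.
async function supportsImages(provider: string, modelId: string): Promise<boolean> {
  if (provider !== 'llamacpp') return true
  return checkMmprojExists(modelId)
}
```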
@@ -1644,4 +1670,18 @@ export default class llamacpp_extension extends AIEngine {
       'tokenizer.chat_template'
     ]?.includes('tools')
   }

+  private async loadMetadata(path: string): Promise<GgufMetadata> {
+    try {
+      const data = await invoke<GgufMetadata>(
+        'plugin:llamacpp|read_gguf_metadata',
+        {
+          path: path,
+        }
+      )
+      return data
+    } catch (err) {
+      throw err
+    }
+  }
 }
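`loadMetadata` reaches the Rust side through Tauri's `invoke` with the plugin command registered as `plugin:llamacpp|read_gguf_metadata` (the `catch` that only rethrows is a no-op and could be dropped). A usage sketch, not part of the diff; the `GgufMetadata` shape and the dotted key are assumptions based on GGUF convention and the `tokenizer.chat_template` lookup seen above:

```ts
import { invoke } from '@tauri-apps/api/core'

// Illustrative shape; the real GgufMetadata type lives in the extension.
type GgufMetadata = Record<string, unknown>

async function hasChatTemplate(modelPath: string): Promise<boolean> {
  const meta = await invoke<GgufMetadata>('plugin:llamacpp|read_gguf_metadata', {
    path: modelPath,
  })
  // 'tokenizer.chat_template' is the GGUF key this extension checks elsewhere.
  return typeof meta['tokenizer.chat_template'] === 'string'
}
```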
@@ -12,7 +12,7 @@ use tokio::time::Instant;

 use crate::device::{get_devices_from_backend, DeviceInfo};
 use crate::error::{ErrorCode, LlamacppError, ServerError, ServerResult};
-use crate::path::{validate_binary_path, validate_model_path};
+use crate::path::{validate_binary_path, validate_model_path, validate_mmproj_path};
 use crate::process::{
     find_session_by_model_id, get_all_active_sessions, get_all_loaded_model_ids,
     get_random_available_port, is_process_running_by_pid,
@@ -55,6 +55,7 @@ pub async fn load_llama_model<R: Runtime>(

     let port = parse_port_from_args(&args);
     let model_path_pb = validate_model_path(&mut args)?;
+    let _mmproj_path_pb = validate_mmproj_path(&mut args)?;

     let api_key: String;

@@ -98,3 +98,50 @@ pub fn validate_model_path(args: &mut Vec<String>) -> ServerResult<PathBuf> {

     Ok(model_path_pb)
 }
+
+/// Validate mmproj path exists and update args with platform-appropriate path format
+pub fn validate_mmproj_path(args: &mut Vec<String>) -> ServerResult<Option<PathBuf>> {
+    let mmproj_path_index = match args.iter().position(|arg| arg == "--mmproj") {
+        Some(index) => index,
+        None => return Ok(None), // mmproj is optional
+    };
+
+    let mmproj_path = args.get(mmproj_path_index + 1).cloned().ok_or_else(|| {
+        LlamacppError::new(
+            ErrorCode::ModelLoadFailed,
+            "Mmproj path was not provided after '--mmproj' flag.".into(),
+            None,
+        )
+    })?;
+
+    let mmproj_path_pb = PathBuf::from(&mmproj_path);
+    if !mmproj_path_pb.exists() {
+        let err_msg = format!(
+            "Invalid or inaccessible mmproj path: {}",
+            mmproj_path_pb.display()
+        );
+        log::error!("{}", &err_msg);
+        return Err(LlamacppError::new(
+            ErrorCode::ModelFileNotFound,
+            "The specified mmproj file does not exist or is not accessible.".into(),
+            Some(err_msg),
+        )
+        .into());
+    }
+
+    #[cfg(windows)]
+    {
+        // use short path on Windows
+        if let Some(short) = get_short_path(&mmproj_path_pb) {
+            args[mmproj_path_index + 1] = short;
+        } else {
+            args[mmproj_path_index + 1] = mmproj_path_pb.display().to_string();
+        }
+    }
+    #[cfg(not(windows))]
+    {
+        args[mmproj_path_index + 1] = mmproj_path_pb.display().to_string();
+    }
+
+    Ok(Some(mmproj_path_pb))
+}
@@ -35,7 +35,8 @@
         "effects": ["fullScreenUI", "mica", "tabbed", "blur", "acrylic"],
         "state": "active",
         "radius": 8
-      }
+      },
+      "dragDropEnabled": false
     }
   ],
   "security": {
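`dragDropEnabled: false` turns off Tauri's native file-drop handling for the window; with it enabled, the webview never sees standard HTML5 drag-and-drop events, so the React handlers introduced below would not fire. A minimal sketch of the DOM contract the chat input relies on (component and prop names are illustrative):

```tsx
import * as React from 'react'

// With native drag-drop disabled, plain HTML5 events reach the webview.
function DropZone({ onFiles }: { onFiles: (files: FileList) => void }) {
  return (
    <div
      // preventDefault on dragover is required, or the browser refuses drops
      onDragOver={(e) => e.preventDefault()}
      onDrop={(e) => {
        e.preventDefault()
        if (e.dataTransfer?.files?.length) onFiles(e.dataTransfer.files)
      }}
    >
      Drop images here
    </div>
  )
}
```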
@@ -1,7 +1,7 @@
 'use client'

 import TextareaAutosize from 'react-textarea-autosize'
-import { cn, toGigabytes } from '@/lib/utils'
+import { cn } from '@/lib/utils'
 import { usePrompt } from '@/hooks/usePrompt'
 import { useThreads } from '@/hooks/useThreads'
 import { useCallback, useEffect, useRef, useState } from 'react'
@@ -14,7 +14,7 @@ import {
 } from '@/components/ui/tooltip'
 import { ArrowRight } from 'lucide-react'
 import {
-  IconPaperclip,
+  IconPhoto,
   IconWorld,
   IconAtom,
   IconEye,
@@ -34,6 +34,7 @@ import DropdownModelProvider from '@/containers/DropdownModelProvider'
 import { ModelLoader } from '@/containers/loaders/ModelLoader'
 import DropdownToolsAvailable from '@/containers/DropdownToolsAvailable'
 import { getConnectedServers } from '@/services/mcp'
+import { checkMmprojExists } from '@/services/models'

 type ChatInputProps = {
   className?: string
@@ -60,7 +61,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {

   const maxRows = 10

-  const { selectedModel } = useModelProvider()
+  const { selectedModel, selectedProvider } = useModelProvider()
   const { sendMessage } = useChat()
   const [message, setMessage] = useState('')
   const [dropdownToolsAvailable, setDropdownToolsAvailable] = useState(false)
@@ -75,6 +76,8 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
   }>
   >([])
   const [connectedServers, setConnectedServers] = useState<string[]>([])
+  const [isDragOver, setIsDragOver] = useState(false)
+  const [hasMmproj, setHasMmproj] = useState(false)

   // Check for connected MCP servers
   useEffect(() => {
@@ -96,6 +99,29 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
     return () => clearInterval(intervalId)
   }, [])

+  // Check for mmproj existence or vision capability when model changes
+  useEffect(() => {
+    const checkMmprojSupport = async () => {
+      if (selectedModel?.id) {
+        try {
+          // Only check mmproj for llamacpp provider
+          if (selectedProvider === 'llamacpp') {
+            const hasLocalMmproj = await checkMmprojExists(selectedModel.id)
+            setHasMmproj(hasLocalMmproj)
+          } else {
+            // For non-llamacpp providers, only check vision capability
+            setHasMmproj(true)
+          }
+        } catch (error) {
+          console.error('Error checking mmproj:', error)
+          setHasMmproj(false)
+        }
+      }
+    }
+
+    checkMmprojSupport()
+  }, [selectedModel?.id, selectedProvider])
+
   // Check if there are active MCP servers
   const hasActiveMCPServers = connectedServers.length > 0 || tools.length > 0

@@ -104,11 +130,16 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
       setMessage('Please select a model to start chatting.')
       return
     }
-    if (!prompt.trim()) {
+    if (!prompt.trim() && uploadedFiles.length === 0) {
       return
     }
     setMessage('')
-    sendMessage(prompt)
+    sendMessage(
+      prompt,
+      true,
+      uploadedFiles.length > 0 ? uploadedFiles : undefined
+    )
+    setUploadedFiles([])
   }

   useEffect(() => {
@@ -191,8 +222,6 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
         return 'image/jpeg'
       case 'png':
         return 'image/png'
-      case 'pdf':
-        return 'application/pdf'
       default:
         return ''
     }
@@ -226,17 +255,12 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
     const detectedType = file.type || getFileTypeFromExtension(file.name)
     const actualType = getFileTypeFromExtension(file.name) || detectedType

-    // Check file type
-    const allowedTypes = [
-      'image/jpg',
-      'image/jpeg',
-      'image/png',
-      'application/pdf',
-    ]
+    // Check file type - images only
+    const allowedTypes = ['image/jpg', 'image/jpeg', 'image/png']

     if (!allowedTypes.includes(actualType)) {
       setMessage(
-        `File is not supported. Only JPEG, JPG, PNG, and PDF files are allowed.`
+        `File attachments not supported currently. Only JPEG, JPG, and PNG files are allowed.`
       )
       // Reset file input to allow re-uploading
       if (fileInputRef.current) {
@@ -287,6 +311,104 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
     }
   }

+  const handleDragEnter = (e: React.DragEvent) => {
+    e.preventDefault()
+    e.stopPropagation()
+    // Only allow drag if model supports mmproj
+    if (hasMmproj) {
+      setIsDragOver(true)
+    }
+  }
+
+  const handleDragLeave = (e: React.DragEvent) => {
+    e.preventDefault()
+    e.stopPropagation()
+    // Only set dragOver to false if we're leaving the drop zone entirely
+    // In Tauri, relatedTarget can be null, so we need to handle that case
+    const relatedTarget = e.relatedTarget as Node | null
+    if (!relatedTarget || !e.currentTarget.contains(relatedTarget)) {
+      setIsDragOver(false)
+    }
+  }
+
+  const handleDragOver = (e: React.DragEvent) => {
+    e.preventDefault()
+    e.stopPropagation()
+    // Ensure drag state is maintained during drag over
+    if (hasMmproj) {
+      setIsDragOver(true)
+    }
+  }
+
+  const handleDrop = (e: React.DragEvent) => {
+    e.preventDefault()
+    e.stopPropagation()
+    setIsDragOver(false)
+
+    // Only allow drop if model supports mmproj
+    if (!hasMmproj) {
+      return
+    }
+
+    // Check if dataTransfer exists (it might not in some Tauri scenarios)
+    if (!e.dataTransfer) {
+      console.warn('No dataTransfer available in drop event')
+      return
+    }
+
+    const files = e.dataTransfer.files
+    if (files && files.length > 0) {
+      // Create a synthetic event to reuse existing file handling logic
+      const syntheticEvent = {
+        target: {
+          files: files,
+        },
+      } as React.ChangeEvent<HTMLInputElement>
+
+      handleFileChange(syntheticEvent)
+    }
+  }
+
+  const handlePaste = (e: React.ClipboardEvent) => {
+    const clipboardItems = e.clipboardData?.items
+    if (!clipboardItems) return
+
+    // Only allow paste if model supports mmproj
+    if (!hasMmproj) {
+      return
+    }
+
+    const imageItems = Array.from(clipboardItems).filter((item) =>
+      item.type.startsWith('image/')
+    )
+
+    if (imageItems.length > 0) {
+      e.preventDefault()
+
+      const files: File[] = []
+      let processedCount = 0
+
+      imageItems.forEach((item) => {
+        const file = item.getAsFile()
+        if (file) {
+          files.push(file)
+        }
+        processedCount++
+
+        // When all items are processed, handle the valid files
+        if (processedCount === imageItems.length && files.length > 0) {
+          const syntheticEvent = {
+            target: {
+              files: files,
+            },
+          } as unknown as React.ChangeEvent<HTMLInputElement>
+
+          handleFileChange(syntheticEvent)
+        }
+      })
+    }
+  }
+
   return (
     <div className="relative">
       <div className="relative">
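Both `handleDrop` and `handlePaste` reuse `handleFileChange` by faking a change event. The paste path passes a plain `File[]` where a `FileList` is expected (hence the `as unknown as` cast), which holds up only as long as the handler sticks to `length` and indexed access. A small helper, not part of the diff, is one way to make that contract explicit:

```ts
// Normalize a FileList or a plain File[] before processing.
function toFileArray(input: FileList | File[]): File[] {
  return Array.isArray(input) ? input : Array.from(input)
}

// Usage inside a handler: toFileArray(e.dataTransfer.files).forEach(handleOne)
```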
@@ -311,8 +433,14 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
         <div
           className={cn(
             'relative z-20 px-0 pb-10 border border-main-view-fg/5 rounded-lg text-main-view-fg bg-main-view',
-            isFocused && 'ring-1 ring-main-view-fg/10'
+            isFocused && 'ring-1 ring-main-view-fg/10',
+            isDragOver && 'ring-2 ring-accent border-accent'
           )}
+          data-drop-zone={hasMmproj ? 'true' : undefined}
+          onDragEnter={hasMmproj ? handleDragEnter : undefined}
+          onDragLeave={hasMmproj ? handleDragLeave : undefined}
+          onDragOver={hasMmproj ? handleDragOver : undefined}
+          onDrop={hasMmproj ? handleDrop : undefined}
         >
           {uploadedFiles.length > 0 && (
             <div className="flex gap-3 items-center p-2 pb-0">
@@ -332,25 +460,6 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
                     alt={`${file.name} - ${index}`}
                   />
                 )}
-                {file.type === 'application/pdf' && (
-                  <div className="bg-main-view-fg/4 h-full rounded-lg p-2 max-w-[400px] pr-4">
-                    <div className="flex gap-2 items-center justify-center h-full">
-                      <div className="size-10 rounded-md bg-main-view shrink-0 flex items-center justify-center">
-                        <span className="uppercase font-bold">
-                          {file.name.split('.').pop()}
-                        </span>
-                      </div>
-                      <div className="truncate">
-                        <h6 className="truncate mb-0.5 text-main-view-fg/80">
-                          {file.name}
-                        </h6>
-                        <p className="text-xs text-main-view-fg/70">
-                          {toGigabytes(file.size)}
-                        </p>
-                      </div>
-                    </div>
-                  </div>
-                )}
                 <div
                   className="absolute -top-1 -right-2.5 bg-destructive size-5 flex rounded-full items-center justify-center cursor-pointer"
                   onClick={() => handleRemoveFile(index)}
@@ -369,7 +478,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
             rows={1}
             maxRows={10}
             value={prompt}
-            data-test-id={'chat-input'}
+            data-testid={'chat-input'}
             onChange={(e) => {
               setPrompt(e.target.value)
               // Count the number of newlines to estimate rows
@@ -378,14 +487,21 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
             }}
             onKeyDown={(e) => {
               // e.keyCode 229 is for IME input with Safari
-              const isComposing = e.nativeEvent.isComposing || e.keyCode === 229;
-              if (e.key === 'Enter' && !e.shiftKey && prompt.trim() && !isComposing) {
+              const isComposing =
+                e.nativeEvent.isComposing || e.keyCode === 229
+              if (
+                e.key === 'Enter' &&
+                !e.shiftKey &&
+                prompt.trim() &&
+                !isComposing
+              ) {
                 e.preventDefault()
                 // Submit the message when Enter is pressed without Shift
                 handleSendMesage(prompt)
                 // When Shift+Enter is pressed, a new line is added (default behavior)
               }
             }}
+            onPaste={hasMmproj ? handlePaste : undefined}
             placeholder={t('common:placeholder.chatInput')}
             autoFocus
             spellCheck={spellCheckChatInput}
@@ -406,7 +522,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
           <div className="px-1 flex items-center gap-1">
             <div
               className={cn(
-                'px-1 flex items-center gap-1',
+                'px-1 flex items-center',
                 streamingContent && 'opacity-50 pointer-events-none'
               )}
             >
@@ -418,19 +534,22 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
                   useLastUsedModel={initialMessage}
                 />
               )}
-              {/* File attachment - always available */}
-              <div
-                className="h-6 hidden p-1 items-center justify-center rounded-sm hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out gap-1"
-                onClick={handleAttachmentClick}
-              >
-                <IconPaperclip size={18} className="text-main-view-fg/50" />
-                <input
-                  type="file"
-                  ref={fileInputRef}
-                  className="hidden"
-                  onChange={handleFileChange}
-                />
-              </div>
+              {/* File attachment - show only for models with mmproj */}
+              {hasMmproj && (
+                <div
+                  className="h-6 p-1 ml-1 flex items-center justify-center rounded-sm hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out gap-1"
+                  onClick={handleAttachmentClick}
+                >
+                  <IconPhoto size={18} className="text-main-view-fg/50" />
+                  <input
+                    type="file"
+                    ref={fileInputRef}
+                    className="hidden"
+                    multiple
+                    onChange={handleFileChange}
+                  />
+                </div>
+              )}
               {/* Microphone - always available - Temp Hide */}
               {/* <div className="h-6 p-1 flex items-center justify-center rounded-sm hover:bg-main-view-fg/10 transition-all duration-200 ease-in-out gap-1">
                 <IconMicrophone size={18} className="text-main-view-fg/50" />
@@ -574,9 +693,13 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
                 </Button>
               ) : (
                 <Button
-                  variant={!prompt.trim() ? null : 'default'}
+                  variant={
+                    !prompt.trim() && uploadedFiles.length === 0
+                      ? null
+                      : 'default'
+                  }
                   size="icon"
-                  disabled={!prompt.trim()}
+                  disabled={!prompt.trim() && uploadedFiles.length === 0}
                   data-test-id="send-message-button"
                   onClick={() => handleSendMesage(prompt)}
                 >
@@ -590,6 +713,7 @@ const ChatInput = ({ model, className, initialMessage }: ChatInputProps) => {
           </div>
         </div>
       </div>
+
       {message && (
         <div className="bg-main-view-fg/2 -mt-0.5 mx-2 pb-2 px-3 pt-1.5 rounded-b-lg text-xs text-destructive transition-all duration-200 ease-in-out">
           <div className="flex items-center gap-1 justify-between">
@@ -19,6 +19,7 @@ import { localStorageKey } from '@/constants/localStorage'
 import { useTranslation } from '@/i18n/react-i18next-compat'
 import { useFavoriteModel } from '@/hooks/useFavoriteModel'
 import { predefinedProviders } from '@/consts/providers'
+import { checkMmprojExistsAndUpdateOffloadMMprojSetting } from '@/services/models'

 type DropdownModelProviderProps = {
   model?: ThreadModel
@@ -66,6 +67,7 @@ const DropdownModelProvider = ({
     getModelBy,
     selectedProvider,
     selectedModel,
+    updateProvider,
   } = useModelProvider()
   const [displayModel, setDisplayModel] = useState<string>('')
   const { updateCurrentThreadModel } = useThreads()
@@ -79,31 +81,52 @@ const DropdownModelProvider = ({
   const searchInputRef = useRef<HTMLInputElement>(null)

   // Helper function to check if a model exists in providers
-  const checkModelExists = useCallback((providerName: string, modelId: string) => {
-    const provider = providers.find(
-      (p) => p.provider === providerName && p.active
-    )
-    return provider?.models.find((m) => m.id === modelId)
-  }, [providers])
+  const checkModelExists = useCallback(
+    (providerName: string, modelId: string) => {
+      const provider = providers.find(
+        (p) => p.provider === providerName && p.active
+      )
+      return provider?.models.find((m) => m.id === modelId)
+    },
+    [providers]
+  )

   // Initialize model provider only once
   useEffect(() => {
-    // Auto select model when existing thread is passed
-    if (model) {
-      selectModelProvider(model?.provider as string, model?.id as string)
-      if (!checkModelExists(model.provider, model.id)) {
-        selectModelProvider('', '')
-      }
-    } else if (useLastUsedModel) {
-      // Try to use last used model only when explicitly requested (for new chat)
-      const lastUsed = getLastUsedModel()
-      if (lastUsed && checkModelExists(lastUsed.provider, lastUsed.model)) {
-        selectModelProvider(lastUsed.provider, lastUsed.model)
-      } else {
-        // Fallback to default model if last used model no longer exists
-        selectModelProvider('', '')
-      }
-    }
+    const initializeModel = async () => {
+      // Auto select model when existing thread is passed
+      if (model) {
+        selectModelProvider(model?.provider as string, model?.id as string)
+        if (!checkModelExists(model.provider, model.id)) {
+          selectModelProvider('', '')
+        }
+        // Check mmproj existence for llamacpp models
+        if (model?.provider === 'llamacpp') {
+          await checkMmprojExistsAndUpdateOffloadMMprojSetting(
+            model.id as string,
+            updateProvider,
+            getProviderByName
+          )
+        }
+      } else if (useLastUsedModel) {
+        // Try to use last used model only when explicitly requested (for new chat)
+        const lastUsed = getLastUsedModel()
+        if (lastUsed && checkModelExists(lastUsed.provider, lastUsed.model)) {
+          selectModelProvider(lastUsed.provider, lastUsed.model)
+          if (lastUsed.provider === 'llamacpp') {
+            await checkMmprojExistsAndUpdateOffloadMMprojSetting(
+              lastUsed.model,
+              updateProvider,
+              getProviderByName
+            )
+          }
+        } else {
+          selectModelProvider('', '')
+        }
+      }
+    }
+
+    initializeModel()
   }, [
     model,
     selectModelProvider,
@@ -111,6 +134,8 @@ const DropdownModelProvider = ({
     providers,
     useLastUsedModel,
     checkModelExists,
+    updateProvider,
+    getProviderByName,
   ])

   // Update display model when selection changes
@@ -245,7 +270,7 @@ const DropdownModelProvider = ({
   }, [filteredItems, providers, searchValue, favoriteModels])

   const handleSelect = useCallback(
-    (searchableModel: SearchableModel) => {
+    async (searchableModel: SearchableModel) => {
       selectModelProvider(
         searchableModel.provider.provider,
         searchableModel.model.id
@@ -254,6 +279,16 @@ const DropdownModelProvider = ({
         id: searchableModel.model.id,
         provider: searchableModel.provider.provider,
       })
+
+      // Check mmproj existence for llamacpp models
+      if (searchableModel.provider.provider === 'llamacpp') {
+        await checkMmprojExistsAndUpdateOffloadMMprojSetting(
+          searchableModel.model.id,
+          updateProvider,
+          getProviderByName
+        )
+      }
+
       // Store the selected model as last used
       if (useLastUsedModel) {
         setLastUsedModel(
@@ -264,7 +299,13 @@ const DropdownModelProvider = ({
       setSearchValue('')
       setOpen(false)
     },
-    [selectModelProvider, updateCurrentThreadModel, useLastUsedModel]
+    [
+      selectModelProvider,
+      updateCurrentThreadModel,
+      useLastUsedModel,
+      updateProvider,
+      getProviderByName,
+    ]
   )

   const currentModel = selectedModel?.id
@@ -70,8 +70,8 @@ export function ModelSetting({
       models: updatedModels,
     })

-    // Call debounced stopModel only when updating ctx_len or ngl
-    if (key === 'ctx_len' || key === 'ngl' || key === 'chat_template') {
+    // Call debounced stopModel only when updating ctx_len, ngl, chat_template, or offload_mmproj
+    if (key === 'ctx_len' || key === 'ngl' || key === 'chat_template' || key === 'offload_mmproj') {
       debouncedStopModel(model.id)
     }
   }
@@ -1,3 +1,4 @@
+/* eslint-disable @typescript-eslint/no-explicit-any */
 import { ThreadMessage } from '@janhq/core'
 import { RenderMarkdown } from './RenderMarkdown'
 import React, { Fragment, memo, useCallback, useMemo, useState } from 'react'
@@ -144,7 +145,7 @@ export const ThreadContent = memo(
     isLastMessage?: boolean
     index?: number
     showAssistant?: boolean
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
     streamTools?: any
     contextOverflowModal?: React.ReactNode | null
     updateMessage?: (item: ThreadMessage, message: string) => void
@@ -172,9 +173,12 @@ export const ThreadContent = memo(
     const { reasoningSegment, textSegment } = useMemo(() => {
       // Check for thinking formats
       const hasThinkTag = text.includes('<think>') && !text.includes('</think>')
-      const hasAnalysisChannel = text.includes('<|channel|>analysis<|message|>') && !text.includes('<|start|>assistant<|channel|>final<|message|>')
+      const hasAnalysisChannel =
+        text.includes('<|channel|>analysis<|message|>') &&
+        !text.includes('<|start|>assistant<|channel|>final<|message|>')

-      if (hasThinkTag || hasAnalysisChannel) return { reasoningSegment: text, textSegment: '' }
+      if (hasThinkTag || hasAnalysisChannel)
+        return { reasoningSegment: text, textSegment: '' }

       // Check for completed think tag format
       const thinkMatch = text.match(/<think>([\s\S]*?)<\/think>/)
@@ -187,7 +191,9 @@ export const ThreadContent = memo(
       }

       // Check for completed analysis channel format
-      const analysisMatch = text.match(/<\|channel\|>analysis<\|message\|>([\s\S]*?)<\|start\|>assistant<\|channel\|>final<\|message\|>/)
+      const analysisMatch = text.match(
+        /<\|channel\|>analysis<\|message\|>([\s\S]*?)<\|start\|>assistant<\|channel\|>final<\|message\|>/
+      )
       if (analysisMatch?.index !== undefined) {
         const splitIndex = analysisMatch.index + analysisMatch[0].length
         return {
@@ -213,7 +219,36 @@ export const ThreadContent = memo(
       }
       if (toSendMessage) {
         deleteMessage(toSendMessage.thread_id, toSendMessage.id ?? '')
-        sendMessage(toSendMessage.content?.[0]?.text?.value || '')
+        // Extract text content and any attachments
+        const textContent =
+          toSendMessage.content?.find((c) => c.type === 'text')?.text?.value ||
+          ''
+        const attachments = toSendMessage.content
+          ?.filter((c) => (c.type === 'image_url' && c.image_url?.url) || false)
+          .map((c) => {
+            if (c.type === 'image_url' && c.image_url?.url) {
+              const url = c.image_url.url
+              const [mimeType, base64] = url
+                .replace('data:', '')
+                .split(';base64,')
+              return {
+                name: 'image', // We don't have the original filename
+                type: mimeType,
+                size: 0, // We don't have the original size
+                base64: base64,
+                dataUrl: url,
+              }
+            }
+            return null
+          })
+          .filter(Boolean) as Array<{
+          name: string
+          type: string
+          size: number
+          base64: string
+          dataUrl: string
+        }>
+        sendMessage(textContent, true, attachments)
       }
     }, [deleteMessage, getMessages, item, sendMessage])

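Attachments travel as data URLs (`data:<mime>;base64,<payload>`), so resending a message decodes them with the `replace`/`split` pair above. A standalone sketch of the round trip; the encode helper mirrors what an upload path would plausibly do and is not part of the diff:

```ts
// Encode: File -> data URL (the plausible upload-time step).
async function fileToDataUrl(file: File): Promise<string> {
  return new Promise((resolve, reject) => {
    const reader = new FileReader()
    reader.onload = () => resolve(reader.result as string)
    reader.onerror = () => reject(reader.error)
    reader.readAsDataURL(file)
  })
}

// Decode: data URL -> mime type + base64 payload (mirrors the resend path above).
function parseDataUrl(url: string): { mimeType: string; base64: string } {
  const [mimeType, base64] = url.replace('data:', '').split(';base64,')
  return { mimeType, base64 }
}
```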
@@ -255,22 +290,68 @@ export const ThreadContent = memo(

     return (
       <Fragment>
-        {item.content?.[0]?.text && item.role === 'user' && (
+        {item.role === 'user' && (
           <div className="w-full">
-            <div className="flex justify-end w-full h-full text-start break-words whitespace-normal">
-              <div className="bg-main-view-fg/4 relative text-main-view-fg p-2 rounded-md inline-block max-w-[80%] ">
-                <div className="select-text">
-                  <RenderMarkdown
-                    content={item.content?.[0].text.value}
-                    components={linkComponents}
-                    isUser
-                  />
-                </div>
-              </div>
-            </div>
+            {/* Render attachments above the message bubble */}
+            {item.content?.some(
+              (c) => (c.type === 'image_url' && c.image_url?.url) || false
+            ) && (
+              <div className="flex justify-end w-full mb-2">
+                <div className="flex flex-wrap gap-2 max-w-[80%] justify-end">
+                  {item.content
+                    ?.filter(
+                      (c) =>
+                        (c.type === 'image_url' && c.image_url?.url) || false
+                    )
+                    .map((contentPart, index) => {
+                      // Handle images
+                      if (
+                        contentPart.type === 'image_url' &&
+                        contentPart.image_url?.url
+                      ) {
+                        return (
+                          <div key={index} className="relative">
+                            <img
+                              src={contentPart.image_url.url}
+                              alt="Uploaded attachment"
+                              className="size-40 rounded-md object-cover border border-main-view-fg/10"
+                            />
+                          </div>
+                        )
+                      }
+                      return null
+                    })}
+                </div>
+              </div>
+            )}
+
+            {/* Render text content in the message bubble */}
+            {item.content?.some((c) => c.type === 'text' && c.text?.value) && (
+              <div className="flex justify-end w-full h-full text-start break-words whitespace-normal">
+                <div className="bg-main-view-fg/4 relative text-main-view-fg p-2 rounded-md inline-block max-w-[80%] ">
+                  <div className="select-text">
+                    {item.content
+                      ?.filter((c) => c.type === 'text' && c.text?.value)
+                      .map((contentPart, index) => (
+                        <div key={index}>
+                          <RenderMarkdown
+                            content={contentPart.text!.value}
+                            components={linkComponents}
+                            isUser
+                          />
+                        </div>
+                      ))}
+                  </div>
+                </div>
+              </div>
+            )}

           <div className="flex items-center justify-end gap-2 text-main-view-fg/60 text-xs mt-2">
             <EditDialog
-              message={item.content?.[0]?.text.value}
+              message={
+                item.content?.find((c) => c.type === 'text')?.text?.value ||
+                ''
+              }
               setMessage={(message) => {
                 if (item.updateMessage) {
                   item.updateMessage(item, message)
@@ -73,6 +73,11 @@ vi.mock('@/services/mcp', () => ({

 vi.mock('@/services/models', () => ({
   stopAllModels: vi.fn(),
+  checkMmprojExists: vi.fn(() => Promise.resolve(true)),
+}))
+
+vi.mock('../MovingBorder', () => ({
+  MovingBorder: ({ children }: { children: React.ReactNode }) => <div data-testid="moving-border">{children}</div>,
 }))

 describe('ChatInput', () => {
@@ -231,7 +236,7 @@ describe('ChatInput', () => {
     const sendButton = document.querySelector('[data-test-id="send-message-button"]')
     await user.click(sendButton)

-    expect(mockSendMessage).toHaveBeenCalledWith('Hello world')
+    expect(mockSendMessage).toHaveBeenCalledWith('Hello world', true, undefined)
   })

   it('sends message when Enter key is pressed', async () => {
@@ -248,7 +253,7 @@ describe('ChatInput', () => {
     const textarea = screen.getByRole('textbox')
     await user.type(textarea, '{Enter}')

-    expect(mockSendMessage).toHaveBeenCalledWith('Hello world')
+    expect(mockSendMessage).toHaveBeenCalledWith('Hello world', true, undefined)
   })

   it('does not send message when Shift+Enter is pressed', async () => {
@@ -343,9 +348,12 @@ describe('ChatInput', () => {
     const user = userEvent.setup()
     renderWithRouter()

-    // File upload is rendered as hidden input element
-    const fileInput = document.querySelector('input[type="file"]')
-    expect(fileInput).toBeInTheDocument()
+    // Wait for async effects to complete (mmproj check)
+    await waitFor(() => {
+      // File upload is rendered as hidden input element
+      const fileInput = document.querySelector('input[type="file"]')
+      expect(fileInput).toBeInTheDocument()
+    })
   })

   it('disables input when streaming', () => {
@@ -361,7 +369,7 @@ describe('ChatInput', () => {
       renderWithRouter()
     })

-    const textarea = screen.getByRole('textbox')
+    const textarea = screen.getByTestId('chat-input')
     expect(textarea).toBeDisabled()
   })

@@ -378,4 +386,28 @@ describe('ChatInput', () => {
     expect(toolsIcon).toBeInTheDocument()
   })
 })
+
+  it('uses selectedProvider for provider checks', () => {
+    // Test that the component correctly uses selectedProvider instead of selectedModel.provider
+    vi.mocked(useModelProvider).mockReturnValue({
+      selectedModel: {
+        id: 'test-model',
+        capabilities: ['vision'],
+      },
+      providers: [],
+      getModelBy: vi.fn(),
+      selectModelProvider: vi.fn(),
+      selectedProvider: 'llamacpp',
+      setProviders: vi.fn(),
+      getProviderByName: vi.fn(),
+      updateProvider: vi.fn(),
+      addProvider: vi.fn(),
+      deleteProvider: vi.fn(),
+      deleteModel: vi.fn(),
+      deletedModels: [],
+    })
+
+    // This test ensures the component renders without errors when using selectedProvider
+    expect(() => renderWithRouter()).not.toThrow()
+  })
 })
@@ -203,7 +203,17 @@ export const useChat = () => {
   )

   const sendMessage = useCallback(
-    async (message: string, troubleshooting = true) => {
+    async (
+      message: string,
+      troubleshooting = true,
+      attachments?: Array<{
+        name: string
+        type: string
+        size: number
+        base64: string
+        dataUrl: string
+      }>
+    ) => {
       const activeThread = await getCurrentThread()

       resetTokenSpeed()
@@ -217,7 +227,7 @@ export const useChat = () => {
       updateStreamingContent(emptyThreadContent)
       // Do not add new message on retry
       if (troubleshooting)
-        addMessage(newUserThreadContent(activeThread.id, message))
+        addMessage(newUserThreadContent(activeThread.id, message, attachments))
       updateThreadTimestamp(activeThread.id)
       setPrompt('')
       try {
@@ -231,7 +241,7 @@ export const useChat = () => {
         messages,
         currentAssistant?.instructions
       )
-      if (troubleshooting) builder.addUserMessage(message)
+      if (troubleshooting) builder.addUserMessage(message, attachments)

       let isCompleted = false

@@ -1,3 +1,4 @@
+/* eslint-disable @typescript-eslint/no-explicit-any */
 import {
   ContentType,
   ChatCompletionRole,
@@ -51,11 +52,16 @@ export type ChatCompletionResponse =
  */
 export const newUserThreadContent = (
   threadId: string,
-  content: string
-): ThreadMessage => ({
-  type: 'text',
-  role: ChatCompletionRole.User,
-  content: [
+  content: string,
+  attachments?: Array<{
+    name: string
+    type: string
+    size: number
+    base64: string
+    dataUrl: string
+  }>
+): ThreadMessage => {
+  const contentParts = [
     {
       type: ContentType.Text,
       text: {
@@ -63,14 +69,35 @@ export const newUserThreadContent = (
         annotations: [],
       },
     },
-  ],
-  id: ulid(),
-  object: 'thread.message',
-  thread_id: threadId,
-  status: MessageStatus.Ready,
-  created_at: 0,
-  completed_at: 0,
-})
+  ]
+
+  // Add attachments to content array
+  if (attachments) {
+    attachments.forEach((attachment) => {
+      if (attachment.type.startsWith('image/')) {
+        contentParts.push({
+          type: ContentType.Image,
+          image_url: {
+            url: `data:${attachment.type};base64,${attachment.base64}`,
+            detail: 'auto',
+          },
+        } as any)
+      }
+    })
+  }
+
+  return {
+    type: 'text',
+    role: ChatCompletionRole.User,
+    content: contentParts,
+    id: ulid(),
+    object: 'thread.message',
+    thread_id: threadId,
+    status: MessageStatus.Ready,
+    created_at: 0,
+    completed_at: 0,
+  }
+}
 /**
  * @fileoverview Helper functions for creating thread content.
  * These functions are used to create thread content objects
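With this change a user `ThreadMessage` can carry mixed content parts. Roughly, a message with one image attachment comes out shaped like this (values are placeholders, and the literal `type` strings stand in for the `ContentType` enum members):

```ts
// Illustrative content array produced by newUserThreadContent with one image.
const exampleContent = [
  {
    type: 'text', // ContentType.Text
    text: { value: 'What is in this picture?', annotations: [] },
  },
  {
    type: 'image_url', // ContentType.Image
    image_url: {
      url: 'data:image/png;base64,...', // truncated payload
      detail: 'auto',
    },
  },
]
```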
@@ -162,13 +189,11 @@ export const sendCompletion = async (
   if (
     thread.model.id &&
     !Object.values(models[providerName]).flat().includes(thread.model.id) &&
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
     !tokenJS.extendedModelExist(providerName as any, thread.model.id) &&
     provider.provider !== 'llamacpp'
   ) {
     try {
       tokenJS.extendModelList(
-        // eslint-disable-next-line @typescript-eslint/no-explicit-any
         providerName as any,
         thread.model.id,
         // This is to inherit the model capabilities from another built-in model
@@ -201,7 +226,7 @@ export const sendCompletion = async (
     ? await tokenJS.chat.completions.create(
         {
           stream: true,
-          // eslint-disable-next-line @typescript-eslint/no-explicit-any
           provider: providerName as any,
           model: thread.model?.id,
           messages,
@@ -1,3 +1,4 @@
+/* eslint-disable @typescript-eslint/no-explicit-any */
 import { ChatCompletionMessageParam } from 'token.js'
 import { ChatCompletionMessageToolCall } from 'openai/resources'
 import { ThreadMessage } from '@janhq/core'
@@ -19,32 +20,106 @@ export class CompletionMessagesBuilder {
     this.messages.push(
       ...messages
         .filter((e) => !e.metadata?.error)
-        .map<ChatCompletionMessageParam>(
-          (msg) =>
-            ({
-              role: msg.role,
-              content:
-                msg.role === 'assistant'
-                  ? this.normalizeContent(msg.content[0]?.text?.value || '.')
-                  : msg.content[0]?.text?.value || '.',
-            }) as ChatCompletionMessageParam
-        )
+        .map<ChatCompletionMessageParam>((msg) => {
+          if (msg.role === 'assistant') {
+            return {
+              role: msg.role,
+              content: this.normalizeContent(
+                msg.content[0]?.text?.value || '.'
+              ),
+            } as ChatCompletionMessageParam
+          } else {
+            // For user messages, handle multimodal content
+            if (msg.content.length > 1) {
+              // Multiple content parts (text + images + files)
+              const content = msg.content.map((contentPart) => {
+                if (contentPart.type === 'text') {
+                  return {
+                    type: 'text',
+                    text: contentPart.text?.value || '',
+                  }
+                } else if (contentPart.type === 'image_url') {
+                  return {
+                    type: 'image_url',
+                    image_url: {
+                      url: contentPart.image_url?.url || '',
+                      detail: contentPart.image_url?.detail || 'auto',
+                    },
+                  }
+                } else {
+                  return contentPart
+                }
+              })
+              return {
+                role: msg.role,
+                content,
+              } as ChatCompletionMessageParam
+            } else {
+              // Single text content
+              return {
+                role: msg.role,
+                content: msg.content[0]?.text?.value || '.',
+              } as ChatCompletionMessageParam
+            }
+          }
+        })
     )
   }

   /**
    * Add a user message to the messages array.
    * @param content - The content of the user message.
+   * @param attachments - Optional attachments for the message.
    */
-  addUserMessage(content: string) {
+  addUserMessage(
+    content: string,
+    attachments?: Array<{
+      name: string
+      type: string
+      size: number
+      base64: string
+      dataUrl: string
+    }>
+  ) {
     // Ensure no consecutive user messages
     if (this.messages[this.messages.length - 1]?.role === 'user') {
       this.messages.pop()
     }
-    this.messages.push({
-      role: 'user',
-      content: content,
-    })
+
+    // Handle multimodal content with attachments
+    if (attachments && attachments.length > 0) {
+      const messageContent: any[] = [
+        {
+          type: 'text',
+          text: content,
+        },
+      ]
+
+      // Add image attachments
+      attachments.forEach((attachment) => {
+        if (attachment.type.startsWith('image/')) {
+          messageContent.push({
+            type: 'image_url',
+            image_url: {
+              url: `data:${attachment.type};base64,${attachment.base64}`,
+              detail: 'auto',
+            },
+          })
+        }
+      })
+
+      this.messages.push({
+        role: 'user',
+        content: messageContent,
+      } as any)
+    } else {
+      // Text-only message
+      this.messages.push({
+        role: 'user',
+        content: content,
+      })
+    }
   }

   /**
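The builder emits the standard OpenAI-style multimodal `content` array for user messages, which token.js forwards to the provider. A sketch of the resulting request payload; the model name and data URL are placeholders:

```ts
// Illustrative chat-completion request built from a message with an attachment.
const request = {
  model: 'some-vision-model', // placeholder
  stream: true,
  messages: [
    {
      role: 'user' as const,
      content: [
        { type: 'text' as const, text: 'Describe this image.' },
        {
          type: 'image_url' as const,
          image_url: { url: 'data:image/png;base64,...', detail: 'auto' },
        },
      ],
    },
  ],
}
```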
@ -26,7 +26,7 @@ import {
|
|||||||
ResizablePanel,
|
ResizablePanel,
|
||||||
ResizableHandle,
|
ResizableHandle,
|
||||||
} from '@/components/ui/resizable'
|
} from '@/components/ui/resizable'
|
||||||
import { useCallback } from 'react'
|
import { useCallback, useEffect } from 'react'
|
||||||
import GlobalError from '@/containers/GlobalError'
|
import GlobalError from '@/containers/GlobalError'
|
||||||
import { GlobalEventHandler } from '@/providers/GlobalEventHandler'
|
import { GlobalEventHandler } from '@/providers/GlobalEventHandler'
|
||||||
|
|
||||||
@ -65,6 +65,41 @@ const AppLayout = () => {
|
|||||||
[setLeftPanelSize, setLeftPanel]
|
[setLeftPanelSize, setLeftPanel]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Prevent default drag and drop behavior globally
|
||||||
|
useEffect(() => {
|
||||||
|
const preventDefaults = (e: DragEvent) => {
|
||||||
|
e.preventDefault()
|
||||||
|
e.stopPropagation()
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleGlobalDrop = (e: DragEvent) => {
|
||||||
|
e.preventDefault()
|
||||||
|
e.stopPropagation()
|
||||||
|
|
||||||
|
// Only prevent if the target is not within a chat input or other valid drop zone
|
||||||
|
const target = e.target as Element
|
||||||
|
const isValidDropZone = target?.closest('[data-drop-zone="true"]') ||
|
||||||
|
target?.closest('.chat-input-drop-zone') ||
|
||||||
|
target?.closest('[data-tauri-drag-region]')
|
||||||
|
|
||||||
|
if (!isValidDropZone) {
|
||||||
|
// Prevent the file from opening in the window
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add event listeners to prevent default drag/drop behavior
|
||||||
|
window.addEventListener('dragenter', preventDefaults)
|
||||||
|
window.addEventListener('dragover', preventDefaults)
|
||||||
|
window.addEventListener('drop', handleGlobalDrop)
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
window.removeEventListener('dragenter', preventDefaults)
|
||||||
|
window.removeEventListener('dragover', preventDefaults)
|
||||||
|
window.removeEventListener('drop', handleGlobalDrop)
|
||||||
|
}
|
||||||
|
}, [])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Fragment>
|
<Fragment>
|
||||||
<AnalyticProvider />
|
<AnalyticProvider />
|
||||||
|
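Note: the global handler above swallows drops everywhere except elements matching [data-drop-zone="true"], .chat-input-drop-zone, or [data-tauri-drag-region]. A minimal sketch of a component that opts in to receiving drops; the component name and onFiles callback are illustrative, not part of the diff.

import { useCallback, type DragEvent } from 'react'

// Hypothetical drop target; data-drop-zone="true" is what the global
// handler's closest() check looks for before letting the drop through.
export function ChatInputDropZone({ onFiles }: { onFiles: (files: File[]) => void }) {
  const handleDrop = useCallback(
    (e: DragEvent<HTMLDivElement>) => {
      e.preventDefault()
      onFiles(Array.from(e.dataTransfer.files))
    },
    [onFiles]
  )

  return (
    <div
      data-drop-zone="true"
      onDragOver={(e) => e.preventDefault()}
      onDrop={handleDrop}
    >
      Drop images here
    </div>
  )
}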
@ -290,7 +290,7 @@ describe('models service', () => {
       likes: 100,
       tags: ['conversational', 'pytorch'],
       pipeline_tag: 'text-generation',
-      created_at: '2023-01-01T00:00:00Z',
+      createdAt: '2023-01-01T00:00:00Z',
       last_modified: '2023-12-01T00:00:00Z',
       private: false,
       disabled: false,

@ -443,7 +443,7 @@ describe('models service', () => {
       likes: 100,
       tags: ['conversational'],
       pipeline_tag: 'text-generation',
-      created_at: '2023-01-01T00:00:00Z',
+      createdAt: '2023-01-01T00:00:00Z',
       last_modified: '2023-12-01T00:00:00Z',
       private: false,
       disabled: false,

@ -471,7 +471,7 @@ describe('models service', () => {
       likes: 100,
       tags: ['conversational'],
       pipeline_tag: 'text-generation',
-      created_at: '2023-01-01T00:00:00Z',
+      createdAt: '2023-01-01T00:00:00Z',
       last_modified: '2023-12-01T00:00:00Z',
       private: false,
       disabled: false,

@ -510,7 +510,7 @@ describe('models service', () => {
       likes: 100,
       tags: ['conversational'],
       pipeline_tag: 'text-generation',
-      created_at: '2023-01-01T00:00:00Z',
+      createdAt: '2023-01-01T00:00:00Z',
       last_modified: '2023-12-01T00:00:00Z',
       private: false,
       disabled: false,

@ -559,7 +559,7 @@ describe('models service', () => {
       likes: 75,
       tags: ['pytorch', 'transformers', 'text-generation'],
       pipeline_tag: 'text-generation',
-      created_at: '2021-01-01T00:00:00Z',
+      createdAt: '2021-01-01T00:00:00Z',
       last_modified: '2021-12-01T00:00:00Z',
       private: false,
       disabled: false,

@ -605,6 +605,8 @@ describe('models service', () => {
           file_size: '4.0 GB',
         },
       ],
+      num_mmproj: 0,
+      mmproj_models: [],
       created_at: '2021-01-01T00:00:00Z',
       readme:
         'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md',

@ -820,7 +822,7 @@ describe('models service', () => {
       downloads: 0,
       likes: 0,
       tags: [],
-      created_at: '2021-01-01T00:00:00Z',
+      createdAt: '2021-01-01T00:00:00Z',
       last_modified: '2021-12-01T00:00:00Z',
       private: false,
       disabled: false,
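Note: the fixture rename above (created_at → createdAt on the repo mocks) tracks the HuggingFaceRepo interface change in the source diff below. As far as one can tell from the diff, the repo payload uses the camelCase field the Hub API returns, while Jan's own catalog entries keep snake_case. Roughly:

// Repo payload (API shape)          -> catalog entry (Jan shape)
// repo.createdAt: '2023-01-01...'   -> catalog.created_at: '2023-01-01...'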
@ -1,3 +1,4 @@
+/* eslint-disable @typescript-eslint/no-explicit-any */
 import { sanitizeModelId } from '@/lib/utils'
 import {
   AIEngine,

@ -27,6 +28,7 @@ export interface CatalogModel {
   num_quants: number
   quants: ModelQuant[]
   mmproj_models?: MMProjModel[]
+  num_mmproj: number
   created_at?: string
   readme?: string
   tools?: boolean

@ -44,7 +46,7 @@ export interface HuggingFaceRepo {
   library_name?: string
   tags: string[]
   pipeline_tag?: string
-  created_at: string
+  createdAt: string
   last_modified: string
   private: boolean
   disabled: boolean
@ -155,21 +157,30 @@ export const fetchHuggingFaceRepo = async (
 export const convertHfRepoToCatalogModel = (
   repo: HuggingFaceRepo
 ): CatalogModel => {
+  // Format file size helper
+  const formatFileSize = (size?: number) => {
+    if (!size) return 'Unknown size'
+    if (size < 1024 ** 3) return `${(size / 1024 ** 2).toFixed(1)} MB`
+    return `${(size / 1024 ** 3).toFixed(1)} GB`
+  }
+
   // Extract GGUF files from the repository siblings
   const ggufFiles =
     repo.siblings?.filter((file) =>
       file.rfilename.toLowerCase().endsWith('.gguf')
     ) || []

-  // Convert GGUF files to quants format
-  const quants = ggufFiles.map((file) => {
-    // Format file size
-    const formatFileSize = (size?: number) => {
-      if (!size) return 'Unknown size'
-      if (size < 1024 ** 3) return `${(size / 1024 ** 2).toFixed(1)} MB`
-      return `${(size / 1024 ** 3).toFixed(1)} GB`
-    }
+  // Separate regular GGUF files from mmproj files
+  const regularGgufFiles = ggufFiles.filter(
+    (file) => !file.rfilename.toLowerCase().includes('mmproj')
+  )
+
+  const mmprojFiles = ggufFiles.filter((file) =>
+    file.rfilename.toLowerCase().includes('mmproj')
+  )
+
+  // Convert regular GGUF files to quants format
+  const quants = regularGgufFiles.map((file) => {
     // Generate model_id from filename (remove .gguf extension, case-insensitive)
     const modelId = file.rfilename.replace(/\.gguf$/i, '')

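A couple of worked inputs for the hoisted formatFileSize helper, to pin down its units: sizes under 1024³ bytes are rendered in binary-prefixed megabytes, everything at or above that in gigabytes. The inputs below are illustrative.

formatFileSize(undefined)      // 'Unknown size'
formatFileSize(24_000_000)     // '22.9 MB'  (24e6 / 1024**2 ≈ 22.9)
formatFileSize(4_060_000_000)  // '3.8 GB'   (4.06e9 / 1024**3 ≈ 3.8)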
@ -180,15 +191,28 @@ export const convertHfRepoToCatalogModel = (
     }
   })

+  // Convert mmproj files to mmproj_models format
+  const mmprojModels = mmprojFiles.map((file) => {
+    const modelId = file.rfilename.replace(/\.gguf$/i, '')
+
+    return {
+      model_id: sanitizeModelId(modelId),
+      path: `https://huggingface.co/${repo.modelId}/resolve/main/${file.rfilename}`,
+      file_size: formatFileSize(file.size),
+    }
+  })
+
   return {
     model_name: repo.modelId,
-    description: `**Tags**: ${repo.tags?.join(', ')}`,
     developer: repo.author,
     downloads: repo.downloads || 0,
+    created_at: repo.createdAt,
     num_quants: quants.length,
     quants: quants,
-    created_at: repo.created_at,
+    num_mmproj: mmprojModels.length,
+    mmproj_models: mmprojModels,
     readme: `https://huggingface.co/${repo.modelId}/resolve/main/README.md`,
+    description: `**Tags**: ${repo.tags?.join(', ')}`,
   }
 }

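A minimal sketch of the converter's new behavior for a vision-model repo, assuming a hypothetical repo with one quant file and one mmproj sibling (all values illustrative; required fields not shown are elided via the cast):

const repo = {
  modelId: 'org/vision-model',
  author: 'org',
  downloads: 1200,
  tags: ['vision'],
  createdAt: '2024-01-01T00:00:00Z',
  siblings: [
    { rfilename: 'model-Q4_K_M.gguf', size: 4_060_000_000 },
    { rfilename: 'mmproj-model-f16.gguf', size: 600_000_000 },
  ],
} as HuggingFaceRepo // remaining required fields elided for brevity

const catalog = convertHfRepoToCatalogModel(repo)
// catalog.num_quants === 1   (the mmproj file is filtered out of quants)
// catalog.num_mmproj === 1
// catalog.mmproj_models?.[0].path
//   === 'https://huggingface.co/org/vision-model/resolve/main/mmproj-model-f16.gguf'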
@ -327,3 +351,137 @@ export const isToolSupported = async (modelId: string): Promise<boolean> => {

   return engine.isToolSupported(modelId)
 }
+
+/**
+ * Checks if mmproj.gguf file exists for a given model ID in the llamacpp provider.
+ * Also checks if the model has offload_mmproj setting.
+ * If mmproj.gguf exists, adds offload_mmproj setting with value true.
+ * @param modelId - The model ID to check for mmproj.gguf
+ * @param updateProvider - Function to update the provider state
+ * @param getProviderByName - Function to get provider by name
+ * @returns Promise<{exists: boolean, settingsUpdated: boolean}> - exists: true if mmproj.gguf exists, settingsUpdated: true if settings were modified
+ */
+export const checkMmprojExistsAndUpdateOffloadMMprojSetting = async (
+  modelId: string,
+  updateProvider?: (providerName: string, data: Partial<ModelProvider>) => void,
+  getProviderByName?: (providerName: string) => ModelProvider | undefined
+): Promise<{ exists: boolean; settingsUpdated: boolean }> => {
+  let settingsUpdated = false
+
+  try {
+    const engine = getEngine('llamacpp') as AIEngine & {
+      checkMmprojExists?: (id: string) => Promise<boolean>
+    }
+    if (engine && typeof engine.checkMmprojExists === 'function') {
+      const exists = await engine.checkMmprojExists(modelId)
+
+      // If we have the store functions, use them; otherwise fall back to localStorage
+      if (updateProvider && getProviderByName) {
+        const provider = getProviderByName('llamacpp')
+        if (provider) {
+          const model = provider.models.find((m) => m.id === modelId)
+
+          if (model?.settings) {
+            const hasOffloadMmproj = 'offload_mmproj' in model.settings
+
+            // If mmproj exists, add offload_mmproj setting (only if it doesn't exist)
+            if (exists && !hasOffloadMmproj) {
+              // Create updated models array with the new setting
+              const updatedModels = provider.models.map((m) => {
+                if (m.id === modelId) {
+                  return {
+                    ...m,
+                    settings: {
+                      ...m.settings,
+                      offload_mmproj: {
+                        key: 'offload_mmproj',
+                        title: 'Offload MMProj',
+                        description:
+                          'Offload multimodal projection layers to GPU',
+                        controller_type: 'checkbox',
+                        controller_props: {
+                          value: true,
+                        },
+                      },
+                    },
+                  }
+                }
+                return m
+              })
+
+              // Update the provider with the new models array
+              updateProvider('llamacpp', { models: updatedModels })
+              settingsUpdated = true
+            }
+          }
+        }
+      } else {
+        // Fall back to localStorage approach for backwards compatibility
+        try {
+          const modelProviderData = JSON.parse(
+            localStorage.getItem('model-provider') || '{}'
+          )
+          const llamacppProvider = modelProviderData.state?.providers?.find(
+            (p: any) => p.provider === 'llamacpp'
+          )
+          const model = llamacppProvider?.models?.find(
+            (m: any) => m.id === modelId
+          )
+
+          if (model?.settings) {
+            // If mmproj exists, add offload_mmproj setting (only if it doesn't exist)
+            if (exists) {
+              if (!model.settings.offload_mmproj) {
+                model.settings.offload_mmproj = {
+                  key: 'offload_mmproj',
+                  title: 'Offload MMProj',
+                  description: 'Offload multimodal projection layers to GPU',
+                  controller_type: 'checkbox',
+                  controller_props: {
+                    value: true,
+                  },
+                }
+                // Save updated settings back to localStorage
+                localStorage.setItem(
+                  'model-provider',
+                  JSON.stringify(modelProviderData)
+                )
+                settingsUpdated = true
+              }
+            }
+          }
+        } catch (localStorageError) {
+          console.error(
+            `Error checking localStorage for model ${modelId}:`,
+            localStorageError
+          )
+        }
+      }
+
+      return { exists, settingsUpdated }
+    }
+  } catch (error) {
+    console.error(`Error checking mmproj for model ${modelId}:`, error)
+  }
+  return { exists: false, settingsUpdated }
+}
+
+/**
+ * Checks if an mmproj.gguf file exists for a given model ID in the llamacpp provider.
+ * Unlike the helper above, this is a read-only check and does not touch settings.
+ * @param modelId - The model ID to check for mmproj.gguf
+ * @returns Promise<boolean> - true if mmproj.gguf exists, false otherwise
+ */
+export const checkMmprojExists = async (modelId: string): Promise<boolean> => {
+  try {
+    const engine = getEngine('llamacpp') as AIEngine & {
+      checkMmprojExists?: (id: string) => Promise<boolean>
+    }
+    if (engine && typeof engine.checkMmprojExists === 'function') {
+      return await engine.checkMmprojExists(modelId)
+    }
+  } catch (error) {
+    console.error(`Error checking mmproj for model ${modelId}:`, error)
+  }
+  return false
+}
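A minimal sketch of calling the settings-updating helper from UI code, assuming hypothetical store accessors (useModelProvider and its fields are illustrative; only the helper itself comes from this diff):

// Inside some async handler:
const { updateProvider, getProviderByName } = useModelProvider() // hypothetical store hook

const { exists, settingsUpdated } =
  await checkMmprojExistsAndUpdateOffloadMMprojSetting(
    'my-vision-model-Q4_K_M', // illustrative model ID
    updateProvider,
    getProviderByName
  )

// exists: the model ships an mmproj.gguf, i.e. it is multimodal
// settingsUpdated: an 'Offload MMProj' checkbox (default true) was added
// to the model's llamacpp settings because it wasn't there before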