Merge pull request #6278 from lugnicca/feat/model-selector

feat: add model selector (fetch from v1/models) when user adds a provider model
This commit is contained in:
Louis 2025-09-15 20:23:22 +07:00 committed by GitHub
commit e78e4e5cca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 1255 additions and 23 deletions

View File

@ -0,0 +1,430 @@
import { useState, useMemo, useRef, useEffect, useCallback } from 'react'
import { createPortal } from 'react-dom'
import { Input } from '@/components/ui/input'
import { Button } from '@/components/ui/button'
import { IconChevronDown, IconLoader2, IconRefresh } from '@tabler/icons-react'
import { cn } from '@/lib/utils'
import { useTranslation } from '@/i18n/react-i18next-compat'
// Hook for the dropdown position
function useDropdownPosition(open: boolean, containerRef: React.RefObject<HTMLDivElement | null>) {
const [dropdownPosition, setDropdownPosition] = useState({ top: 0, left: 0, width: 0 })
const updateDropdownPosition = useCallback(() => {
if (containerRef.current) {
const rect = containerRef.current.getBoundingClientRect()
setDropdownPosition({
top: rect.bottom + window.scrollY + 4,
left: rect.left + window.scrollX,
width: rect.width,
})
}
}, [containerRef])
// Update the position when the dropdown opens
useEffect(() => {
if (open) {
requestAnimationFrame(() => {
updateDropdownPosition()
})
}
}, [open, updateDropdownPosition])
// Update the position when the window is resized
useEffect(() => {
if (!open) return
const handleResize = () => {
updateDropdownPosition()
}
window.addEventListener('resize', handleResize)
window.addEventListener('scroll', handleResize)
return () => {
window.removeEventListener('resize', handleResize)
window.removeEventListener('scroll', handleResize)
}
}, [open, updateDropdownPosition])
return { dropdownPosition, updateDropdownPosition }
}
// Components for the different sections of the dropdown
const ErrorSection = ({ error, t }: { error: string; t: (key: string) => string }) => (
<div className="px-3 py-2 text-sm text-destructive">
<div className="flex items-center justify-between">
<span className="text-destructive font-medium">{t('common:failedToLoadModels')}</span>
</div>
<div className="text-xs text-main-view-fg/50 mt-0">{error}</div>
</div>
)
const LoadingSection = ({ t }: { t: (key: string) => string }) => (
<div className="flex items-center justify-center px-3 py-3 text-sm text-main-view-fg/50">
<IconLoader2 className="h-4 w-4 animate-spin mr-2 text-main-view-fg/50" />
<span className="text-sm text-main-view-fg/50">{t('common:loading')}</span>
</div>
)
const EmptySection = ({ inputValue, t }: { inputValue: string; t: (key: string, options?: Record<string, string>) => string }) => (
<div className="px-3 py-3 text-sm text-main-view-fg/50 text-center">
<div className="flex items-center justify-between">
<div className="flex-1">
{inputValue.trim() ? (
<span className="text-main-view-fg/50">{t('common:noModelsFoundFor', { searchValue: inputValue })}</span>
) : (
<span className="text-main-view-fg/50">{t('common:noModels')}</span>
)}
</div>
</div>
</div>
)
const ModelsList = ({
filteredModels,
value,
highlightedIndex,
onModelSelect,
onHighlight
}: {
filteredModels: string[]
value: string
highlightedIndex: number
onModelSelect: (model: string) => void
onHighlight: (index: number) => void
}) => (
<>
{filteredModels.map((model, index) => (
<div
key={model}
data-model={model}
onClick={(e) => {
e.stopPropagation()
onModelSelect(model)
}}
onMouseEnter={() => onHighlight(index)}
className={cn(
'cursor-pointer px-3 py-2 hover:bg-main-view-fg/15 hover:shadow-sm transition-all duration-200 text-main-view-fg',
value === model && 'bg-main-view-fg/12 shadow-sm',
highlightedIndex === index && 'bg-main-view-fg/20 shadow-md'
)}
>
<span className="text-sm truncate text-main-view-fg">{model}</span>
</div>
))}
</>
)
// Custom hook for keyboard navigation
function useKeyboardNavigation(
open: boolean,
setOpen: React.Dispatch<React.SetStateAction<boolean>>,
models: string[],
filteredModels: string[],
highlightedIndex: number,
setHighlightedIndex: React.Dispatch<React.SetStateAction<number>>,
onModelSelect: (model: string) => void,
dropdownRef: React.RefObject<HTMLDivElement | null>
) {
// Scroll to the highlighted element
useEffect(() => {
if (highlightedIndex >= 0 && dropdownRef.current) {
requestAnimationFrame(() => {
const modelElements = dropdownRef.current?.querySelectorAll('[data-model]')
const highlightedElement = modelElements?.[highlightedIndex] as HTMLElement
if (highlightedElement) {
highlightedElement.scrollIntoView({
block: 'nearest',
behavior: 'auto'
})
}
})
}
}, [highlightedIndex, dropdownRef])
const handleKeyDown = useCallback((e: React.KeyboardEvent) => {
// Open the dropdown with the arrows if closed
if (!open && (e.key === 'ArrowDown' || e.key === 'ArrowUp')) {
if (models.length > 0) {
e.preventDefault()
setOpen(true)
setHighlightedIndex(0)
}
return
}
if (!open) return
switch (e.key) {
case 'ArrowDown':
e.preventDefault()
setHighlightedIndex((prev: number) => filteredModels.length === 0 ? 0 : (prev < filteredModels.length - 1 ? prev + 1 : 0))
break
case 'ArrowUp':
e.preventDefault()
setHighlightedIndex((prev: number) => filteredModels.length === 0 ? 0 : (prev > 0 ? prev - 1 : filteredModels.length - 1))
break
case 'Enter':
e.preventDefault()
if (highlightedIndex >= 0 && highlightedIndex < filteredModels.length) {
onModelSelect(filteredModels[highlightedIndex])
}
break
case 'Escape':
e.preventDefault()
e.stopPropagation()
setOpen(false)
setHighlightedIndex(-1)
break
case 'PageUp':
e.preventDefault()
setHighlightedIndex(0)
break
case 'PageDown':
e.preventDefault()
setHighlightedIndex(filteredModels.length - 1)
break
}
}, [open, setOpen, models.length, filteredModels, highlightedIndex, setHighlightedIndex, onModelSelect])
return { handleKeyDown }
}
type ModelComboboxProps = {
value: string
onChange: (value: string) => void
models: string[]
loading?: boolean
error?: string | null
onRefresh?: () => void
placeholder?: string
disabled?: boolean
className?: string
onOpenChange?: (open: boolean) => void
}
export function ModelCombobox({
value,
onChange,
models,
loading = false,
error = null,
onRefresh,
placeholder = 'Type or select a model...',
disabled = false,
className,
onOpenChange,
}: ModelComboboxProps) {
const [open, setOpen] = useState(false)
const [inputValue, setInputValue] = useState(value)
const [highlightedIndex, setHighlightedIndex] = useState(-1)
const inputRef = useRef<HTMLInputElement>(null)
const containerRef = useRef<HTMLDivElement>(null)
const dropdownRef = useRef<HTMLDivElement | null>(null)
const { t } = useTranslation()
// Sync input value with prop value
useEffect(() => {
setInputValue(value)
}, [value])
// Notify parent when open state changes
useEffect(() => {
onOpenChange?.(open)
}, [open, onOpenChange])
// Hook for the dropdown position
const { dropdownPosition } = useDropdownPosition(open, containerRef)
// Optimized model filtering
const filteredModels = useMemo(() => {
if (!inputValue.trim()) return models
const searchValue = inputValue.toLowerCase()
return models.filter((model) => model.toLowerCase().includes(searchValue))
}, [models, inputValue])
// Reset highlighted index when filtered models change
useEffect(() => {
setHighlightedIndex(-1)
}, [filteredModels])
// Close the dropdown when clicking outside
useEffect(() => {
if (!open) return
const handleClickOutside = (event: Event) => {
const target = event.target as Node
const isInsideContainer = containerRef.current?.contains(target)
const isInsideDropdown = dropdownRef.current?.contains(target)
if (!isInsideContainer && !isInsideDropdown) {
setOpen(false)
setHighlightedIndex(-1)
}
}
const events = ['mousedown', 'touchstart']
events.forEach(eventType => {
document.addEventListener(eventType, handleClickOutside, { capture: true, passive: true })
})
return () => {
events.forEach(eventType => {
document.removeEventListener(eventType, handleClickOutside, { capture: true })
})
}
}, [open])
// Cleanup: close the dropdown when the component is unmounted
useEffect(() => {
return () => {
setOpen(false)
setHighlightedIndex(-1)
}
}, [])
// Handler for the input change
const handleInputChange = useCallback((newValue: string) => {
setInputValue(newValue)
onChange(newValue)
// Open the dropdown if the user types and there are models
if (newValue.trim() && models.length > 0) {
setOpen(true)
} else {
setOpen(false)
}
}, [onChange, models.length])
// Handler for the model selection
const handleModelSelect = useCallback((model: string) => {
setInputValue(model)
onChange(model)
setOpen(false)
setHighlightedIndex(-1)
inputRef.current?.focus()
}, [onChange])
// Hook for the keyboard navigation
const { handleKeyDown } = useKeyboardNavigation(
open,
setOpen,
models,
filteredModels,
highlightedIndex,
setHighlightedIndex,
handleModelSelect,
dropdownRef
)
// Handler for the dropdown opening
const handleDropdownToggle = useCallback(() => {
inputRef.current?.focus()
setOpen(!open)
}, [open])
// Handler for the input click
const handleInputClick = useCallback(() => {
if (models.length > 0) {
setOpen(true)
}
}, [models.length])
return (
<div className={cn('relative', className)} ref={containerRef}>
<div className="relative">
<Input
ref={inputRef}
value={inputValue}
onChange={(e) => handleInputChange(e.target.value)}
onKeyDown={handleKeyDown}
onClick={handleInputClick}
placeholder={placeholder}
disabled={disabled}
className="pr-16"
/>
{/* Input action buttons */}
<div className="absolute right-1 top-1/2 -translate-y-1/2 flex gap-1">
{onRefresh && (
<Button
variant="link"
size="sm"
disabled={disabled || loading}
onMouseDown={(e) => e.preventDefault()}
onClick={(e) => {
e.stopPropagation()
onRefresh()
}}
className="h-6 w-6 p-0 no-underline hover:bg-main-view-fg/10"
aria-label="Refresh models"
>
{loading ? (
<IconLoader2 className="h-3 w-3 animate-spin" />
) : (
<IconRefresh className="h-3 w-3 opacity-70" />
)}
</Button>
)}
<Button
variant="link"
size="sm"
disabled={disabled}
onMouseDown={(e) => e.preventDefault()}
onClick={handleDropdownToggle}
className="h-6 w-6 p-0 no-underline hover:bg-main-view-fg/10"
>
{loading ? (
<IconLoader2 className="h-3 w-3 animate-spin" />
) : (
<IconChevronDown className="h-3 w-3 opacity-50" />
)}
</Button>
</div>
{/* Custom dropdown rendered as portal */}
{open && dropdownPosition.width > 0 && createPortal(
<div
ref={dropdownRef}
className="fixed z-[9999] bg-main-view border border-main-view-fg/10 rounded-md shadow-lg max-h-[300px] overflow-y-auto text-main-view-fg animate-in fade-in-0 zoom-in-95 duration-200"
style={{
top: dropdownPosition.top,
left: dropdownPosition.left,
width: dropdownPosition.width,
minWidth: dropdownPosition.width,
maxWidth: dropdownPosition.width,
pointerEvents: 'auto',
}}
data-dropdown="model-combobox"
onPointerDown={(e) => e.stopPropagation()}
onWheel={(e) => e.stopPropagation()}
>
{/* Error state */}
{error && <ErrorSection error={error} t={t} />}
{/* Loading state */}
{loading && <LoadingSection t={t} />}
{/* Models list */}
{!loading && !error && (
filteredModels.length === 0 ? (
<EmptySection inputValue={inputValue} t={t} />
) : (
<ModelsList
filteredModels={filteredModels}
value={value}
highlightedIndex={highlightedIndex}
onModelSelect={handleModelSelect}
onHighlight={setHighlightedIndex}
/>
)
)}
</div>,
document.body
)}
</div>
</div>
)
}

View File

@ -0,0 +1,490 @@
import { describe, it, expect, vi, beforeEach, beforeAll, afterAll } from 'vitest'
import { render, screen, fireEvent, waitFor, act } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import '@testing-library/jest-dom/vitest'
import React from 'react'
import { ModelCombobox } from '../ModelCombobox'
// Mock translation hook
vi.mock('@/i18n/react-i18next-compat', () => ({
useTranslation: () => ({
t: (key: string, options?: Record<string, string>) => {
if (key === 'common:failedToLoadModels') return 'Failed to load models'
if (key === 'common:loading') return 'Loading'
if (key === 'common:noModelsFoundFor') return `No models found for "${options?.searchValue}"`
if (key === 'common:noModels') return 'No models available'
return key
},
}),
}))
describe('ModelCombobox', () => {
const mockOnChange = vi.fn()
const mockOnRefresh = vi.fn()
const defaultProps = {
value: '',
onChange: mockOnChange,
models: ['gpt-3.5-turbo', 'gpt-4', 'claude-3-haiku'],
}
let bcrSpy: ReturnType<typeof vi.spyOn>
let scrollSpy: ReturnType<typeof vi.spyOn>
beforeAll(() => {
const mockRect = {
width: 300,
height: 40,
top: 100,
left: 50,
bottom: 140,
right: 350,
x: 50,
y: 100,
toJSON: () => {},
} as unknown as DOMRect
bcrSpy = vi
.spyOn(Element.prototype as any, 'getBoundingClientRect')
.mockReturnValue(mockRect)
Element.prototype.scrollIntoView = () => {}
})
beforeEach(() => {
vi.clearAllMocks()
})
afterAll(() => {
bcrSpy?.mockRestore()
scrollSpy?.mockRestore()
})
it('renders input field with default placeholder', () => {
act(() => {
render(<ModelCombobox {...defaultProps} />)
})
const input = screen.getByRole('textbox')
expect(input).toBeInTheDocument()
expect(input).toHaveAttribute('placeholder', 'Type or select a model...')
})
it('renders custom placeholder', () => {
act(() => {
render(<ModelCombobox {...defaultProps} placeholder="Choose a model" />)
})
const input = screen.getByRole('textbox')
expect(input).toHaveAttribute('placeholder', 'Choose a model')
})
it('renders dropdown trigger button', () => {
act(() => {
render(<ModelCombobox {...defaultProps} />)
})
const button = screen.getByRole('button')
expect(button).toBeInTheDocument()
})
it('displays current value in input', () => {
act(() => {
render(<ModelCombobox {...defaultProps} value="gpt-4" />)
})
const input = screen.getByDisplayValue('gpt-4')
expect(input).toBeInTheDocument()
})
it('applies custom className', () => {
const { container } = render(
<ModelCombobox {...defaultProps} className="custom-class" />
)
const wrapper = container.firstChild as HTMLElement
expect(wrapper).toHaveClass('custom-class')
})
it('disables input when disabled prop is true', () => {
act(() => {
render(<ModelCombobox {...defaultProps} disabled />)
})
const input = screen.getByRole('textbox')
const button = screen.getByRole('button')
expect(input).toBeDisabled()
expect(button).toBeDisabled()
})
it('shows loading spinner in trigger button', () => {
act(() => {
render(<ModelCombobox {...defaultProps} loading />)
})
const button = screen.getByRole('button')
const spinner = button.querySelector('.animate-spin')
expect(spinner).toBeInTheDocument()
})
it('shows loading section when dropdown is opened during loading', async () => {
const user = userEvent.setup()
render(<ModelCombobox {...defaultProps} loading />)
// Click input to trigger dropdown opening
const input = screen.getByRole('textbox')
await user.click(input)
// Wait for dropdown to appear and check loading section
await waitFor(() => {
const dropdown = document.querySelector('[data-dropdown="model-combobox"]')
expect(dropdown).toBeInTheDocument()
expect(screen.getByText('Loading')).toBeInTheDocument()
})
})
it('calls onChange when typing', async () => {
const user = userEvent.setup()
const localMockOnChange = vi.fn()
render(<ModelCombobox {...defaultProps} onChange={localMockOnChange} />)
const input = screen.getByRole('textbox')
await user.type(input, 'g')
expect(localMockOnChange).toHaveBeenCalledWith('g')
})
it('updates input value when typing', async () => {
const user = userEvent.setup()
render(<ModelCombobox {...defaultProps} />)
const input = screen.getByRole('textbox')
await user.type(input, 'test')
expect(input).toHaveValue('test')
})
it('handles input focus', async () => {
const user = userEvent.setup()
render(<ModelCombobox {...defaultProps} />)
const input = screen.getByRole('textbox')
await user.click(input)
expect(input).toHaveFocus()
})
it('renders with empty models array', () => {
act(() => {
render(<ModelCombobox {...defaultProps} models={[]} />)
})
const input = screen.getByRole('textbox')
expect(input).toBeInTheDocument()
})
it('renders with models array', () => {
act(() => {
render(<ModelCombobox {...defaultProps} models={['model1', 'model2']} />)
})
const input = screen.getByRole('textbox')
expect(input).toBeInTheDocument()
})
it('handles mount and unmount without errors', () => {
const { unmount } = render(<ModelCombobox {...defaultProps} />)
expect(screen.getByRole('textbox')).toBeInTheDocument()
unmount()
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
})
it('handles props changes', () => {
const { rerender } = render(<ModelCombobox {...defaultProps} value="" />)
expect(screen.getByDisplayValue('')).toBeInTheDocument()
rerender(<ModelCombobox {...defaultProps} value="gpt-4" />)
expect(screen.getByDisplayValue('gpt-4')).toBeInTheDocument()
})
it('handles models array changes', () => {
const { rerender } = render(<ModelCombobox {...defaultProps} models={[]} />)
expect(screen.getByRole('textbox')).toBeInTheDocument()
rerender(<ModelCombobox {...defaultProps} models={['model1', 'model2']} />)
expect(screen.getByRole('textbox')).toBeInTheDocument()
})
it('does not open dropdown when clicking input with no models', async () => {
const user = userEvent.setup()
render(<ModelCombobox {...defaultProps} models={[]} />)
const input = screen.getByRole('textbox')
await user.click(input)
// Should focus but not open dropdown
expect(input).toHaveFocus()
const dropdown = document.querySelector('[data-dropdown="model-combobox"]')
expect(dropdown).not.toBeInTheDocument()
})
it('accepts error prop without crashing', () => {
act(() => {
render(<ModelCombobox {...defaultProps} error="Test error message" />)
})
const input = screen.getByRole('textbox')
expect(input).toBeInTheDocument()
expect(input).toHaveAttribute('placeholder', 'Type or select a model...')
})
it('renders with all props', () => {
act(() => {
render(
<ModelCombobox
{...defaultProps}
loading
error="Error message"
onRefresh={mockOnRefresh}
placeholder="Custom placeholder"
disabled
/>
)
})
const input = screen.getByRole('textbox')
expect(input).toBeInTheDocument()
expect(input).toBeDisabled()
})
it('opens dropdown when clicking trigger button', async () => {
const user = userEvent.setup()
render(<ModelCombobox {...defaultProps} />)
const button = screen.getByRole('button')
await user.click(button)
await waitFor(() => {
const dropdown = document.querySelector('[data-dropdown="model-combobox"]')
expect(dropdown).toBeInTheDocument()
})
})
it('opens dropdown when clicking input', async () => {
const user = userEvent.setup()
render(<ModelCombobox {...defaultProps} />)
const input = screen.getByRole('textbox')
await user.click(input)
expect(input).toHaveFocus()
await waitFor(() => {
const dropdown = document.querySelector('[data-dropdown="model-combobox"]')
expect(dropdown).toBeInTheDocument()
})
})
it('filters models based on input value', async () => {
const user = userEvent.setup()
const localMockOnChange = vi.fn()
render(<ModelCombobox {...defaultProps} onChange={localMockOnChange} />)
const input = screen.getByRole('textbox')
await user.type(input, 'gpt-4')
expect(localMockOnChange).toHaveBeenCalledWith('gpt-4')
})
it('shows filtered models in dropdown when typing', async () => {
const user = userEvent.setup()
render(<ModelCombobox {...defaultProps} />)
const input = screen.getByRole('textbox')
// Type 'gpt' to trigger dropdown opening
await user.type(input, 'gpt')
await waitFor(() => {
// Dropdown should be open
const dropdown = document.querySelector('[data-dropdown="model-combobox"]')
expect(dropdown).toBeInTheDocument()
// Should show GPT models
expect(screen.getByText('gpt-3.5-turbo')).toBeInTheDocument()
expect(screen.getByText('gpt-4')).toBeInTheDocument()
// Should not show Claude
expect(screen.queryByText('claude-3-haiku')).not.toBeInTheDocument()
})
})
it('handles case insensitive filtering', async () => {
const user = userEvent.setup()
render(<ModelCombobox {...defaultProps} />)
const input = screen.getByRole('textbox')
await user.type(input, 'GPT')
expect(mockOnChange).toHaveBeenCalledWith('GPT')
})
it('shows empty state when no models match filter', async () => {
const user = userEvent.setup()
render(<ModelCombobox {...defaultProps} />)
const input = screen.getByRole('textbox')
// Type something that doesn't match any model to trigger dropdown + empty state
await user.type(input, 'nonexistent')
await waitFor(() => {
// Dropdown should be open
const dropdown = document.querySelector('[data-dropdown="model-combobox"]')
expect(dropdown).toBeInTheDocument()
// Should show empty state message
expect(screen.getByText('No models found for "nonexistent"')).toBeInTheDocument()
})
})
it('selects model from dropdown when clicked', async () => {
const user = userEvent.setup()
const localMockOnChange = vi.fn()
render(<ModelCombobox {...defaultProps} onChange={localMockOnChange} />)
const input = screen.getByRole('textbox')
await user.click(input)
await waitFor(() => {
const modelOption = screen.getByText('gpt-4')
expect(modelOption).toBeInTheDocument()
})
const modelOption = screen.getByText('gpt-4')
await user.click(modelOption)
expect(localMockOnChange).toHaveBeenCalledWith('gpt-4')
expect(input).toHaveValue('gpt-4')
})
it('submits input value with Enter key', async () => {
const user = userEvent.setup()
const localMockOnChange = vi.fn()
render(<ModelCombobox {...defaultProps} onChange={localMockOnChange} />)
const input = screen.getByRole('textbox')
await user.type(input, 'gpt')
await user.keyboard('{Enter}')
expect(localMockOnChange).toHaveBeenCalledWith('gpt')
})
it('displays error message in dropdown', async () => {
const user = userEvent.setup()
render(<ModelCombobox {...defaultProps} error="Network connection failed" />)
const input = screen.getByRole('textbox')
// Click input to open dropdown
await user.click(input)
await waitFor(() => {
// Dropdown should be open
const dropdown = document.querySelector('[data-dropdown="model-combobox"]')
expect(dropdown).toBeInTheDocument()
// Error messages should be displayed
expect(screen.getByText('Failed to load models')).toBeInTheDocument()
expect(screen.getByText('Network connection failed')).toBeInTheDocument()
})
})
it('calls onRefresh when refresh button is clicked', async () => {
const user = userEvent.setup()
const localMockOnRefresh = vi.fn()
render(<ModelCombobox {...defaultProps} error="Network error" onRefresh={localMockOnRefresh} />)
const input = screen.getByRole('textbox')
// Click input to open dropdown
await user.click(input)
await waitFor(() => {
// Dropdown should be open with error section
const dropdown = document.querySelector('[data-dropdown="model-combobox"]')
expect(dropdown).toBeInTheDocument()
const refreshButton = document.querySelector('[aria-label="Refresh models"]')
expect(refreshButton).toBeInTheDocument()
})
const refreshButton = document.querySelector('[aria-label="Refresh models"]')
if (refreshButton) {
await user.click(refreshButton)
expect(localMockOnRefresh).toHaveBeenCalledTimes(1)
}
})
it('opens dropdown when pressing ArrowDown', async () => {
const user = userEvent.setup()
render(<ModelCombobox {...defaultProps} />)
const input = screen.getByRole('textbox')
input.focus()
await user.keyboard('{ArrowDown}')
expect(input).toHaveFocus()
await waitFor(() => {
const dropdown = document.querySelector('[data-dropdown="model-combobox"]')
expect(dropdown).toBeInTheDocument()
})
})
it('navigates through models with arrow keys', async () => {
const user = userEvent.setup()
render(<ModelCombobox {...defaultProps} />)
const input = screen.getByRole('textbox')
input.focus()
// ArrowDown should open dropdown
await user.keyboard('{ArrowDown}')
await waitFor(() => {
// Dropdown should be open
const dropdown = document.querySelector('[data-dropdown="model-combobox"]')
expect(dropdown).toBeInTheDocument()
})
// Navigate to second item
await user.keyboard('{ArrowDown}')
await waitFor(() => {
const secondModel = screen.getByText('gpt-4')
const modelElement = secondModel.closest('[data-model]')
expect(modelElement).toHaveClass('bg-main-view-fg/20')
})
})
it('handles Enter key to select highlighted model', async () => {
const user = userEvent.setup()
const localMockOnChange = vi.fn()
render(<ModelCombobox {...defaultProps} onChange={localMockOnChange} />)
const input = screen.getByRole('textbox')
// Type 'gpt' to open dropdown and filter models
await user.type(input, 'gpt')
await waitFor(() => {
// Dropdown should be open with filtered models
const dropdown = document.querySelector('[data-dropdown="model-combobox"]')
expect(dropdown).toBeInTheDocument()
})
// Navigate to highlight first model and select it
await user.keyboard('{ArrowDown}')
await user.keyboard('{Enter}')
expect(localMockOnChange).toHaveBeenCalledWith('gpt-3.5-turbo')
})
})

View File

@ -8,8 +8,9 @@ import {
DialogFooter, DialogFooter,
} from '@/components/ui/dialog' } from '@/components/ui/dialog'
import { Button } from '@/components/ui/button' import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { useModelProvider } from '@/hooks/useModelProvider' import { useModelProvider } from '@/hooks/useModelProvider'
import { useProviderModels } from '@/hooks/useProviderModels'
import { ModelCombobox } from '@/containers/ModelCombobox'
import { IconPlus } from '@tabler/icons-react' import { IconPlus } from '@tabler/icons-react'
import { useState } from 'react' import { useState } from 'react'
import { getProviderTitle } from '@/lib/utils' import { getProviderTitle } from '@/lib/utils'
@ -25,6 +26,12 @@ export const DialogAddModel = ({ provider, trigger }: DialogAddModelProps) => {
const { updateProvider } = useModelProvider() const { updateProvider } = useModelProvider()
const [modelId, setModelId] = useState<string>('') const [modelId, setModelId] = useState<string>('')
const [open, setOpen] = useState(false) const [open, setOpen] = useState(false)
const [isComboboxOpen, setIsComboboxOpen] = useState(false)
// Fetch models from provider API (API key is optional)
const { models, loading, error, refetch } = useProviderModels(
provider.base_url ? provider : undefined
)
// Handle form submission // Handle form submission
const handleSubmit = () => { const handleSubmit = () => {
@ -62,7 +69,13 @@ export const DialogAddModel = ({ provider, trigger }: DialogAddModelProps) => {
</div> </div>
)} )}
</DialogTrigger> </DialogTrigger>
<DialogContent> <DialogContent
onEscapeKeyDown={(e: KeyboardEvent) => {
if (isComboboxOpen) {
e.preventDefault()
}
}}
>
<DialogHeader> <DialogHeader>
<DialogTitle>{t('providers:addModel.title')}</DialogTitle> <DialogTitle>{t('providers:addModel.title')}</DialogTitle>
<DialogDescription> <DialogDescription>
@ -72,7 +85,7 @@ export const DialogAddModel = ({ provider, trigger }: DialogAddModelProps) => {
</DialogDescription> </DialogDescription>
</DialogHeader> </DialogHeader>
{/* Model ID field - required */} {/* Model selection field - required */}
<div className="space-y-2"> <div className="space-y-2">
<label <label
htmlFor="model-id" htmlFor="model-id"
@ -81,12 +94,16 @@ export const DialogAddModel = ({ provider, trigger }: DialogAddModelProps) => {
{t('providers:addModel.modelId')}{' '} {t('providers:addModel.modelId')}{' '}
<span className="text-destructive">*</span> <span className="text-destructive">*</span>
</label> </label>
<Input <ModelCombobox
id="model-id" key={`${provider.provider}-${provider.base_url || ''}`}
value={modelId} value={modelId}
onChange={(e) => setModelId(e.target.value)} onChange={setModelId}
models={models}
loading={loading}
error={error}
onRefresh={refetch}
placeholder={t('providers:addModel.enterModelId')} placeholder={t('providers:addModel.enterModelId')}
required onOpenChange={setIsComboboxOpen}
/> />
</div> </div>

View File

@ -0,0 +1,102 @@
import { describe, it, expect, beforeEach, vi } from 'vitest'
import { renderHook, waitFor } from '@testing-library/react'
import { useProviderModels } from '../useProviderModels'
import { useServiceHub } from '@/hooks/useServiceHub'
// Local minimal provider type for tests
type MockModelProvider = {
active: boolean
provider: string
base_url?: string
api_key?: string
settings: any[]
models: any[]
}
describe('useProviderModels', () => {
const mockProvider: MockModelProvider = {
active: true,
provider: 'openai',
base_url: 'https://api.openai.com/v1',
api_key: 'test-api-key',
settings: [],
models: [],
}
const mockModels = ['gpt-4', 'gpt-3.5-turbo', 'gpt-4-turbo']
let fetchModelsSpy: ReturnType<typeof vi.fn>
beforeEach(() => {
vi.restoreAllMocks()
vi.clearAllMocks()
const hub = (useServiceHub as unknown as () => any)()
const mockedFetch = vi.fn()
vi.spyOn(hub, 'providers').mockReturnValue({
fetchModelsFromProvider: mockedFetch,
} as any)
fetchModelsSpy = mockedFetch
})
it('should initialize with empty state', () => {
const { result } = renderHook(() => useProviderModels())
expect(result.current.models).toEqual([])
expect(result.current.loading).toBe(false)
expect(result.current.error).toBe(null)
expect(typeof result.current.refetch).toBe('function')
})
it('should not fetch models when provider is undefined', () => {
renderHook(() => useProviderModels(undefined))
expect(fetchModelsSpy).not.toHaveBeenCalled()
})
it('should not fetch models when provider has no base_url', () => {
const providerWithoutUrl = { ...mockProvider, base_url: undefined }
renderHook(() => useProviderModels(providerWithoutUrl))
expect(fetchModelsSpy).not.toHaveBeenCalled()
})
it('should fetch and sort models', async () => {
fetchModelsSpy.mockResolvedValueOnce(mockModels)
const { result } = renderHook(() => useProviderModels(mockProvider))
await waitFor(() => {
expect(result.current.models).toEqual(['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo'])
})
expect(result.current.error).toBe(null)
expect(fetchModelsSpy).toHaveBeenCalledWith(mockProvider)
})
it('should clear models when switching to invalid provider', async () => {
fetchModelsSpy.mockResolvedValueOnce(mockModels)
const { result, rerender } = renderHook(
({ provider }) => useProviderModels(provider),
{ initialProps: { provider: mockProvider } }
)
await waitFor(() => {
expect(result.current.models).toEqual(['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo'])
expect(result.current.loading).toBe(false)
}, { timeout: 500 })
// Switch to invalid provider
rerender({ provider: { ...mockProvider, base_url: undefined } })
expect(result.current.models).toEqual([])
expect(result.current.error).toBe(null)
expect(result.current.loading).toBe(false)
})
it('should not refetch when provider is undefined', () => {
const { result } = renderHook(() => useProviderModels(undefined))
result.current.refetch()
expect(fetchModelsSpy).not.toHaveBeenCalled()
})
})

View File

@ -0,0 +1,93 @@
import { useState, useEffect, useCallback, useRef } from 'react'
import { useServiceHub } from './useServiceHub'
type UseProviderModelsState = {
models: string[]
loading: boolean
error: string | null
refetch: () => void
}
const modelsCache = new Map<string, { models: string[]; timestamp: number }>()
const CACHE_DURATION = 5 * 60 * 1000 // 5 minutes
export const useProviderModels = (provider?: ModelProvider): UseProviderModelsState => {
const serviceHub = useServiceHub()
const [models, setModels] = useState<string[]>([])
const [loading, setLoading] = useState(false)
const [error, setError] = useState<string | null>(null)
const prevProviderKey = useRef<string>('')
const requestIdRef = useRef(0)
const fetchModels = useCallback(async () => {
if (!provider || !provider.base_url) {
// Clear models if provider is invalid (base_url is required, api_key is optional)
setModels([])
setError(null)
setLoading(false)
return
}
// Clear any previous state when starting a new fetch for a different provider
const currentProviderKey = `${provider.provider}-${provider.base_url}`
if (currentProviderKey !== prevProviderKey.current) {
setModels([])
setError(null)
setLoading(false)
prevProviderKey.current = currentProviderKey
}
const cacheKey = `${provider.provider}-${provider.base_url}`
const cached = modelsCache.get(cacheKey)
// Check cache first
if (cached && Date.now() - cached.timestamp < CACHE_DURATION) {
setModels(cached.models)
return
}
const currentRequestId = ++requestIdRef.current
setLoading(true)
setError(null)
try {
const fetchedModels = await serviceHub.providers().fetchModelsFromProvider(provider)
if (currentRequestId !== requestIdRef.current) return
const sortedModels = fetchedModels.sort((a, b) => a.localeCompare(b))
setModels(sortedModels)
// Cache the results
modelsCache.set(cacheKey, {
models: sortedModels,
timestamp: Date.now(),
})
} catch (err) {
if (currentRequestId !== requestIdRef.current) return
const errorMessage = err instanceof Error ? err.message : 'Failed to fetch models'
setError(errorMessage)
console.error(`Error fetching models from ${provider.provider}:`, err)
} finally {
if (currentRequestId === requestIdRef.current) setLoading(false)
}
}, [provider, serviceHub])
const refetch = useCallback(() => {
if (provider) {
const cacheKey = `${provider.provider}-${provider.base_url}`
modelsCache.delete(cacheKey)
fetchModels()
}
}, [provider, fetchModels])
useEffect(() => {
fetchModels()
}, [fetchModels])
return {
models,
loading,
error,
refetch,
}
}

View File

@ -75,6 +75,8 @@
"selectAModel": "Select a model", "selectAModel": "Select a model",
"noToolsAvailable": "No tools available", "noToolsAvailable": "No tools available",
"noModelsFoundFor": "No models found for \"{{searchValue}}\"", "noModelsFoundFor": "No models found for \"{{searchValue}}\"",
"failedToLoadModels": "Failed to load models",
"noModels": "No models found",
"customAvatar": "Custom avatar", "customAvatar": "Custom avatar",
"editAssistant": "Edit Assistant", "editAssistant": "Edit Assistant",
"jan": "Jan", "jan": "Jan",

View File

@ -215,7 +215,7 @@ describe('WebProvidersService', () => {
) )
}) })
it('should throw error when API response is not ok', async () => { it('should throw error when API response is not ok (404)', async () => {
const mockResponse = { const mockResponse = {
ok: false, ok: false,
status: 404, status: 404,
@ -229,7 +229,43 @@ describe('WebProvidersService', () => {
} }
await expect(providersService.fetchModelsFromProvider(provider)).rejects.toThrow( await expect(providersService.fetchModelsFromProvider(provider)).rejects.toThrow(
'Cannot connect to custom at https://api.custom.com. Please check that the service is running and accessible.' 'Models endpoint not found for custom. Check the base URL configuration.'
)
})
it('should throw error when API response is not ok (403)', async () => {
const mockResponse = {
ok: false,
status: 403,
statusText: 'Forbidden',
}
vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
const provider = {
provider: 'custom',
base_url: 'https://api.custom.com',
} as ModelProvider
await expect(providersService.fetchModelsFromProvider(provider)).rejects.toThrow(
'Access forbidden: Check your API key permissions for custom'
)
})
it('should throw error when API response is not ok (401)', async () => {
const mockResponse = {
ok: false,
status: 401,
statusText: 'Unauthorized',
}
vi.mocked(global.fetch).mockResolvedValue(mockResponse as any)
const provider = {
provider: 'custom',
base_url: 'https://api.custom.com',
} as ModelProvider
await expect(providersService.fetchModelsFromProvider(provider)).rejects.toThrow(
'Authentication failed: API key is required or invalid for custom'
) )
}) })

View File

@ -127,7 +127,7 @@ export class TauriProvidersService extends DefaultProvidersService {
} }
return runtimeProviders.concat(builtinProviders as ModelProvider[]) return runtimeProviders.concat(builtinProviders as ModelProvider[])
} catch (error) { } catch (error: unknown) {
console.error('Error getting providers in Tauri:', error) console.error('Error getting providers in Tauri:', error)
return [] return []
} }
@ -162,9 +162,24 @@ export class TauriProvidersService extends DefaultProvidersService {
}) })
if (!response.ok) { if (!response.ok) {
// Provide more specific error messages based on status code (aligned with web implementation)
if (response.status === 401) {
throw new Error( throw new Error(
`Failed to fetch models: ${response.status} ${response.statusText}` `Authentication failed: API key is required or invalid for ${provider.provider}`
) )
} else if (response.status === 403) {
throw new Error(
`Access forbidden: Check your API key permissions for ${provider.provider}`
)
} else if (response.status === 404) {
throw new Error(
`Models endpoint not found for ${provider.provider}. Check the base URL configuration.`
)
} else {
throw new Error(
`Failed to fetch models from ${provider.provider}: ${response.status} ${response.statusText}`
)
}
} }
const data = await response.json() const data = await response.json()
@ -194,14 +209,30 @@ export class TauriProvidersService extends DefaultProvidersService {
} catch (error) { } catch (error) {
console.error('Error fetching models from provider:', error) console.error('Error fetching models from provider:', error)
// Provide helpful error message // Preserve structured error messages thrown above
const structuredErrorPrefixes = [
'Authentication failed',
'Access forbidden',
'Models endpoint not found',
'Failed to fetch models from'
]
if (error instanceof Error &&
structuredErrorPrefixes.some(prefix => (error as Error).message.startsWith(prefix))) {
throw new Error(error.message)
}
// Provide helpful error message for any connection errors
if (error instanceof Error && error.message.includes('fetch')) { if (error instanceof Error && error.message.includes('fetch')) {
throw new Error( throw new Error(
`Cannot connect to ${provider.provider} at ${provider.base_url}. Please check that the service is running and accessible.` `Cannot connect to ${provider.provider} at ${provider.base_url}. Please check that the service is running and accessible.`
) )
} }
throw error // Generic fallback
throw new Error(
`Unexpected error while fetching models from ${provider.provider}: ${error instanceof Error ? error.message : 'Unknown error'}`
)
} }
} }

View File

@ -138,9 +138,24 @@ export class WebProvidersService implements ProvidersService {
}) })
if (!response.ok) { if (!response.ok) {
// Provide more specific error messages based on status code
if (response.status === 401) {
throw new Error( throw new Error(
`Cannot connect to ${provider.provider} at ${provider.base_url}. Please check that the service is running and accessible.` `Authentication failed: API key is required or invalid for ${provider.provider}`
) )
} else if (response.status === 403) {
throw new Error(
`Access forbidden: Check your API key permissions for ${provider.provider}`
)
} else if (response.status === 404) {
throw new Error(
`Models endpoint not found for ${provider.provider}. Check the base URL configuration.`
)
} else {
throw new Error(
`Failed to fetch models from ${provider.provider}: ${response.status} ${response.statusText}`
)
}
} }
const data = await response.json() const data = await response.json()
@ -170,15 +185,30 @@ export class WebProvidersService implements ProvidersService {
} catch (error) { } catch (error) {
console.error('Error fetching models from provider:', error) console.error('Error fetching models from provider:', error)
// Provide helpful error message for any connection errors const structuredErrorPrefixes = [
if (error instanceof Error && error.message.includes('Cannot connect')) { 'Authentication failed',
throw error 'Access forbidden',
'Models endpoint not found',
'Failed to fetch models from'
]
if (error instanceof Error &&
structuredErrorPrefixes.some(prefix => (error as Error).message.startsWith(prefix))) {
throw new Error(error.message)
} }
// Provide helpful error message for any connection errors
if (error instanceof Error && error.message.includes('fetch')) {
throw new Error( throw new Error(
`Cannot connect to ${provider.provider} at ${provider.base_url}. Please check that the service is running and accessible.` `Cannot connect to ${provider.provider} at ${provider.base_url}. Please check that the service is running and accessible.`
) )
} }
// Generic fallback
throw new Error(
`Unexpected error while fetching models from ${provider.provider}: ${error instanceof Error ? error.message : 'Unknown error'}`
)
}
} }
async updateSettings(providerName: string, settings: ProviderSetting[]): Promise<void> { async updateSettings(providerName: string, settings: ProviderSetting[]): Promise<void> {

View File

@ -104,6 +104,7 @@ const mockServiceHub = {
deleteProvider: vi.fn().mockResolvedValue(undefined), deleteProvider: vi.fn().mockResolvedValue(undefined),
updateProvider: vi.fn().mockResolvedValue(undefined), updateProvider: vi.fn().mockResolvedValue(undefined),
getProvider: vi.fn().mockResolvedValue(null), getProvider: vi.fn().mockResolvedValue(null),
fetchModelsFromProvider: vi.fn().mockResolvedValue([]),
}), }),
models: () => ({ models: () => ({
getModels: vi.fn().mockResolvedValue([]), getModels: vi.fn().mockResolvedValue([]),