diff --git a/web-app/src/containers/ModelCombobox.tsx b/web-app/src/containers/ModelCombobox.tsx new file mode 100644 index 000000000..ea5b3d670 --- /dev/null +++ b/web-app/src/containers/ModelCombobox.tsx @@ -0,0 +1,430 @@ +import { useState, useMemo, useRef, useEffect, useCallback } from 'react' +import { createPortal } from 'react-dom' +import { Input } from '@/components/ui/input' +import { Button } from '@/components/ui/button' +import { IconChevronDown, IconLoader2, IconRefresh } from '@tabler/icons-react' +import { cn } from '@/lib/utils' +import { useTranslation } from '@/i18n/react-i18next-compat' + +// Hook for the dropdown position +function useDropdownPosition(open: boolean, containerRef: React.RefObject) { + const [dropdownPosition, setDropdownPosition] = useState({ top: 0, left: 0, width: 0 }) + + const updateDropdownPosition = useCallback(() => { + if (containerRef.current) { + const rect = containerRef.current.getBoundingClientRect() + setDropdownPosition({ + top: rect.bottom + window.scrollY + 4, + left: rect.left + window.scrollX, + width: rect.width, + }) + } + }, [containerRef]) + + // Update the position when the dropdown opens + useEffect(() => { + if (open) { + requestAnimationFrame(() => { + updateDropdownPosition() + }) + } + }, [open, updateDropdownPosition]) + + // Update the position when the window is resized + useEffect(() => { + if (!open) return + + const handleResize = () => { + updateDropdownPosition() + } + + window.addEventListener('resize', handleResize) + window.addEventListener('scroll', handleResize) + + return () => { + window.removeEventListener('resize', handleResize) + window.removeEventListener('scroll', handleResize) + } + }, [open, updateDropdownPosition]) + + return { dropdownPosition, updateDropdownPosition } +} + +// Components for the different sections of the dropdown +const ErrorSection = ({ error, t }: { error: string; t: (key: string) => string }) => ( +
+
+ {t('common:failedToLoadModels')} +
+
{error}
+
+) + +const LoadingSection = ({ t }: { t: (key: string) => string }) => ( +
+ + {t('common:loading')} +
+) + +const EmptySection = ({ inputValue, t }: { inputValue: string; t: (key: string, options?: Record) => string }) => ( +
+
+
+ {inputValue.trim() ? ( + {t('common:noModelsFoundFor', { searchValue: inputValue })} + ) : ( + {t('common:noModels')} + )} +
+
+
+) + +const ModelsList = ({ + filteredModels, + value, + highlightedIndex, + onModelSelect, + onHighlight +}: { + filteredModels: string[] + value: string + highlightedIndex: number + onModelSelect: (model: string) => void + onHighlight: (index: number) => void +}) => ( + <> + {filteredModels.map((model, index) => ( +
{ + e.stopPropagation() + onModelSelect(model) + }} + onMouseEnter={() => onHighlight(index)} + className={cn( + 'cursor-pointer px-3 py-2 hover:bg-main-view-fg/15 hover:shadow-sm transition-all duration-200 text-main-view-fg', + value === model && 'bg-main-view-fg/12 shadow-sm', + highlightedIndex === index && 'bg-main-view-fg/20 shadow-md' + )} + > + {model} +
+ ))} + +) + +// Custom hook for keyboard navigation +function useKeyboardNavigation( + open: boolean, + setOpen: React.Dispatch>, + models: string[], + filteredModels: string[], + highlightedIndex: number, + setHighlightedIndex: React.Dispatch>, + onModelSelect: (model: string) => void, + dropdownRef: React.RefObject +) { + + // Scroll to the highlighted element + useEffect(() => { + if (highlightedIndex >= 0 && dropdownRef.current) { + requestAnimationFrame(() => { + const modelElements = dropdownRef.current?.querySelectorAll('[data-model]') + const highlightedElement = modelElements?.[highlightedIndex] as HTMLElement + if (highlightedElement) { + highlightedElement.scrollIntoView({ + block: 'nearest', + behavior: 'auto' + }) + } + }) + } + }, [highlightedIndex, dropdownRef]) + + const handleKeyDown = useCallback((e: React.KeyboardEvent) => { + // Open the dropdown with the arrows if closed + if (!open && (e.key === 'ArrowDown' || e.key === 'ArrowUp')) { + if (models.length > 0) { + e.preventDefault() + setOpen(true) + setHighlightedIndex(0) + } + return + } + + if (!open) return + + switch (e.key) { + case 'ArrowDown': + e.preventDefault() + setHighlightedIndex((prev: number) => filteredModels.length === 0 ? 0 : (prev < filteredModels.length - 1 ? prev + 1 : 0)) + break + case 'ArrowUp': + e.preventDefault() + setHighlightedIndex((prev: number) => filteredModels.length === 0 ? 0 : (prev > 0 ? prev - 1 : filteredModels.length - 1)) + break + case 'Enter': + e.preventDefault() + if (highlightedIndex >= 0 && highlightedIndex < filteredModels.length) { + onModelSelect(filteredModels[highlightedIndex]) + } + break + case 'Escape': + e.preventDefault() + e.stopPropagation() + setOpen(false) + setHighlightedIndex(-1) + break + case 'PageUp': + e.preventDefault() + setHighlightedIndex(0) + break + case 'PageDown': + e.preventDefault() + setHighlightedIndex(filteredModels.length - 1) + break + } + }, [open, setOpen, models.length, filteredModels, highlightedIndex, setHighlightedIndex, onModelSelect]) + + return { handleKeyDown } +} + +type ModelComboboxProps = { + value: string + onChange: (value: string) => void + models: string[] + loading?: boolean + error?: string | null + onRefresh?: () => void + placeholder?: string + disabled?: boolean + className?: string + onOpenChange?: (open: boolean) => void +} + +export function ModelCombobox({ + value, + onChange, + models, + loading = false, + error = null, + onRefresh, + placeholder = 'Type or select a model...', + disabled = false, + className, + onOpenChange, +}: ModelComboboxProps) { + const [open, setOpen] = useState(false) + const [inputValue, setInputValue] = useState(value) + const [highlightedIndex, setHighlightedIndex] = useState(-1) + const inputRef = useRef(null) + const containerRef = useRef(null) + const dropdownRef = useRef(null) + const { t } = useTranslation() + + // Sync input value with prop value + useEffect(() => { + setInputValue(value) + }, [value]) + + // Notify parent when open state changes + useEffect(() => { + onOpenChange?.(open) + }, [open, onOpenChange]) + + // Hook for the dropdown position + const { dropdownPosition } = useDropdownPosition(open, containerRef) + + // Optimized model filtering + const filteredModels = useMemo(() => { + if (!inputValue.trim()) return models + const searchValue = inputValue.toLowerCase() + return models.filter((model) => model.toLowerCase().includes(searchValue)) + }, [models, inputValue]) + + // Reset highlighted index when filtered models change + useEffect(() => { + setHighlightedIndex(-1) + }, [filteredModels]) + + // Close the dropdown when clicking outside + useEffect(() => { + if (!open) return + + const handleClickOutside = (event: Event) => { + const target = event.target as Node + const isInsideContainer = containerRef.current?.contains(target) + const isInsideDropdown = dropdownRef.current?.contains(target) + + if (!isInsideContainer && !isInsideDropdown) { + setOpen(false) + setHighlightedIndex(-1) + } + } + + const events = ['mousedown', 'touchstart'] + events.forEach(eventType => { + document.addEventListener(eventType, handleClickOutside, { capture: true, passive: true }) + }) + + return () => { + events.forEach(eventType => { + document.removeEventListener(eventType, handleClickOutside, { capture: true }) + }) + } + }, [open]) + + // Cleanup: close the dropdown when the component is unmounted + useEffect(() => { + return () => { + setOpen(false) + setHighlightedIndex(-1) + } + }, []) + + // Handler for the input change + const handleInputChange = useCallback((newValue: string) => { + setInputValue(newValue) + onChange(newValue) + + // Open the dropdown if the user types and there are models + if (newValue.trim() && models.length > 0) { + setOpen(true) + } else { + setOpen(false) + } + }, [onChange, models.length]) + + // Handler for the model selection + const handleModelSelect = useCallback((model: string) => { + setInputValue(model) + onChange(model) + setOpen(false) + setHighlightedIndex(-1) + inputRef.current?.focus() + }, [onChange]) + + // Hook for the keyboard navigation + const { handleKeyDown } = useKeyboardNavigation( + open, + setOpen, + models, + filteredModels, + highlightedIndex, + setHighlightedIndex, + handleModelSelect, + dropdownRef + ) + + // Handler for the dropdown opening + const handleDropdownToggle = useCallback(() => { + inputRef.current?.focus() + setOpen(!open) + }, [open]) + + // Handler for the input click + const handleInputClick = useCallback(() => { + if (models.length > 0) { + setOpen(true) + } + }, [models.length]) + + return ( +
+
+ handleInputChange(e.target.value)} + onKeyDown={handleKeyDown} + onClick={handleInputClick} + placeholder={placeholder} + disabled={disabled} + className="pr-16" + /> + + {/* Input action buttons */} +
+ {onRefresh && ( + + )} + +
+ + {/* Custom dropdown rendered as portal */} + {open && dropdownPosition.width > 0 && createPortal( +
e.stopPropagation()} + onWheel={(e) => e.stopPropagation()} + > + {/* Error state */} + {error && } + + {/* Loading state */} + {loading && } + + {/* Models list */} + {!loading && !error && ( + filteredModels.length === 0 ? ( + + ) : ( + + ) + )} +
, + document.body + )} +
+
+ ) +} diff --git a/web-app/src/containers/__tests__/ModelCombobox.test.tsx b/web-app/src/containers/__tests__/ModelCombobox.test.tsx new file mode 100644 index 000000000..38f9b97c8 --- /dev/null +++ b/web-app/src/containers/__tests__/ModelCombobox.test.tsx @@ -0,0 +1,490 @@ +import { describe, it, expect, vi, beforeEach, beforeAll, afterAll } from 'vitest' +import { render, screen, fireEvent, waitFor, act } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import '@testing-library/jest-dom/vitest' +import React from 'react' +import { ModelCombobox } from '../ModelCombobox' + +// Mock translation hook +vi.mock('@/i18n/react-i18next-compat', () => ({ + useTranslation: () => ({ + t: (key: string, options?: Record) => { + if (key === 'common:failedToLoadModels') return 'Failed to load models' + if (key === 'common:loading') return 'Loading' + if (key === 'common:noModelsFoundFor') return `No models found for "${options?.searchValue}"` + if (key === 'common:noModels') return 'No models available' + return key + }, + }), +})) + +describe('ModelCombobox', () => { + const mockOnChange = vi.fn() + const mockOnRefresh = vi.fn() + + const defaultProps = { + value: '', + onChange: mockOnChange, + models: ['gpt-3.5-turbo', 'gpt-4', 'claude-3-haiku'], + } + + let bcrSpy: ReturnType + let scrollSpy: ReturnType + + beforeAll(() => { + const mockRect = { + width: 300, + height: 40, + top: 100, + left: 50, + bottom: 140, + right: 350, + x: 50, + y: 100, + toJSON: () => {}, + } as unknown as DOMRect + + bcrSpy = vi + .spyOn(Element.prototype as any, 'getBoundingClientRect') + .mockReturnValue(mockRect) + + Element.prototype.scrollIntoView = () => {} + }) + + beforeEach(() => { + vi.clearAllMocks() + }) + + afterAll(() => { + bcrSpy?.mockRestore() + scrollSpy?.mockRestore() + }) + + it('renders input field with default placeholder', () => { + act(() => { + render() + }) + + const input = screen.getByRole('textbox') + expect(input).toBeInTheDocument() + expect(input).toHaveAttribute('placeholder', 'Type or select a model...') + }) + + it('renders custom placeholder', () => { + act(() => { + render() + }) + + const input = screen.getByRole('textbox') + expect(input).toHaveAttribute('placeholder', 'Choose a model') + }) + + it('renders dropdown trigger button', () => { + act(() => { + render() + }) + + const button = screen.getByRole('button') + expect(button).toBeInTheDocument() + }) + + it('displays current value in input', () => { + act(() => { + render() + }) + + const input = screen.getByDisplayValue('gpt-4') + expect(input).toBeInTheDocument() + }) + + it('applies custom className', () => { + const { container } = render( + + ) + + const wrapper = container.firstChild as HTMLElement + expect(wrapper).toHaveClass('custom-class') + }) + + it('disables input when disabled prop is true', () => { + act(() => { + render() + }) + + const input = screen.getByRole('textbox') + const button = screen.getByRole('button') + + expect(input).toBeDisabled() + expect(button).toBeDisabled() + }) + + it('shows loading spinner in trigger button', () => { + act(() => { + render() + }) + + const button = screen.getByRole('button') + const spinner = button.querySelector('.animate-spin') + expect(spinner).toBeInTheDocument() + }) + + it('shows loading section when dropdown is opened during loading', async () => { + const user = userEvent.setup() + render() + + // Click input to trigger dropdown opening + const input = screen.getByRole('textbox') + await user.click(input) + + // Wait for dropdown to appear and check loading section + await waitFor(() => { + const dropdown = document.querySelector('[data-dropdown="model-combobox"]') + expect(dropdown).toBeInTheDocument() + expect(screen.getByText('Loading')).toBeInTheDocument() + }) + }) + + it('calls onChange when typing', async () => { + const user = userEvent.setup() + const localMockOnChange = vi.fn() + render() + + const input = screen.getByRole('textbox') + await user.type(input, 'g') + + expect(localMockOnChange).toHaveBeenCalledWith('g') + }) + + it('updates input value when typing', async () => { + const user = userEvent.setup() + render() + + const input = screen.getByRole('textbox') + await user.type(input, 'test') + + expect(input).toHaveValue('test') + }) + + it('handles input focus', async () => { + const user = userEvent.setup() + render() + + const input = screen.getByRole('textbox') + await user.click(input) + + expect(input).toHaveFocus() + }) + + it('renders with empty models array', () => { + act(() => { + render() + }) + + const input = screen.getByRole('textbox') + expect(input).toBeInTheDocument() + }) + + it('renders with models array', () => { + act(() => { + render() + }) + + const input = screen.getByRole('textbox') + expect(input).toBeInTheDocument() + }) + + it('handles mount and unmount without errors', () => { + const { unmount } = render() + + expect(screen.getByRole('textbox')).toBeInTheDocument() + + unmount() + + expect(screen.queryByRole('textbox')).not.toBeInTheDocument() + }) + + it('handles props changes', () => { + const { rerender } = render() + + expect(screen.getByDisplayValue('')).toBeInTheDocument() + + rerender() + + expect(screen.getByDisplayValue('gpt-4')).toBeInTheDocument() + }) + + it('handles models array changes', () => { + const { rerender } = render() + + expect(screen.getByRole('textbox')).toBeInTheDocument() + + rerender() + + expect(screen.getByRole('textbox')).toBeInTheDocument() + }) + + it('does not open dropdown when clicking input with no models', async () => { + const user = userEvent.setup() + render() + + const input = screen.getByRole('textbox') + await user.click(input) + + // Should focus but not open dropdown + expect(input).toHaveFocus() + const dropdown = document.querySelector('[data-dropdown="model-combobox"]') + expect(dropdown).not.toBeInTheDocument() + }) + + it('accepts error prop without crashing', () => { + act(() => { + render() + }) + + const input = screen.getByRole('textbox') + expect(input).toBeInTheDocument() + expect(input).toHaveAttribute('placeholder', 'Type or select a model...') + }) + + it('renders with all props', () => { + act(() => { + render( + + ) + }) + + const input = screen.getByRole('textbox') + expect(input).toBeInTheDocument() + expect(input).toBeDisabled() + }) + + it('opens dropdown when clicking trigger button', async () => { + const user = userEvent.setup() + render() + + const button = screen.getByRole('button') + await user.click(button) + + await waitFor(() => { + const dropdown = document.querySelector('[data-dropdown="model-combobox"]') + expect(dropdown).toBeInTheDocument() + }) + }) + + it('opens dropdown when clicking input', async () => { + const user = userEvent.setup() + render() + + const input = screen.getByRole('textbox') + await user.click(input) + + expect(input).toHaveFocus() + await waitFor(() => { + const dropdown = document.querySelector('[data-dropdown="model-combobox"]') + expect(dropdown).toBeInTheDocument() + }) + }) + + it('filters models based on input value', async () => { + const user = userEvent.setup() + const localMockOnChange = vi.fn() + render() + + const input = screen.getByRole('textbox') + await user.type(input, 'gpt-4') + + expect(localMockOnChange).toHaveBeenCalledWith('gpt-4') + }) + + it('shows filtered models in dropdown when typing', async () => { + const user = userEvent.setup() + render() + + const input = screen.getByRole('textbox') + // Type 'gpt' to trigger dropdown opening + await user.type(input, 'gpt') + + await waitFor(() => { + // Dropdown should be open + const dropdown = document.querySelector('[data-dropdown="model-combobox"]') + expect(dropdown).toBeInTheDocument() + + // Should show GPT models + expect(screen.getByText('gpt-3.5-turbo')).toBeInTheDocument() + expect(screen.getByText('gpt-4')).toBeInTheDocument() + // Should not show Claude + expect(screen.queryByText('claude-3-haiku')).not.toBeInTheDocument() + }) + }) + + it('handles case insensitive filtering', async () => { + const user = userEvent.setup() + render() + + const input = screen.getByRole('textbox') + await user.type(input, 'GPT') + + expect(mockOnChange).toHaveBeenCalledWith('GPT') + }) + + it('shows empty state when no models match filter', async () => { + const user = userEvent.setup() + render() + + const input = screen.getByRole('textbox') + // Type something that doesn't match any model to trigger dropdown + empty state + await user.type(input, 'nonexistent') + + await waitFor(() => { + // Dropdown should be open + const dropdown = document.querySelector('[data-dropdown="model-combobox"]') + expect(dropdown).toBeInTheDocument() + // Should show empty state message + expect(screen.getByText('No models found for "nonexistent"')).toBeInTheDocument() + }) + }) + + it('selects model from dropdown when clicked', async () => { + const user = userEvent.setup() + const localMockOnChange = vi.fn() + render() + + const input = screen.getByRole('textbox') + await user.click(input) + + await waitFor(() => { + const modelOption = screen.getByText('gpt-4') + expect(modelOption).toBeInTheDocument() + }) + + const modelOption = screen.getByText('gpt-4') + await user.click(modelOption) + + expect(localMockOnChange).toHaveBeenCalledWith('gpt-4') + expect(input).toHaveValue('gpt-4') + }) + + it('submits input value with Enter key', async () => { + const user = userEvent.setup() + const localMockOnChange = vi.fn() + render() + + const input = screen.getByRole('textbox') + await user.type(input, 'gpt') + await user.keyboard('{Enter}') + + expect(localMockOnChange).toHaveBeenCalledWith('gpt') + }) + + it('displays error message in dropdown', async () => { + const user = userEvent.setup() + render() + + const input = screen.getByRole('textbox') + // Click input to open dropdown + await user.click(input) + + await waitFor(() => { + // Dropdown should be open + const dropdown = document.querySelector('[data-dropdown="model-combobox"]') + expect(dropdown).toBeInTheDocument() + // Error messages should be displayed + expect(screen.getByText('Failed to load models')).toBeInTheDocument() + expect(screen.getByText('Network connection failed')).toBeInTheDocument() + }) + }) + + it('calls onRefresh when refresh button is clicked', async () => { + const user = userEvent.setup() + const localMockOnRefresh = vi.fn() + render() + + const input = screen.getByRole('textbox') + // Click input to open dropdown + await user.click(input) + + await waitFor(() => { + // Dropdown should be open with error section + const dropdown = document.querySelector('[data-dropdown="model-combobox"]') + expect(dropdown).toBeInTheDocument() + const refreshButton = document.querySelector('[aria-label="Refresh models"]') + expect(refreshButton).toBeInTheDocument() + }) + + const refreshButton = document.querySelector('[aria-label="Refresh models"]') + if (refreshButton) { + await user.click(refreshButton) + expect(localMockOnRefresh).toHaveBeenCalledTimes(1) + } + }) + + it('opens dropdown when pressing ArrowDown', async () => { + const user = userEvent.setup() + render() + + const input = screen.getByRole('textbox') + input.focus() + await user.keyboard('{ArrowDown}') + + expect(input).toHaveFocus() + await waitFor(() => { + const dropdown = document.querySelector('[data-dropdown="model-combobox"]') + expect(dropdown).toBeInTheDocument() + }) + }) + + it('navigates through models with arrow keys', async () => { + const user = userEvent.setup() + render() + + const input = screen.getByRole('textbox') + input.focus() + + // ArrowDown should open dropdown + await user.keyboard('{ArrowDown}') + + await waitFor(() => { + // Dropdown should be open + const dropdown = document.querySelector('[data-dropdown="model-combobox"]') + expect(dropdown).toBeInTheDocument() + }) + + // Navigate to second item + await user.keyboard('{ArrowDown}') + + await waitFor(() => { + const secondModel = screen.getByText('gpt-4') + const modelElement = secondModel.closest('[data-model]') + expect(modelElement).toHaveClass('bg-main-view-fg/20') + }) + }) + + it('handles Enter key to select highlighted model', async () => { + const user = userEvent.setup() + const localMockOnChange = vi.fn() + render() + + const input = screen.getByRole('textbox') + // Type 'gpt' to open dropdown and filter models + await user.type(input, 'gpt') + + await waitFor(() => { + // Dropdown should be open with filtered models + const dropdown = document.querySelector('[data-dropdown="model-combobox"]') + expect(dropdown).toBeInTheDocument() + }) + + // Navigate to highlight first model and select it + await user.keyboard('{ArrowDown}') + await user.keyboard('{Enter}') + + expect(localMockOnChange).toHaveBeenCalledWith('gpt-3.5-turbo') + }) +}) diff --git a/web-app/src/containers/dialogs/AddModel.tsx b/web-app/src/containers/dialogs/AddModel.tsx index 248600212..3ccdc6d65 100644 --- a/web-app/src/containers/dialogs/AddModel.tsx +++ b/web-app/src/containers/dialogs/AddModel.tsx @@ -8,8 +8,9 @@ import { DialogFooter, } from '@/components/ui/dialog' import { Button } from '@/components/ui/button' -import { Input } from '@/components/ui/input' import { useModelProvider } from '@/hooks/useModelProvider' +import { useProviderModels } from '@/hooks/useProviderModels' +import { ModelCombobox } from '@/containers/ModelCombobox' import { IconPlus } from '@tabler/icons-react' import { useState } from 'react' import { getProviderTitle } from '@/lib/utils' @@ -25,6 +26,12 @@ export const DialogAddModel = ({ provider, trigger }: DialogAddModelProps) => { const { updateProvider } = useModelProvider() const [modelId, setModelId] = useState('') const [open, setOpen] = useState(false) + const [isComboboxOpen, setIsComboboxOpen] = useState(false) + + // Fetch models from provider API (API key is optional) + const { models, loading, error, refetch } = useProviderModels( + provider.base_url ? provider : undefined + ) // Handle form submission const handleSubmit = () => { @@ -62,7 +69,13 @@ export const DialogAddModel = ({ provider, trigger }: DialogAddModelProps) => { )} - + { + if (isComboboxOpen) { + e.preventDefault() + } + }} + > {t('providers:addModel.title')} @@ -72,7 +85,7 @@ export const DialogAddModel = ({ provider, trigger }: DialogAddModelProps) => { - {/* Model ID field - required */} + {/* Model selection field - required */}
- setModelId(e.target.value)} + onChange={setModelId} + models={models} + loading={loading} + error={error} + onRefresh={refetch} placeholder={t('providers:addModel.enterModelId')} - required + onOpenChange={setIsComboboxOpen} />
diff --git a/web-app/src/hooks/__tests__/useProviderModels.test.ts b/web-app/src/hooks/__tests__/useProviderModels.test.ts new file mode 100644 index 000000000..da9b60e07 --- /dev/null +++ b/web-app/src/hooks/__tests__/useProviderModels.test.ts @@ -0,0 +1,102 @@ +import { describe, it, expect, beforeEach, vi } from 'vitest' +import { renderHook, waitFor } from '@testing-library/react' +import { useProviderModels } from '../useProviderModels' +import { useServiceHub } from '@/hooks/useServiceHub' + +// Local minimal provider type for tests +type MockModelProvider = { + active: boolean + provider: string + base_url?: string + api_key?: string + settings: any[] + models: any[] +} + +describe('useProviderModels', () => { + const mockProvider: MockModelProvider = { + active: true, + provider: 'openai', + base_url: 'https://api.openai.com/v1', + api_key: 'test-api-key', + settings: [], + models: [], + } + + const mockModels = ['gpt-4', 'gpt-3.5-turbo', 'gpt-4-turbo'] + + let fetchModelsSpy: ReturnType + + beforeEach(() => { + vi.restoreAllMocks() + vi.clearAllMocks() + const hub = (useServiceHub as unknown as () => any)() + const mockedFetch = vi.fn() + vi.spyOn(hub, 'providers').mockReturnValue({ + fetchModelsFromProvider: mockedFetch, + } as any) + fetchModelsSpy = mockedFetch + }) + + it('should initialize with empty state', () => { + const { result } = renderHook(() => useProviderModels()) + + expect(result.current.models).toEqual([]) + expect(result.current.loading).toBe(false) + expect(result.current.error).toBe(null) + expect(typeof result.current.refetch).toBe('function') + }) + + it('should not fetch models when provider is undefined', () => { + renderHook(() => useProviderModels(undefined)) + expect(fetchModelsSpy).not.toHaveBeenCalled() + }) + + it('should not fetch models when provider has no base_url', () => { + const providerWithoutUrl = { ...mockProvider, base_url: undefined } + renderHook(() => useProviderModels(providerWithoutUrl)) + expect(fetchModelsSpy).not.toHaveBeenCalled() + }) + + it('should fetch and sort models', async () => { + fetchModelsSpy.mockResolvedValueOnce(mockModels) + + const { result } = renderHook(() => useProviderModels(mockProvider)) + + await waitFor(() => { + expect(result.current.models).toEqual(['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo']) + }) + + expect(result.current.error).toBe(null) + expect(fetchModelsSpy).toHaveBeenCalledWith(mockProvider) + }) + + it('should clear models when switching to invalid provider', async () => { + fetchModelsSpy.mockResolvedValueOnce(mockModels) + + const { result, rerender } = renderHook( + ({ provider }) => useProviderModels(provider), + { initialProps: { provider: mockProvider } } + ) + + await waitFor(() => { + expect(result.current.models).toEqual(['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo']) + expect(result.current.loading).toBe(false) + }, { timeout: 500 }) + + // Switch to invalid provider + rerender({ provider: { ...mockProvider, base_url: undefined } }) + + expect(result.current.models).toEqual([]) + expect(result.current.error).toBe(null) + expect(result.current.loading).toBe(false) + }) + + it('should not refetch when provider is undefined', () => { + const { result } = renderHook(() => useProviderModels(undefined)) + + result.current.refetch() + + expect(fetchModelsSpy).not.toHaveBeenCalled() + }) +}) \ No newline at end of file diff --git a/web-app/src/hooks/useProviderModels.ts b/web-app/src/hooks/useProviderModels.ts new file mode 100644 index 000000000..3c51a7f70 --- /dev/null +++ b/web-app/src/hooks/useProviderModels.ts @@ -0,0 +1,93 @@ +import { useState, useEffect, useCallback, useRef } from 'react' +import { useServiceHub } from './useServiceHub' + +type UseProviderModelsState = { + models: string[] + loading: boolean + error: string | null + refetch: () => void +} + +const modelsCache = new Map() +const CACHE_DURATION = 5 * 60 * 1000 // 5 minutes + +export const useProviderModels = (provider?: ModelProvider): UseProviderModelsState => { + const serviceHub = useServiceHub() + const [models, setModels] = useState([]) + const [loading, setLoading] = useState(false) + const [error, setError] = useState(null) + const prevProviderKey = useRef('') + const requestIdRef = useRef(0) + + const fetchModels = useCallback(async () => { + if (!provider || !provider.base_url) { + // Clear models if provider is invalid (base_url is required, api_key is optional) + setModels([]) + setError(null) + setLoading(false) + return + } + + // Clear any previous state when starting a new fetch for a different provider + const currentProviderKey = `${provider.provider}-${provider.base_url}` + if (currentProviderKey !== prevProviderKey.current) { + setModels([]) + setError(null) + setLoading(false) + prevProviderKey.current = currentProviderKey + } + + const cacheKey = `${provider.provider}-${provider.base_url}` + const cached = modelsCache.get(cacheKey) + + // Check cache first + if (cached && Date.now() - cached.timestamp < CACHE_DURATION) { + setModels(cached.models) + return + } + + const currentRequestId = ++requestIdRef.current + setLoading(true) + setError(null) + + try { + const fetchedModels = await serviceHub.providers().fetchModelsFromProvider(provider) + if (currentRequestId !== requestIdRef.current) return + const sortedModels = fetchedModels.sort((a, b) => a.localeCompare(b)) + + setModels(sortedModels) + + // Cache the results + modelsCache.set(cacheKey, { + models: sortedModels, + timestamp: Date.now(), + }) + } catch (err) { + if (currentRequestId !== requestIdRef.current) return + const errorMessage = err instanceof Error ? err.message : 'Failed to fetch models' + setError(errorMessage) + console.error(`Error fetching models from ${provider.provider}:`, err) + } finally { + if (currentRequestId === requestIdRef.current) setLoading(false) + } + }, [provider, serviceHub]) + + const refetch = useCallback(() => { + if (provider) { + const cacheKey = `${provider.provider}-${provider.base_url}` + modelsCache.delete(cacheKey) + fetchModels() + } + }, [provider, fetchModels]) + + useEffect(() => { + fetchModels() + }, [fetchModels]) + + return { + models, + loading, + error, + refetch, + } +} diff --git a/web-app/src/locales/en/common.json b/web-app/src/locales/en/common.json index 46f2d5a8a..e5f5aa9f7 100644 --- a/web-app/src/locales/en/common.json +++ b/web-app/src/locales/en/common.json @@ -75,6 +75,8 @@ "selectAModel": "Select a model", "noToolsAvailable": "No tools available", "noModelsFoundFor": "No models found for \"{{searchValue}}\"", + "failedToLoadModels": "Failed to load models", + "noModels": "No models found", "customAvatar": "Custom avatar", "editAssistant": "Edit Assistant", "jan": "Jan", diff --git a/web-app/src/services/__tests__/providers.test.ts b/web-app/src/services/__tests__/providers.test.ts index 63b2b71e4..ed447dba7 100644 --- a/web-app/src/services/__tests__/providers.test.ts +++ b/web-app/src/services/__tests__/providers.test.ts @@ -215,7 +215,7 @@ describe('WebProvidersService', () => { ) }) - it('should throw error when API response is not ok', async () => { + it('should throw error when API response is not ok (404)', async () => { const mockResponse = { ok: false, status: 404, @@ -229,7 +229,43 @@ describe('WebProvidersService', () => { } await expect(providersService.fetchModelsFromProvider(provider)).rejects.toThrow( - 'Cannot connect to custom at https://api.custom.com. Please check that the service is running and accessible.' + 'Models endpoint not found for custom. Check the base URL configuration.' + ) + }) + + it('should throw error when API response is not ok (403)', async () => { + const mockResponse = { + ok: false, + status: 403, + statusText: 'Forbidden', + } + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const provider = { + provider: 'custom', + base_url: 'https://api.custom.com', + } as ModelProvider + + await expect(providersService.fetchModelsFromProvider(provider)).rejects.toThrow( + 'Access forbidden: Check your API key permissions for custom' + ) + }) + + it('should throw error when API response is not ok (401)', async () => { + const mockResponse = { + ok: false, + status: 401, + statusText: 'Unauthorized', + } + vi.mocked(global.fetch).mockResolvedValue(mockResponse as any) + + const provider = { + provider: 'custom', + base_url: 'https://api.custom.com', + } as ModelProvider + + await expect(providersService.fetchModelsFromProvider(provider)).rejects.toThrow( + 'Authentication failed: API key is required or invalid for custom' ) }) diff --git a/web-app/src/services/providers/tauri.ts b/web-app/src/services/providers/tauri.ts index d1554e3cf..f6155a8d9 100644 --- a/web-app/src/services/providers/tauri.ts +++ b/web-app/src/services/providers/tauri.ts @@ -127,7 +127,7 @@ export class TauriProvidersService extends DefaultProvidersService { } return runtimeProviders.concat(builtinProviders as ModelProvider[]) - } catch (error) { + } catch (error: unknown) { console.error('Error getting providers in Tauri:', error) return [] } @@ -162,9 +162,24 @@ export class TauriProvidersService extends DefaultProvidersService { }) if (!response.ok) { - throw new Error( - `Failed to fetch models: ${response.status} ${response.statusText}` - ) + // Provide more specific error messages based on status code (aligned with web implementation) + if (response.status === 401) { + throw new Error( + `Authentication failed: API key is required or invalid for ${provider.provider}` + ) + } else if (response.status === 403) { + throw new Error( + `Access forbidden: Check your API key permissions for ${provider.provider}` + ) + } else if (response.status === 404) { + throw new Error( + `Models endpoint not found for ${provider.provider}. Check the base URL configuration.` + ) + } else { + throw new Error( + `Failed to fetch models from ${provider.provider}: ${response.status} ${response.statusText}` + ) + } } const data = await response.json() @@ -194,14 +209,30 @@ export class TauriProvidersService extends DefaultProvidersService { } catch (error) { console.error('Error fetching models from provider:', error) - // Provide helpful error message + // Preserve structured error messages thrown above + const structuredErrorPrefixes = [ + 'Authentication failed', + 'Access forbidden', + 'Models endpoint not found', + 'Failed to fetch models from' + ] + + if (error instanceof Error && + structuredErrorPrefixes.some(prefix => (error as Error).message.startsWith(prefix))) { + throw new Error(error.message) + } + + // Provide helpful error message for any connection errors if (error instanceof Error && error.message.includes('fetch')) { throw new Error( `Cannot connect to ${provider.provider} at ${provider.base_url}. Please check that the service is running and accessible.` ) } - throw error + // Generic fallback + throw new Error( + `Unexpected error while fetching models from ${provider.provider}: ${error instanceof Error ? error.message : 'Unknown error'}` + ) } } diff --git a/web-app/src/services/providers/web.ts b/web-app/src/services/providers/web.ts index 30fe71366..5ad426a11 100644 --- a/web-app/src/services/providers/web.ts +++ b/web-app/src/services/providers/web.ts @@ -138,9 +138,24 @@ export class WebProvidersService implements ProvidersService { }) if (!response.ok) { - throw new Error( - `Cannot connect to ${provider.provider} at ${provider.base_url}. Please check that the service is running and accessible.` - ) + // Provide more specific error messages based on status code + if (response.status === 401) { + throw new Error( + `Authentication failed: API key is required or invalid for ${provider.provider}` + ) + } else if (response.status === 403) { + throw new Error( + `Access forbidden: Check your API key permissions for ${provider.provider}` + ) + } else if (response.status === 404) { + throw new Error( + `Models endpoint not found for ${provider.provider}. Check the base URL configuration.` + ) + } else { + throw new Error( + `Failed to fetch models from ${provider.provider}: ${response.status} ${response.statusText}` + ) + } } const data = await response.json() @@ -170,13 +185,28 @@ export class WebProvidersService implements ProvidersService { } catch (error) { console.error('Error fetching models from provider:', error) - // Provide helpful error message for any connection errors - if (error instanceof Error && error.message.includes('Cannot connect')) { - throw error + const structuredErrorPrefixes = [ + 'Authentication failed', + 'Access forbidden', + 'Models endpoint not found', + 'Failed to fetch models from' + ] + + if (error instanceof Error && + structuredErrorPrefixes.some(prefix => (error as Error).message.startsWith(prefix))) { + throw new Error(error.message) } - + + // Provide helpful error message for any connection errors + if (error instanceof Error && error.message.includes('fetch')) { + throw new Error( + `Cannot connect to ${provider.provider} at ${provider.base_url}. Please check that the service is running and accessible.` + ) + } + + // Generic fallback throw new Error( - `Cannot connect to ${provider.provider} at ${provider.base_url}. Please check that the service is running and accessible.` + `Unexpected error while fetching models from ${provider.provider}: ${error instanceof Error ? error.message : 'Unknown error'}` ) } } diff --git a/web-app/src/test/setup.ts b/web-app/src/test/setup.ts index 9fb6d928d..f03b6c4f2 100644 --- a/web-app/src/test/setup.ts +++ b/web-app/src/test/setup.ts @@ -104,6 +104,7 @@ const mockServiceHub = { deleteProvider: vi.fn().mockResolvedValue(undefined), updateProvider: vi.fn().mockResolvedValue(undefined), getProvider: vi.fn().mockResolvedValue(null), + fetchModelsFromProvider: vi.fn().mockResolvedValue([]), }), models: () => ({ getModels: vi.fn().mockResolvedValue([]),