fix: use serviceHub to fetch models and fix error message on app

This commit is contained in:
lugnicca 2025-09-08 17:45:35 +02:00
parent 9fcd9503e7
commit 2db9af94fa
4 changed files with 55 additions and 24 deletions

View File

@ -1,9 +1,7 @@
import { describe, it, expect, beforeEach, vi } from 'vitest' import { describe, it, expect, beforeEach, vi } from 'vitest'
import { renderHook, waitFor } from '@testing-library/react' import { renderHook, waitFor } from '@testing-library/react'
import { useProviderModels } from '../useProviderModels' import { useProviderModels } from '../useProviderModels'
import { WebProvidersService } from '../../services/providers/web' import { useServiceHub } from '@/hooks/useServiceHub'
let fetchModelsSpy: ReturnType<typeof vi.spyOn>
// Local minimal provider type for tests // Local minimal provider type for tests
type MockModelProvider = { type MockModelProvider = {
@ -27,13 +25,17 @@ describe('useProviderModels', () => {
const mockModels = ['gpt-4', 'gpt-3.5-turbo', 'gpt-4-turbo'] const mockModels = ['gpt-4', 'gpt-3.5-turbo', 'gpt-4-turbo']
let fetchModelsSpy: ReturnType<typeof vi.fn>
beforeEach(() => { beforeEach(() => {
vi.restoreAllMocks() vi.restoreAllMocks()
vi.clearAllMocks() vi.clearAllMocks()
fetchModelsSpy = vi.spyOn( const hub = (useServiceHub as unknown as () => any)()
WebProvidersService.prototype, const mockedFetch = vi.fn()
'fetchModelsFromProvider' vi.spyOn(hub, 'providers').mockReturnValue({
) fetchModelsFromProvider: mockedFetch,
} as any)
fetchModelsSpy = mockedFetch
}) })
it('should initialize with empty state', () => { it('should initialize with empty state', () => {
@ -62,11 +64,9 @@ describe('useProviderModels', () => {
const { result } = renderHook(() => useProviderModels(mockProvider)) const { result } = renderHook(() => useProviderModels(mockProvider))
await waitFor(() => { await waitFor(() => {
expect(result.current.loading).toBe(false) expect(result.current.models).toEqual(['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo'])
}) })
// Should be sorted alphabetically
expect(result.current.models).toEqual(['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo'])
expect(result.current.error).toBe(null) expect(result.current.error).toBe(null)
expect(fetchModelsSpy).toHaveBeenCalledWith(mockProvider) expect(fetchModelsSpy).toHaveBeenCalledWith(mockProvider)
}) })
@ -80,10 +80,9 @@ describe('useProviderModels', () => {
) )
await waitFor(() => { await waitFor(() => {
expect(result.current.loading).toBe(false)
})
expect(result.current.models).toEqual(['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo']) expect(result.current.models).toEqual(['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo'])
expect(result.current.loading).toBe(false)
}, { timeout: 500 })
// Switch to invalid provider // Switch to invalid provider
rerender({ provider: { ...mockProvider, base_url: undefined } }) rerender({ provider: { ...mockProvider, base_url: undefined } })

View File

@ -1,5 +1,5 @@
import { useState, useEffect, useCallback, useRef, useMemo } from 'react' import { useState, useEffect, useCallback, useRef } from 'react'
import { WebProvidersService } from '../services/providers/web' import { useServiceHub } from './useServiceHub'
type UseProviderModelsState = { type UseProviderModelsState = {
models: string[] models: string[]
@ -12,7 +12,7 @@ const modelsCache = new Map<string, { models: string[]; timestamp: number }>()
const CACHE_DURATION = 5 * 60 * 1000 // 5 minutes const CACHE_DURATION = 5 * 60 * 1000 // 5 minutes
export const useProviderModels = (provider?: ModelProvider): UseProviderModelsState => { export const useProviderModels = (provider?: ModelProvider): UseProviderModelsState => {
const providersService = useMemo(() => new WebProvidersService(), []) const serviceHub = useServiceHub()
const [models, setModels] = useState<string[]>([]) const [models, setModels] = useState<string[]>([])
const [loading, setLoading] = useState(false) const [loading, setLoading] = useState(false)
const [error, setError] = useState<string | null>(null) const [error, setError] = useState<string | null>(null)
@ -51,7 +51,7 @@ export const useProviderModels = (provider?: ModelProvider): UseProviderModelsSt
setError(null) setError(null)
try { try {
const fetchedModels = await providersService.fetchModelsFromProvider(provider) const fetchedModels = await serviceHub.providers().fetchModelsFromProvider(provider)
if (currentRequestId !== requestIdRef.current) return if (currentRequestId !== requestIdRef.current) return
const sortedModels = fetchedModels.sort((a, b) => a.localeCompare(b)) const sortedModels = fetchedModels.sort((a, b) => a.localeCompare(b))
@ -70,7 +70,7 @@ export const useProviderModels = (provider?: ModelProvider): UseProviderModelsSt
} finally { } finally {
if (currentRequestId === requestIdRef.current) setLoading(false) if (currentRequestId === requestIdRef.current) setLoading(false)
} }
}, [provider, providersService]) }, [provider, serviceHub])
const refetch = useCallback(() => { const refetch = useCallback(() => {
if (provider) { if (provider) {

View File

@ -113,7 +113,7 @@ export class TauriProvidersService extends DefaultProvidersService {
} }
return runtimeProviders.concat(builtinProviders as ModelProvider[]) return runtimeProviders.concat(builtinProviders as ModelProvider[])
} catch (error) { } catch (error: unknown) {
console.error('Error getting providers in Tauri:', error) console.error('Error getting providers in Tauri:', error)
return [] return []
} }
@ -142,9 +142,24 @@ export class TauriProvidersService extends DefaultProvidersService {
}) })
if (!response.ok) { if (!response.ok) {
// Provide more specific error messages based on status code (aligned with web implementation)
if (response.status === 401) {
throw new Error( throw new Error(
`Failed to fetch models: ${response.status} ${response.statusText}` `Authentication failed: API key is required or invalid for ${provider.provider}`
) )
} else if (response.status === 403) {
throw new Error(
`Access forbidden: Check your API key permissions for ${provider.provider}`
)
} else if (response.status === 404) {
throw new Error(
`Models endpoint not found for ${provider.provider}. Check the base URL configuration.`
)
} else {
throw new Error(
`Failed to fetch models from ${provider.provider}: ${response.status} ${response.statusText}`
)
}
} }
const data = await response.json() const data = await response.json()
@ -174,14 +189,30 @@ export class TauriProvidersService extends DefaultProvidersService {
} catch (error) { } catch (error) {
console.error('Error fetching models from provider:', error) console.error('Error fetching models from provider:', error)
// Provide helpful error message // Preserve structured error messages thrown above
const structuredErrorPrefixes = [
'Authentication failed',
'Access forbidden',
'Models endpoint not found',
'Failed to fetch models from'
]
if (error instanceof Error &&
structuredErrorPrefixes.some(prefix => (error as Error).message.startsWith(prefix))) {
throw new Error(error.message)
}
// Provide helpful error message for any connection errors
if (error instanceof Error && error.message.includes('fetch')) { if (error instanceof Error && error.message.includes('fetch')) {
throw new Error( throw new Error(
`Cannot connect to ${provider.provider} at ${provider.base_url}. Please check that the service is running and accessible.` `Cannot connect to ${provider.provider} at ${provider.base_url}. Please check that the service is running and accessible.`
) )
} }
throw error // Generic fallback
throw new Error(
`Unexpected error while fetching models from ${provider.provider}: ${error instanceof Error ? error.message : 'Unknown error'}`
)
} }
} }

View File

@ -103,6 +103,7 @@ const mockServiceHub = {
deleteProvider: vi.fn().mockResolvedValue(undefined), deleteProvider: vi.fn().mockResolvedValue(undefined),
updateProvider: vi.fn().mockResolvedValue(undefined), updateProvider: vi.fn().mockResolvedValue(undefined),
getProvider: vi.fn().mockResolvedValue(null), getProvider: vi.fn().mockResolvedValue(null),
fetchModelsFromProvider: vi.fn().mockResolvedValue([]),
}), }),
models: () => ({ models: () => ({
getModels: vi.fn().mockResolvedValue([]), getModels: vi.fn().mockResolvedValue([]),