fix: use serviceHub to fetch models and fix error message on app
This commit is contained in:
parent
9fcd9503e7
commit
2db9af94fa
@ -1,9 +1,7 @@
|
||||
import { describe, it, expect, beforeEach, vi } from 'vitest'
|
||||
import { renderHook, waitFor } from '@testing-library/react'
|
||||
import { useProviderModels } from '../useProviderModels'
|
||||
import { WebProvidersService } from '../../services/providers/web'
|
||||
|
||||
let fetchModelsSpy: ReturnType<typeof vi.spyOn>
|
||||
import { useServiceHub } from '@/hooks/useServiceHub'
|
||||
|
||||
// Local minimal provider type for tests
|
||||
type MockModelProvider = {
|
||||
@ -27,13 +25,17 @@ describe('useProviderModels', () => {
|
||||
|
||||
const mockModels = ['gpt-4', 'gpt-3.5-turbo', 'gpt-4-turbo']
|
||||
|
||||
let fetchModelsSpy: ReturnType<typeof vi.fn>
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.clearAllMocks()
|
||||
fetchModelsSpy = vi.spyOn(
|
||||
WebProvidersService.prototype,
|
||||
'fetchModelsFromProvider'
|
||||
)
|
||||
const hub = (useServiceHub as unknown as () => any)()
|
||||
const mockedFetch = vi.fn()
|
||||
vi.spyOn(hub, 'providers').mockReturnValue({
|
||||
fetchModelsFromProvider: mockedFetch,
|
||||
} as any)
|
||||
fetchModelsSpy = mockedFetch
|
||||
})
|
||||
|
||||
it('should initialize with empty state', () => {
|
||||
@ -62,11 +64,9 @@ describe('useProviderModels', () => {
|
||||
const { result } = renderHook(() => useProviderModels(mockProvider))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
expect(result.current.models).toEqual(['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo'])
|
||||
})
|
||||
|
||||
// Should be sorted alphabetically
|
||||
expect(result.current.models).toEqual(['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo'])
|
||||
expect(result.current.error).toBe(null)
|
||||
expect(fetchModelsSpy).toHaveBeenCalledWith(mockProvider)
|
||||
})
|
||||
@ -80,10 +80,9 @@ describe('useProviderModels', () => {
|
||||
)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.loading).toBe(false)
|
||||
})
|
||||
|
||||
expect(result.current.models).toEqual(['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo'])
|
||||
expect(result.current.loading).toBe(false)
|
||||
}, { timeout: 500 })
|
||||
|
||||
// Switch to invalid provider
|
||||
rerender({ provider: { ...mockProvider, base_url: undefined } })
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { useState, useEffect, useCallback, useRef, useMemo } from 'react'
|
||||
import { WebProvidersService } from '../services/providers/web'
|
||||
import { useState, useEffect, useCallback, useRef } from 'react'
|
||||
import { useServiceHub } from './useServiceHub'
|
||||
|
||||
type UseProviderModelsState = {
|
||||
models: string[]
|
||||
@ -12,7 +12,7 @@ const modelsCache = new Map<string, { models: string[]; timestamp: number }>()
|
||||
const CACHE_DURATION = 5 * 60 * 1000 // 5 minutes
|
||||
|
||||
export const useProviderModels = (provider?: ModelProvider): UseProviderModelsState => {
|
||||
const providersService = useMemo(() => new WebProvidersService(), [])
|
||||
const serviceHub = useServiceHub()
|
||||
const [models, setModels] = useState<string[]>([])
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [error, setError] = useState<string | null>(null)
|
||||
@ -51,7 +51,7 @@ export const useProviderModels = (provider?: ModelProvider): UseProviderModelsSt
|
||||
setError(null)
|
||||
|
||||
try {
|
||||
const fetchedModels = await providersService.fetchModelsFromProvider(provider)
|
||||
const fetchedModels = await serviceHub.providers().fetchModelsFromProvider(provider)
|
||||
if (currentRequestId !== requestIdRef.current) return
|
||||
const sortedModels = fetchedModels.sort((a, b) => a.localeCompare(b))
|
||||
|
||||
@ -70,7 +70,7 @@ export const useProviderModels = (provider?: ModelProvider): UseProviderModelsSt
|
||||
} finally {
|
||||
if (currentRequestId === requestIdRef.current) setLoading(false)
|
||||
}
|
||||
}, [provider, providersService])
|
||||
}, [provider, serviceHub])
|
||||
|
||||
const refetch = useCallback(() => {
|
||||
if (provider) {
|
||||
|
||||
@ -113,7 +113,7 @@ export class TauriProvidersService extends DefaultProvidersService {
|
||||
}
|
||||
|
||||
return runtimeProviders.concat(builtinProviders as ModelProvider[])
|
||||
} catch (error) {
|
||||
} catch (error: unknown) {
|
||||
console.error('Error getting providers in Tauri:', error)
|
||||
return []
|
||||
}
|
||||
@ -142,9 +142,24 @@ export class TauriProvidersService extends DefaultProvidersService {
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
// Provide more specific error messages based on status code (aligned with web implementation)
|
||||
if (response.status === 401) {
|
||||
throw new Error(
|
||||
`Failed to fetch models: ${response.status} ${response.statusText}`
|
||||
`Authentication failed: API key is required or invalid for ${provider.provider}`
|
||||
)
|
||||
} else if (response.status === 403) {
|
||||
throw new Error(
|
||||
`Access forbidden: Check your API key permissions for ${provider.provider}`
|
||||
)
|
||||
} else if (response.status === 404) {
|
||||
throw new Error(
|
||||
`Models endpoint not found for ${provider.provider}. Check the base URL configuration.`
|
||||
)
|
||||
} else {
|
||||
throw new Error(
|
||||
`Failed to fetch models from ${provider.provider}: ${response.status} ${response.statusText}`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
@ -174,14 +189,30 @@ export class TauriProvidersService extends DefaultProvidersService {
|
||||
} catch (error) {
|
||||
console.error('Error fetching models from provider:', error)
|
||||
|
||||
// Provide helpful error message
|
||||
// Preserve structured error messages thrown above
|
||||
const structuredErrorPrefixes = [
|
||||
'Authentication failed',
|
||||
'Access forbidden',
|
||||
'Models endpoint not found',
|
||||
'Failed to fetch models from'
|
||||
]
|
||||
|
||||
if (error instanceof Error &&
|
||||
structuredErrorPrefixes.some(prefix => (error as Error).message.startsWith(prefix))) {
|
||||
throw new Error(error.message)
|
||||
}
|
||||
|
||||
// Provide helpful error message for any connection errors
|
||||
if (error instanceof Error && error.message.includes('fetch')) {
|
||||
throw new Error(
|
||||
`Cannot connect to ${provider.provider} at ${provider.base_url}. Please check that the service is running and accessible.`
|
||||
)
|
||||
}
|
||||
|
||||
throw error
|
||||
// Generic fallback
|
||||
throw new Error(
|
||||
`Unexpected error while fetching models from ${provider.provider}: ${error instanceof Error ? error.message : 'Unknown error'}`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -103,6 +103,7 @@ const mockServiceHub = {
|
||||
deleteProvider: vi.fn().mockResolvedValue(undefined),
|
||||
updateProvider: vi.fn().mockResolvedValue(undefined),
|
||||
getProvider: vi.fn().mockResolvedValue(null),
|
||||
fetchModelsFromProvider: vi.fn().mockResolvedValue([]),
|
||||
}),
|
||||
models: () => ({
|
||||
getModels: vi.fn().mockResolvedValue([]),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user