feat: add refresh button list model remote provider (#5136)

This commit is contained in:
Faisal Amir 2025-05-29 22:06:25 +07:00 committed by GitHub
parent dc1071fff8
commit 8046f95b67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 139 additions and 3 deletions

View File

@ -30,9 +30,9 @@ import Joyride, { CallBackProps, STATUS } from 'react-joyride'
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
import { route } from '@/constants/routes'
import DeleteProvider from '@/containers/dialogs/DeleteProvider'
import { updateSettings } from '@/services/providers'
import { updateSettings, fetchModelsFromProvider } from '@/services/providers'
import { Button } from '@/components/ui/button'
import { IconFolderPlus, IconLoader } from '@tabler/icons-react'
import { IconFolderPlus, IconLoader, IconRefresh } from '@tabler/icons-react'
import { getProviders } from '@/services/providers'
import { toast } from 'sonner'
import { ActiveModel } from '@/types/models'
@ -76,6 +76,7 @@ function ProviderDetail() {
const { step } = useSearch({ from: Route.id })
const [activeModels, setActiveModels] = useState<ActiveModel[]>([])
const [loadingModels, setLoadingModels] = useState<string[]>([])
const [refreshingModels, setRefreshingModels] = useState(false)
const { providerName } = useParams({ from: Route.id })
const { getProviderByName, setProviders, updateProvider } = useModelProvider()
const provider = getProviderByName(providerName)
@ -104,6 +105,61 @@ function ProviderDetail() {
}
}
const handleRefreshModels = async () => {
if (!provider || !provider.base_url) {
toast.error('Refresh Models', {
description:
'Provider must have base URL and API key configured to fetch models.',
})
return
}
setRefreshingModels(true)
try {
const modelIds = await fetchModelsFromProvider(provider)
// Create new models from the fetched IDs
const newModels: Model[] = modelIds.map((id) => ({
id,
model: id,
name: id,
capabilities: ['completion'], // Default capability
version: '1.0',
}))
// Filter out models that already exist
const existingModelIds = provider.models.map((m) => m.id)
const modelsToAdd = newModels.filter(
(model) => !existingModelIds.includes(model.id)
)
if (modelsToAdd.length > 0) {
// Update the provider with new models
const updatedModels = [...provider.models, ...modelsToAdd]
updateProvider(providerName, {
...provider,
models: updatedModels,
})
toast.success('Refresh Models', {
description: `Added ${modelsToAdd.length} new model(s) from ${provider.provider}.`,
})
} else {
toast.success('Refresh Models', {
description:
'No new models found. All available models are already added.',
})
}
} catch (error) {
console.error('Failed to refresh models:', error)
toast.error('Refresh Models', {
description: `Failed to fetch models from ${provider.provider}. Please check your API key and base URL.`,
})
} finally {
setRefreshingModels(false)
}
}
const handleStartModel = (modelId: string) => {
// Add model to loading state
setLoadingModels((prev) => [...prev, modelId])
@ -292,7 +348,33 @@ function ProviderDetail() {
</h1>
<div className="flex items-center gap-2">
{provider && provider.provider !== 'llama.cpp' && (
<>
<Button
variant="link"
size="sm"
className="hover:no-underline"
onClick={handleRefreshModels}
disabled={refreshingModels}
>
<div className="cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/15 bg-main-view-fg/10 transition-all duration-200 ease-in-out px-1.5 py-1 gap-1">
{refreshingModels ? (
<IconLoader
size={18}
className="text-main-view-fg/50 animate-spin"
/>
) : (
<IconRefresh
size={18}
className="text-main-view-fg/50"
/>
)}
<span className="text-main-view-fg/70">
{refreshingModels ? 'Refreshing...' : 'Refresh'}
</span>
</div>
</Button>
<DialogAddModel provider={provider} />
</>
)}
{provider && provider.provider === 'llama.cpp' && (
<Button

View File

@ -101,6 +101,60 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
return runtimeProviders.concat(builtinProviders as ModelProvider[])
}
/**
* Fetches models from a provider's API endpoint
* @param provider The provider object containing base_url and api_key
* @returns Promise<string[]> Array of model IDs
*/
export const fetchModelsFromProvider = async (
provider: ModelProvider
): Promise<string[]> => {
if (!provider.base_url || !provider.api_key) {
throw new Error('Provider must have base_url and api_key configured')
}
try {
const response = await fetch(`${provider.base_url}/models`, {
method: 'GET',
headers: {
'x-api-key': provider.api_key,
'Authorization': `Bearer ${provider.api_key}`,
'Content-Type': 'application/json',
},
})
if (!response.ok) {
throw new Error(
`Failed to fetch models: ${response.status} ${response.statusText}`
)
}
const data = await response.json()
// Handle different response formats that providers might use
if (data.data && Array.isArray(data.data)) {
// OpenAI format: { data: [{ id: "model-id" }, ...] }
return data.data.map((model: { id: string }) => model.id).filter(Boolean)
} else if (Array.isArray(data)) {
// Direct array format: ["model-id1", "model-id2", ...]
return data.filter(Boolean)
} else if (data.models && Array.isArray(data.models)) {
// Alternative format: { models: [...] }
return data.models
.map((model: string | { id: string }) =>
typeof model === 'string' ? model : model.id
)
.filter(Boolean)
} else {
console.warn('Unexpected response format from provider API:', data)
return []
}
} catch (error) {
console.error('Error fetching models from provider:', error)
throw error
}
}
/**
* Update the settings of a provider extension.
* TODO: Later on we don't retrieve this using provider name