feat: add refresh button list model remote provider (#5136)
This commit is contained in:
parent
dc1071fff8
commit
8046f95b67
@ -30,9 +30,9 @@ import Joyride, { CallBackProps, STATUS } from 'react-joyride'
|
||||
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
|
||||
import { route } from '@/constants/routes'
|
||||
import DeleteProvider from '@/containers/dialogs/DeleteProvider'
|
||||
import { updateSettings } from '@/services/providers'
|
||||
import { updateSettings, fetchModelsFromProvider } from '@/services/providers'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { IconFolderPlus, IconLoader } from '@tabler/icons-react'
|
||||
import { IconFolderPlus, IconLoader, IconRefresh } from '@tabler/icons-react'
|
||||
import { getProviders } from '@/services/providers'
|
||||
import { toast } from 'sonner'
|
||||
import { ActiveModel } from '@/types/models'
|
||||
@ -76,6 +76,7 @@ function ProviderDetail() {
|
||||
const { step } = useSearch({ from: Route.id })
|
||||
const [activeModels, setActiveModels] = useState<ActiveModel[]>([])
|
||||
const [loadingModels, setLoadingModels] = useState<string[]>([])
|
||||
const [refreshingModels, setRefreshingModels] = useState(false)
|
||||
const { providerName } = useParams({ from: Route.id })
|
||||
const { getProviderByName, setProviders, updateProvider } = useModelProvider()
|
||||
const provider = getProviderByName(providerName)
|
||||
@ -104,6 +105,61 @@ function ProviderDetail() {
|
||||
}
|
||||
}
|
||||
|
||||
const handleRefreshModels = async () => {
|
||||
if (!provider || !provider.base_url) {
|
||||
toast.error('Refresh Models', {
|
||||
description:
|
||||
'Provider must have base URL and API key configured to fetch models.',
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
setRefreshingModels(true)
|
||||
try {
|
||||
const modelIds = await fetchModelsFromProvider(provider)
|
||||
|
||||
// Create new models from the fetched IDs
|
||||
const newModels: Model[] = modelIds.map((id) => ({
|
||||
id,
|
||||
model: id,
|
||||
name: id,
|
||||
capabilities: ['completion'], // Default capability
|
||||
version: '1.0',
|
||||
}))
|
||||
|
||||
// Filter out models that already exist
|
||||
const existingModelIds = provider.models.map((m) => m.id)
|
||||
const modelsToAdd = newModels.filter(
|
||||
(model) => !existingModelIds.includes(model.id)
|
||||
)
|
||||
|
||||
if (modelsToAdd.length > 0) {
|
||||
// Update the provider with new models
|
||||
const updatedModels = [...provider.models, ...modelsToAdd]
|
||||
updateProvider(providerName, {
|
||||
...provider,
|
||||
models: updatedModels,
|
||||
})
|
||||
|
||||
toast.success('Refresh Models', {
|
||||
description: `Added ${modelsToAdd.length} new model(s) from ${provider.provider}.`,
|
||||
})
|
||||
} else {
|
||||
toast.success('Refresh Models', {
|
||||
description:
|
||||
'No new models found. All available models are already added.',
|
||||
})
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to refresh models:', error)
|
||||
toast.error('Refresh Models', {
|
||||
description: `Failed to fetch models from ${provider.provider}. Please check your API key and base URL.`,
|
||||
})
|
||||
} finally {
|
||||
setRefreshingModels(false)
|
||||
}
|
||||
}
|
||||
|
||||
const handleStartModel = (modelId: string) => {
|
||||
// Add model to loading state
|
||||
setLoadingModels((prev) => [...prev, modelId])
|
||||
@ -292,7 +348,33 @@ function ProviderDetail() {
|
||||
</h1>
|
||||
<div className="flex items-center gap-2">
|
||||
{provider && provider.provider !== 'llama.cpp' && (
|
||||
<>
|
||||
<Button
|
||||
variant="link"
|
||||
size="sm"
|
||||
className="hover:no-underline"
|
||||
onClick={handleRefreshModels}
|
||||
disabled={refreshingModels}
|
||||
>
|
||||
<div className="cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/15 bg-main-view-fg/10 transition-all duration-200 ease-in-out px-1.5 py-1 gap-1">
|
||||
{refreshingModels ? (
|
||||
<IconLoader
|
||||
size={18}
|
||||
className="text-main-view-fg/50 animate-spin"
|
||||
/>
|
||||
) : (
|
||||
<IconRefresh
|
||||
size={18}
|
||||
className="text-main-view-fg/50"
|
||||
/>
|
||||
)}
|
||||
<span className="text-main-view-fg/70">
|
||||
{refreshingModels ? 'Refreshing...' : 'Refresh'}
|
||||
</span>
|
||||
</div>
|
||||
</Button>
|
||||
<DialogAddModel provider={provider} />
|
||||
</>
|
||||
)}
|
||||
{provider && provider.provider === 'llama.cpp' && (
|
||||
<Button
|
||||
|
||||
@ -101,6 +101,60 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
|
||||
return runtimeProviders.concat(builtinProviders as ModelProvider[])
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches models from a provider's API endpoint
|
||||
* @param provider The provider object containing base_url and api_key
|
||||
* @returns Promise<string[]> Array of model IDs
|
||||
*/
|
||||
export const fetchModelsFromProvider = async (
|
||||
provider: ModelProvider
|
||||
): Promise<string[]> => {
|
||||
if (!provider.base_url || !provider.api_key) {
|
||||
throw new Error('Provider must have base_url and api_key configured')
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`${provider.base_url}/models`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'x-api-key': provider.api_key,
|
||||
'Authorization': `Bearer ${provider.api_key}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`Failed to fetch models: ${response.status} ${response.statusText}`
|
||||
)
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
|
||||
// Handle different response formats that providers might use
|
||||
if (data.data && Array.isArray(data.data)) {
|
||||
// OpenAI format: { data: [{ id: "model-id" }, ...] }
|
||||
return data.data.map((model: { id: string }) => model.id).filter(Boolean)
|
||||
} else if (Array.isArray(data)) {
|
||||
// Direct array format: ["model-id1", "model-id2", ...]
|
||||
return data.filter(Boolean)
|
||||
} else if (data.models && Array.isArray(data.models)) {
|
||||
// Alternative format: { models: [...] }
|
||||
return data.models
|
||||
.map((model: string | { id: string }) =>
|
||||
typeof model === 'string' ? model : model.id
|
||||
)
|
||||
.filter(Boolean)
|
||||
} else {
|
||||
console.warn('Unexpected response format from provider API:', data)
|
||||
return []
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error fetching models from provider:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the settings of a provider extension.
|
||||
* TODO: Later on we don't retrieve this using provider name
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user