feat: add refresh button list model remote provider (#5136)
This commit is contained in:
parent
dc1071fff8
commit
8046f95b67
@ -30,9 +30,9 @@ import Joyride, { CallBackProps, STATUS } from 'react-joyride'
|
|||||||
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
|
import { CustomTooltipJoyRide } from '@/containers/CustomeTooltipJoyRide'
|
||||||
import { route } from '@/constants/routes'
|
import { route } from '@/constants/routes'
|
||||||
import DeleteProvider from '@/containers/dialogs/DeleteProvider'
|
import DeleteProvider from '@/containers/dialogs/DeleteProvider'
|
||||||
import { updateSettings } from '@/services/providers'
|
import { updateSettings, fetchModelsFromProvider } from '@/services/providers'
|
||||||
import { Button } from '@/components/ui/button'
|
import { Button } from '@/components/ui/button'
|
||||||
import { IconFolderPlus, IconLoader } from '@tabler/icons-react'
|
import { IconFolderPlus, IconLoader, IconRefresh } from '@tabler/icons-react'
|
||||||
import { getProviders } from '@/services/providers'
|
import { getProviders } from '@/services/providers'
|
||||||
import { toast } from 'sonner'
|
import { toast } from 'sonner'
|
||||||
import { ActiveModel } from '@/types/models'
|
import { ActiveModel } from '@/types/models'
|
||||||
@ -76,6 +76,7 @@ function ProviderDetail() {
|
|||||||
const { step } = useSearch({ from: Route.id })
|
const { step } = useSearch({ from: Route.id })
|
||||||
const [activeModels, setActiveModels] = useState<ActiveModel[]>([])
|
const [activeModels, setActiveModels] = useState<ActiveModel[]>([])
|
||||||
const [loadingModels, setLoadingModels] = useState<string[]>([])
|
const [loadingModels, setLoadingModels] = useState<string[]>([])
|
||||||
|
const [refreshingModels, setRefreshingModels] = useState(false)
|
||||||
const { providerName } = useParams({ from: Route.id })
|
const { providerName } = useParams({ from: Route.id })
|
||||||
const { getProviderByName, setProviders, updateProvider } = useModelProvider()
|
const { getProviderByName, setProviders, updateProvider } = useModelProvider()
|
||||||
const provider = getProviderByName(providerName)
|
const provider = getProviderByName(providerName)
|
||||||
@ -104,6 +105,61 @@ function ProviderDetail() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleRefreshModels = async () => {
|
||||||
|
if (!provider || !provider.base_url) {
|
||||||
|
toast.error('Refresh Models', {
|
||||||
|
description:
|
||||||
|
'Provider must have base URL and API key configured to fetch models.',
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
setRefreshingModels(true)
|
||||||
|
try {
|
||||||
|
const modelIds = await fetchModelsFromProvider(provider)
|
||||||
|
|
||||||
|
// Create new models from the fetched IDs
|
||||||
|
const newModels: Model[] = modelIds.map((id) => ({
|
||||||
|
id,
|
||||||
|
model: id,
|
||||||
|
name: id,
|
||||||
|
capabilities: ['completion'], // Default capability
|
||||||
|
version: '1.0',
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Filter out models that already exist
|
||||||
|
const existingModelIds = provider.models.map((m) => m.id)
|
||||||
|
const modelsToAdd = newModels.filter(
|
||||||
|
(model) => !existingModelIds.includes(model.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
if (modelsToAdd.length > 0) {
|
||||||
|
// Update the provider with new models
|
||||||
|
const updatedModels = [...provider.models, ...modelsToAdd]
|
||||||
|
updateProvider(providerName, {
|
||||||
|
...provider,
|
||||||
|
models: updatedModels,
|
||||||
|
})
|
||||||
|
|
||||||
|
toast.success('Refresh Models', {
|
||||||
|
description: `Added ${modelsToAdd.length} new model(s) from ${provider.provider}.`,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
toast.success('Refresh Models', {
|
||||||
|
description:
|
||||||
|
'No new models found. All available models are already added.',
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Failed to refresh models:', error)
|
||||||
|
toast.error('Refresh Models', {
|
||||||
|
description: `Failed to fetch models from ${provider.provider}. Please check your API key and base URL.`,
|
||||||
|
})
|
||||||
|
} finally {
|
||||||
|
setRefreshingModels(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const handleStartModel = (modelId: string) => {
|
const handleStartModel = (modelId: string) => {
|
||||||
// Add model to loading state
|
// Add model to loading state
|
||||||
setLoadingModels((prev) => [...prev, modelId])
|
setLoadingModels((prev) => [...prev, modelId])
|
||||||
@ -292,7 +348,33 @@ function ProviderDetail() {
|
|||||||
</h1>
|
</h1>
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex items-center gap-2">
|
||||||
{provider && provider.provider !== 'llama.cpp' && (
|
{provider && provider.provider !== 'llama.cpp' && (
|
||||||
<DialogAddModel provider={provider} />
|
<>
|
||||||
|
<Button
|
||||||
|
variant="link"
|
||||||
|
size="sm"
|
||||||
|
className="hover:no-underline"
|
||||||
|
onClick={handleRefreshModels}
|
||||||
|
disabled={refreshingModels}
|
||||||
|
>
|
||||||
|
<div className="cursor-pointer flex items-center justify-center rounded hover:bg-main-view-fg/15 bg-main-view-fg/10 transition-all duration-200 ease-in-out px-1.5 py-1 gap-1">
|
||||||
|
{refreshingModels ? (
|
||||||
|
<IconLoader
|
||||||
|
size={18}
|
||||||
|
className="text-main-view-fg/50 animate-spin"
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<IconRefresh
|
||||||
|
size={18}
|
||||||
|
className="text-main-view-fg/50"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
<span className="text-main-view-fg/70">
|
||||||
|
{refreshingModels ? 'Refreshing...' : 'Refresh'}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</Button>
|
||||||
|
<DialogAddModel provider={provider} />
|
||||||
|
</>
|
||||||
)}
|
)}
|
||||||
{provider && provider.provider === 'llama.cpp' && (
|
{provider && provider.provider === 'llama.cpp' && (
|
||||||
<Button
|
<Button
|
||||||
|
|||||||
@ -101,6 +101,60 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
|
|||||||
return runtimeProviders.concat(builtinProviders as ModelProvider[])
|
return runtimeProviders.concat(builtinProviders as ModelProvider[])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fetches models from a provider's API endpoint
|
||||||
|
* @param provider The provider object containing base_url and api_key
|
||||||
|
* @returns Promise<string[]> Array of model IDs
|
||||||
|
*/
|
||||||
|
export const fetchModelsFromProvider = async (
|
||||||
|
provider: ModelProvider
|
||||||
|
): Promise<string[]> => {
|
||||||
|
if (!provider.base_url || !provider.api_key) {
|
||||||
|
throw new Error('Provider must have base_url and api_key configured')
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch(`${provider.base_url}/models`, {
|
||||||
|
method: 'GET',
|
||||||
|
headers: {
|
||||||
|
'x-api-key': provider.api_key,
|
||||||
|
'Authorization': `Bearer ${provider.api_key}`,
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(
|
||||||
|
`Failed to fetch models: ${response.status} ${response.statusText}`
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json()
|
||||||
|
|
||||||
|
// Handle different response formats that providers might use
|
||||||
|
if (data.data && Array.isArray(data.data)) {
|
||||||
|
// OpenAI format: { data: [{ id: "model-id" }, ...] }
|
||||||
|
return data.data.map((model: { id: string }) => model.id).filter(Boolean)
|
||||||
|
} else if (Array.isArray(data)) {
|
||||||
|
// Direct array format: ["model-id1", "model-id2", ...]
|
||||||
|
return data.filter(Boolean)
|
||||||
|
} else if (data.models && Array.isArray(data.models)) {
|
||||||
|
// Alternative format: { models: [...] }
|
||||||
|
return data.models
|
||||||
|
.map((model: string | { id: string }) =>
|
||||||
|
typeof model === 'string' ? model : model.id
|
||||||
|
)
|
||||||
|
.filter(Boolean)
|
||||||
|
} else {
|
||||||
|
console.warn('Unexpected response format from provider API:', data)
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error fetching models from provider:', error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update the settings of a provider extension.
|
* Update the settings of a provider extension.
|
||||||
* TODO: Later on we don't retrieve this using provider name
|
* TODO: Later on we don't retrieve this using provider name
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user