fix: bring back HF repo ID search in Hub (#5880)

* fix: bring back HF search input

* test: fix useModelSources tests for updated addSource signature
This commit is contained in:
Authored by Louis on 2025-07-24 09:46:13 +07:00; committed via GitHub
parent d8b6b10870
commit 6599d91660
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 903 additions and 89 deletions

View File

@ -234,61 +234,149 @@ describe('useModelSources', () => {
})
describe('addSource', () => {
it('should log the source parameter', async () => {
const consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {})
it('should add a new source to the store', () => {
const { result } = renderHook(() => useModelSources())
await act(async () => {
await result.current.addSource('test-source')
})
expect(consoleSpy).toHaveBeenCalledWith('test-source')
consoleSpy.mockRestore()
})
it('should set loading state during addSource', async () => {
const { result } = renderHook(() => useModelSources())
await act(async () => {
await result.current.addSource('test-source')
})
expect(result.current.loading).toBe(true)
expect(result.current.error).toBe(null)
})
it('should handle different source types', async () => {
const consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {})
const { result } = renderHook(() => useModelSources())
const sources = [
'http://example.com/source1',
'https://secure.example.com/source2',
'file:///local/path/source3',
'custom-source-name',
]
for (const source of sources) {
await act(async () => {
await result.current.addSource(source)
})
expect(consoleSpy).toHaveBeenCalledWith(source)
const testModel: CatalogModel = {
model_name: 'test-model',
description: 'Test model description',
developer: 'test-developer',
downloads: 100,
num_quants: 2,
quants: [
{
model_id: 'test-model-q4',
path: 'https://example.com/test-model-q4.gguf',
file_size: '2.0 GB',
},
],
created_at: '2023-01-01T00:00:00Z',
}
consoleSpy.mockRestore()
})
it('should handle empty source string', async () => {
const consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {})
const { result } = renderHook(() => useModelSources())
await act(async () => {
await result.current.addSource('')
act(() => {
result.current.addSource(testModel)
})
expect(consoleSpy).toHaveBeenCalledWith('')
consoleSpy.mockRestore()
expect(result.current.sources).toHaveLength(1)
expect(result.current.sources[0]).toEqual(testModel)
})
it('should replace existing source with same model_name', () => {
const { result } = renderHook(() => useModelSources())
const originalModel: CatalogModel = {
model_name: 'duplicate-model',
description: 'Original description',
developer: 'original-developer',
downloads: 50,
num_quants: 1,
quants: [],
created_at: '2023-01-01T00:00:00Z',
}
const updatedModel: CatalogModel = {
model_name: 'duplicate-model',
description: 'Updated description',
developer: 'updated-developer',
downloads: 150,
num_quants: 2,
quants: [
{
model_id: 'duplicate-model-q4',
path: 'https://example.com/duplicate-model-q4.gguf',
file_size: '3.0 GB',
},
],
created_at: '2023-02-01T00:00:00Z',
}
act(() => {
result.current.addSource(originalModel)
})
expect(result.current.sources).toHaveLength(1)
act(() => {
result.current.addSource(updatedModel)
})
expect(result.current.sources).toHaveLength(1)
expect(result.current.sources[0]).toEqual(updatedModel)
})
it('should handle multiple different sources', () => {
const { result } = renderHook(() => useModelSources())
const model1: CatalogModel = {
model_name: 'model-1',
description: 'First model',
developer: 'developer-1',
downloads: 100,
num_quants: 1,
quants: [],
created_at: '2023-01-01T00:00:00Z',
}
const model2: CatalogModel = {
model_name: 'model-2',
description: 'Second model',
developer: 'developer-2',
downloads: 200,
num_quants: 1,
quants: [],
created_at: '2023-01-02T00:00:00Z',
}
act(() => {
result.current.addSource(model1)
})
act(() => {
result.current.addSource(model2)
})
expect(result.current.sources).toHaveLength(2)
expect(result.current.sources).toContainEqual(model1)
expect(result.current.sources).toContainEqual(model2)
})
it('should handle CatalogModel with complete quants data', () => {
const { result } = renderHook(() => useModelSources())
const modelWithQuants: CatalogModel = {
model_name: 'model-with-quants',
description: 'Model with quantizations',
developer: 'quant-developer',
downloads: 500,
num_quants: 3,
quants: [
{
model_id: 'model-q4_k_m',
path: 'https://example.com/model-q4_k_m.gguf',
file_size: '2.0 GB',
},
{
model_id: 'model-q8_0',
path: 'https://example.com/model-q8_0.gguf',
file_size: '4.0 GB',
},
{
model_id: 'model-f16',
path: 'https://example.com/model-f16.gguf',
file_size: '8.0 GB',
},
],
created_at: '2023-01-01T00:00:00Z',
readme: 'https://example.com/readme.md',
}
act(() => {
result.current.addSource(modelWithQuants)
})
expect(result.current.sources).toHaveLength(1)
expect(result.current.sources[0]).toEqual(modelWithQuants)
expect(result.current.sources[0].quants).toHaveLength(3)
})
})

View File

@ -8,8 +8,8 @@ type ModelSourcesState = {
sources: CatalogModel[]
error: Error | null
loading: boolean
addSource: (source: CatalogModel) => void
fetchSources: () => Promise<void>
addSource: (source: string) => Promise<void>
}
export const useModelSources = create<ModelSourcesState>()(
@ -19,6 +19,14 @@ export const useModelSources = create<ModelSourcesState>()(
error: null,
loading: false,
addSource: (source: CatalogModel) => {
set((state) => ({
sources: [
...state.sources.filter((e) => e.model_name !== source.model_name),
source,
],
}))
},
fetchSources: async () => {
set({ loading: true, error: null })
try {
@ -38,24 +46,6 @@ export const useModelSources = create<ModelSourcesState>()(
set({ error: error as Error, loading: false })
}
},
addSource: async (source: string) => {
set({ loading: true, error: null })
console.log(source)
// try {
// await addModelSource(source)
// const newSources = await fetchModelSources()
// const currentSources = get().sources
// if (!deepCompareModelSources(currentSources, newSources)) {
// set({ sources: newSources, loading: false })
// } else {
// set({ loading: false })
// }
// } catch (error) {
// set({ error: error as Error, loading: false })
// }
},
}),
{
name: localStorageKey.modelSources,

View File

@ -0,0 +1,307 @@
import { describe, it, expect } from 'vitest'
import { HuggingFaceRepo, CatalogModel } from '@/services/models'
// Local mirror of the Hub component's conversion logic so the mapping from a
// HuggingFace repository payload to the hub's CatalogModel shape can be
// unit-tested in isolation.
const convertHfRepoToCatalogModel = (repo: HuggingFaceRepo): CatalogModel => {
  // Human-readable size label: MB below 1 GiB, GB otherwise; falsy/absent
  // sizes render as 'Unknown size'.
  const describeSize = (bytes?: number): string => {
    if (!bytes) return 'Unknown size'
    const GIB = 1024 ** 3
    return bytes < GIB
      ? `${(bytes / 1024 ** 2).toFixed(1)} MB`
      : `${(bytes / GIB).toFixed(1)} GB`
  }

  const repoBase = `https://huggingface.co/${repo.modelId}`

  // Only GGUF files (any case) become downloadable quant entries.
  const quants = (repo.siblings ?? [])
    .filter((sibling) => sibling.rfilename.toLowerCase().endsWith('.gguf'))
    .map((sibling) => ({
      // Quant id is the filename minus its (case-insensitive) .gguf suffix.
      model_id: sibling.rfilename.replace(/\.gguf$/i, ''),
      path: `${repoBase}/resolve/main/${sibling.rfilename}`,
      file_size: describeSize(sibling.size),
    }))

  return {
    model_name: repo.modelId,
    description: `**Metadata:** ${repo.pipeline_tag}\n\n **Tags**: ${repo.tags?.join(', ')}`,
    developer: repo.author,
    downloads: repo.downloads || 0,
    num_quants: quants.length,
    quants: quants,
    created_at: repo.created_at,
    readme: `${repoBase}/resolve/main/README.md`,
  }
}
describe('HuggingFace Repository Conversion', () => {
// Shared fixture: a realistic Hub API payload with three GGUF files of
// different sizes plus a non-GGUF README, exercising filtering, size
// formatting, and URL construction in one object.
const mockHuggingFaceRepo: HuggingFaceRepo = {
id: 'microsoft/DialoGPT-medium',
modelId: 'microsoft/DialoGPT-medium',
sha: 'abc123',
downloads: 5000,
likes: 100,
tags: ['conversational', 'pytorch', 'text-generation'],
pipeline_tag: 'text-generation',
created_at: '2023-01-01T00:00:00Z',
last_modified: '2023-12-01T00:00:00Z',
private: false,
disabled: false,
gated: false,
author: 'microsoft',
siblings: [
{
rfilename: 'model-Q4_K_M.gguf',
size: 2147483648, // 2GB
blobId: 'blob123',
},
{
rfilename: 'model-Q8_0.gguf',
size: 4294967296, // 4GB
blobId: 'blob456',
},
{
rfilename: 'model-small.gguf',
size: 536870912, // 512MB
blobId: 'blob789',
},
{
rfilename: 'README.md',
size: 1024,
blobId: 'blob101',
},
],
readme: '# DialoGPT Model\nThis is a conversational AI model.',
}
describe('convertHfRepoToCatalogModel', () => {
// Golden test: the complete converted CatalogModel, including quant download
// paths built from modelId and the derived README URL.
it('should convert HuggingFace repository to CatalogModel correctly', () => {
const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo)
expect(result).toEqual({
model_name: 'microsoft/DialoGPT-medium',
description: '**Metadata:** text-generation\n\n **Tags**: conversational, pytorch, text-generation',
developer: 'microsoft',
downloads: 5000,
num_quants: 3,
quants: [
{
model_id: 'model-Q4_K_M',
path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-Q4_K_M.gguf',
file_size: '2.0 GB',
},
{
model_id: 'model-Q8_0',
path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-Q8_0.gguf',
file_size: '4.0 GB',
},
{
model_id: 'model-small',
path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-small.gguf',
file_size: '512.0 MB',
},
],
created_at: '2023-01-01T00:00:00Z',
readme: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md',
})
})
it('should filter only GGUF files from siblings', () => {
const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo)
// Should have 3 GGUF files, not 4 total files
expect(result.num_quants).toBe(3)
expect(result.quants).toHaveLength(3)
// All quants should be from GGUF files
result.quants.forEach((quant) => {
expect(quant.path).toContain('.gguf')
})
})
it('should format file sizes correctly', () => {
const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo)
expect(result.quants[0].file_size).toBe('2.0 GB') // 2GB
expect(result.quants[1].file_size).toBe('4.0 GB') // 4GB
expect(result.quants[2].file_size).toBe('512.0 MB') // 512MB
})
it('should generate correct download paths', () => {
const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo)
expect(result.quants[0].path).toBe(
'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-Q4_K_M.gguf'
)
expect(result.quants[1].path).toBe(
'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-Q8_0.gguf'
)
})
it('should generate correct model IDs by removing .gguf extension', () => {
const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo)
expect(result.quants[0].model_id).toBe('model-Q4_K_M')
expect(result.quants[1].model_id).toBe('model-Q8_0')
expect(result.quants[2].model_id).toBe('model-small')
})
// Missing vs. empty siblings both degrade to an empty quants list.
it('should handle repository with no siblings', () => {
const repoWithoutSiblings = {
...mockHuggingFaceRepo,
siblings: undefined,
}
const result = convertHfRepoToCatalogModel(repoWithoutSiblings)
expect(result.num_quants).toBe(0)
expect(result.quants).toEqual([])
})
it('should handle repository with empty siblings array', () => {
const repoWithEmptySiblings = {
...mockHuggingFaceRepo,
siblings: [],
}
const result = convertHfRepoToCatalogModel(repoWithEmptySiblings)
expect(result.num_quants).toBe(0)
expect(result.quants).toEqual([])
})
it('should handle repository with no GGUF files', () => {
const repoWithoutGGUF = {
...mockHuggingFaceRepo,
siblings: [
{
rfilename: 'README.md',
size: 1024,
blobId: 'blob101',
},
{
rfilename: 'config.json',
size: 512,
blobId: 'blob102',
},
],
}
const result = convertHfRepoToCatalogModel(repoWithoutGGUF)
expect(result.num_quants).toBe(0)
expect(result.quants).toEqual([])
})
// An absent size falls back to the literal 'Unknown size' label.
it('should handle files with unknown sizes', () => {
const repoWithUnknownSizes = {
...mockHuggingFaceRepo,
siblings: [
{
rfilename: 'model-unknown.gguf',
size: undefined,
blobId: 'blob123',
},
],
}
const result = convertHfRepoToCatalogModel(repoWithUnknownSizes)
expect(result.quants[0].file_size).toBe('Unknown size')
})
it('should handle repository with zero downloads', () => {
const repoWithZeroDownloads = {
...mockHuggingFaceRepo,
downloads: 0,
}
const result = convertHfRepoToCatalogModel(repoWithZeroDownloads)
expect(result.downloads).toBe(0)
})
it('should handle repository with no tags', () => {
const repoWithoutTags = {
...mockHuggingFaceRepo,
tags: [],
}
const result = convertHfRepoToCatalogModel(repoWithoutTags)
expect(result.description).toContain('**Tags**: ')
})
// Documents current (arguably sloppy) behavior: a missing pipeline_tag is
// interpolated as the string 'undefined' in the description.
it('should handle repository with no pipeline_tag', () => {
const repoWithoutPipelineTag = {
...mockHuggingFaceRepo,
pipeline_tag: undefined,
}
const result = convertHfRepoToCatalogModel(repoWithoutPipelineTag)
expect(result.description).toContain('**Metadata:** undefined')
})
it('should generate README URL correctly', () => {
const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo)
expect(result.readme).toBe(
'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md'
)
})
it('should handle case-insensitive GGUF file extensions', () => {
const repoWithMixedCase = {
...mockHuggingFaceRepo,
siblings: [
{
rfilename: 'model-uppercase.GGUF',
size: 1024,
blobId: 'blob1',
},
{
rfilename: 'model-mixedcase.Gguf',
size: 2048,
blobId: 'blob2',
},
{
rfilename: 'model-lowercase.gguf',
size: 4096,
blobId: 'blob3',
},
],
}
const result = convertHfRepoToCatalogModel(repoWithMixedCase)
expect(result.num_quants).toBe(3)
expect(result.quants[0].model_id).toBe('model-uppercase')
expect(result.quants[1].model_id).toBe('model-mixedcase')
expect(result.quants[2].model_id).toBe('model-lowercase')
})
// The formatter only knows MB/GB, so a 1 TB file renders as 1024.0 GB.
it('should handle very large file sizes (> 1TB)', () => {
const repoWithLargeFiles = {
...mockHuggingFaceRepo,
siblings: [
{
rfilename: 'huge-model.gguf',
size: 1099511627776, // 1TB
blobId: 'blob1',
},
],
}
const result = convertHfRepoToCatalogModel(repoWithLargeFiles)
expect(result.quants[0].file_size).toBe('1024.0 GB')
})
})
})

View File

@ -26,7 +26,12 @@ import {
DropdownMenuItem,
DropdownMenuTrigger,
} from '@/components/ui/dropdown-menu'
import { CatalogModel, pullModel } from '@/services/models'
import {
CatalogModel,
pullModel,
fetchHuggingFaceRepo,
HuggingFaceRepo,
} from '@/services/models'
import { useDownloadStore } from '@/hooks/useDownloadStore'
import { Progress } from '@/components/ui/progress'
import HeaderPage from '@/containers/HeaderPage'
@ -54,7 +59,7 @@ function Hub() {
{ value: 'newest', name: t('hub:sortNewest') },
{ value: 'most-downloaded', name: t('hub:sortMostDownloaded') },
]
const { sources, fetchSources, addSource, loading } = useModelSources()
const { sources, addSource, fetchSources, loading } = useModelSources()
const search = useSearch({ from: route.hub.index as any })
const [searchValue, setSearchValue] = useState('')
const [sortSelected, setSortSelected] = useState('newest')
@ -63,6 +68,9 @@ function Hub() {
)
const [isSearching, setIsSearching] = useState(false)
const [showOnlyDownloaded, setShowOnlyDownloaded] = useState(false)
const [huggingFaceRepo, setHuggingFaceRepo] = useState<CatalogModel | null>(
null
)
const [joyrideReady, setJoyrideReady] = useState(false)
const [currentStepIndex, setCurrentStepIndex] = useState(0)
const addModelSourceTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(
@ -74,6 +82,48 @@ function Hub() {
const { getProviderByName } = useModelProvider()
const llamaProvider = getProviderByName('llamacpp')
// Convert HuggingFace repository to CatalogModel format
const convertHfRepoToCatalogModel = useCallback(
(repo: HuggingFaceRepo): CatalogModel => {
// Extract GGUF files from the repository siblings
const ggufFiles =
repo.siblings?.filter((file) =>
file.rfilename.toLowerCase().endsWith('.gguf')
) || []
// Convert GGUF files to quants format
const quants = ggufFiles.map((file) => {
// Format file size
const formatFileSize = (size?: number) => {
if (!size) return 'Unknown size'
if (size < 1024 ** 3) return `${(size / 1024 ** 2).toFixed(1)} MB`
return `${(size / 1024 ** 3).toFixed(1)} GB`
}
// Generate model_id from filename (remove .gguf extension, case-insensitive)
const modelId = file.rfilename.replace(/\.gguf$/i, '')
return {
model_id: modelId,
path: `https://huggingface.co/${repo.modelId}/resolve/main/${file.rfilename}`,
file_size: formatFileSize(file.size),
}
})
return {
model_name: repo.modelId,
description: `**Metadata:** ${repo.pipeline_tag}\n\n **Tags**: ${repo.tags?.join(', ')}`,
developer: repo.author,
downloads: repo.downloads || 0,
num_quants: quants.length,
quants: quants,
created_at: repo.created_at,
readme: `https://huggingface.co/${repo.modelId}/resolve/main/README.md`,
}
},
[]
)
const toggleModelExpansion = (modelId: string) => {
setExpandedModels((prev) => ({
...prev,
@ -85,17 +135,26 @@ function Hub() {
if (search.repo) {
setSearchValue(search.repo || '')
setIsSearching(true)
addModelSourceTimeoutRef.current = setTimeout(() => {
addSource(search.repo)
.then(() => {
fetchSources()
})
.finally(() => {
setIsSearching(false)
})
addModelSourceTimeoutRef.current = setTimeout(async () => {
try {
// Fetch HuggingFace repository information
const repoInfo = await fetchHuggingFaceRepo(search.repo)
if (repoInfo) {
const catalogModel = convertHfRepoToCatalogModel(repoInfo)
setHuggingFaceRepo(catalogModel)
addSource(catalogModel)
}
await fetchSources()
} catch (error) {
console.error('Error fetching repository info:', error)
} finally {
setIsSearching(false)
}
}, 500)
}
}, [addSource, fetchSources, search])
}, [convertHfRepoToCatalogModel, fetchSources, addSource, search])
// Sorting functionality
const sortedModels = useMemo(() => {
@ -143,8 +202,19 @@ function Hub() {
)
}
// Add HuggingFace repo at the beginning if available
if (huggingFaceRepo) {
filtered = [huggingFaceRepo, ...filtered]
}
return filtered
}, [searchValue, sortedModels, showOnlyDownloaded, llamaProvider?.models])
}, [
searchValue,
sortedModels,
showOnlyDownloaded,
llamaProvider?.models,
huggingFaceRepo,
])
useEffect(() => {
fetchSources()
@ -153,22 +223,35 @@ function Hub() {
const handleSearchChange = (e: ChangeEvent<HTMLInputElement>) => {
setIsSearching(false)
setSearchValue(e.target.value)
setHuggingFaceRepo(null) // Clear previous repo info
if (addModelSourceTimeoutRef.current) {
clearTimeout(addModelSourceTimeoutRef.current)
}
if (
e.target.value.length &&
(e.target.value.includes('/') || e.target.value.startsWith('http'))
) {
setIsSearching(true)
addModelSourceTimeoutRef.current = setTimeout(() => {
addSource(e.target.value)
.then(() => {
fetchSources()
})
.finally(() => {
setIsSearching(false)
})
addModelSourceTimeoutRef.current = setTimeout(async () => {
try {
// Fetch HuggingFace repository information
const repoInfo = await fetchHuggingFaceRepo(e.target.value)
if (repoInfo) {
const catalogModel = convertHfRepoToCatalogModel(repoInfo)
setHuggingFaceRepo(catalogModel)
addSource(catalogModel)
}
// Original addSource logic (if needed)
await fetchSources()
} catch (error) {
console.error('Error fetching repository info:', error)
} finally {
setIsSearching(false)
}
}, 500)
}
}
@ -213,6 +296,25 @@ function Hub() {
const DownloadButtonPlaceholder = useMemo(() => {
return ({ model }: ModelProps) => {
// Check if this is a HuggingFace repository (no quants)
if (model.quants.length === 0) {
return (
<div className="flex items-center gap-2">
<Button
size="sm"
onClick={() => {
window.open(
`https://huggingface.co/${model.model_name}`,
'_blank'
)
}}
>
View on HuggingFace
</Button>
</div>
)
}
const quant =
model.quants.find((e) =>
defaultModelQuantizations.some((m) =>

View File

@ -3,6 +3,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'
import {
fetchModels,
fetchModelCatalog,
fetchHuggingFaceRepo,
updateModel,
pullModel,
abortDownload,
@ -271,4 +272,259 @@ describe('models service', () => {
await expect(startModel(provider, model)).resolves.toBe(undefined)
})
})
// These tests stub global fetch and verify that fetchHuggingFaceRepo builds
// the Hub API URL (https://huggingface.co/api/models/<repo>?blobs=true),
// normalizes assorted input formats, and degrades to null on failure.
describe('fetchHuggingFaceRepo', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('should fetch HuggingFace repository successfully with blobs=true', async () => {
const mockRepoData = {
id: 'microsoft/DialoGPT-medium',
modelId: 'microsoft/DialoGPT-medium',
sha: 'abc123',
downloads: 5000,
likes: 100,
tags: ['conversational', 'pytorch'],
pipeline_tag: 'text-generation',
created_at: '2023-01-01T00:00:00Z',
last_modified: '2023-12-01T00:00:00Z',
private: false,
disabled: false,
gated: false,
author: 'microsoft',
siblings: [
{
rfilename: 'model-Q4_K_M.gguf',
size: 2147483648,
blobId: 'blob123',
},
{
rfilename: 'model-Q8_0.gguf',
size: 4294967296,
blobId: 'blob456',
},
{
rfilename: 'README.md',
size: 1024,
blobId: 'blob789',
},
],
readme: '# DialoGPT Model\nThis is a conversational AI model.',
}
;(fetch as any).mockResolvedValue({
ok: true,
json: vi.fn().mockResolvedValue(mockRepoData),
})
const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
expect(result).toEqual(mockRepoData)
expect(fetch).toHaveBeenCalledWith(
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true'
)
})
// All accepted input formats (full URL, domain prefix, trailing slash) must
// normalize to the same bare "<owner>/<repo>" API request.
it('should clean repository ID from various input formats', async () => {
const mockRepoData = { modelId: 'microsoft/DialoGPT-medium' }
;(fetch as any).mockResolvedValue({
ok: true,
json: vi.fn().mockResolvedValue(mockRepoData),
})
// Test with full URL
await fetchHuggingFaceRepo('https://huggingface.co/microsoft/DialoGPT-medium')
expect(fetch).toHaveBeenCalledWith(
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true'
)
// Test with domain prefix
await fetchHuggingFaceRepo('huggingface.co/microsoft/DialoGPT-medium')
expect(fetch).toHaveBeenCalledWith(
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true'
)
// Test with trailing slash
await fetchHuggingFaceRepo('microsoft/DialoGPT-medium/')
expect(fetch).toHaveBeenCalledWith(
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true'
)
})
// Inputs without an "<owner>/<repo>" shape are rejected before any fetch.
it('should return null for invalid repository IDs', async () => {
// Test empty string
expect(await fetchHuggingFaceRepo('')).toBeNull()
// Test string without slash
expect(await fetchHuggingFaceRepo('invalid-repo')).toBeNull()
// Test whitespace only
expect(await fetchHuggingFaceRepo(' ')).toBeNull()
})
// 404 is a quiet "repo not found" -> null without logging an error.
it('should return null for 404 responses', async () => {
;(fetch as any).mockResolvedValue({
ok: false,
status: 404,
statusText: 'Not Found',
})
const result = await fetchHuggingFaceRepo('nonexistent/model')
expect(result).toBeNull()
expect(fetch).toHaveBeenCalledWith(
'https://huggingface.co/api/models/nonexistent/model?blobs=true'
)
})
// Non-404 HTTP errors and network failures are logged and resolved to null.
it('should handle other HTTP errors', async () => {
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
;(fetch as any).mockResolvedValue({
ok: false,
status: 500,
statusText: 'Internal Server Error',
})
const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
expect(result).toBeNull()
expect(consoleSpy).toHaveBeenCalledWith(
'Error fetching HuggingFace repository:',
expect.any(Error)
)
consoleSpy.mockRestore()
})
it('should handle network errors', async () => {
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
;(fetch as any).mockRejectedValue(new Error('Network error'))
const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
expect(result).toBeNull()
expect(consoleSpy).toHaveBeenCalledWith(
'Error fetching HuggingFace repository:',
expect.any(Error)
)
consoleSpy.mockRestore()
})
// The service returns the payload as-is; it does not filter or reshape
// siblings, so odd payloads pass through unchanged.
it('should handle repository with no siblings', async () => {
const mockRepoData = {
id: 'microsoft/DialoGPT-medium',
modelId: 'microsoft/DialoGPT-medium',
sha: 'abc123',
downloads: 5000,
likes: 100,
tags: ['conversational'],
pipeline_tag: 'text-generation',
created_at: '2023-01-01T00:00:00Z',
last_modified: '2023-12-01T00:00:00Z',
private: false,
disabled: false,
gated: false,
author: 'microsoft',
siblings: undefined,
}
;(fetch as any).mockResolvedValue({
ok: true,
json: vi.fn().mockResolvedValue(mockRepoData),
})
const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
expect(result).toEqual(mockRepoData)
})
it('should handle repository with no GGUF files', async () => {
const mockRepoData = {
id: 'microsoft/DialoGPT-medium',
modelId: 'microsoft/DialoGPT-medium',
sha: 'abc123',
downloads: 5000,
likes: 100,
tags: ['conversational'],
pipeline_tag: 'text-generation',
created_at: '2023-01-01T00:00:00Z',
last_modified: '2023-12-01T00:00:00Z',
private: false,
disabled: false,
gated: false,
author: 'microsoft',
siblings: [
{
rfilename: 'README.md',
size: 1024,
blobId: 'blob789',
},
{
rfilename: 'config.json',
size: 512,
blobId: 'blob101',
},
],
}
;(fetch as any).mockResolvedValue({
ok: true,
json: vi.fn().mockResolvedValue(mockRepoData),
})
const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
expect(result).toEqual(mockRepoData)
})
it('should handle repository with mixed file types including GGUF', async () => {
const mockRepoData = {
id: 'microsoft/DialoGPT-medium',
modelId: 'microsoft/DialoGPT-medium',
sha: 'abc123',
downloads: 5000,
likes: 100,
tags: ['conversational'],
pipeline_tag: 'text-generation',
created_at: '2023-01-01T00:00:00Z',
last_modified: '2023-12-01T00:00:00Z',
private: false,
disabled: false,
gated: false,
author: 'microsoft',
siblings: [
{
rfilename: 'model-Q4_K_M.gguf',
size: 2147483648, // 2GB
blobId: 'blob123',
},
{
rfilename: 'README.md',
size: 1024,
blobId: 'blob789',
},
{
rfilename: 'config.json',
size: 512,
blobId: 'blob101',
},
],
}
;(fetch as any).mockResolvedValue({
ok: true,
json: vi.fn().mockResolvedValue(mockRepoData),
})
const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
expect(result).toEqual(mockRepoData)
// Verify the GGUF file is present in siblings
expect(result?.siblings?.some(s => s.rfilename.endsWith('.gguf'))).toBe(true)
})
})
})

View File

@ -25,6 +25,36 @@ export interface CatalogModel {
export type ModelCatalog = CatalogModel[]
// HuggingFace repository information: the subset of the Hub "model info"
// API response (https://huggingface.co/api/models/<owner>/<repo>) that this
// app consumes. Fields mirror the API's snake_case/camelCase mix verbatim.
export interface HuggingFaceRepo {
id: string
modelId: string
sha: string
downloads: number
likes: number
library_name?: string
tags: string[]
pipeline_tag?: string
created_at: string
last_modified: string
private: boolean
disabled: boolean
// May be a plain boolean or a string describing the gating mode — TODO confirm
// exact string values against the Hub API.
gated: boolean | string
author: string
// Model-card front matter, when the repo provides one.
cardData?: {
license?: string
language?: string[]
datasets?: string[]
metrics?: string[]
}
// Repository files; size/blobId are populated when fetched with ?blobs=true
// (see fetchHuggingFaceRepo below).
siblings?: Array<{
rfilename: string
size?: number
blobId?: string
}>
readme?: string
}
// TODO: Replace this with the actual provider later
const defaultProvider = 'llamacpp'
@ -63,6 +93,47 @@ export const fetchModelCatalog = async (): Promise<ModelCatalog> => {
}
}
/**
 * Fetches HuggingFace repository information.
 * @param repoId The repository ID (e.g., "microsoft/DialoGPT-medium"); also
 *   accepts a full URL ("https://huggingface.co/<owner>/<repo>"), a
 *   "huggingface.co/" prefix, surrounding whitespace, or a trailing slash.
 * @returns A promise resolving to the repository information, or null when
 *   the input is not a valid repo ID, the repository does not exist (404),
 *   or the request fails for any other reason (the error is logged).
 */
export const fetchHuggingFaceRepo = async (
  repoId: string
): Promise<HuggingFaceRepo | null> => {
  try {
    // Clean the repo ID to handle various input formats. Trim FIRST:
    // otherwise surrounding whitespace defeats the anchored prefix strips
    // and the trailing-slash removal below.
    const cleanRepoId = repoId
      .trim()
      .replace(/^https?:\/\/huggingface\.co\//, '')
      .replace(/^huggingface\.co\//, '')
      .replace(/\/$/, '') // Remove trailing slash

    // A valid repo ID has the "<owner>/<name>" shape — bail out early
    // without issuing a request.
    if (!cleanRepoId || !cleanRepoId.includes('/')) {
      return null
    }

    // blobs=true makes the Hub API include per-file metadata (size, blobId)
    // in the siblings list.
    const response = await fetch(
      `https://huggingface.co/api/models/${cleanRepoId}?blobs=true`
    )

    if (!response.ok) {
      if (response.status === 404) {
        return null // Repository not found
      }
      throw new Error(
        `Failed to fetch HuggingFace repository: ${response.status} ${response.statusText}`
      )
    }

    const repoData: HuggingFaceRepo = await response.json()
    return repoData
  } catch (error) {
    // Network and non-404 HTTP failures degrade to null so callers can treat
    // "no repo info" uniformly; the cause is still logged for diagnosis.
    console.error('Error fetching HuggingFace repository:', error)
    return null
  }
}
/**
* Updates a model.
* @param model The model to update.