fix: bring back HF repo ID search in Hub (#5880)
* fix: bring back HF search input * test: fix useModelSources tests for updated addSource signature
This commit is contained in:
parent
d8b6b10870
commit
6599d91660
@ -234,61 +234,149 @@ describe('useModelSources', () => {
|
||||
})
|
||||
|
||||
describe('addSource', () => {
|
||||
it('should log the source parameter', async () => {
|
||||
const consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {})
|
||||
it('should add a new source to the store', () => {
|
||||
const { result } = renderHook(() => useModelSources())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.addSource('test-source')
|
||||
})
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith('test-source')
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should set loading state during addSource', async () => {
|
||||
const { result } = renderHook(() => useModelSources())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.addSource('test-source')
|
||||
})
|
||||
|
||||
expect(result.current.loading).toBe(true)
|
||||
expect(result.current.error).toBe(null)
|
||||
})
|
||||
|
||||
it('should handle different source types', async () => {
|
||||
const consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {})
|
||||
const { result } = renderHook(() => useModelSources())
|
||||
|
||||
const sources = [
|
||||
'http://example.com/source1',
|
||||
'https://secure.example.com/source2',
|
||||
'file:///local/path/source3',
|
||||
'custom-source-name',
|
||||
]
|
||||
|
||||
for (const source of sources) {
|
||||
await act(async () => {
|
||||
await result.current.addSource(source)
|
||||
})
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(source)
|
||||
|
||||
const testModel: CatalogModel = {
|
||||
model_name: 'test-model',
|
||||
description: 'Test model description',
|
||||
developer: 'test-developer',
|
||||
downloads: 100,
|
||||
num_quants: 2,
|
||||
quants: [
|
||||
{
|
||||
model_id: 'test-model-q4',
|
||||
path: 'https://example.com/test-model-q4.gguf',
|
||||
file_size: '2.0 GB',
|
||||
},
|
||||
],
|
||||
created_at: '2023-01-01T00:00:00Z',
|
||||
}
|
||||
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should handle empty source string', async () => {
|
||||
const consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {})
|
||||
const { result } = renderHook(() => useModelSources())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.addSource('')
|
||||
act(() => {
|
||||
result.current.addSource(testModel)
|
||||
})
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith('')
|
||||
consoleSpy.mockRestore()
|
||||
expect(result.current.sources).toHaveLength(1)
|
||||
expect(result.current.sources[0]).toEqual(testModel)
|
||||
})
|
||||
|
||||
it('should replace existing source with same model_name', () => {
|
||||
const { result } = renderHook(() => useModelSources())
|
||||
|
||||
const originalModel: CatalogModel = {
|
||||
model_name: 'duplicate-model',
|
||||
description: 'Original description',
|
||||
developer: 'original-developer',
|
||||
downloads: 50,
|
||||
num_quants: 1,
|
||||
quants: [],
|
||||
created_at: '2023-01-01T00:00:00Z',
|
||||
}
|
||||
|
||||
const updatedModel: CatalogModel = {
|
||||
model_name: 'duplicate-model',
|
||||
description: 'Updated description',
|
||||
developer: 'updated-developer',
|
||||
downloads: 150,
|
||||
num_quants: 2,
|
||||
quants: [
|
||||
{
|
||||
model_id: 'duplicate-model-q4',
|
||||
path: 'https://example.com/duplicate-model-q4.gguf',
|
||||
file_size: '3.0 GB',
|
||||
},
|
||||
],
|
||||
created_at: '2023-02-01T00:00:00Z',
|
||||
}
|
||||
|
||||
act(() => {
|
||||
result.current.addSource(originalModel)
|
||||
})
|
||||
|
||||
expect(result.current.sources).toHaveLength(1)
|
||||
|
||||
act(() => {
|
||||
result.current.addSource(updatedModel)
|
||||
})
|
||||
|
||||
expect(result.current.sources).toHaveLength(1)
|
||||
expect(result.current.sources[0]).toEqual(updatedModel)
|
||||
})
|
||||
|
||||
it('should handle multiple different sources', () => {
|
||||
const { result } = renderHook(() => useModelSources())
|
||||
|
||||
const model1: CatalogModel = {
|
||||
model_name: 'model-1',
|
||||
description: 'First model',
|
||||
developer: 'developer-1',
|
||||
downloads: 100,
|
||||
num_quants: 1,
|
||||
quants: [],
|
||||
created_at: '2023-01-01T00:00:00Z',
|
||||
}
|
||||
|
||||
const model2: CatalogModel = {
|
||||
model_name: 'model-2',
|
||||
description: 'Second model',
|
||||
developer: 'developer-2',
|
||||
downloads: 200,
|
||||
num_quants: 1,
|
||||
quants: [],
|
||||
created_at: '2023-01-02T00:00:00Z',
|
||||
}
|
||||
|
||||
act(() => {
|
||||
result.current.addSource(model1)
|
||||
})
|
||||
|
||||
act(() => {
|
||||
result.current.addSource(model2)
|
||||
})
|
||||
|
||||
expect(result.current.sources).toHaveLength(2)
|
||||
expect(result.current.sources).toContainEqual(model1)
|
||||
expect(result.current.sources).toContainEqual(model2)
|
||||
})
|
||||
|
||||
it('should handle CatalogModel with complete quants data', () => {
|
||||
const { result } = renderHook(() => useModelSources())
|
||||
|
||||
const modelWithQuants: CatalogModel = {
|
||||
model_name: 'model-with-quants',
|
||||
description: 'Model with quantizations',
|
||||
developer: 'quant-developer',
|
||||
downloads: 500,
|
||||
num_quants: 3,
|
||||
quants: [
|
||||
{
|
||||
model_id: 'model-q4_k_m',
|
||||
path: 'https://example.com/model-q4_k_m.gguf',
|
||||
file_size: '2.0 GB',
|
||||
},
|
||||
{
|
||||
model_id: 'model-q8_0',
|
||||
path: 'https://example.com/model-q8_0.gguf',
|
||||
file_size: '4.0 GB',
|
||||
},
|
||||
{
|
||||
model_id: 'model-f16',
|
||||
path: 'https://example.com/model-f16.gguf',
|
||||
file_size: '8.0 GB',
|
||||
},
|
||||
],
|
||||
created_at: '2023-01-01T00:00:00Z',
|
||||
readme: 'https://example.com/readme.md',
|
||||
}
|
||||
|
||||
act(() => {
|
||||
result.current.addSource(modelWithQuants)
|
||||
})
|
||||
|
||||
expect(result.current.sources).toHaveLength(1)
|
||||
expect(result.current.sources[0]).toEqual(modelWithQuants)
|
||||
expect(result.current.sources[0].quants).toHaveLength(3)
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@ -8,8 +8,8 @@ type ModelSourcesState = {
|
||||
sources: CatalogModel[]
|
||||
error: Error | null
|
||||
loading: boolean
|
||||
addSource: (source: CatalogModel) => void
|
||||
fetchSources: () => Promise<void>
|
||||
addSource: (source: string) => Promise<void>
|
||||
}
|
||||
|
||||
export const useModelSources = create<ModelSourcesState>()(
|
||||
@ -19,6 +19,14 @@ export const useModelSources = create<ModelSourcesState>()(
|
||||
error: null,
|
||||
loading: false,
|
||||
|
||||
addSource: (source: CatalogModel) => {
|
||||
set((state) => ({
|
||||
sources: [
|
||||
...state.sources.filter((e) => e.model_name !== source.model_name),
|
||||
source,
|
||||
],
|
||||
}))
|
||||
},
|
||||
fetchSources: async () => {
|
||||
set({ loading: true, error: null })
|
||||
try {
|
||||
@ -38,24 +46,6 @@ export const useModelSources = create<ModelSourcesState>()(
|
||||
set({ error: error as Error, loading: false })
|
||||
}
|
||||
},
|
||||
|
||||
addSource: async (source: string) => {
|
||||
set({ loading: true, error: null })
|
||||
console.log(source)
|
||||
// try {
|
||||
// await addModelSource(source)
|
||||
// const newSources = await fetchModelSources()
|
||||
// const currentSources = get().sources
|
||||
|
||||
// if (!deepCompareModelSources(currentSources, newSources)) {
|
||||
// set({ sources: newSources, loading: false })
|
||||
// } else {
|
||||
// set({ loading: false })
|
||||
// }
|
||||
// } catch (error) {
|
||||
// set({ error: error as Error, loading: false })
|
||||
// }
|
||||
},
|
||||
}),
|
||||
{
|
||||
name: localStorageKey.modelSources,
|
||||
|
||||
307
web-app/src/routes/hub/__tests__/huggingface-conversion.test.ts
Normal file
307
web-app/src/routes/hub/__tests__/huggingface-conversion.test.ts
Normal file
@ -0,0 +1,307 @@
|
||||
import { describe, it, expect } from 'vitest'
|
||||
import { HuggingFaceRepo, CatalogModel } from '@/services/models'
|
||||
|
||||
// Helper function to test the conversion logic (extracted from the component)
|
||||
const convertHfRepoToCatalogModel = (repo: HuggingFaceRepo): CatalogModel => {
|
||||
// Extract GGUF files from the repository siblings
|
||||
const ggufFiles =
|
||||
repo.siblings?.filter((file) =>
|
||||
file.rfilename.toLowerCase().endsWith('.gguf')
|
||||
) || []
|
||||
|
||||
// Convert GGUF files to quants format
|
||||
const quants = ggufFiles.map((file) => {
|
||||
// Format file size
|
||||
const formatFileSize = (size?: number) => {
|
||||
if (!size) return 'Unknown size'
|
||||
if (size < 1024 ** 3) return `${(size / 1024 ** 2).toFixed(1)} MB`
|
||||
return `${(size / 1024 ** 3).toFixed(1)} GB`
|
||||
}
|
||||
|
||||
// Generate model_id from filename (remove .gguf extension, case-insensitive)
|
||||
const modelId = file.rfilename.replace(/\.gguf$/i, '')
|
||||
|
||||
return {
|
||||
model_id: modelId,
|
||||
path: `https://huggingface.co/${repo.modelId}/resolve/main/${file.rfilename}`,
|
||||
file_size: formatFileSize(file.size),
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
model_name: repo.modelId,
|
||||
description: `**Metadata:** ${repo.pipeline_tag}\n\n **Tags**: ${repo.tags?.join(', ')}`,
|
||||
developer: repo.author,
|
||||
downloads: repo.downloads || 0,
|
||||
num_quants: quants.length,
|
||||
quants: quants,
|
||||
created_at: repo.created_at,
|
||||
readme: `https://huggingface.co/${repo.modelId}/resolve/main/README.md`,
|
||||
}
|
||||
}
|
||||
|
||||
describe('HuggingFace Repository Conversion', () => {
|
||||
const mockHuggingFaceRepo: HuggingFaceRepo = {
|
||||
id: 'microsoft/DialoGPT-medium',
|
||||
modelId: 'microsoft/DialoGPT-medium',
|
||||
sha: 'abc123',
|
||||
downloads: 5000,
|
||||
likes: 100,
|
||||
tags: ['conversational', 'pytorch', 'text-generation'],
|
||||
pipeline_tag: 'text-generation',
|
||||
created_at: '2023-01-01T00:00:00Z',
|
||||
last_modified: '2023-12-01T00:00:00Z',
|
||||
private: false,
|
||||
disabled: false,
|
||||
gated: false,
|
||||
author: 'microsoft',
|
||||
siblings: [
|
||||
{
|
||||
rfilename: 'model-Q4_K_M.gguf',
|
||||
size: 2147483648, // 2GB
|
||||
blobId: 'blob123',
|
||||
},
|
||||
{
|
||||
rfilename: 'model-Q8_0.gguf',
|
||||
size: 4294967296, // 4GB
|
||||
blobId: 'blob456',
|
||||
},
|
||||
{
|
||||
rfilename: 'model-small.gguf',
|
||||
size: 536870912, // 512MB
|
||||
blobId: 'blob789',
|
||||
},
|
||||
{
|
||||
rfilename: 'README.md',
|
||||
size: 1024,
|
||||
blobId: 'blob101',
|
||||
},
|
||||
],
|
||||
readme: '# DialoGPT Model\nThis is a conversational AI model.',
|
||||
}
|
||||
|
||||
describe('convertHfRepoToCatalogModel', () => {
|
||||
it('should convert HuggingFace repository to CatalogModel correctly', () => {
|
||||
const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
||||
|
||||
expect(result).toEqual({
|
||||
model_name: 'microsoft/DialoGPT-medium',
|
||||
description: '**Metadata:** text-generation\n\n **Tags**: conversational, pytorch, text-generation',
|
||||
developer: 'microsoft',
|
||||
downloads: 5000,
|
||||
num_quants: 3,
|
||||
quants: [
|
||||
{
|
||||
model_id: 'model-Q4_K_M',
|
||||
path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-Q4_K_M.gguf',
|
||||
file_size: '2.0 GB',
|
||||
},
|
||||
{
|
||||
model_id: 'model-Q8_0',
|
||||
path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-Q8_0.gguf',
|
||||
file_size: '4.0 GB',
|
||||
},
|
||||
{
|
||||
model_id: 'model-small',
|
||||
path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-small.gguf',
|
||||
file_size: '512.0 MB',
|
||||
},
|
||||
],
|
||||
created_at: '2023-01-01T00:00:00Z',
|
||||
readme: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md',
|
||||
})
|
||||
})
|
||||
|
||||
it('should filter only GGUF files from siblings', () => {
|
||||
const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
||||
|
||||
// Should have 3 GGUF files, not 4 total files
|
||||
expect(result.num_quants).toBe(3)
|
||||
expect(result.quants).toHaveLength(3)
|
||||
|
||||
// All quants should be from GGUF files
|
||||
result.quants.forEach((quant) => {
|
||||
expect(quant.path).toContain('.gguf')
|
||||
})
|
||||
})
|
||||
|
||||
it('should format file sizes correctly', () => {
|
||||
const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
||||
|
||||
expect(result.quants[0].file_size).toBe('2.0 GB') // 2GB
|
||||
expect(result.quants[1].file_size).toBe('4.0 GB') // 4GB
|
||||
expect(result.quants[2].file_size).toBe('512.0 MB') // 512MB
|
||||
})
|
||||
|
||||
it('should generate correct download paths', () => {
|
||||
const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
||||
|
||||
expect(result.quants[0].path).toBe(
|
||||
'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-Q4_K_M.gguf'
|
||||
)
|
||||
expect(result.quants[1].path).toBe(
|
||||
'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-Q8_0.gguf'
|
||||
)
|
||||
})
|
||||
|
||||
it('should generate correct model IDs by removing .gguf extension', () => {
|
||||
const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
||||
|
||||
expect(result.quants[0].model_id).toBe('model-Q4_K_M')
|
||||
expect(result.quants[1].model_id).toBe('model-Q8_0')
|
||||
expect(result.quants[2].model_id).toBe('model-small')
|
||||
})
|
||||
|
||||
it('should handle repository with no siblings', () => {
|
||||
const repoWithoutSiblings = {
|
||||
...mockHuggingFaceRepo,
|
||||
siblings: undefined,
|
||||
}
|
||||
|
||||
const result = convertHfRepoToCatalogModel(repoWithoutSiblings)
|
||||
|
||||
expect(result.num_quants).toBe(0)
|
||||
expect(result.quants).toEqual([])
|
||||
})
|
||||
|
||||
it('should handle repository with empty siblings array', () => {
|
||||
const repoWithEmptySiblings = {
|
||||
...mockHuggingFaceRepo,
|
||||
siblings: [],
|
||||
}
|
||||
|
||||
const result = convertHfRepoToCatalogModel(repoWithEmptySiblings)
|
||||
|
||||
expect(result.num_quants).toBe(0)
|
||||
expect(result.quants).toEqual([])
|
||||
})
|
||||
|
||||
it('should handle repository with no GGUF files', () => {
|
||||
const repoWithoutGGUF = {
|
||||
...mockHuggingFaceRepo,
|
||||
siblings: [
|
||||
{
|
||||
rfilename: 'README.md',
|
||||
size: 1024,
|
||||
blobId: 'blob101',
|
||||
},
|
||||
{
|
||||
rfilename: 'config.json',
|
||||
size: 512,
|
||||
blobId: 'blob102',
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const result = convertHfRepoToCatalogModel(repoWithoutGGUF)
|
||||
|
||||
expect(result.num_quants).toBe(0)
|
||||
expect(result.quants).toEqual([])
|
||||
})
|
||||
|
||||
it('should handle files with unknown sizes', () => {
|
||||
const repoWithUnknownSizes = {
|
||||
...mockHuggingFaceRepo,
|
||||
siblings: [
|
||||
{
|
||||
rfilename: 'model-unknown.gguf',
|
||||
size: undefined,
|
||||
blobId: 'blob123',
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const result = convertHfRepoToCatalogModel(repoWithUnknownSizes)
|
||||
|
||||
expect(result.quants[0].file_size).toBe('Unknown size')
|
||||
})
|
||||
|
||||
it('should handle repository with zero downloads', () => {
|
||||
const repoWithZeroDownloads = {
|
||||
...mockHuggingFaceRepo,
|
||||
downloads: 0,
|
||||
}
|
||||
|
||||
const result = convertHfRepoToCatalogModel(repoWithZeroDownloads)
|
||||
|
||||
expect(result.downloads).toBe(0)
|
||||
})
|
||||
|
||||
it('should handle repository with no tags', () => {
|
||||
const repoWithoutTags = {
|
||||
...mockHuggingFaceRepo,
|
||||
tags: [],
|
||||
}
|
||||
|
||||
const result = convertHfRepoToCatalogModel(repoWithoutTags)
|
||||
|
||||
expect(result.description).toContain('**Tags**: ')
|
||||
})
|
||||
|
||||
it('should handle repository with no pipeline_tag', () => {
|
||||
const repoWithoutPipelineTag = {
|
||||
...mockHuggingFaceRepo,
|
||||
pipeline_tag: undefined,
|
||||
}
|
||||
|
||||
const result = convertHfRepoToCatalogModel(repoWithoutPipelineTag)
|
||||
|
||||
expect(result.description).toContain('**Metadata:** undefined')
|
||||
})
|
||||
|
||||
it('should generate README URL correctly', () => {
|
||||
const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo)
|
||||
|
||||
expect(result.readme).toBe(
|
||||
'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md'
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle case-insensitive GGUF file extensions', () => {
|
||||
const repoWithMixedCase = {
|
||||
...mockHuggingFaceRepo,
|
||||
siblings: [
|
||||
{
|
||||
rfilename: 'model-uppercase.GGUF',
|
||||
size: 1024,
|
||||
blobId: 'blob1',
|
||||
},
|
||||
{
|
||||
rfilename: 'model-mixedcase.Gguf',
|
||||
size: 2048,
|
||||
blobId: 'blob2',
|
||||
},
|
||||
{
|
||||
rfilename: 'model-lowercase.gguf',
|
||||
size: 4096,
|
||||
blobId: 'blob3',
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const result = convertHfRepoToCatalogModel(repoWithMixedCase)
|
||||
|
||||
expect(result.num_quants).toBe(3)
|
||||
expect(result.quants[0].model_id).toBe('model-uppercase')
|
||||
expect(result.quants[1].model_id).toBe('model-mixedcase')
|
||||
expect(result.quants[2].model_id).toBe('model-lowercase')
|
||||
})
|
||||
|
||||
it('should handle very large file sizes (> 1TB)', () => {
|
||||
const repoWithLargeFiles = {
|
||||
...mockHuggingFaceRepo,
|
||||
siblings: [
|
||||
{
|
||||
rfilename: 'huge-model.gguf',
|
||||
size: 1099511627776, // 1TB
|
||||
blobId: 'blob1',
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
const result = convertHfRepoToCatalogModel(repoWithLargeFiles)
|
||||
|
||||
expect(result.quants[0].file_size).toBe('1024.0 GB')
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -26,7 +26,12 @@ import {
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from '@/components/ui/dropdown-menu'
|
||||
import { CatalogModel, pullModel } from '@/services/models'
|
||||
import {
|
||||
CatalogModel,
|
||||
pullModel,
|
||||
fetchHuggingFaceRepo,
|
||||
HuggingFaceRepo,
|
||||
} from '@/services/models'
|
||||
import { useDownloadStore } from '@/hooks/useDownloadStore'
|
||||
import { Progress } from '@/components/ui/progress'
|
||||
import HeaderPage from '@/containers/HeaderPage'
|
||||
@ -54,7 +59,7 @@ function Hub() {
|
||||
{ value: 'newest', name: t('hub:sortNewest') },
|
||||
{ value: 'most-downloaded', name: t('hub:sortMostDownloaded') },
|
||||
]
|
||||
const { sources, fetchSources, addSource, loading } = useModelSources()
|
||||
const { sources, addSource, fetchSources, loading } = useModelSources()
|
||||
const search = useSearch({ from: route.hub.index as any })
|
||||
const [searchValue, setSearchValue] = useState('')
|
||||
const [sortSelected, setSortSelected] = useState('newest')
|
||||
@ -63,6 +68,9 @@ function Hub() {
|
||||
)
|
||||
const [isSearching, setIsSearching] = useState(false)
|
||||
const [showOnlyDownloaded, setShowOnlyDownloaded] = useState(false)
|
||||
const [huggingFaceRepo, setHuggingFaceRepo] = useState<CatalogModel | null>(
|
||||
null
|
||||
)
|
||||
const [joyrideReady, setJoyrideReady] = useState(false)
|
||||
const [currentStepIndex, setCurrentStepIndex] = useState(0)
|
||||
const addModelSourceTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(
|
||||
@ -74,6 +82,48 @@ function Hub() {
|
||||
const { getProviderByName } = useModelProvider()
|
||||
const llamaProvider = getProviderByName('llamacpp')
|
||||
|
||||
// Convert HuggingFace repository to CatalogModel format
|
||||
const convertHfRepoToCatalogModel = useCallback(
|
||||
(repo: HuggingFaceRepo): CatalogModel => {
|
||||
// Extract GGUF files from the repository siblings
|
||||
const ggufFiles =
|
||||
repo.siblings?.filter((file) =>
|
||||
file.rfilename.toLowerCase().endsWith('.gguf')
|
||||
) || []
|
||||
|
||||
// Convert GGUF files to quants format
|
||||
const quants = ggufFiles.map((file) => {
|
||||
// Format file size
|
||||
const formatFileSize = (size?: number) => {
|
||||
if (!size) return 'Unknown size'
|
||||
if (size < 1024 ** 3) return `${(size / 1024 ** 2).toFixed(1)} MB`
|
||||
return `${(size / 1024 ** 3).toFixed(1)} GB`
|
||||
}
|
||||
|
||||
// Generate model_id from filename (remove .gguf extension, case-insensitive)
|
||||
const modelId = file.rfilename.replace(/\.gguf$/i, '')
|
||||
|
||||
return {
|
||||
model_id: modelId,
|
||||
path: `https://huggingface.co/${repo.modelId}/resolve/main/${file.rfilename}`,
|
||||
file_size: formatFileSize(file.size),
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
model_name: repo.modelId,
|
||||
description: `**Metadata:** ${repo.pipeline_tag}\n\n **Tags**: ${repo.tags?.join(', ')}`,
|
||||
developer: repo.author,
|
||||
downloads: repo.downloads || 0,
|
||||
num_quants: quants.length,
|
||||
quants: quants,
|
||||
created_at: repo.created_at,
|
||||
readme: `https://huggingface.co/${repo.modelId}/resolve/main/README.md`,
|
||||
}
|
||||
},
|
||||
[]
|
||||
)
|
||||
|
||||
const toggleModelExpansion = (modelId: string) => {
|
||||
setExpandedModels((prev) => ({
|
||||
...prev,
|
||||
@ -85,17 +135,26 @@ function Hub() {
|
||||
if (search.repo) {
|
||||
setSearchValue(search.repo || '')
|
||||
setIsSearching(true)
|
||||
addModelSourceTimeoutRef.current = setTimeout(() => {
|
||||
addSource(search.repo)
|
||||
.then(() => {
|
||||
fetchSources()
|
||||
})
|
||||
.finally(() => {
|
||||
setIsSearching(false)
|
||||
})
|
||||
|
||||
addModelSourceTimeoutRef.current = setTimeout(async () => {
|
||||
try {
|
||||
// Fetch HuggingFace repository information
|
||||
const repoInfo = await fetchHuggingFaceRepo(search.repo)
|
||||
if (repoInfo) {
|
||||
const catalogModel = convertHfRepoToCatalogModel(repoInfo)
|
||||
setHuggingFaceRepo(catalogModel)
|
||||
addSource(catalogModel)
|
||||
}
|
||||
|
||||
await fetchSources()
|
||||
} catch (error) {
|
||||
console.error('Error fetching repository info:', error)
|
||||
} finally {
|
||||
setIsSearching(false)
|
||||
}
|
||||
}, 500)
|
||||
}
|
||||
}, [addSource, fetchSources, search])
|
||||
}, [convertHfRepoToCatalogModel, fetchSources, addSource, search])
|
||||
|
||||
// Sorting functionality
|
||||
const sortedModels = useMemo(() => {
|
||||
@ -143,8 +202,19 @@ function Hub() {
|
||||
)
|
||||
}
|
||||
|
||||
// Add HuggingFace repo at the beginning if available
|
||||
if (huggingFaceRepo) {
|
||||
filtered = [huggingFaceRepo, ...filtered]
|
||||
}
|
||||
|
||||
return filtered
|
||||
}, [searchValue, sortedModels, showOnlyDownloaded, llamaProvider?.models])
|
||||
}, [
|
||||
searchValue,
|
||||
sortedModels,
|
||||
showOnlyDownloaded,
|
||||
llamaProvider?.models,
|
||||
huggingFaceRepo,
|
||||
])
|
||||
|
||||
useEffect(() => {
|
||||
fetchSources()
|
||||
@ -153,22 +223,35 @@ function Hub() {
|
||||
const handleSearchChange = (e: ChangeEvent<HTMLInputElement>) => {
|
||||
setIsSearching(false)
|
||||
setSearchValue(e.target.value)
|
||||
setHuggingFaceRepo(null) // Clear previous repo info
|
||||
|
||||
if (addModelSourceTimeoutRef.current) {
|
||||
clearTimeout(addModelSourceTimeoutRef.current)
|
||||
}
|
||||
|
||||
if (
|
||||
e.target.value.length &&
|
||||
(e.target.value.includes('/') || e.target.value.startsWith('http'))
|
||||
) {
|
||||
setIsSearching(true)
|
||||
addModelSourceTimeoutRef.current = setTimeout(() => {
|
||||
addSource(e.target.value)
|
||||
.then(() => {
|
||||
fetchSources()
|
||||
})
|
||||
.finally(() => {
|
||||
setIsSearching(false)
|
||||
})
|
||||
|
||||
addModelSourceTimeoutRef.current = setTimeout(async () => {
|
||||
try {
|
||||
// Fetch HuggingFace repository information
|
||||
const repoInfo = await fetchHuggingFaceRepo(e.target.value)
|
||||
if (repoInfo) {
|
||||
const catalogModel = convertHfRepoToCatalogModel(repoInfo)
|
||||
setHuggingFaceRepo(catalogModel)
|
||||
addSource(catalogModel)
|
||||
}
|
||||
|
||||
// Original addSource logic (if needed)
|
||||
await fetchSources()
|
||||
} catch (error) {
|
||||
console.error('Error fetching repository info:', error)
|
||||
} finally {
|
||||
setIsSearching(false)
|
||||
}
|
||||
}, 500)
|
||||
}
|
||||
}
|
||||
@ -213,6 +296,25 @@ function Hub() {
|
||||
|
||||
const DownloadButtonPlaceholder = useMemo(() => {
|
||||
return ({ model }: ModelProps) => {
|
||||
// Check if this is a HuggingFace repository (no quants)
|
||||
if (model.quants.length === 0) {
|
||||
return (
|
||||
<div className="flex items-center gap-2">
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
window.open(
|
||||
`https://huggingface.co/${model.model_name}`,
|
||||
'_blank'
|
||||
)
|
||||
}}
|
||||
>
|
||||
View on HuggingFace
|
||||
</Button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const quant =
|
||||
model.quants.find((e) =>
|
||||
defaultModelQuantizations.some((m) =>
|
||||
|
||||
@ -3,6 +3,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
import {
|
||||
fetchModels,
|
||||
fetchModelCatalog,
|
||||
fetchHuggingFaceRepo,
|
||||
updateModel,
|
||||
pullModel,
|
||||
abortDownload,
|
||||
@ -271,4 +272,259 @@ describe('models service', () => {
|
||||
await expect(startModel(provider, model)).resolves.toBe(undefined)
|
||||
})
|
||||
})
|
||||
|
||||
describe('fetchHuggingFaceRepo', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('should fetch HuggingFace repository successfully with blobs=true', async () => {
|
||||
const mockRepoData = {
|
||||
id: 'microsoft/DialoGPT-medium',
|
||||
modelId: 'microsoft/DialoGPT-medium',
|
||||
sha: 'abc123',
|
||||
downloads: 5000,
|
||||
likes: 100,
|
||||
tags: ['conversational', 'pytorch'],
|
||||
pipeline_tag: 'text-generation',
|
||||
created_at: '2023-01-01T00:00:00Z',
|
||||
last_modified: '2023-12-01T00:00:00Z',
|
||||
private: false,
|
||||
disabled: false,
|
||||
gated: false,
|
||||
author: 'microsoft',
|
||||
siblings: [
|
||||
{
|
||||
rfilename: 'model-Q4_K_M.gguf',
|
||||
size: 2147483648,
|
||||
blobId: 'blob123',
|
||||
},
|
||||
{
|
||||
rfilename: 'model-Q8_0.gguf',
|
||||
size: 4294967296,
|
||||
blobId: 'blob456',
|
||||
},
|
||||
{
|
||||
rfilename: 'README.md',
|
||||
size: 1024,
|
||||
blobId: 'blob789',
|
||||
},
|
||||
],
|
||||
readme: '# DialoGPT Model\nThis is a conversational AI model.',
|
||||
}
|
||||
|
||||
;(fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
json: vi.fn().mockResolvedValue(mockRepoData),
|
||||
})
|
||||
|
||||
const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
||||
|
||||
expect(result).toEqual(mockRepoData)
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true'
|
||||
)
|
||||
})
|
||||
|
||||
it('should clean repository ID from various input formats', async () => {
|
||||
const mockRepoData = { modelId: 'microsoft/DialoGPT-medium' }
|
||||
;(fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
json: vi.fn().mockResolvedValue(mockRepoData),
|
||||
})
|
||||
|
||||
// Test with full URL
|
||||
await fetchHuggingFaceRepo('https://huggingface.co/microsoft/DialoGPT-medium')
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true'
|
||||
)
|
||||
|
||||
// Test with domain prefix
|
||||
await fetchHuggingFaceRepo('huggingface.co/microsoft/DialoGPT-medium')
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true'
|
||||
)
|
||||
|
||||
// Test with trailing slash
|
||||
await fetchHuggingFaceRepo('microsoft/DialoGPT-medium/')
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true'
|
||||
)
|
||||
})
|
||||
|
||||
it('should return null for invalid repository IDs', async () => {
|
||||
// Test empty string
|
||||
expect(await fetchHuggingFaceRepo('')).toBeNull()
|
||||
|
||||
// Test string without slash
|
||||
expect(await fetchHuggingFaceRepo('invalid-repo')).toBeNull()
|
||||
|
||||
// Test whitespace only
|
||||
expect(await fetchHuggingFaceRepo(' ')).toBeNull()
|
||||
})
|
||||
|
||||
it('should return null for 404 responses', async () => {
|
||||
;(fetch as any).mockResolvedValue({
|
||||
ok: false,
|
||||
status: 404,
|
||||
statusText: 'Not Found',
|
||||
})
|
||||
|
||||
const result = await fetchHuggingFaceRepo('nonexistent/model')
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
'https://huggingface.co/api/models/nonexistent/model?blobs=true'
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle other HTTP errors', async () => {
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
;(fetch as any).mockResolvedValue({
|
||||
ok: false,
|
||||
status: 500,
|
||||
statusText: 'Internal Server Error',
|
||||
})
|
||||
|
||||
const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
'Error fetching HuggingFace repository:',
|
||||
expect.any(Error)
|
||||
)
|
||||
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should handle network errors', async () => {
|
||||
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
;(fetch as any).mockRejectedValue(new Error('Network error'))
|
||||
|
||||
const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
||||
|
||||
expect(result).toBeNull()
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
'Error fetching HuggingFace repository:',
|
||||
expect.any(Error)
|
||||
)
|
||||
|
||||
consoleSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('should handle repository with no siblings', async () => {
|
||||
const mockRepoData = {
|
||||
id: 'microsoft/DialoGPT-medium',
|
||||
modelId: 'microsoft/DialoGPT-medium',
|
||||
sha: 'abc123',
|
||||
downloads: 5000,
|
||||
likes: 100,
|
||||
tags: ['conversational'],
|
||||
pipeline_tag: 'text-generation',
|
||||
created_at: '2023-01-01T00:00:00Z',
|
||||
last_modified: '2023-12-01T00:00:00Z',
|
||||
private: false,
|
||||
disabled: false,
|
||||
gated: false,
|
||||
author: 'microsoft',
|
||||
siblings: undefined,
|
||||
}
|
||||
|
||||
;(fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
json: vi.fn().mockResolvedValue(mockRepoData),
|
||||
})
|
||||
|
||||
const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
||||
|
||||
expect(result).toEqual(mockRepoData)
|
||||
})
|
||||
|
||||
it('should handle repository with no GGUF files', async () => {
|
||||
const mockRepoData = {
|
||||
id: 'microsoft/DialoGPT-medium',
|
||||
modelId: 'microsoft/DialoGPT-medium',
|
||||
sha: 'abc123',
|
||||
downloads: 5000,
|
||||
likes: 100,
|
||||
tags: ['conversational'],
|
||||
pipeline_tag: 'text-generation',
|
||||
created_at: '2023-01-01T00:00:00Z',
|
||||
last_modified: '2023-12-01T00:00:00Z',
|
||||
private: false,
|
||||
disabled: false,
|
||||
gated: false,
|
||||
author: 'microsoft',
|
||||
siblings: [
|
||||
{
|
||||
rfilename: 'README.md',
|
||||
size: 1024,
|
||||
blobId: 'blob789',
|
||||
},
|
||||
{
|
||||
rfilename: 'config.json',
|
||||
size: 512,
|
||||
blobId: 'blob101',
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
;(fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
json: vi.fn().mockResolvedValue(mockRepoData),
|
||||
})
|
||||
|
||||
const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
||||
|
||||
expect(result).toEqual(mockRepoData)
|
||||
})
|
||||
|
||||
it('should handle repository with mixed file types including GGUF', async () => {
|
||||
const mockRepoData = {
|
||||
id: 'microsoft/DialoGPT-medium',
|
||||
modelId: 'microsoft/DialoGPT-medium',
|
||||
sha: 'abc123',
|
||||
downloads: 5000,
|
||||
likes: 100,
|
||||
tags: ['conversational'],
|
||||
pipeline_tag: 'text-generation',
|
||||
created_at: '2023-01-01T00:00:00Z',
|
||||
last_modified: '2023-12-01T00:00:00Z',
|
||||
private: false,
|
||||
disabled: false,
|
||||
gated: false,
|
||||
author: 'microsoft',
|
||||
siblings: [
|
||||
{
|
||||
rfilename: 'model-Q4_K_M.gguf',
|
||||
size: 2147483648, // 2GB
|
||||
blobId: 'blob123',
|
||||
},
|
||||
{
|
||||
rfilename: 'README.md',
|
||||
size: 1024,
|
||||
blobId: 'blob789',
|
||||
},
|
||||
{
|
||||
rfilename: 'config.json',
|
||||
size: 512,
|
||||
blobId: 'blob101',
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
;(fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
json: vi.fn().mockResolvedValue(mockRepoData),
|
||||
})
|
||||
|
||||
const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium')
|
||||
|
||||
expect(result).toEqual(mockRepoData)
|
||||
// Verify the GGUF file is present in siblings
|
||||
expect(result?.siblings?.some(s => s.rfilename.endsWith('.gguf'))).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -25,6 +25,36 @@ export interface CatalogModel {
|
||||
|
||||
export type ModelCatalog = CatalogModel[]
|
||||
|
||||
// HuggingFace repository information
|
||||
export interface HuggingFaceRepo {
|
||||
id: string
|
||||
modelId: string
|
||||
sha: string
|
||||
downloads: number
|
||||
likes: number
|
||||
library_name?: string
|
||||
tags: string[]
|
||||
pipeline_tag?: string
|
||||
created_at: string
|
||||
last_modified: string
|
||||
private: boolean
|
||||
disabled: boolean
|
||||
gated: boolean | string
|
||||
author: string
|
||||
cardData?: {
|
||||
license?: string
|
||||
language?: string[]
|
||||
datasets?: string[]
|
||||
metrics?: string[]
|
||||
}
|
||||
siblings?: Array<{
|
||||
rfilename: string
|
||||
size?: number
|
||||
blobId?: string
|
||||
}>
|
||||
readme?: string
|
||||
}
|
||||
|
||||
// TODO: Replace this with the actual provider later
|
||||
const defaultProvider = 'llamacpp'
|
||||
|
||||
@ -63,6 +93,47 @@ export const fetchModelCatalog = async (): Promise<ModelCatalog> => {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches HuggingFace repository information.
|
||||
* @param repoId The repository ID (e.g., "microsoft/DialoGPT-medium")
|
||||
* @returns A promise that resolves to the repository information.
|
||||
*/
|
||||
export const fetchHuggingFaceRepo = async (
|
||||
repoId: string
|
||||
): Promise<HuggingFaceRepo | null> => {
|
||||
try {
|
||||
// Clean the repo ID to handle various input formats
|
||||
const cleanRepoId = repoId
|
||||
.replace(/^https?:\/\/huggingface\.co\//, '')
|
||||
.replace(/^huggingface\.co\//, '')
|
||||
.replace(/\/$/, '') // Remove trailing slash
|
||||
.trim()
|
||||
|
||||
if (!cleanRepoId || !cleanRepoId.includes('/')) {
|
||||
return null
|
||||
}
|
||||
|
||||
const response = await fetch(
|
||||
`https://huggingface.co/api/models/${cleanRepoId}?blobs=true`
|
||||
)
|
||||
|
||||
if (!response.ok) {
|
||||
if (response.status === 404) {
|
||||
return null // Repository not found
|
||||
}
|
||||
throw new Error(
|
||||
`Failed to fetch HuggingFace repository: ${response.status} ${response.statusText}`
|
||||
)
|
||||
}
|
||||
|
||||
const repoData: HuggingFaceRepo = await response.json()
|
||||
return repoData
|
||||
} catch (error) {
|
||||
console.error('Error fetching HuggingFace repository:', error)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates a model.
|
||||
* @param model The model to update.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user