From 6599d916606deb35dc087ac6c20df1a474da5dea Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 24 Jul 2025 09:46:13 +0700 Subject: [PATCH] fix: bring back HF repo ID search in Hub (#5880) * fix: bring back HF search input * test: fix useModelSources tests for updated addSource signature --- .../hooks/__tests__/useModelSources.test.ts | 188 ++++++++--- web-app/src/hooks/useModelSources.ts | 28 +- .../__tests__/huggingface-conversion.test.ts | 307 ++++++++++++++++++ web-app/src/routes/hub/index.tsx | 142 ++++++-- web-app/src/services/__tests__/models.test.ts | 256 +++++++++++++++ web-app/src/services/models.ts | 71 ++++ 6 files changed, 903 insertions(+), 89 deletions(-) create mode 100644 web-app/src/routes/hub/__tests__/huggingface-conversion.test.ts diff --git a/web-app/src/hooks/__tests__/useModelSources.test.ts b/web-app/src/hooks/__tests__/useModelSources.test.ts index dfff0ba7c..bd06da434 100644 --- a/web-app/src/hooks/__tests__/useModelSources.test.ts +++ b/web-app/src/hooks/__tests__/useModelSources.test.ts @@ -234,61 +234,149 @@ describe('useModelSources', () => { }) describe('addSource', () => { - it('should log the source parameter', async () => { - const consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {}) + it('should add a new source to the store', () => { const { result } = renderHook(() => useModelSources()) - - await act(async () => { - await result.current.addSource('test-source') - }) - - expect(consoleSpy).toHaveBeenCalledWith('test-source') - consoleSpy.mockRestore() - }) - - it('should set loading state during addSource', async () => { - const { result } = renderHook(() => useModelSources()) - - await act(async () => { - await result.current.addSource('test-source') - }) - - expect(result.current.loading).toBe(true) - expect(result.current.error).toBe(null) - }) - - it('should handle different source types', async () => { - const consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {}) - const { result } = renderHook(() => useModelSources()) - - const sources = [ - 'http://example.com/source1', - 'https://secure.example.com/source2', - 'file:///local/path/source3', - 'custom-source-name', - ] - - for (const source of sources) { - await act(async () => { - await result.current.addSource(source) - }) - - expect(consoleSpy).toHaveBeenCalledWith(source) + + const testModel: CatalogModel = { + model_name: 'test-model', + description: 'Test model description', + developer: 'test-developer', + downloads: 100, + num_quants: 2, + quants: [ + { + model_id: 'test-model-q4', + path: 'https://example.com/test-model-q4.gguf', + file_size: '2.0 GB', + }, + ], + created_at: '2023-01-01T00:00:00Z', } - consoleSpy.mockRestore() - }) - - it('should handle empty source string', async () => { - const consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {}) - const { result } = renderHook(() => useModelSources()) - - await act(async () => { - await result.current.addSource('') + act(() => { + result.current.addSource(testModel) }) - expect(consoleSpy).toHaveBeenCalledWith('') - consoleSpy.mockRestore() + expect(result.current.sources).toHaveLength(1) + expect(result.current.sources[0]).toEqual(testModel) + }) + + it('should replace existing source with same model_name', () => { + const { result } = renderHook(() => useModelSources()) + + const originalModel: CatalogModel = { + model_name: 'duplicate-model', + description: 'Original description', + developer: 'original-developer', + downloads: 50, + num_quants: 1, + quants: [], + created_at: '2023-01-01T00:00:00Z', + } + + const updatedModel: CatalogModel = { + model_name: 'duplicate-model', + description: 'Updated description', + developer: 'updated-developer', + downloads: 150, + num_quants: 2, + quants: [ + { + model_id: 'duplicate-model-q4', + path: 'https://example.com/duplicate-model-q4.gguf', + file_size: '3.0 GB', + }, + ], + created_at: '2023-02-01T00:00:00Z', + } + + act(() => { + result.current.addSource(originalModel) + }) + + expect(result.current.sources).toHaveLength(1) + + act(() => { + result.current.addSource(updatedModel) + }) + + expect(result.current.sources).toHaveLength(1) + expect(result.current.sources[0]).toEqual(updatedModel) + }) + + it('should handle multiple different sources', () => { + const { result } = renderHook(() => useModelSources()) + + const model1: CatalogModel = { + model_name: 'model-1', + description: 'First model', + developer: 'developer-1', + downloads: 100, + num_quants: 1, + quants: [], + created_at: '2023-01-01T00:00:00Z', + } + + const model2: CatalogModel = { + model_name: 'model-2', + description: 'Second model', + developer: 'developer-2', + downloads: 200, + num_quants: 1, + quants: [], + created_at: '2023-01-02T00:00:00Z', + } + + act(() => { + result.current.addSource(model1) + }) + + act(() => { + result.current.addSource(model2) + }) + + expect(result.current.sources).toHaveLength(2) + expect(result.current.sources).toContainEqual(model1) + expect(result.current.sources).toContainEqual(model2) + }) + + it('should handle CatalogModel with complete quants data', () => { + const { result } = renderHook(() => useModelSources()) + + const modelWithQuants: CatalogModel = { + model_name: 'model-with-quants', + description: 'Model with quantizations', + developer: 'quant-developer', + downloads: 500, + num_quants: 3, + quants: [ + { + model_id: 'model-q4_k_m', + path: 'https://example.com/model-q4_k_m.gguf', + file_size: '2.0 GB', + }, + { + model_id: 'model-q8_0', + path: 'https://example.com/model-q8_0.gguf', + file_size: '4.0 GB', + }, + { + model_id: 'model-f16', + path: 'https://example.com/model-f16.gguf', + file_size: '8.0 GB', + }, + ], + created_at: '2023-01-01T00:00:00Z', + readme: 'https://example.com/readme.md', + } + + act(() => { + result.current.addSource(modelWithQuants) + }) + + expect(result.current.sources).toHaveLength(1) + expect(result.current.sources[0]).toEqual(modelWithQuants) + expect(result.current.sources[0].quants).toHaveLength(3) }) }) diff --git a/web-app/src/hooks/useModelSources.ts b/web-app/src/hooks/useModelSources.ts index e85815b23..263d6a0dd 100644 --- a/web-app/src/hooks/useModelSources.ts +++ b/web-app/src/hooks/useModelSources.ts @@ -8,8 +8,8 @@ type ModelSourcesState = { sources: CatalogModel[] error: Error | null loading: boolean + addSource: (source: CatalogModel) => void fetchSources: () => Promise - addSource: (source: string) => Promise } export const useModelSources = create()( @@ -19,6 +19,14 @@ export const useModelSources = create()( error: null, loading: false, + addSource: (source: CatalogModel) => { + set((state) => ({ + sources: [ + ...state.sources.filter((e) => e.model_name !== source.model_name), + source, + ], + })) + }, fetchSources: async () => { set({ loading: true, error: null }) try { @@ -38,24 +46,6 @@ export const useModelSources = create()( set({ error: error as Error, loading: false }) } }, - - addSource: async (source: string) => { - set({ loading: true, error: null }) - console.log(source) - // try { - // await addModelSource(source) - // const newSources = await fetchModelSources() - // const currentSources = get().sources - - // if (!deepCompareModelSources(currentSources, newSources)) { - // set({ sources: newSources, loading: false }) - // } else { - // set({ loading: false }) - // } - // } catch (error) { - // set({ error: error as Error, loading: false }) - // } - }, }), { name: localStorageKey.modelSources, diff --git a/web-app/src/routes/hub/__tests__/huggingface-conversion.test.ts b/web-app/src/routes/hub/__tests__/huggingface-conversion.test.ts new file mode 100644 index 000000000..23b65b9ef --- /dev/null +++ b/web-app/src/routes/hub/__tests__/huggingface-conversion.test.ts @@ -0,0 +1,307 @@ +import { describe, it, expect } from 'vitest' +import { HuggingFaceRepo, CatalogModel } from '@/services/models' + +// Helper function to test the conversion logic (extracted from the component) +const convertHfRepoToCatalogModel = (repo: HuggingFaceRepo): CatalogModel => { + // Extract GGUF files from the repository siblings + const ggufFiles = + repo.siblings?.filter((file) => + file.rfilename.toLowerCase().endsWith('.gguf') + ) || [] + + // Convert GGUF files to quants format + const quants = ggufFiles.map((file) => { + // Format file size + const formatFileSize = (size?: number) => { + if (!size) return 'Unknown size' + if (size < 1024 ** 3) return `${(size / 1024 ** 2).toFixed(1)} MB` + return `${(size / 1024 ** 3).toFixed(1)} GB` + } + + // Generate model_id from filename (remove .gguf extension, case-insensitive) + const modelId = file.rfilename.replace(/\.gguf$/i, '') + + return { + model_id: modelId, + path: `https://huggingface.co/${repo.modelId}/resolve/main/${file.rfilename}`, + file_size: formatFileSize(file.size), + } + }) + + return { + model_name: repo.modelId, + description: `**Metadata:** ${repo.pipeline_tag}\n\n **Tags**: ${repo.tags?.join(', ')}`, + developer: repo.author, + downloads: repo.downloads || 0, + num_quants: quants.length, + quants: quants, + created_at: repo.created_at, + readme: `https://huggingface.co/${repo.modelId}/resolve/main/README.md`, + } +} + +describe('HuggingFace Repository Conversion', () => { + const mockHuggingFaceRepo: HuggingFaceRepo = { + id: 'microsoft/DialoGPT-medium', + modelId: 'microsoft/DialoGPT-medium', + sha: 'abc123', + downloads: 5000, + likes: 100, + tags: ['conversational', 'pytorch', 'text-generation'], + pipeline_tag: 'text-generation', + created_at: '2023-01-01T00:00:00Z', + last_modified: '2023-12-01T00:00:00Z', + private: false, + disabled: false, + gated: false, + author: 'microsoft', + siblings: [ + { + rfilename: 'model-Q4_K_M.gguf', + size: 2147483648, // 2GB + blobId: 'blob123', + }, + { + rfilename: 'model-Q8_0.gguf', + size: 4294967296, // 4GB + blobId: 'blob456', + }, + { + rfilename: 'model-small.gguf', + size: 536870912, // 512MB + blobId: 'blob789', + }, + { + rfilename: 'README.md', + size: 1024, + blobId: 'blob101', + }, + ], + readme: '# DialoGPT Model\nThis is a conversational AI model.', + } + + describe('convertHfRepoToCatalogModel', () => { + it('should convert HuggingFace repository to CatalogModel correctly', () => { + const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo) + + expect(result).toEqual({ + model_name: 'microsoft/DialoGPT-medium', + description: '**Metadata:** text-generation\n\n **Tags**: conversational, pytorch, text-generation', + developer: 'microsoft', + downloads: 5000, + num_quants: 3, + quants: [ + { + model_id: 'model-Q4_K_M', + path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-Q4_K_M.gguf', + file_size: '2.0 GB', + }, + { + model_id: 'model-Q8_0', + path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-Q8_0.gguf', + file_size: '4.0 GB', + }, + { + model_id: 'model-small', + path: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-small.gguf', + file_size: '512.0 MB', + }, + ], + created_at: '2023-01-01T00:00:00Z', + readme: 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md', + }) + }) + + it('should filter only GGUF files from siblings', () => { + const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo) + + // Should have 3 GGUF files, not 4 total files + expect(result.num_quants).toBe(3) + expect(result.quants).toHaveLength(3) + + // All quants should be from GGUF files + result.quants.forEach((quant) => { + expect(quant.path).toContain('.gguf') + }) + }) + + it('should format file sizes correctly', () => { + const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo) + + expect(result.quants[0].file_size).toBe('2.0 GB') // 2GB + expect(result.quants[1].file_size).toBe('4.0 GB') // 4GB + expect(result.quants[2].file_size).toBe('512.0 MB') // 512MB + }) + + it('should generate correct download paths', () => { + const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo) + + expect(result.quants[0].path).toBe( + 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-Q4_K_M.gguf' + ) + expect(result.quants[1].path).toBe( + 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/model-Q8_0.gguf' + ) + }) + + it('should generate correct model IDs by removing .gguf extension', () => { + const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo) + + expect(result.quants[0].model_id).toBe('model-Q4_K_M') + expect(result.quants[1].model_id).toBe('model-Q8_0') + expect(result.quants[2].model_id).toBe('model-small') + }) + + it('should handle repository with no siblings', () => { + const repoWithoutSiblings = { + ...mockHuggingFaceRepo, + siblings: undefined, + } + + const result = convertHfRepoToCatalogModel(repoWithoutSiblings) + + expect(result.num_quants).toBe(0) + expect(result.quants).toEqual([]) + }) + + it('should handle repository with empty siblings array', () => { + const repoWithEmptySiblings = { + ...mockHuggingFaceRepo, + siblings: [], + } + + const result = convertHfRepoToCatalogModel(repoWithEmptySiblings) + + expect(result.num_quants).toBe(0) + expect(result.quants).toEqual([]) + }) + + it('should handle repository with no GGUF files', () => { + const repoWithoutGGUF = { + ...mockHuggingFaceRepo, + siblings: [ + { + rfilename: 'README.md', + size: 1024, + blobId: 'blob101', + }, + { + rfilename: 'config.json', + size: 512, + blobId: 'blob102', + }, + ], + } + + const result = convertHfRepoToCatalogModel(repoWithoutGGUF) + + expect(result.num_quants).toBe(0) + expect(result.quants).toEqual([]) + }) + + it('should handle files with unknown sizes', () => { + const repoWithUnknownSizes = { + ...mockHuggingFaceRepo, + siblings: [ + { + rfilename: 'model-unknown.gguf', + size: undefined, + blobId: 'blob123', + }, + ], + } + + const result = convertHfRepoToCatalogModel(repoWithUnknownSizes) + + expect(result.quants[0].file_size).toBe('Unknown size') + }) + + it('should handle repository with zero downloads', () => { + const repoWithZeroDownloads = { + ...mockHuggingFaceRepo, + downloads: 0, + } + + const result = convertHfRepoToCatalogModel(repoWithZeroDownloads) + + expect(result.downloads).toBe(0) + }) + + it('should handle repository with no tags', () => { + const repoWithoutTags = { + ...mockHuggingFaceRepo, + tags: [], + } + + const result = convertHfRepoToCatalogModel(repoWithoutTags) + + expect(result.description).toContain('**Tags**: ') + }) + + it('should handle repository with no pipeline_tag', () => { + const repoWithoutPipelineTag = { + ...mockHuggingFaceRepo, + pipeline_tag: undefined, + } + + const result = convertHfRepoToCatalogModel(repoWithoutPipelineTag) + + expect(result.description).toContain('**Metadata:** undefined') + }) + + it('should generate README URL correctly', () => { + const result = convertHfRepoToCatalogModel(mockHuggingFaceRepo) + + expect(result.readme).toBe( + 'https://huggingface.co/microsoft/DialoGPT-medium/resolve/main/README.md' + ) + }) + + it('should handle case-insensitive GGUF file extensions', () => { + const repoWithMixedCase = { + ...mockHuggingFaceRepo, + siblings: [ + { + rfilename: 'model-uppercase.GGUF', + size: 1024, + blobId: 'blob1', + }, + { + rfilename: 'model-mixedcase.Gguf', + size: 2048, + blobId: 'blob2', + }, + { + rfilename: 'model-lowercase.gguf', + size: 4096, + blobId: 'blob3', + }, + ], + } + + const result = convertHfRepoToCatalogModel(repoWithMixedCase) + + expect(result.num_quants).toBe(3) + expect(result.quants[0].model_id).toBe('model-uppercase') + expect(result.quants[1].model_id).toBe('model-mixedcase') + expect(result.quants[2].model_id).toBe('model-lowercase') + }) + + it('should handle very large file sizes (> 1TB)', () => { + const repoWithLargeFiles = { + ...mockHuggingFaceRepo, + siblings: [ + { + rfilename: 'huge-model.gguf', + size: 1099511627776, // 1TB + blobId: 'blob1', + }, + ], + } + + const result = convertHfRepoToCatalogModel(repoWithLargeFiles) + + expect(result.quants[0].file_size).toBe('1024.0 GB') + }) + }) +}) \ No newline at end of file diff --git a/web-app/src/routes/hub/index.tsx b/web-app/src/routes/hub/index.tsx index 845dbcccb..af1eade6c 100644 --- a/web-app/src/routes/hub/index.tsx +++ b/web-app/src/routes/hub/index.tsx @@ -26,7 +26,12 @@ import { DropdownMenuItem, DropdownMenuTrigger, } from '@/components/ui/dropdown-menu' -import { CatalogModel, pullModel } from '@/services/models' +import { + CatalogModel, + pullModel, + fetchHuggingFaceRepo, + HuggingFaceRepo, +} from '@/services/models' import { useDownloadStore } from '@/hooks/useDownloadStore' import { Progress } from '@/components/ui/progress' import HeaderPage from '@/containers/HeaderPage' @@ -54,7 +59,7 @@ function Hub() { { value: 'newest', name: t('hub:sortNewest') }, { value: 'most-downloaded', name: t('hub:sortMostDownloaded') }, ] - const { sources, fetchSources, addSource, loading } = useModelSources() + const { sources, addSource, fetchSources, loading } = useModelSources() const search = useSearch({ from: route.hub.index as any }) const [searchValue, setSearchValue] = useState('') const [sortSelected, setSortSelected] = useState('newest') @@ -63,6 +68,9 @@ function Hub() { ) const [isSearching, setIsSearching] = useState(false) const [showOnlyDownloaded, setShowOnlyDownloaded] = useState(false) + const [huggingFaceRepo, setHuggingFaceRepo] = useState( + null + ) const [joyrideReady, setJoyrideReady] = useState(false) const [currentStepIndex, setCurrentStepIndex] = useState(0) const addModelSourceTimeoutRef = useRef | null>( @@ -74,6 +82,48 @@ function Hub() { const { getProviderByName } = useModelProvider() const llamaProvider = getProviderByName('llamacpp') + // Convert HuggingFace repository to CatalogModel format + const convertHfRepoToCatalogModel = useCallback( + (repo: HuggingFaceRepo): CatalogModel => { + // Extract GGUF files from the repository siblings + const ggufFiles = + repo.siblings?.filter((file) => + file.rfilename.toLowerCase().endsWith('.gguf') + ) || [] + + // Convert GGUF files to quants format + const quants = ggufFiles.map((file) => { + // Format file size + const formatFileSize = (size?: number) => { + if (!size) return 'Unknown size' + if (size < 1024 ** 3) return `${(size / 1024 ** 2).toFixed(1)} MB` + return `${(size / 1024 ** 3).toFixed(1)} GB` + } + + // Generate model_id from filename (remove .gguf extension, case-insensitive) + const modelId = file.rfilename.replace(/\.gguf$/i, '') + + return { + model_id: modelId, + path: `https://huggingface.co/${repo.modelId}/resolve/main/${file.rfilename}`, + file_size: formatFileSize(file.size), + } + }) + + return { + model_name: repo.modelId, + description: `**Metadata:** ${repo.pipeline_tag}\n\n **Tags**: ${repo.tags?.join(', ')}`, + developer: repo.author, + downloads: repo.downloads || 0, + num_quants: quants.length, + quants: quants, + created_at: repo.created_at, + readme: `https://huggingface.co/${repo.modelId}/resolve/main/README.md`, + } + }, + [] + ) + const toggleModelExpansion = (modelId: string) => { setExpandedModels((prev) => ({ ...prev, @@ -85,17 +135,26 @@ function Hub() { if (search.repo) { setSearchValue(search.repo || '') setIsSearching(true) - addModelSourceTimeoutRef.current = setTimeout(() => { - addSource(search.repo) - .then(() => { - fetchSources() - }) - .finally(() => { - setIsSearching(false) - }) + + addModelSourceTimeoutRef.current = setTimeout(async () => { + try { + // Fetch HuggingFace repository information + const repoInfo = await fetchHuggingFaceRepo(search.repo) + if (repoInfo) { + const catalogModel = convertHfRepoToCatalogModel(repoInfo) + setHuggingFaceRepo(catalogModel) + addSource(catalogModel) + } + + await fetchSources() + } catch (error) { + console.error('Error fetching repository info:', error) + } finally { + setIsSearching(false) + } }, 500) } - }, [addSource, fetchSources, search]) + }, [convertHfRepoToCatalogModel, fetchSources, addSource, search]) // Sorting functionality const sortedModels = useMemo(() => { @@ -143,8 +202,19 @@ function Hub() { ) } + // Add HuggingFace repo at the beginning if available + if (huggingFaceRepo) { + filtered = [huggingFaceRepo, ...filtered] + } + return filtered - }, [searchValue, sortedModels, showOnlyDownloaded, llamaProvider?.models]) + }, [ + searchValue, + sortedModels, + showOnlyDownloaded, + llamaProvider?.models, + huggingFaceRepo, + ]) useEffect(() => { fetchSources() @@ -153,22 +223,35 @@ function Hub() { const handleSearchChange = (e: ChangeEvent) => { setIsSearching(false) setSearchValue(e.target.value) + setHuggingFaceRepo(null) // Clear previous repo info + if (addModelSourceTimeoutRef.current) { clearTimeout(addModelSourceTimeoutRef.current) } + if ( e.target.value.length && (e.target.value.includes('/') || e.target.value.startsWith('http')) ) { setIsSearching(true) - addModelSourceTimeoutRef.current = setTimeout(() => { - addSource(e.target.value) - .then(() => { - fetchSources() - }) - .finally(() => { - setIsSearching(false) - }) + + addModelSourceTimeoutRef.current = setTimeout(async () => { + try { + // Fetch HuggingFace repository information + const repoInfo = await fetchHuggingFaceRepo(e.target.value) + if (repoInfo) { + const catalogModel = convertHfRepoToCatalogModel(repoInfo) + setHuggingFaceRepo(catalogModel) + addSource(catalogModel) + } + + // Original addSource logic (if needed) + await fetchSources() + } catch (error) { + console.error('Error fetching repository info:', error) + } finally { + setIsSearching(false) + } }, 500) } } @@ -213,6 +296,25 @@ function Hub() { const DownloadButtonPlaceholder = useMemo(() => { return ({ model }: ModelProps) => { + // Check if this is a HuggingFace repository (no quants) + if (model.quants.length === 0) { + return ( +
+ +
+ ) + } + const quant = model.quants.find((e) => defaultModelQuantizations.some((m) => diff --git a/web-app/src/services/__tests__/models.test.ts b/web-app/src/services/__tests__/models.test.ts index d5e38b034..c6f626911 100644 --- a/web-app/src/services/__tests__/models.test.ts +++ b/web-app/src/services/__tests__/models.test.ts @@ -3,6 +3,7 @@ import { describe, it, expect, vi, beforeEach } from 'vitest' import { fetchModels, fetchModelCatalog, + fetchHuggingFaceRepo, updateModel, pullModel, abortDownload, @@ -271,4 +272,259 @@ describe('models service', () => { await expect(startModel(provider, model)).resolves.toBe(undefined) }) }) + + describe('fetchHuggingFaceRepo', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should fetch HuggingFace repository successfully with blobs=true', async () => { + const mockRepoData = { + id: 'microsoft/DialoGPT-medium', + modelId: 'microsoft/DialoGPT-medium', + sha: 'abc123', + downloads: 5000, + likes: 100, + tags: ['conversational', 'pytorch'], + pipeline_tag: 'text-generation', + created_at: '2023-01-01T00:00:00Z', + last_modified: '2023-12-01T00:00:00Z', + private: false, + disabled: false, + gated: false, + author: 'microsoft', + siblings: [ + { + rfilename: 'model-Q4_K_M.gguf', + size: 2147483648, + blobId: 'blob123', + }, + { + rfilename: 'model-Q8_0.gguf', + size: 4294967296, + blobId: 'blob456', + }, + { + rfilename: 'README.md', + size: 1024, + blobId: 'blob789', + }, + ], + readme: '# DialoGPT Model\nThis is a conversational AI model.', + } + + ;(fetch as any).mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue(mockRepoData), + }) + + const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium') + + expect(result).toEqual(mockRepoData) + expect(fetch).toHaveBeenCalledWith( + 'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true' + ) + }) + + it('should clean repository ID from various input formats', async () => { + const mockRepoData = { modelId: 'microsoft/DialoGPT-medium' } + ;(fetch as any).mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue(mockRepoData), + }) + + // Test with full URL + await fetchHuggingFaceRepo('https://huggingface.co/microsoft/DialoGPT-medium') + expect(fetch).toHaveBeenCalledWith( + 'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true' + ) + + // Test with domain prefix + await fetchHuggingFaceRepo('huggingface.co/microsoft/DialoGPT-medium') + expect(fetch).toHaveBeenCalledWith( + 'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true' + ) + + // Test with trailing slash + await fetchHuggingFaceRepo('microsoft/DialoGPT-medium/') + expect(fetch).toHaveBeenCalledWith( + 'https://huggingface.co/api/models/microsoft/DialoGPT-medium?blobs=true' + ) + }) + + it('should return null for invalid repository IDs', async () => { + // Test empty string + expect(await fetchHuggingFaceRepo('')).toBeNull() + + // Test string without slash + expect(await fetchHuggingFaceRepo('invalid-repo')).toBeNull() + + // Test whitespace only + expect(await fetchHuggingFaceRepo(' ')).toBeNull() + }) + + it('should return null for 404 responses', async () => { + ;(fetch as any).mockResolvedValue({ + ok: false, + status: 404, + statusText: 'Not Found', + }) + + const result = await fetchHuggingFaceRepo('nonexistent/model') + + expect(result).toBeNull() + expect(fetch).toHaveBeenCalledWith( + 'https://huggingface.co/api/models/nonexistent/model?blobs=true' + ) + }) + + it('should handle other HTTP errors', async () => { + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + + ;(fetch as any).mockResolvedValue({ + ok: false, + status: 500, + statusText: 'Internal Server Error', + }) + + const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium') + + expect(result).toBeNull() + expect(consoleSpy).toHaveBeenCalledWith( + 'Error fetching HuggingFace repository:', + expect.any(Error) + ) + + consoleSpy.mockRestore() + }) + + it('should handle network errors', async () => { + const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) + + ;(fetch as any).mockRejectedValue(new Error('Network error')) + + const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium') + + expect(result).toBeNull() + expect(consoleSpy).toHaveBeenCalledWith( + 'Error fetching HuggingFace repository:', + expect.any(Error) + ) + + consoleSpy.mockRestore() + }) + + it('should handle repository with no siblings', async () => { + const mockRepoData = { + id: 'microsoft/DialoGPT-medium', + modelId: 'microsoft/DialoGPT-medium', + sha: 'abc123', + downloads: 5000, + likes: 100, + tags: ['conversational'], + pipeline_tag: 'text-generation', + created_at: '2023-01-01T00:00:00Z', + last_modified: '2023-12-01T00:00:00Z', + private: false, + disabled: false, + gated: false, + author: 'microsoft', + siblings: undefined, + } + + ;(fetch as any).mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue(mockRepoData), + }) + + const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium') + + expect(result).toEqual(mockRepoData) + }) + + it('should handle repository with no GGUF files', async () => { + const mockRepoData = { + id: 'microsoft/DialoGPT-medium', + modelId: 'microsoft/DialoGPT-medium', + sha: 'abc123', + downloads: 5000, + likes: 100, + tags: ['conversational'], + pipeline_tag: 'text-generation', + created_at: '2023-01-01T00:00:00Z', + last_modified: '2023-12-01T00:00:00Z', + private: false, + disabled: false, + gated: false, + author: 'microsoft', + siblings: [ + { + rfilename: 'README.md', + size: 1024, + blobId: 'blob789', + }, + { + rfilename: 'config.json', + size: 512, + blobId: 'blob101', + }, + ], + } + + ;(fetch as any).mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue(mockRepoData), + }) + + const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium') + + expect(result).toEqual(mockRepoData) + }) + + it('should handle repository with mixed file types including GGUF', async () => { + const mockRepoData = { + id: 'microsoft/DialoGPT-medium', + modelId: 'microsoft/DialoGPT-medium', + sha: 'abc123', + downloads: 5000, + likes: 100, + tags: ['conversational'], + pipeline_tag: 'text-generation', + created_at: '2023-01-01T00:00:00Z', + last_modified: '2023-12-01T00:00:00Z', + private: false, + disabled: false, + gated: false, + author: 'microsoft', + siblings: [ + { + rfilename: 'model-Q4_K_M.gguf', + size: 2147483648, // 2GB + blobId: 'blob123', + }, + { + rfilename: 'README.md', + size: 1024, + blobId: 'blob789', + }, + { + rfilename: 'config.json', + size: 512, + blobId: 'blob101', + }, + ], + } + + ;(fetch as any).mockResolvedValue({ + ok: true, + json: vi.fn().mockResolvedValue(mockRepoData), + }) + + const result = await fetchHuggingFaceRepo('microsoft/DialoGPT-medium') + + expect(result).toEqual(mockRepoData) + // Verify the GGUF file is present in siblings + expect(result?.siblings?.some(s => s.rfilename.endsWith('.gguf'))).toBe(true) + }) + }) }) diff --git a/web-app/src/services/models.ts b/web-app/src/services/models.ts index f38afa06f..71911244f 100644 --- a/web-app/src/services/models.ts +++ b/web-app/src/services/models.ts @@ -25,6 +25,36 @@ export interface CatalogModel { export type ModelCatalog = CatalogModel[] +// HuggingFace repository information +export interface HuggingFaceRepo { + id: string + modelId: string + sha: string + downloads: number + likes: number + library_name?: string + tags: string[] + pipeline_tag?: string + created_at: string + last_modified: string + private: boolean + disabled: boolean + gated: boolean | string + author: string + cardData?: { + license?: string + language?: string[] + datasets?: string[] + metrics?: string[] + } + siblings?: Array<{ + rfilename: string + size?: number + blobId?: string + }> + readme?: string +} + // TODO: Replace this with the actual provider later const defaultProvider = 'llamacpp' @@ -63,6 +93,47 @@ export const fetchModelCatalog = async (): Promise => { } } +/** + * Fetches HuggingFace repository information. + * @param repoId The repository ID (e.g., "microsoft/DialoGPT-medium") + * @returns A promise that resolves to the repository information. + */ +export const fetchHuggingFaceRepo = async ( + repoId: string +): Promise => { + try { + // Clean the repo ID to handle various input formats + const cleanRepoId = repoId + .replace(/^https?:\/\/huggingface\.co\//, '') + .replace(/^huggingface\.co\//, '') + .replace(/\/$/, '') // Remove trailing slash + .trim() + + if (!cleanRepoId || !cleanRepoId.includes('/')) { + return null + } + + const response = await fetch( + `https://huggingface.co/api/models/${cleanRepoId}?blobs=true` + ) + + if (!response.ok) { + if (response.status === 404) { + return null // Repository not found + } + throw new Error( + `Failed to fetch HuggingFace repository: ${response.status} ${response.statusText}` + ) + } + + const repoData: HuggingFaceRepo = await response.json() + return repoData + } catch (error) { + console.error('Error fetching HuggingFace repository:', error) + return null + } +} + /** * Updates a model. * @param model The model to update.