diff --git a/web/utils/huggingface.test.ts b/web/utils/huggingface.test.ts new file mode 100644 index 000000000..db7dbf3e1 --- /dev/null +++ b/web/utils/huggingface.test.ts @@ -0,0 +1,96 @@ +import { + fetchHuggingFaceRepoData, + toHuggingFaceUrl, + InvalidHostError, +} from './huggingface' +import { getFileSize } from '@janhq/core' + +// Mock the getFileSize function +jest.mock('@janhq/core', () => ({ + getFileSize: jest.fn(), + AllQuantizations: ['q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0'], +})) + +describe('huggingface utils', () => { + let originalFetch: typeof global.fetch + + beforeAll(() => { + originalFetch = global.fetch + global.fetch = jest.fn() + }) + + afterAll(() => { + global.fetch = originalFetch + }) + + beforeEach(() => { + jest.resetAllMocks() + }) + + describe('fetchHuggingFaceRepoData', () => { + it('should fetch and process repo data correctly', async () => { + const mockResponse = { + tags: ['gguf'], + siblings: [ + { rfilename: 'model-q4_0.gguf' }, + { rfilename: 'model-q8_0.gguf' }, + ], + } + + ;(global.fetch as jest.Mock).mockResolvedValue({ + json: jest.fn().mockResolvedValue(mockResponse), + }) + ;(getFileSize as jest.Mock).mockResolvedValue(1000000) + + const result = await fetchHuggingFaceRepoData('user/repo') + + expect(result.tags).toEqual(['gguf']) + expect(result.siblings).toHaveLength(2) + expect(result.siblings[0].fileSize).toBe(1000000) + expect(result.siblings[0].quantization).toBe('q4_0') + expect(result.modelUrl).toBe('https://huggingface.co/user/repo') + }) + + it('should throw an error if the model is not GGUF', async () => { + const mockResponse = { + tags: ['not-gguf'], + } + + ;(global.fetch as jest.Mock).mockResolvedValue({ + json: jest.fn().mockResolvedValue(mockResponse), + }) + + await expect(fetchHuggingFaceRepoData('user/repo')).rejects.toThrow( + 'user/repo is not supported. Only GGUF models are supported.' + ) + }) + + // ... existing code ... + }) + + describe('toHuggingFaceUrl', () => { + it('should convert a valid repo ID to a Hugging Face API URL', () => { + expect(toHuggingFaceUrl('user/repo')).toBe( + 'https://huggingface.co/api/models/user/repo' + ) + }) + + it('should handle a full Hugging Face URL', () => { + expect(toHuggingFaceUrl('https://huggingface.co/user/repo')).toBe( + 'https://huggingface.co/api/models/user/repo' + ) + }) + + it('should throw an InvalidHostError for non-Hugging Face URLs', () => { + expect(() => toHuggingFaceUrl('https://example.com/user/repo')).toThrow( + InvalidHostError + ) + }) + + it('should throw an error for invalid URLs', () => { + expect(() => toHuggingFaceUrl('https://invalid-url')).toThrow( + 'Invalid Hugging Face repo URL: https://invalid-url' + ) + }) + }) +}) diff --git a/web/utils/huggingface.ts b/web/utils/huggingface.ts index 328d684e6..ceddc6867 100644 --- a/web/utils/huggingface.ts +++ b/web/utils/huggingface.ts @@ -60,7 +60,7 @@ export const fetchHuggingFaceRepoData = async ( return data } -function toHuggingFaceUrl(repoId: string): string { +export function toHuggingFaceUrl(repoId: string): string { try { const url = new URL(repoId) if (url.host !== 'huggingface.co') { @@ -85,7 +85,7 @@ function toHuggingFaceUrl(repoId: string): string { return `https://huggingface.co/api/models/${repoId}` } } -class InvalidHostError extends Error { +export class InvalidHostError extends Error { constructor(message: string) { super(message) this.name = 'InvalidHostError'