Merge pull request #4133 from janhq/fix/4096-failed-to-get-huggingface-models

fix: 4096 - failed to get huggingface models
This commit is contained in:
Louis 2024-11-27 00:53:34 +07:00 committed by GitHub
commit 3a9a8dad3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 21 additions and 72 deletions

View File

@ -3,7 +3,6 @@ import { joinPath } from './core'
import { openFileExplorer } from './core' import { openFileExplorer } from './core'
import { getJanDataFolderPath } from './core' import { getJanDataFolderPath } from './core'
import { abortDownload } from './core' import { abortDownload } from './core'
import { getFileSize } from './core'
import { executeOnMain } from './core' import { executeOnMain } from './core'
describe('test core apis', () => { describe('test core apis', () => {
@ -66,18 +65,6 @@ describe('test core apis', () => {
expect(result).toBe('aborted') expect(result).toBe('aborted')
}) })
it('should get file size', async () => {
const url = 'http://example.com/file'
globalThis.core = {
api: {
getFileSize: jest.fn().mockResolvedValue(1024),
},
}
const result = await getFileSize(url)
expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url)
expect(result).toBe(1024)
})
it('should execute function on main process', async () => { it('should execute function on main process', async () => {
const extension = 'testExtension' const extension = 'testExtension'
const method = 'testMethod' const method = 'testMethod'

View File

@ -28,15 +28,6 @@ const downloadFile: (downloadRequest: DownloadRequest, network?: NetworkConfig)
network network
) => globalThis.core?.api?.downloadFile(downloadRequest, network) ) => globalThis.core?.api?.downloadFile(downloadRequest, network)
/**
* Get unit in bytes for a remote file.
*
* @param url - The url of the file.
* @returns {Promise<number>} - A promise that resolves with the file size.
*/
const getFileSize: (url: string) => Promise<number> = (url: string) =>
globalThis.core.api?.getFileSize(url)
/** /**
* Aborts the download of a specific file. * Aborts the download of a specific file.
* @param {string} fileName - The name of the file whose download is to be aborted. * @param {string} fileName - The name of the file whose download is to be aborted.
@ -167,7 +158,6 @@ export {
getUserHomePath, getUserHomePath,
systemInformation, systemInformation,
showToast, showToast,
getFileSize,
dirName, dirName,
FileStat, FileStat,
} }

View File

@ -23,6 +23,11 @@ jest.mock('fs', () => ({
createWriteStream: jest.fn(), createWriteStream: jest.fn(),
})) }))
const requestMock = jest.fn((options, callback) => {
callback(new Error('Test error'), null)
})
jest.mock('request', () => requestMock)
jest.mock('request-progress', () => { jest.mock('request-progress', () => {
return jest.fn().mockImplementation(() => { return jest.fn().mockImplementation(() => {
return { return {
@ -54,18 +59,6 @@ describe('Downloader', () => {
beforeEach(() => { beforeEach(() => {
jest.resetAllMocks() jest.resetAllMocks()
}) })
it('should handle getFileSize errors correctly', async () => {
const observer = jest.fn()
const url = 'http://example.com/file'
const downloader = new Downloader(observer)
const requestMock = jest.fn((options, callback) => {
callback(new Error('Test error'), null)
})
jest.mock('request', () => requestMock)
await expect(downloader.getFileSize(observer, url)).rejects.toThrow('Test error')
})
it('should pause download correctly', () => { it('should pause download correctly', () => {
const observer = jest.fn() const observer = jest.fn()

View File

@ -135,25 +135,4 @@ export class Downloader implements Processor {
pauseDownload(_observer: any, fileName: any) { pauseDownload(_observer: any, fileName: any) {
DownloadManager.instance.networkRequests[fileName]?.pause() DownloadManager.instance.networkRequests[fileName]?.pause()
} }
async getFileSize(_observer: any, url: string): Promise<number> {
return new Promise((resolve, reject) => {
const request = require('request')
request(
{
url,
method: 'HEAD',
},
function (err: any, response: any) {
if (err) {
console.error('Getting file size failed:', err)
reject(err)
} else {
const size: number = response.headers['content-length'] ?? -1
resolve(size)
}
}
)
})
}
} }

View File

@ -65,7 +65,6 @@ export enum DownloadRoute {
pauseDownload = 'pauseDownload', pauseDownload = 'pauseDownload',
resumeDownload = 'resumeDownload', resumeDownload = 'resumeDownload',
getDownloadProgress = 'getDownloadProgress', getDownloadProgress = 'getDownloadProgress',
getFileSize = 'getFileSize',
} }
export enum DownloadEvent { export enum DownloadEvent {

View File

@ -46,7 +46,7 @@ const ModelSearch = ({ onSearchLocal }: Props) => {
errMessage = err.message errMessage = err.message
} }
toaster({ toaster({
title: 'Failed to get Hugging Face models', title: 'Oops, you may be rate limited, give it a bit more time',
description: errMessage, description: errMessage,
type: 'error', type: 'error',
}) })

View File

@ -3,11 +3,8 @@ import {
toHuggingFaceUrl, toHuggingFaceUrl,
InvalidHostError, InvalidHostError,
} from './huggingface' } from './huggingface'
import { getFileSize } from '@janhq/core'
// Mock the getFileSize function
jest.mock('@janhq/core', () => ({ jest.mock('@janhq/core', () => ({
getFileSize: jest.fn(),
AllQuantizations: ['q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0'], AllQuantizations: ['q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0'],
})) }))
@ -38,9 +35,15 @@ describe('huggingface utils', () => {
} }
;(global.fetch as jest.Mock).mockResolvedValue({ ;(global.fetch as jest.Mock).mockResolvedValue({
json: jest.fn().mockResolvedValue(mockResponse), json: jest
.fn()
.mockResolvedValueOnce(mockResponse)
.mockResolvedValueOnce([{
path: 'model-q4_0.gguf', size: 1000000,
},{
path: 'model-q4_0.gguf', size: 2000
}]),
}) })
;(getFileSize as jest.Mock).mockResolvedValue(1000000)
const result = await fetchHuggingFaceRepoData('user/repo') const result = await fetchHuggingFaceRepoData('user/repo')

View File

@ -1,4 +1,4 @@
import { AllQuantizations, getFileSize, HuggingFaceRepoData } from '@janhq/core' import { AllQuantizations, HuggingFaceRepoData } from '@janhq/core'
/** /**
* Fetches data from a Hugging Face repository. * Fetches data from a Hugging Face repository.
@ -39,21 +39,19 @@ export const fetchHuggingFaceRepoData = async (
) )
} }
const promises: Promise<number>[] = []
// fetching file sizes // fetching file sizes
const url = new URL(sanitizedUrl) const url = new URL(sanitizedUrl)
const paths = url.pathname.split('/').filter((e) => e.trim().length > 0) const paths = url.pathname.split('/').filter((e) => e.trim().length > 0)
const repoTree: { path: string; size: number }[] = await fetch(
`https://huggingface.co/api/models/${paths[2]}/${paths[3]}/tree/main`
).then((res) => res.json())
for (const sibling of data.siblings) { for (const sibling of data.siblings) {
const downloadUrl = `https://huggingface.co/${paths[2]}/${paths[3]}/resolve/main/${sibling.rfilename}` const downloadUrl = `https://huggingface.co/${paths[2]}/${paths[3]}/resolve/main/${sibling.rfilename}`
sibling.downloadUrl = downloadUrl sibling.downloadUrl = downloadUrl
promises.push(getFileSize(downloadUrl)) sibling.fileSize =
} repoTree.find((file) => file.path === sibling.rfilename)?.size ?? 0
const result = await Promise.all(promises)
for (let i = 0; i < data.siblings.length; i++) {
data.siblings[i].fileSize = result[i]
} }
AllQuantizations.forEach((quantization) => { AllQuantizations.forEach((quantization) => {