Merge pull request #4133 from janhq/fix/4096-failed-to-get-huggingface-models

fix: 4096 - failed to get huggingface models
This commit is contained in:
Louis 2024-11-27 00:53:34 +07:00 committed by GitHub
commit 3a9a8dad3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 21 additions and 72 deletions

View File

@ -3,7 +3,6 @@ import { joinPath } from './core'
import { openFileExplorer } from './core'
import { getJanDataFolderPath } from './core'
import { abortDownload } from './core'
import { getFileSize } from './core'
import { executeOnMain } from './core'
describe('test core apis', () => {
@ -66,18 +65,6 @@ describe('test core apis', () => {
expect(result).toBe('aborted')
})
it('should get file size', async () => {
const url = 'http://example.com/file'
globalThis.core = {
api: {
getFileSize: jest.fn().mockResolvedValue(1024),
},
}
const result = await getFileSize(url)
expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url)
expect(result).toBe(1024)
})
it('should execute function on main process', async () => {
const extension = 'testExtension'
const method = 'testMethod'

View File

@ -28,15 +28,6 @@ const downloadFile: (downloadRequest: DownloadRequest, network?: NetworkConfig)
network
) => globalThis.core?.api?.downloadFile(downloadRequest, network)
/**
* Get unit in bytes for a remote file.
*
* @param url - The url of the file.
* @returns {Promise<number>} - A promise that resolves with the file size.
*/
const getFileSize: (url: string) => Promise<number> = (url: string) =>
globalThis.core.api?.getFileSize(url)
/**
* Aborts the download of a specific file.
* @param {string} fileName - The name of the file whose download is to be aborted.
@ -167,7 +158,6 @@ export {
getUserHomePath,
systemInformation,
showToast,
getFileSize,
dirName,
FileStat,
}

View File

@ -23,6 +23,11 @@ jest.mock('fs', () => ({
createWriteStream: jest.fn(),
}))
const requestMock = jest.fn((options, callback) => {
callback(new Error('Test error'), null)
})
jest.mock('request', () => requestMock)
jest.mock('request-progress', () => {
return jest.fn().mockImplementation(() => {
return {
@ -54,18 +59,6 @@ describe('Downloader', () => {
beforeEach(() => {
jest.resetAllMocks()
})
it('should handle getFileSize errors correctly', async () => {
const observer = jest.fn()
const url = 'http://example.com/file'
const downloader = new Downloader(observer)
const requestMock = jest.fn((options, callback) => {
callback(new Error('Test error'), null)
})
jest.mock('request', () => requestMock)
await expect(downloader.getFileSize(observer, url)).rejects.toThrow('Test error')
})
it('should pause download correctly', () => {
const observer = jest.fn()

View File

@ -135,25 +135,4 @@ export class Downloader implements Processor {
pauseDownload(_observer: any, fileName: any) {
DownloadManager.instance.networkRequests[fileName]?.pause()
}
async getFileSize(_observer: any, url: string): Promise<number> {
return new Promise((resolve, reject) => {
const request = require('request')
request(
{
url,
method: 'HEAD',
},
function (err: any, response: any) {
if (err) {
console.error('Getting file size failed:', err)
reject(err)
} else {
const size: number = response.headers['content-length'] ?? -1
resolve(size)
}
}
)
})
}
}

View File

@ -65,7 +65,6 @@ export enum DownloadRoute {
pauseDownload = 'pauseDownload',
resumeDownload = 'resumeDownload',
getDownloadProgress = 'getDownloadProgress',
getFileSize = 'getFileSize',
}
export enum DownloadEvent {

View File

@ -46,7 +46,7 @@ const ModelSearch = ({ onSearchLocal }: Props) => {
errMessage = err.message
}
toaster({
title: 'Failed to get Hugging Face models',
title: 'Oops, you may be rate limited, give it a bit more time',
description: errMessage,
type: 'error',
})

View File

@ -3,11 +3,8 @@ import {
toHuggingFaceUrl,
InvalidHostError,
} from './huggingface'
import { getFileSize } from '@janhq/core'
// Mock the getFileSize function
jest.mock('@janhq/core', () => ({
getFileSize: jest.fn(),
AllQuantizations: ['q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0'],
}))
@ -38,9 +35,15 @@ describe('huggingface utils', () => {
}
;(global.fetch as jest.Mock).mockResolvedValue({
json: jest.fn().mockResolvedValue(mockResponse),
json: jest
.fn()
.mockResolvedValueOnce(mockResponse)
.mockResolvedValueOnce([{
path: 'model-q4_0.gguf', size: 1000000,
},{
path: 'model-q4_0.gguf', size: 2000
}]),
})
;(getFileSize as jest.Mock).mockResolvedValue(1000000)
const result = await fetchHuggingFaceRepoData('user/repo')

View File

@ -1,4 +1,4 @@
import { AllQuantizations, getFileSize, HuggingFaceRepoData } from '@janhq/core'
import { AllQuantizations, HuggingFaceRepoData } from '@janhq/core'
/**
* Fetches data from a Hugging Face repository.
@ -39,21 +39,19 @@ export const fetchHuggingFaceRepoData = async (
)
}
const promises: Promise<number>[] = []
// fetching file sizes
const url = new URL(sanitizedUrl)
const paths = url.pathname.split('/').filter((e) => e.trim().length > 0)
const repoTree: { path: string; size: number }[] = await fetch(
`https://huggingface.co/api/models/${paths[2]}/${paths[3]}/tree/main`
).then((res) => res.json())
for (const sibling of data.siblings) {
const downloadUrl = `https://huggingface.co/${paths[2]}/${paths[3]}/resolve/main/${sibling.rfilename}`
sibling.downloadUrl = downloadUrl
promises.push(getFileSize(downloadUrl))
}
const result = await Promise.all(promises)
for (let i = 0; i < data.siblings.length; i++) {
data.siblings[i].fileSize = result[i]
sibling.fileSize =
repoTree.find((file) => file.path === sibling.rfilename)?.size ?? 0
}
AllQuantizations.forEach((quantization) => {