Merge pull request #4133 from janhq/fix/4096-failed-to-get-huggingface-models
fix: 4096 - failed to get huggingface models
This commit is contained in:
commit
3a9a8dad3f
@ -3,7 +3,6 @@ import { joinPath } from './core'
|
|||||||
import { openFileExplorer } from './core'
|
import { openFileExplorer } from './core'
|
||||||
import { getJanDataFolderPath } from './core'
|
import { getJanDataFolderPath } from './core'
|
||||||
import { abortDownload } from './core'
|
import { abortDownload } from './core'
|
||||||
import { getFileSize } from './core'
|
|
||||||
import { executeOnMain } from './core'
|
import { executeOnMain } from './core'
|
||||||
|
|
||||||
describe('test core apis', () => {
|
describe('test core apis', () => {
|
||||||
@ -66,18 +65,6 @@ describe('test core apis', () => {
|
|||||||
expect(result).toBe('aborted')
|
expect(result).toBe('aborted')
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should get file size', async () => {
|
|
||||||
const url = 'http://example.com/file'
|
|
||||||
globalThis.core = {
|
|
||||||
api: {
|
|
||||||
getFileSize: jest.fn().mockResolvedValue(1024),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
const result = await getFileSize(url)
|
|
||||||
expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url)
|
|
||||||
expect(result).toBe(1024)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should execute function on main process', async () => {
|
it('should execute function on main process', async () => {
|
||||||
const extension = 'testExtension'
|
const extension = 'testExtension'
|
||||||
const method = 'testMethod'
|
const method = 'testMethod'
|
||||||
|
|||||||
@ -28,15 +28,6 @@ const downloadFile: (downloadRequest: DownloadRequest, network?: NetworkConfig)
|
|||||||
network
|
network
|
||||||
) => globalThis.core?.api?.downloadFile(downloadRequest, network)
|
) => globalThis.core?.api?.downloadFile(downloadRequest, network)
|
||||||
|
|
||||||
/**
|
|
||||||
* Get unit in bytes for a remote file.
|
|
||||||
*
|
|
||||||
* @param url - The url of the file.
|
|
||||||
* @returns {Promise<number>} - A promise that resolves with the file size.
|
|
||||||
*/
|
|
||||||
const getFileSize: (url: string) => Promise<number> = (url: string) =>
|
|
||||||
globalThis.core.api?.getFileSize(url)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Aborts the download of a specific file.
|
* Aborts the download of a specific file.
|
||||||
* @param {string} fileName - The name of the file whose download is to be aborted.
|
* @param {string} fileName - The name of the file whose download is to be aborted.
|
||||||
@ -167,7 +158,6 @@ export {
|
|||||||
getUserHomePath,
|
getUserHomePath,
|
||||||
systemInformation,
|
systemInformation,
|
||||||
showToast,
|
showToast,
|
||||||
getFileSize,
|
|
||||||
dirName,
|
dirName,
|
||||||
FileStat,
|
FileStat,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -23,6 +23,11 @@ jest.mock('fs', () => ({
|
|||||||
createWriteStream: jest.fn(),
|
createWriteStream: jest.fn(),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
const requestMock = jest.fn((options, callback) => {
|
||||||
|
callback(new Error('Test error'), null)
|
||||||
|
})
|
||||||
|
jest.mock('request', () => requestMock)
|
||||||
|
|
||||||
jest.mock('request-progress', () => {
|
jest.mock('request-progress', () => {
|
||||||
return jest.fn().mockImplementation(() => {
|
return jest.fn().mockImplementation(() => {
|
||||||
return {
|
return {
|
||||||
@ -54,18 +59,6 @@ describe('Downloader', () => {
|
|||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
jest.resetAllMocks()
|
jest.resetAllMocks()
|
||||||
})
|
})
|
||||||
it('should handle getFileSize errors correctly', async () => {
|
|
||||||
const observer = jest.fn()
|
|
||||||
const url = 'http://example.com/file'
|
|
||||||
|
|
||||||
const downloader = new Downloader(observer)
|
|
||||||
const requestMock = jest.fn((options, callback) => {
|
|
||||||
callback(new Error('Test error'), null)
|
|
||||||
})
|
|
||||||
jest.mock('request', () => requestMock)
|
|
||||||
|
|
||||||
await expect(downloader.getFileSize(observer, url)).rejects.toThrow('Test error')
|
|
||||||
})
|
|
||||||
|
|
||||||
it('should pause download correctly', () => {
|
it('should pause download correctly', () => {
|
||||||
const observer = jest.fn()
|
const observer = jest.fn()
|
||||||
|
|||||||
@ -135,25 +135,4 @@ export class Downloader implements Processor {
|
|||||||
pauseDownload(_observer: any, fileName: any) {
|
pauseDownload(_observer: any, fileName: any) {
|
||||||
DownloadManager.instance.networkRequests[fileName]?.pause()
|
DownloadManager.instance.networkRequests[fileName]?.pause()
|
||||||
}
|
}
|
||||||
|
|
||||||
async getFileSize(_observer: any, url: string): Promise<number> {
|
|
||||||
return new Promise((resolve, reject) => {
|
|
||||||
const request = require('request')
|
|
||||||
request(
|
|
||||||
{
|
|
||||||
url,
|
|
||||||
method: 'HEAD',
|
|
||||||
},
|
|
||||||
function (err: any, response: any) {
|
|
||||||
if (err) {
|
|
||||||
console.error('Getting file size failed:', err)
|
|
||||||
reject(err)
|
|
||||||
} else {
|
|
||||||
const size: number = response.headers['content-length'] ?? -1
|
|
||||||
resolve(size)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -65,7 +65,6 @@ export enum DownloadRoute {
|
|||||||
pauseDownload = 'pauseDownload',
|
pauseDownload = 'pauseDownload',
|
||||||
resumeDownload = 'resumeDownload',
|
resumeDownload = 'resumeDownload',
|
||||||
getDownloadProgress = 'getDownloadProgress',
|
getDownloadProgress = 'getDownloadProgress',
|
||||||
getFileSize = 'getFileSize',
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export enum DownloadEvent {
|
export enum DownloadEvent {
|
||||||
|
|||||||
@ -46,7 +46,7 @@ const ModelSearch = ({ onSearchLocal }: Props) => {
|
|||||||
errMessage = err.message
|
errMessage = err.message
|
||||||
}
|
}
|
||||||
toaster({
|
toaster({
|
||||||
title: 'Failed to get Hugging Face models',
|
title: 'Oops, you may be rate limited, give it a bit more time',
|
||||||
description: errMessage,
|
description: errMessage,
|
||||||
type: 'error',
|
type: 'error',
|
||||||
})
|
})
|
||||||
|
|||||||
@ -3,11 +3,8 @@ import {
|
|||||||
toHuggingFaceUrl,
|
toHuggingFaceUrl,
|
||||||
InvalidHostError,
|
InvalidHostError,
|
||||||
} from './huggingface'
|
} from './huggingface'
|
||||||
import { getFileSize } from '@janhq/core'
|
|
||||||
|
|
||||||
// Mock the getFileSize function
|
|
||||||
jest.mock('@janhq/core', () => ({
|
jest.mock('@janhq/core', () => ({
|
||||||
getFileSize: jest.fn(),
|
|
||||||
AllQuantizations: ['q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0'],
|
AllQuantizations: ['q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0'],
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@ -38,9 +35,15 @@ describe('huggingface utils', () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
;(global.fetch as jest.Mock).mockResolvedValue({
|
;(global.fetch as jest.Mock).mockResolvedValue({
|
||||||
json: jest.fn().mockResolvedValue(mockResponse),
|
json: jest
|
||||||
|
.fn()
|
||||||
|
.mockResolvedValueOnce(mockResponse)
|
||||||
|
.mockResolvedValueOnce([{
|
||||||
|
path: 'model-q4_0.gguf', size: 1000000,
|
||||||
|
},{
|
||||||
|
path: 'model-q4_0.gguf', size: 2000
|
||||||
|
}]),
|
||||||
})
|
})
|
||||||
;(getFileSize as jest.Mock).mockResolvedValue(1000000)
|
|
||||||
|
|
||||||
const result = await fetchHuggingFaceRepoData('user/repo')
|
const result = await fetchHuggingFaceRepoData('user/repo')
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import { AllQuantizations, getFileSize, HuggingFaceRepoData } from '@janhq/core'
|
import { AllQuantizations, HuggingFaceRepoData } from '@janhq/core'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fetches data from a Hugging Face repository.
|
* Fetches data from a Hugging Face repository.
|
||||||
@ -39,21 +39,19 @@ export const fetchHuggingFaceRepoData = async (
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
const promises: Promise<number>[] = []
|
|
||||||
|
|
||||||
// fetching file sizes
|
// fetching file sizes
|
||||||
const url = new URL(sanitizedUrl)
|
const url = new URL(sanitizedUrl)
|
||||||
const paths = url.pathname.split('/').filter((e) => e.trim().length > 0)
|
const paths = url.pathname.split('/').filter((e) => e.trim().length > 0)
|
||||||
|
|
||||||
|
const repoTree: { path: string; size: number }[] = await fetch(
|
||||||
|
`https://huggingface.co/api/models/${paths[2]}/${paths[3]}/tree/main`
|
||||||
|
).then((res) => res.json())
|
||||||
|
|
||||||
for (const sibling of data.siblings) {
|
for (const sibling of data.siblings) {
|
||||||
const downloadUrl = `https://huggingface.co/${paths[2]}/${paths[3]}/resolve/main/${sibling.rfilename}`
|
const downloadUrl = `https://huggingface.co/${paths[2]}/${paths[3]}/resolve/main/${sibling.rfilename}`
|
||||||
sibling.downloadUrl = downloadUrl
|
sibling.downloadUrl = downloadUrl
|
||||||
promises.push(getFileSize(downloadUrl))
|
sibling.fileSize =
|
||||||
}
|
repoTree.find((file) => file.path === sibling.rfilename)?.size ?? 0
|
||||||
|
|
||||||
const result = await Promise.all(promises)
|
|
||||||
for (let i = 0; i < data.siblings.length; i++) {
|
|
||||||
data.siblings[i].fileSize = result[i]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
AllQuantizations.forEach((quantization) => {
|
AllQuantizations.forEach((quantization) => {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user