fix: wrong model download location when there is a mismatch model_id (#3733)

This commit is contained in:
Louis 2024-09-26 12:43:34 +07:00 committed by GitHub
parent 7f08f0fa79
commit 143f2f5c58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 202 additions and 15 deletions

View File

@ -34,7 +34,7 @@ export class Downloader implements Processor {
}
const array = normalizedPath.split(sep)
const fileName = array.pop() ?? ''
const modelId = array.pop() ?? ''
const modelId = downloadRequest.modelId ?? array.pop() ?? ''
const destination = resolve(getJanDataFolderPath(), normalizedPath)
validatePath(destination)

View File

@ -40,6 +40,14 @@ export type DownloadRequest = {
*/
extensionId?: string
/**
* The model ID of the model that initiated the download.
*/
modelId?: string
/**
* The download type.
*/
downloadType?: DownloadType | string
}

View File

@ -12,7 +12,7 @@ export interface ModelInterface {
* @returns A Promise that resolves when the model has been downloaded.
*/
downloadModel(
model: Model,
model: ModelFile,
gpuSettings?: GpuSetting,
network?: { ignoreSSL?: boolean; proxy?: string }
): Promise<void>
@ -35,11 +35,11 @@ export interface ModelInterface {
* Gets a list of downloaded models.
* @returns A Promise that resolves with an array of downloaded models.
*/
getDownloadedModels(): Promise<Model[]>
getDownloadedModels(): Promise<ModelFile[]>
/**
* Gets a list of configured models.
* @returns A Promise that resolves with an array of configured models.
*/
getConfiguredModels(): Promise<Model[]>
getConfiguredModels(): Promise<ModelFile[]>
}

View File

@ -8,9 +8,14 @@ const downloadMock = jest.fn()
const mkdirMock = jest.fn()
const writeFileSyncMock = jest.fn()
const copyFileMock = jest.fn()
const dirNameMock = jest.fn()
const executeMock = jest.fn()
jest.mock('@janhq/core', () => ({
...jest.requireActual('@janhq/core/node'),
events: {
emit: jest.fn(),
},
fs: {
existsSync: existMock,
readdirSync: readDirSyncMock,
@ -22,12 +27,15 @@ jest.mock('@janhq/core', () => ({
isDirectory: false,
}),
},
dirName: jest.fn(),
dirName: dirNameMock,
joinPath: (paths) => paths.join('/'),
ModelExtension: jest.fn(),
downloadFile: downloadMock,
executeOnMain: executeMock,
}))
jest.mock('@huggingface/gguf')
global.fetch = jest.fn(() =>
Promise.resolve({
json: () => Promise.resolve({ test: 100 }),
@ -37,8 +45,7 @@ global.fetch = jest.fn(() =>
import JanModelExtension from '.'
import { fs, dirName } from '@janhq/core'
import { renderJinjaTemplate } from './node/index'
import { Template } from '@huggingface/jinja'
import { gguf } from '@huggingface/gguf'
describe('JanModelExtension', () => {
let sut: JanModelExtension
@ -48,7 +55,7 @@ describe('JanModelExtension', () => {
sut = new JanModelExtension()
})
afterEach(() => {
beforeEach(() => {
jest.clearAllMocks()
})
@ -610,7 +617,172 @@ describe('JanModelExtension', () => {
).rejects.toBeTruthy()
})
it('should download corresponding ID', async () => {
existMock.mockImplementation(() => true)
dirNameMock.mockImplementation(() => 'file://models/model1')
downloadMock.mockImplementation(() => {
return Promise.resolve({})
})
expect(
await sut.downloadModel(
{ ...model, file_path: 'file://models/model1/model.json' },
gpuSettings,
network
)
).toBeUndefined()
expect(downloadMock).toHaveBeenCalledWith(
{
localPath: 'file://models/model1/model.gguf',
modelId: 'model-id',
url: 'http://example.com/model.gguf',
},
{ ignoreSSL: true, proxy: 'http://proxy.example.com' }
)
})
it('should handle invalid model file', async () => {
executeMock.mockResolvedValue({})
fs.readFileSync = jest.fn(() => {
return JSON.stringify({ metadata: { author: 'user' } })
})
expect(
sut.downloadModel(
{ ...model, file_path: 'file://models/model1/model.json' },
gpuSettings,
network
)
).resolves.not.toThrow()
expect(downloadMock).not.toHaveBeenCalled()
})
it('should handle model file with no sources', async () => {
executeMock.mockResolvedValue({})
const modelWithoutSources = { ...model, sources: [] }
expect(
sut.downloadModel(
{
...modelWithoutSources,
file_path: 'file://models/model1/model.json',
},
gpuSettings,
network
)
).resolves.toBe(undefined)
expect(downloadMock).not.toHaveBeenCalled()
})
it('should handle model file with multiple sources', async () => {
const modelWithMultipleSources = {
...model,
sources: [
{ url: 'http://example.com/model1.gguf', filename: 'model1.gguf' },
{ url: 'http://example.com/model2.gguf', filename: 'model2.gguf' },
],
}
executeMock.mockResolvedValue({
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
})
;(gguf as jest.Mock).mockResolvedValue({
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
})
// @ts-ignore
global.NODE = 'node'
// @ts-ignore
global.DEFAULT_MODEL = {
parameters: { stop: [] },
}
downloadMock.mockImplementation(() => {
return Promise.resolve({})
})
expect(
await sut.downloadModel(
{
...modelWithMultipleSources,
file_path: 'file://models/model1/model.json',
},
gpuSettings,
network
)
).toBeUndefined()
expect(downloadMock).toHaveBeenCalledWith(
{
localPath: 'file://models/model1/model1.gguf',
modelId: 'model-id',
url: 'http://example.com/model1.gguf',
},
{ ignoreSSL: true, proxy: 'http://proxy.example.com' }
)
expect(downloadMock).toHaveBeenCalledWith(
{
localPath: 'file://models/model1/model2.gguf',
modelId: 'model-id',
url: 'http://example.com/model2.gguf',
},
{ ignoreSSL: true, proxy: 'http://proxy.example.com' }
)
})
it('should handle model file with no file_path', async () => {
executeMock.mockResolvedValue({
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
})
;(gguf as jest.Mock).mockResolvedValue({
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
})
// @ts-ignore
global.NODE = 'node'
// @ts-ignore
global.DEFAULT_MODEL = {
parameters: { stop: [] },
}
const modelWithoutFilepath = { ...model, file_path: undefined }
await sut.downloadModel(modelWithoutFilepath, gpuSettings, network)
expect(downloadMock).toHaveBeenCalledWith(
expect.objectContaining({
localPath: 'file://models/model-id/model.gguf',
}),
expect.anything()
)
})
it('should handle model file with invalid file_path', async () => {
executeMock.mockResolvedValue({
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
})
;(gguf as jest.Mock).mockResolvedValue({
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
})
// @ts-ignore
global.NODE = 'node'
// @ts-ignore
global.DEFAULT_MODEL = {
parameters: { stop: [] },
}
const modelWithInvalidFilepath = {
...model,
file_path: 'file://models/invalid-model.json',
}
await sut.downloadModel(modelWithInvalidFilepath, gpuSettings, network)
expect(downloadMock).toHaveBeenCalledWith(
expect.objectContaining({
localPath: 'file://models/model1/model.gguf',
}),
expect.anything()
)
})
})
})

View File

@ -24,7 +24,6 @@ import {
ModelEvent,
ModelFile,
dirName,
ModelSettingParams,
} from '@janhq/core'
import { extractFileName } from './helpers/path'
@ -77,14 +76,15 @@ export default class JanModelExtension extends ModelExtension {
* @returns A Promise that resolves when the model is downloaded.
*/
async downloadModel(
model: Model,
model: ModelFile,
gpuSettings?: GpuSetting,
network?: { ignoreSSL?: boolean; proxy?: string }
): Promise<void> {
// Create corresponding directory
const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id])
if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath)
const modelJsonPath = await joinPath([modelDirPath, 'model.json'])
const modelJsonPath =
model.file_path ?? (await joinPath([modelDirPath, 'model.json']))
// Download HF model - model.json not exist
if (!(await fs.existsSync(modelJsonPath))) {
@ -152,11 +152,15 @@ export default class JanModelExtension extends ModelExtension {
JanModelExtension._supportedModelFormat
)
if (source.filename) {
path = await joinPath([modelDirPath, source.filename])
path = model.file_path
? await joinPath([await dirName(model.file_path), source.filename])
: await joinPath([modelDirPath, source.filename])
}
const downloadRequest: DownloadRequest = {
url: source.url,
localPath: path,
modelId: model.id,
}
downloadFile(downloadRequest, network)
}
@ -166,10 +170,13 @@ export default class JanModelExtension extends ModelExtension {
model.sources[0]?.url,
JanModelExtension._supportedModelFormat
)
const path = await joinPath([modelDirPath, fileName])
const path = model.file_path
? await joinPath([await dirName(model.file_path), fileName])
: await joinPath([modelDirPath, fileName])
const downloadRequest: DownloadRequest = {
url: model.sources[0]?.url,
localPath: path,
modelId: model.id,
}
downloadFile(downloadRequest, network)