fix: wrong model download location when there is a mismatch model_id (#3733)
This commit is contained in:
parent
7f08f0fa79
commit
143f2f5c58
@ -34,7 +34,7 @@ export class Downloader implements Processor {
|
||||
}
|
||||
const array = normalizedPath.split(sep)
|
||||
const fileName = array.pop() ?? ''
|
||||
const modelId = array.pop() ?? ''
|
||||
const modelId = downloadRequest.modelId ?? array.pop() ?? ''
|
||||
|
||||
const destination = resolve(getJanDataFolderPath(), normalizedPath)
|
||||
validatePath(destination)
|
||||
|
||||
@ -40,6 +40,14 @@ export type DownloadRequest = {
|
||||
*/
|
||||
extensionId?: string
|
||||
|
||||
/**
|
||||
* The model ID of the model that initiated the download.
|
||||
*/
|
||||
modelId?: string
|
||||
|
||||
/**
|
||||
* The download type.
|
||||
*/
|
||||
downloadType?: DownloadType | string
|
||||
}
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ export interface ModelInterface {
|
||||
* @returns A Promise that resolves when the model has been downloaded.
|
||||
*/
|
||||
downloadModel(
|
||||
model: Model,
|
||||
model: ModelFile,
|
||||
gpuSettings?: GpuSetting,
|
||||
network?: { ignoreSSL?: boolean; proxy?: string }
|
||||
): Promise<void>
|
||||
@ -35,11 +35,11 @@ export interface ModelInterface {
|
||||
* Gets a list of downloaded models.
|
||||
* @returns A Promise that resolves with an array of downloaded models.
|
||||
*/
|
||||
getDownloadedModels(): Promise<Model[]>
|
||||
getDownloadedModels(): Promise<ModelFile[]>
|
||||
|
||||
/**
|
||||
* Gets a list of configured models.
|
||||
* @returns A Promise that resolves with an array of configured models.
|
||||
*/
|
||||
getConfiguredModels(): Promise<Model[]>
|
||||
getConfiguredModels(): Promise<ModelFile[]>
|
||||
}
|
||||
|
||||
@ -8,9 +8,14 @@ const downloadMock = jest.fn()
|
||||
const mkdirMock = jest.fn()
|
||||
const writeFileSyncMock = jest.fn()
|
||||
const copyFileMock = jest.fn()
|
||||
const dirNameMock = jest.fn()
|
||||
const executeMock = jest.fn()
|
||||
|
||||
jest.mock('@janhq/core', () => ({
|
||||
...jest.requireActual('@janhq/core/node'),
|
||||
events: {
|
||||
emit: jest.fn(),
|
||||
},
|
||||
fs: {
|
||||
existsSync: existMock,
|
||||
readdirSync: readDirSyncMock,
|
||||
@ -22,12 +27,15 @@ jest.mock('@janhq/core', () => ({
|
||||
isDirectory: false,
|
||||
}),
|
||||
},
|
||||
dirName: jest.fn(),
|
||||
dirName: dirNameMock,
|
||||
joinPath: (paths) => paths.join('/'),
|
||||
ModelExtension: jest.fn(),
|
||||
downloadFile: downloadMock,
|
||||
executeOnMain: executeMock,
|
||||
}))
|
||||
|
||||
jest.mock('@huggingface/gguf')
|
||||
|
||||
global.fetch = jest.fn(() =>
|
||||
Promise.resolve({
|
||||
json: () => Promise.resolve({ test: 100 }),
|
||||
@ -37,8 +45,7 @@ global.fetch = jest.fn(() =>
|
||||
|
||||
import JanModelExtension from '.'
|
||||
import { fs, dirName } from '@janhq/core'
|
||||
import { renderJinjaTemplate } from './node/index'
|
||||
import { Template } from '@huggingface/jinja'
|
||||
import { gguf } from '@huggingface/gguf'
|
||||
|
||||
describe('JanModelExtension', () => {
|
||||
let sut: JanModelExtension
|
||||
@ -48,7 +55,7 @@ describe('JanModelExtension', () => {
|
||||
sut = new JanModelExtension()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
@ -610,7 +617,172 @@ describe('JanModelExtension', () => {
|
||||
).rejects.toBeTruthy()
|
||||
})
|
||||
|
||||
|
||||
it('should download corresponding ID', async () => {
|
||||
existMock.mockImplementation(() => true)
|
||||
dirNameMock.mockImplementation(() => 'file://models/model1')
|
||||
downloadMock.mockImplementation(() => {
|
||||
return Promise.resolve({})
|
||||
})
|
||||
|
||||
expect(
|
||||
await sut.downloadModel(
|
||||
{ ...model, file_path: 'file://models/model1/model.json' },
|
||||
gpuSettings,
|
||||
network
|
||||
)
|
||||
).toBeUndefined()
|
||||
|
||||
expect(downloadMock).toHaveBeenCalledWith(
|
||||
{
|
||||
localPath: 'file://models/model1/model.gguf',
|
||||
modelId: 'model-id',
|
||||
url: 'http://example.com/model.gguf',
|
||||
},
|
||||
{ ignoreSSL: true, proxy: 'http://proxy.example.com' }
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle invalid model file', async () => {
|
||||
executeMock.mockResolvedValue({})
|
||||
|
||||
fs.readFileSync = jest.fn(() => {
|
||||
return JSON.stringify({ metadata: { author: 'user' } })
|
||||
})
|
||||
|
||||
expect(
|
||||
sut.downloadModel(
|
||||
{ ...model, file_path: 'file://models/model1/model.json' },
|
||||
gpuSettings,
|
||||
network
|
||||
)
|
||||
).resolves.not.toThrow()
|
||||
|
||||
expect(downloadMock).not.toHaveBeenCalled()
|
||||
})
|
||||
it('should handle model file with no sources', async () => {
|
||||
executeMock.mockResolvedValue({})
|
||||
const modelWithoutSources = { ...model, sources: [] }
|
||||
|
||||
expect(
|
||||
sut.downloadModel(
|
||||
{
|
||||
...modelWithoutSources,
|
||||
file_path: 'file://models/model1/model.json',
|
||||
},
|
||||
gpuSettings,
|
||||
network
|
||||
)
|
||||
).resolves.toBe(undefined)
|
||||
|
||||
expect(downloadMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle model file with multiple sources', async () => {
|
||||
const modelWithMultipleSources = {
|
||||
...model,
|
||||
sources: [
|
||||
{ url: 'http://example.com/model1.gguf', filename: 'model1.gguf' },
|
||||
{ url: 'http://example.com/model2.gguf', filename: 'model2.gguf' },
|
||||
],
|
||||
}
|
||||
|
||||
executeMock.mockResolvedValue({
|
||||
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
|
||||
})
|
||||
;(gguf as jest.Mock).mockResolvedValue({
|
||||
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
|
||||
})
|
||||
// @ts-ignore
|
||||
global.NODE = 'node'
|
||||
// @ts-ignore
|
||||
global.DEFAULT_MODEL = {
|
||||
parameters: { stop: [] },
|
||||
}
|
||||
downloadMock.mockImplementation(() => {
|
||||
return Promise.resolve({})
|
||||
})
|
||||
|
||||
expect(
|
||||
await sut.downloadModel(
|
||||
{
|
||||
...modelWithMultipleSources,
|
||||
file_path: 'file://models/model1/model.json',
|
||||
},
|
||||
gpuSettings,
|
||||
network
|
||||
)
|
||||
).toBeUndefined()
|
||||
|
||||
expect(downloadMock).toHaveBeenCalledWith(
|
||||
{
|
||||
localPath: 'file://models/model1/model1.gguf',
|
||||
modelId: 'model-id',
|
||||
url: 'http://example.com/model1.gguf',
|
||||
},
|
||||
{ ignoreSSL: true, proxy: 'http://proxy.example.com' }
|
||||
)
|
||||
|
||||
expect(downloadMock).toHaveBeenCalledWith(
|
||||
{
|
||||
localPath: 'file://models/model1/model2.gguf',
|
||||
modelId: 'model-id',
|
||||
url: 'http://example.com/model2.gguf',
|
||||
},
|
||||
{ ignoreSSL: true, proxy: 'http://proxy.example.com' }
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle model file with no file_path', async () => {
|
||||
executeMock.mockResolvedValue({
|
||||
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
|
||||
})
|
||||
;(gguf as jest.Mock).mockResolvedValue({
|
||||
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
|
||||
})
|
||||
// @ts-ignore
|
||||
global.NODE = 'node'
|
||||
// @ts-ignore
|
||||
global.DEFAULT_MODEL = {
|
||||
parameters: { stop: [] },
|
||||
}
|
||||
const modelWithoutFilepath = { ...model, file_path: undefined }
|
||||
|
||||
await sut.downloadModel(modelWithoutFilepath, gpuSettings, network)
|
||||
|
||||
expect(downloadMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
localPath: 'file://models/model-id/model.gguf',
|
||||
}),
|
||||
expect.anything()
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle model file with invalid file_path', async () => {
|
||||
executeMock.mockResolvedValue({
|
||||
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
|
||||
})
|
||||
;(gguf as jest.Mock).mockResolvedValue({
|
||||
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
|
||||
})
|
||||
// @ts-ignore
|
||||
global.NODE = 'node'
|
||||
// @ts-ignore
|
||||
global.DEFAULT_MODEL = {
|
||||
parameters: { stop: [] },
|
||||
}
|
||||
const modelWithInvalidFilepath = {
|
||||
...model,
|
||||
file_path: 'file://models/invalid-model.json',
|
||||
}
|
||||
|
||||
await sut.downloadModel(modelWithInvalidFilepath, gpuSettings, network)
|
||||
|
||||
expect(downloadMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
localPath: 'file://models/model1/model.gguf',
|
||||
}),
|
||||
expect.anything()
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@ -24,7 +24,6 @@ import {
|
||||
ModelEvent,
|
||||
ModelFile,
|
||||
dirName,
|
||||
ModelSettingParams,
|
||||
} from '@janhq/core'
|
||||
|
||||
import { extractFileName } from './helpers/path'
|
||||
@ -77,14 +76,15 @@ export default class JanModelExtension extends ModelExtension {
|
||||
* @returns A Promise that resolves when the model is downloaded.
|
||||
*/
|
||||
async downloadModel(
|
||||
model: Model,
|
||||
model: ModelFile,
|
||||
gpuSettings?: GpuSetting,
|
||||
network?: { ignoreSSL?: boolean; proxy?: string }
|
||||
): Promise<void> {
|
||||
// Create corresponding directory
|
||||
const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id])
|
||||
if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath)
|
||||
const modelJsonPath = await joinPath([modelDirPath, 'model.json'])
|
||||
const modelJsonPath =
|
||||
model.file_path ?? (await joinPath([modelDirPath, 'model.json']))
|
||||
|
||||
// Download HF model - model.json not exist
|
||||
if (!(await fs.existsSync(modelJsonPath))) {
|
||||
@ -152,11 +152,15 @@ export default class JanModelExtension extends ModelExtension {
|
||||
JanModelExtension._supportedModelFormat
|
||||
)
|
||||
if (source.filename) {
|
||||
path = await joinPath([modelDirPath, source.filename])
|
||||
path = model.file_path
|
||||
? await joinPath([await dirName(model.file_path), source.filename])
|
||||
: await joinPath([modelDirPath, source.filename])
|
||||
}
|
||||
|
||||
const downloadRequest: DownloadRequest = {
|
||||
url: source.url,
|
||||
localPath: path,
|
||||
modelId: model.id,
|
||||
}
|
||||
downloadFile(downloadRequest, network)
|
||||
}
|
||||
@ -166,10 +170,13 @@ export default class JanModelExtension extends ModelExtension {
|
||||
model.sources[0]?.url,
|
||||
JanModelExtension._supportedModelFormat
|
||||
)
|
||||
const path = await joinPath([modelDirPath, fileName])
|
||||
const path = model.file_path
|
||||
? await joinPath([await dirName(model.file_path), fileName])
|
||||
: await joinPath([modelDirPath, fileName])
|
||||
const downloadRequest: DownloadRequest = {
|
||||
url: model.sources[0]?.url,
|
||||
localPath: path,
|
||||
modelId: model.id,
|
||||
}
|
||||
downloadFile(downloadRequest, network)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user