fix: wrong model download location when there is a mismatch model_id (#3733)
This commit is contained in:
parent
7f08f0fa79
commit
143f2f5c58
@ -34,7 +34,7 @@ export class Downloader implements Processor {
|
|||||||
}
|
}
|
||||||
const array = normalizedPath.split(sep)
|
const array = normalizedPath.split(sep)
|
||||||
const fileName = array.pop() ?? ''
|
const fileName = array.pop() ?? ''
|
||||||
const modelId = array.pop() ?? ''
|
const modelId = downloadRequest.modelId ?? array.pop() ?? ''
|
||||||
|
|
||||||
const destination = resolve(getJanDataFolderPath(), normalizedPath)
|
const destination = resolve(getJanDataFolderPath(), normalizedPath)
|
||||||
validatePath(destination)
|
validatePath(destination)
|
||||||
|
|||||||
@ -40,6 +40,14 @@ export type DownloadRequest = {
|
|||||||
*/
|
*/
|
||||||
extensionId?: string
|
extensionId?: string
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The model ID of the model that initiated the download.
|
||||||
|
*/
|
||||||
|
modelId?: string
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The download type.
|
||||||
|
*/
|
||||||
downloadType?: DownloadType | string
|
downloadType?: DownloadType | string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,7 @@ export interface ModelInterface {
|
|||||||
* @returns A Promise that resolves when the model has been downloaded.
|
* @returns A Promise that resolves when the model has been downloaded.
|
||||||
*/
|
*/
|
||||||
downloadModel(
|
downloadModel(
|
||||||
model: Model,
|
model: ModelFile,
|
||||||
gpuSettings?: GpuSetting,
|
gpuSettings?: GpuSetting,
|
||||||
network?: { ignoreSSL?: boolean; proxy?: string }
|
network?: { ignoreSSL?: boolean; proxy?: string }
|
||||||
): Promise<void>
|
): Promise<void>
|
||||||
@ -35,11 +35,11 @@ export interface ModelInterface {
|
|||||||
* Gets a list of downloaded models.
|
* Gets a list of downloaded models.
|
||||||
* @returns A Promise that resolves with an array of downloaded models.
|
* @returns A Promise that resolves with an array of downloaded models.
|
||||||
*/
|
*/
|
||||||
getDownloadedModels(): Promise<Model[]>
|
getDownloadedModels(): Promise<ModelFile[]>
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gets a list of configured models.
|
* Gets a list of configured models.
|
||||||
* @returns A Promise that resolves with an array of configured models.
|
* @returns A Promise that resolves with an array of configured models.
|
||||||
*/
|
*/
|
||||||
getConfiguredModels(): Promise<Model[]>
|
getConfiguredModels(): Promise<ModelFile[]>
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,9 +8,14 @@ const downloadMock = jest.fn()
|
|||||||
const mkdirMock = jest.fn()
|
const mkdirMock = jest.fn()
|
||||||
const writeFileSyncMock = jest.fn()
|
const writeFileSyncMock = jest.fn()
|
||||||
const copyFileMock = jest.fn()
|
const copyFileMock = jest.fn()
|
||||||
|
const dirNameMock = jest.fn()
|
||||||
|
const executeMock = jest.fn()
|
||||||
|
|
||||||
jest.mock('@janhq/core', () => ({
|
jest.mock('@janhq/core', () => ({
|
||||||
...jest.requireActual('@janhq/core/node'),
|
...jest.requireActual('@janhq/core/node'),
|
||||||
|
events: {
|
||||||
|
emit: jest.fn(),
|
||||||
|
},
|
||||||
fs: {
|
fs: {
|
||||||
existsSync: existMock,
|
existsSync: existMock,
|
||||||
readdirSync: readDirSyncMock,
|
readdirSync: readDirSyncMock,
|
||||||
@ -22,12 +27,15 @@ jest.mock('@janhq/core', () => ({
|
|||||||
isDirectory: false,
|
isDirectory: false,
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
dirName: jest.fn(),
|
dirName: dirNameMock,
|
||||||
joinPath: (paths) => paths.join('/'),
|
joinPath: (paths) => paths.join('/'),
|
||||||
ModelExtension: jest.fn(),
|
ModelExtension: jest.fn(),
|
||||||
downloadFile: downloadMock,
|
downloadFile: downloadMock,
|
||||||
|
executeOnMain: executeMock,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
jest.mock('@huggingface/gguf')
|
||||||
|
|
||||||
global.fetch = jest.fn(() =>
|
global.fetch = jest.fn(() =>
|
||||||
Promise.resolve({
|
Promise.resolve({
|
||||||
json: () => Promise.resolve({ test: 100 }),
|
json: () => Promise.resolve({ test: 100 }),
|
||||||
@ -37,8 +45,7 @@ global.fetch = jest.fn(() =>
|
|||||||
|
|
||||||
import JanModelExtension from '.'
|
import JanModelExtension from '.'
|
||||||
import { fs, dirName } from '@janhq/core'
|
import { fs, dirName } from '@janhq/core'
|
||||||
import { renderJinjaTemplate } from './node/index'
|
import { gguf } from '@huggingface/gguf'
|
||||||
import { Template } from '@huggingface/jinja'
|
|
||||||
|
|
||||||
describe('JanModelExtension', () => {
|
describe('JanModelExtension', () => {
|
||||||
let sut: JanModelExtension
|
let sut: JanModelExtension
|
||||||
@ -48,7 +55,7 @@ describe('JanModelExtension', () => {
|
|||||||
sut = new JanModelExtension()
|
sut = new JanModelExtension()
|
||||||
})
|
})
|
||||||
|
|
||||||
afterEach(() => {
|
beforeEach(() => {
|
||||||
jest.clearAllMocks()
|
jest.clearAllMocks()
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -610,7 +617,172 @@ describe('JanModelExtension', () => {
|
|||||||
).rejects.toBeTruthy()
|
).rejects.toBeTruthy()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('should download corresponding ID', async () => {
|
||||||
|
existMock.mockImplementation(() => true)
|
||||||
|
dirNameMock.mockImplementation(() => 'file://models/model1')
|
||||||
|
downloadMock.mockImplementation(() => {
|
||||||
|
return Promise.resolve({})
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(
|
||||||
|
await sut.downloadModel(
|
||||||
|
{ ...model, file_path: 'file://models/model1/model.json' },
|
||||||
|
gpuSettings,
|
||||||
|
network
|
||||||
|
)
|
||||||
|
).toBeUndefined()
|
||||||
|
|
||||||
|
expect(downloadMock).toHaveBeenCalledWith(
|
||||||
|
{
|
||||||
|
localPath: 'file://models/model1/model.gguf',
|
||||||
|
modelId: 'model-id',
|
||||||
|
url: 'http://example.com/model.gguf',
|
||||||
|
},
|
||||||
|
{ ignoreSSL: true, proxy: 'http://proxy.example.com' }
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle invalid model file', async () => {
|
||||||
|
executeMock.mockResolvedValue({})
|
||||||
|
|
||||||
|
fs.readFileSync = jest.fn(() => {
|
||||||
|
return JSON.stringify({ metadata: { author: 'user' } })
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(
|
||||||
|
sut.downloadModel(
|
||||||
|
{ ...model, file_path: 'file://models/model1/model.json' },
|
||||||
|
gpuSettings,
|
||||||
|
network
|
||||||
|
)
|
||||||
|
).resolves.not.toThrow()
|
||||||
|
|
||||||
|
expect(downloadMock).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
it('should handle model file with no sources', async () => {
|
||||||
|
executeMock.mockResolvedValue({})
|
||||||
|
const modelWithoutSources = { ...model, sources: [] }
|
||||||
|
|
||||||
|
expect(
|
||||||
|
sut.downloadModel(
|
||||||
|
{
|
||||||
|
...modelWithoutSources,
|
||||||
|
file_path: 'file://models/model1/model.json',
|
||||||
|
},
|
||||||
|
gpuSettings,
|
||||||
|
network
|
||||||
|
)
|
||||||
|
).resolves.toBe(undefined)
|
||||||
|
|
||||||
|
expect(downloadMock).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle model file with multiple sources', async () => {
|
||||||
|
const modelWithMultipleSources = {
|
||||||
|
...model,
|
||||||
|
sources: [
|
||||||
|
{ url: 'http://example.com/model1.gguf', filename: 'model1.gguf' },
|
||||||
|
{ url: 'http://example.com/model2.gguf', filename: 'model2.gguf' },
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
executeMock.mockResolvedValue({
|
||||||
|
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
|
||||||
|
})
|
||||||
|
;(gguf as jest.Mock).mockResolvedValue({
|
||||||
|
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
|
||||||
|
})
|
||||||
|
// @ts-ignore
|
||||||
|
global.NODE = 'node'
|
||||||
|
// @ts-ignore
|
||||||
|
global.DEFAULT_MODEL = {
|
||||||
|
parameters: { stop: [] },
|
||||||
|
}
|
||||||
|
downloadMock.mockImplementation(() => {
|
||||||
|
return Promise.resolve({})
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(
|
||||||
|
await sut.downloadModel(
|
||||||
|
{
|
||||||
|
...modelWithMultipleSources,
|
||||||
|
file_path: 'file://models/model1/model.json',
|
||||||
|
},
|
||||||
|
gpuSettings,
|
||||||
|
network
|
||||||
|
)
|
||||||
|
).toBeUndefined()
|
||||||
|
|
||||||
|
expect(downloadMock).toHaveBeenCalledWith(
|
||||||
|
{
|
||||||
|
localPath: 'file://models/model1/model1.gguf',
|
||||||
|
modelId: 'model-id',
|
||||||
|
url: 'http://example.com/model1.gguf',
|
||||||
|
},
|
||||||
|
{ ignoreSSL: true, proxy: 'http://proxy.example.com' }
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(downloadMock).toHaveBeenCalledWith(
|
||||||
|
{
|
||||||
|
localPath: 'file://models/model1/model2.gguf',
|
||||||
|
modelId: 'model-id',
|
||||||
|
url: 'http://example.com/model2.gguf',
|
||||||
|
},
|
||||||
|
{ ignoreSSL: true, proxy: 'http://proxy.example.com' }
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle model file with no file_path', async () => {
|
||||||
|
executeMock.mockResolvedValue({
|
||||||
|
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
|
||||||
|
})
|
||||||
|
;(gguf as jest.Mock).mockResolvedValue({
|
||||||
|
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
|
||||||
|
})
|
||||||
|
// @ts-ignore
|
||||||
|
global.NODE = 'node'
|
||||||
|
// @ts-ignore
|
||||||
|
global.DEFAULT_MODEL = {
|
||||||
|
parameters: { stop: [] },
|
||||||
|
}
|
||||||
|
const modelWithoutFilepath = { ...model, file_path: undefined }
|
||||||
|
|
||||||
|
await sut.downloadModel(modelWithoutFilepath, gpuSettings, network)
|
||||||
|
|
||||||
|
expect(downloadMock).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
localPath: 'file://models/model-id/model.gguf',
|
||||||
|
}),
|
||||||
|
expect.anything()
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle model file with invalid file_path', async () => {
|
||||||
|
executeMock.mockResolvedValue({
|
||||||
|
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
|
||||||
|
})
|
||||||
|
;(gguf as jest.Mock).mockResolvedValue({
|
||||||
|
metadata: { 'tokenizer.ggml.eos_token_id': 0 },
|
||||||
|
})
|
||||||
|
// @ts-ignore
|
||||||
|
global.NODE = 'node'
|
||||||
|
// @ts-ignore
|
||||||
|
global.DEFAULT_MODEL = {
|
||||||
|
parameters: { stop: [] },
|
||||||
|
}
|
||||||
|
const modelWithInvalidFilepath = {
|
||||||
|
...model,
|
||||||
|
file_path: 'file://models/invalid-model.json',
|
||||||
|
}
|
||||||
|
|
||||||
|
await sut.downloadModel(modelWithInvalidFilepath, gpuSettings, network)
|
||||||
|
|
||||||
|
expect(downloadMock).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
localPath: 'file://models/model1/model.gguf',
|
||||||
|
}),
|
||||||
|
expect.anything()
|
||||||
|
)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|||||||
@ -24,7 +24,6 @@ import {
|
|||||||
ModelEvent,
|
ModelEvent,
|
||||||
ModelFile,
|
ModelFile,
|
||||||
dirName,
|
dirName,
|
||||||
ModelSettingParams,
|
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
|
|
||||||
import { extractFileName } from './helpers/path'
|
import { extractFileName } from './helpers/path'
|
||||||
@ -77,14 +76,15 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
* @returns A Promise that resolves when the model is downloaded.
|
* @returns A Promise that resolves when the model is downloaded.
|
||||||
*/
|
*/
|
||||||
async downloadModel(
|
async downloadModel(
|
||||||
model: Model,
|
model: ModelFile,
|
||||||
gpuSettings?: GpuSetting,
|
gpuSettings?: GpuSetting,
|
||||||
network?: { ignoreSSL?: boolean; proxy?: string }
|
network?: { ignoreSSL?: boolean; proxy?: string }
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
// Create corresponding directory
|
// Create corresponding directory
|
||||||
const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id])
|
const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id])
|
||||||
if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath)
|
if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath)
|
||||||
const modelJsonPath = await joinPath([modelDirPath, 'model.json'])
|
const modelJsonPath =
|
||||||
|
model.file_path ?? (await joinPath([modelDirPath, 'model.json']))
|
||||||
|
|
||||||
// Download HF model - model.json not exist
|
// Download HF model - model.json not exist
|
||||||
if (!(await fs.existsSync(modelJsonPath))) {
|
if (!(await fs.existsSync(modelJsonPath))) {
|
||||||
@ -152,11 +152,15 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
JanModelExtension._supportedModelFormat
|
JanModelExtension._supportedModelFormat
|
||||||
)
|
)
|
||||||
if (source.filename) {
|
if (source.filename) {
|
||||||
path = await joinPath([modelDirPath, source.filename])
|
path = model.file_path
|
||||||
|
? await joinPath([await dirName(model.file_path), source.filename])
|
||||||
|
: await joinPath([modelDirPath, source.filename])
|
||||||
}
|
}
|
||||||
|
|
||||||
const downloadRequest: DownloadRequest = {
|
const downloadRequest: DownloadRequest = {
|
||||||
url: source.url,
|
url: source.url,
|
||||||
localPath: path,
|
localPath: path,
|
||||||
|
modelId: model.id,
|
||||||
}
|
}
|
||||||
downloadFile(downloadRequest, network)
|
downloadFile(downloadRequest, network)
|
||||||
}
|
}
|
||||||
@ -166,10 +170,13 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
model.sources[0]?.url,
|
model.sources[0]?.url,
|
||||||
JanModelExtension._supportedModelFormat
|
JanModelExtension._supportedModelFormat
|
||||||
)
|
)
|
||||||
const path = await joinPath([modelDirPath, fileName])
|
const path = model.file_path
|
||||||
|
? await joinPath([await dirName(model.file_path), fileName])
|
||||||
|
: await joinPath([modelDirPath, fileName])
|
||||||
const downloadRequest: DownloadRequest = {
|
const downloadRequest: DownloadRequest = {
|
||||||
url: model.sources[0]?.url,
|
url: model.sources[0]?.url,
|
||||||
localPath: path,
|
localPath: path,
|
||||||
|
modelId: model.id,
|
||||||
}
|
}
|
||||||
downloadFile(downloadRequest, network)
|
downloadFile(downloadRequest, network)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user