From 143f2f5c585329a9569bfeaf43f4b3e3c1aa4196 Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 26 Sep 2024 12:43:34 +0700 Subject: [PATCH] fix: wrong model download location when there is a mismatch model_id (#3733) --- core/src/node/api/processors/download.ts | 2 +- core/src/types/file/index.ts | 8 + core/src/types/model/modelInterface.ts | 6 +- extensions/model-extension/src/index.test.ts | 184 ++++++++++++++++++- extensions/model-extension/src/index.ts | 17 +- 5 files changed, 202 insertions(+), 15 deletions(-) diff --git a/core/src/node/api/processors/download.ts b/core/src/node/api/processors/download.ts index 21f7a6f1c..5db18a53a 100644 --- a/core/src/node/api/processors/download.ts +++ b/core/src/node/api/processors/download.ts @@ -34,7 +34,7 @@ export class Downloader implements Processor { } const array = normalizedPath.split(sep) const fileName = array.pop() ?? '' - const modelId = array.pop() ?? '' + const modelId = downloadRequest.modelId ?? array.pop() ?? '' const destination = resolve(getJanDataFolderPath(), normalizedPath) validatePath(destination) diff --git a/core/src/types/file/index.ts b/core/src/types/file/index.ts index 4db956b1e..9f3e32b3e 100644 --- a/core/src/types/file/index.ts +++ b/core/src/types/file/index.ts @@ -40,6 +40,14 @@ export type DownloadRequest = { */ extensionId?: string + /** + * The model ID of the model that initiated the download. + */ + modelId?: string + + /** + * The download type. + */ downloadType?: DownloadType | string } diff --git a/core/src/types/model/modelInterface.ts b/core/src/types/model/modelInterface.ts index 5b5856231..08d456b7e 100644 --- a/core/src/types/model/modelInterface.ts +++ b/core/src/types/model/modelInterface.ts @@ -12,7 +12,7 @@ export interface ModelInterface { * @returns A Promise that resolves when the model has been downloaded. */ downloadModel( - model: Model, + model: ModelFile, gpuSettings?: GpuSetting, network?: { ignoreSSL?: boolean; proxy?: string } ): Promise @@ -35,11 +35,11 @@ export interface ModelInterface { * Gets a list of downloaded models. * @returns A Promise that resolves with an array of downloaded models. */ - getDownloadedModels(): Promise + getDownloadedModels(): Promise /** * Gets a list of configured models. * @returns A Promise that resolves with an array of configured models. */ - getConfiguredModels(): Promise + getConfiguredModels(): Promise } diff --git a/extensions/model-extension/src/index.test.ts b/extensions/model-extension/src/index.test.ts index 823b3a41d..5b126d4cc 100644 --- a/extensions/model-extension/src/index.test.ts +++ b/extensions/model-extension/src/index.test.ts @@ -8,9 +8,14 @@ const downloadMock = jest.fn() const mkdirMock = jest.fn() const writeFileSyncMock = jest.fn() const copyFileMock = jest.fn() +const dirNameMock = jest.fn() +const executeMock = jest.fn() jest.mock('@janhq/core', () => ({ ...jest.requireActual('@janhq/core/node'), + events: { + emit: jest.fn(), + }, fs: { existsSync: existMock, readdirSync: readDirSyncMock, @@ -22,12 +27,15 @@ jest.mock('@janhq/core', () => ({ isDirectory: false, }), }, - dirName: jest.fn(), + dirName: dirNameMock, joinPath: (paths) => paths.join('/'), ModelExtension: jest.fn(), downloadFile: downloadMock, + executeOnMain: executeMock, })) +jest.mock('@huggingface/gguf') + global.fetch = jest.fn(() => Promise.resolve({ json: () => Promise.resolve({ test: 100 }), @@ -37,8 +45,7 @@ global.fetch = jest.fn(() => import JanModelExtension from '.' import { fs, dirName } from '@janhq/core' -import { renderJinjaTemplate } from './node/index' -import { Template } from '@huggingface/jinja' +import { gguf } from '@huggingface/gguf' describe('JanModelExtension', () => { let sut: JanModelExtension @@ -48,7 +55,7 @@ describe('JanModelExtension', () => { sut = new JanModelExtension() }) - afterEach(() => { + beforeEach(() => { jest.clearAllMocks() }) @@ -610,7 +617,172 @@ describe('JanModelExtension', () => { ).rejects.toBeTruthy() }) - + it('should download corresponding ID', async () => { + existMock.mockImplementation(() => true) + dirNameMock.mockImplementation(() => 'file://models/model1') + downloadMock.mockImplementation(() => { + return Promise.resolve({}) + }) + + expect( + await sut.downloadModel( + { ...model, file_path: 'file://models/model1/model.json' }, + gpuSettings, + network + ) + ).toBeUndefined() + + expect(downloadMock).toHaveBeenCalledWith( + { + localPath: 'file://models/model1/model.gguf', + modelId: 'model-id', + url: 'http://example.com/model.gguf', + }, + { ignoreSSL: true, proxy: 'http://proxy.example.com' } + ) + }) + + it('should handle invalid model file', async () => { + executeMock.mockResolvedValue({}) + + fs.readFileSync = jest.fn(() => { + return JSON.stringify({ metadata: { author: 'user' } }) + }) + + expect( + sut.downloadModel( + { ...model, file_path: 'file://models/model1/model.json' }, + gpuSettings, + network + ) + ).resolves.not.toThrow() + + expect(downloadMock).not.toHaveBeenCalled() + }) + it('should handle model file with no sources', async () => { + executeMock.mockResolvedValue({}) + const modelWithoutSources = { ...model, sources: [] } + + expect( + sut.downloadModel( + { + ...modelWithoutSources, + file_path: 'file://models/model1/model.json', + }, + gpuSettings, + network + ) + ).resolves.toBe(undefined) + + expect(downloadMock).not.toHaveBeenCalled() + }) + + it('should handle model file with multiple sources', async () => { + const modelWithMultipleSources = { + ...model, + sources: [ + { url: 'http://example.com/model1.gguf', filename: 'model1.gguf' }, + { url: 'http://example.com/model2.gguf', filename: 'model2.gguf' }, + ], + } + + executeMock.mockResolvedValue({ + metadata: { 'tokenizer.ggml.eos_token_id': 0 }, + }) + ;(gguf as jest.Mock).mockResolvedValue({ + metadata: { 'tokenizer.ggml.eos_token_id': 0 }, + }) + // @ts-ignore + global.NODE = 'node' + // @ts-ignore + global.DEFAULT_MODEL = { + parameters: { stop: [] }, + } + downloadMock.mockImplementation(() => { + return Promise.resolve({}) + }) + + expect( + await sut.downloadModel( + { + ...modelWithMultipleSources, + file_path: 'file://models/model1/model.json', + }, + gpuSettings, + network + ) + ).toBeUndefined() + + expect(downloadMock).toHaveBeenCalledWith( + { + localPath: 'file://models/model1/model1.gguf', + modelId: 'model-id', + url: 'http://example.com/model1.gguf', + }, + { ignoreSSL: true, proxy: 'http://proxy.example.com' } + ) + + expect(downloadMock).toHaveBeenCalledWith( + { + localPath: 'file://models/model1/model2.gguf', + modelId: 'model-id', + url: 'http://example.com/model2.gguf', + }, + { ignoreSSL: true, proxy: 'http://proxy.example.com' } + ) + }) + + it('should handle model file with no file_path', async () => { + executeMock.mockResolvedValue({ + metadata: { 'tokenizer.ggml.eos_token_id': 0 }, + }) + ;(gguf as jest.Mock).mockResolvedValue({ + metadata: { 'tokenizer.ggml.eos_token_id': 0 }, + }) + // @ts-ignore + global.NODE = 'node' + // @ts-ignore + global.DEFAULT_MODEL = { + parameters: { stop: [] }, + } + const modelWithoutFilepath = { ...model, file_path: undefined } + + await sut.downloadModel(modelWithoutFilepath, gpuSettings, network) + + expect(downloadMock).toHaveBeenCalledWith( + expect.objectContaining({ + localPath: 'file://models/model-id/model.gguf', + }), + expect.anything() + ) + }) + + it('should handle model file with invalid file_path', async () => { + executeMock.mockResolvedValue({ + metadata: { 'tokenizer.ggml.eos_token_id': 0 }, + }) + ;(gguf as jest.Mock).mockResolvedValue({ + metadata: { 'tokenizer.ggml.eos_token_id': 0 }, + }) + // @ts-ignore + global.NODE = 'node' + // @ts-ignore + global.DEFAULT_MODEL = { + parameters: { stop: [] }, + } + const modelWithInvalidFilepath = { + ...model, + file_path: 'file://models/invalid-model.json', + } + + await sut.downloadModel(modelWithInvalidFilepath, gpuSettings, network) + + expect(downloadMock).toHaveBeenCalledWith( + expect.objectContaining({ + localPath: 'file://models/model1/model.gguf', + }), + expect.anything() + ) + }) }) - }) diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index beb9f1fed..20d23b747 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -24,7 +24,6 @@ import { ModelEvent, ModelFile, dirName, - ModelSettingParams, } from '@janhq/core' import { extractFileName } from './helpers/path' @@ -77,14 +76,15 @@ export default class JanModelExtension extends ModelExtension { * @returns A Promise that resolves when the model is downloaded. */ async downloadModel( - model: Model, + model: ModelFile, gpuSettings?: GpuSetting, network?: { ignoreSSL?: boolean; proxy?: string } ): Promise { // Create corresponding directory const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id]) if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath) - const modelJsonPath = await joinPath([modelDirPath, 'model.json']) + const modelJsonPath = + model.file_path ?? (await joinPath([modelDirPath, 'model.json'])) // Download HF model - model.json not exist if (!(await fs.existsSync(modelJsonPath))) { @@ -152,11 +152,15 @@ export default class JanModelExtension extends ModelExtension { JanModelExtension._supportedModelFormat ) if (source.filename) { - path = await joinPath([modelDirPath, source.filename]) + path = model.file_path + ? await joinPath([await dirName(model.file_path), source.filename]) + : await joinPath([modelDirPath, source.filename]) } + const downloadRequest: DownloadRequest = { url: source.url, localPath: path, + modelId: model.id, } downloadFile(downloadRequest, network) } @@ -166,10 +170,13 @@ export default class JanModelExtension extends ModelExtension { model.sources[0]?.url, JanModelExtension._supportedModelFormat ) - const path = await joinPath([modelDirPath, fileName]) + const path = model.file_path + ? await joinPath([await dirName(model.file_path), fileName]) + : await joinPath([modelDirPath, fileName]) const downloadRequest: DownloadRequest = { url: model.sources[0]?.url, localPath: path, + modelId: model.id, } downloadFile(downloadRequest, network)