From c0b59ece4d41162f38b525a6a7802b77e1936508 Mon Sep 17 00:00:00 2001 From: Louis Date: Tue, 24 Sep 2024 10:07:53 +0700 Subject: [PATCH] fix: #3558 wrong model metadata import or download from HuggingFace (#3725) * fix: #3558 wrong model metadata import * chore: remove redundant metadata retrieval --- extensions/model-extension/src/index.test.ts | 56 ++++++++++++- extensions/model-extension/src/index.ts | 78 +++++++++++++------ extensions/model-extension/src/node/index.ts | 52 ++++++++----- .../model-extension/src/node/node.test.ts | 53 +++++++++++++ 4 files changed, 195 insertions(+), 44 deletions(-) create mode 100644 extensions/model-extension/src/node/node.test.ts diff --git a/extensions/model-extension/src/index.test.ts b/extensions/model-extension/src/index.test.ts index 6816d7101..823b3a41d 100644 --- a/extensions/model-extension/src/index.test.ts +++ b/extensions/model-extension/src/index.test.ts @@ -1,6 +1,13 @@ +/** + * @jest-environment jsdom + */ const readDirSyncMock = jest.fn() const existMock = jest.fn() const readFileSyncMock = jest.fn() +const downloadMock = jest.fn() +const mkdirMock = jest.fn() +const writeFileSyncMock = jest.fn() +const copyFileMock = jest.fn() jest.mock('@janhq/core', () => ({ ...jest.requireActual('@janhq/core/node'), @@ -8,6 +15,9 @@ jest.mock('@janhq/core', () => ({ existsSync: existMock, readdirSync: readDirSyncMock, readFileSync: readFileSyncMock, + writeFileSync: writeFileSyncMock, + mkdir: mkdirMock, + copyFile: copyFileMock, fileStat: () => ({ isDirectory: false, }), @@ -15,10 +25,20 @@ jest.mock('@janhq/core', () => ({ dirName: jest.fn(), joinPath: (paths) => paths.join('/'), ModelExtension: jest.fn(), + downloadFile: downloadMock, })) +global.fetch = jest.fn(() => + Promise.resolve({ + json: () => Promise.resolve({ test: 100 }), + arrayBuffer: jest.fn(), + }) +) as jest.Mock + import JanModelExtension from '.' import { fs, dirName } from '@janhq/core' +import { renderJinjaTemplate } from './node/index' +import { Template } from '@huggingface/jinja' describe('JanModelExtension', () => { let sut: JanModelExtension @@ -187,7 +207,6 @@ describe('JanModelExtension', () => { describe('no models downloaded', () => { it('should return empty array', async () => { // Mock downloaded models data - const downloadedModels = [] existMock.mockReturnValue(true) readDirSyncMock.mockReturnValue([]) @@ -557,8 +576,41 @@ describe('JanModelExtension', () => { file_path: 'file://models/model1/model.json', } as any) - expect(fs.unlinkSync).toHaveBeenCalledWith('file://models/model1/test.engine') + expect(fs.unlinkSync).toHaveBeenCalledWith( + 'file://models/model1/test.engine' + ) }) }) }) + + describe('downloadModel', () => { + const model: any = { + id: 'model-id', + name: 'Test Model', + sources: [ + { url: 'http://example.com/model.gguf', filename: 'model.gguf' }, + ], + engine: 'test-engine', + } + + const network = { + ignoreSSL: true, + proxy: 'http://proxy.example.com', + } + + const gpuSettings: any = { + gpus: [{ name: 'nvidia-rtx-3080', arch: 'ampere' }], + } + + it('should reject with invalid gguf metadata', async () => { + existMock.mockImplementation(() => false) + + expect( + sut.downloadModel(model, gpuSettings, network) + ).rejects.toBeTruthy() + }) + + + }) + }) diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index ac9b06a09..beb9f1fed 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -24,6 +24,7 @@ import { ModelEvent, ModelFile, dirName, + ModelSettingParams, } from '@janhq/core' import { extractFileName } from './helpers/path' @@ -80,11 +81,27 @@ export default class JanModelExtension extends ModelExtension { gpuSettings?: GpuSetting, network?: { ignoreSSL?: boolean; proxy?: string } ): Promise { - // create corresponding directory + // Create corresponding directory const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id]) if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath) const modelJsonPath = await joinPath([modelDirPath, 'model.json']) + + // Download HF model - model.json not exist if (!(await fs.existsSync(modelJsonPath))) { + // It supports only one source for HF download + const metadata = await this.fetchModelMetadata(model.sources[0].url) + const updatedModel = await this.retrieveGGUFMetadata(metadata) + if (updatedModel) { + // Update model settings + model.settings = { + ...model.settings, + ...updatedModel.settings, + } + model.parameters = { + ...model.parameters, + ...updatedModel.parameters, + } + } await fs.writeFileSync(modelJsonPath, JSON.stringify(model, null, 2)) events.emit(ModelEvent.OnModelsUpdate, {}) } @@ -327,7 +344,7 @@ export default class JanModelExtension extends ModelExtension { // Should depend on sources? const isUserImportModel = modelInfo.metadata?.author?.toLowerCase() === 'user' - if (isUserImportModel) { + if (isUserImportModel) { // just delete the folder return fs.rm(dirPath) } @@ -555,7 +572,7 @@ export default class JanModelExtension extends ModelExtension { ]) ) - const eos_id = metadata?.['tokenizer.ggml.eos_token_id'] + const updatedModel = await this.retrieveGGUFMetadata(metadata) if (!defaultModel) { console.error('Unable to find default model') @@ -575,18 +592,11 @@ export default class JanModelExtension extends ModelExtension { ], parameters: { ...defaultModel.parameters, - stop: eos_id - ? [metadata['tokenizer.ggml.tokens'][eos_id] ?? ''] - : defaultModel.parameters.stop, + ...updatedModel.parameters, }, settings: { ...defaultModel.settings, - prompt_template: - metadata?.parsed_chat_template ?? - defaultModel.settings.prompt_template, - ctx_len: - metadata?.['llama.context_length'] ?? defaultModel.settings.ctx_len, - ngl: (metadata?.['llama.block_count'] ?? 32) + 1, + ...updatedModel.settings, llama_model_path: binaryFileName, }, created: Date.now(), @@ -666,9 +676,9 @@ export default class JanModelExtension extends ModelExtension { 'retrieveGGUFMetadata', modelBinaryPath ) - const eos_id = metadata?.['tokenizer.ggml.eos_token_id'] const binaryFileName = await baseName(modelBinaryPath) + const updatedModel = await this.retrieveGGUFMetadata(metadata) const model: Model = { ...defaultModel, @@ -682,19 +692,12 @@ export default class JanModelExtension extends ModelExtension { ], parameters: { ...defaultModel.parameters, - stop: eos_id - ? [metadata?.['tokenizer.ggml.tokens'][eos_id] ?? ''] - : defaultModel.parameters.stop, + ...updatedModel.parameters, }, settings: { ...defaultModel.settings, - prompt_template: - metadata?.parsed_chat_template ?? - defaultModel.settings.prompt_template, - ctx_len: - metadata?.['llama.context_length'] ?? defaultModel.settings.ctx_len, - ngl: (metadata?.['llama.block_count'] ?? 32) + 1, + ...updatedModel.settings, llama_model_path: binaryFileName, }, created: Date.now(), @@ -860,4 +863,35 @@ export default class JanModelExtension extends ModelExtension { importedModels ) } + + /** + * Retrieve Model Settings from GGUF Metadata + * @param metadata + * @returns + */ + async retrieveGGUFMetadata(metadata: any): Promise> { + const template = await executeOnMain(NODE, 'renderJinjaTemplate', metadata) + const defaultModel = DEFAULT_MODEL as Model + const eos_id = metadata['tokenizer.ggml.eos_token_id'] + const architecture = metadata['general.architecture'] + + return { + settings: { + prompt_template: template ?? defaultModel.settings.prompt_template, + ctx_len: + metadata[`${architecture}.context_length`] ?? + metadata['llama.context_length'] ?? + 4096, + ngl: + (metadata[`${architecture}.block_count`] ?? + metadata['llama.block_count'] ?? + 32) + 1, + }, + parameters: { + stop: eos_id + ? [metadata?.['tokenizer.ggml.tokens'][eos_id] ?? ''] + : defaultModel.parameters.stop, + }, + } + } } diff --git a/extensions/model-extension/src/node/index.ts b/extensions/model-extension/src/node/index.ts index 2b498f424..6323d7f97 100644 --- a/extensions/model-extension/src/node/index.ts +++ b/extensions/model-extension/src/node/index.ts @@ -16,27 +16,8 @@ export const retrieveGGUFMetadata = async (ggufPath: string) => { // Parse metadata and tensor info const { metadata } = ggufMetadata(buffer.buffer) - const template = new Template(metadata['tokenizer.chat_template']) - const eos_id = metadata['tokenizer.ggml.eos_token_id'] - const bos_id = metadata['tokenizer.ggml.bos_token_id'] - const eos_token = metadata['tokenizer.ggml.tokens'][eos_id] - const bos_token = metadata['tokenizer.ggml.tokens'][bos_id] // Parse jinja template - const renderedTemplate = template.render({ - add_generation_prompt: true, - eos_token, - bos_token, - messages: [ - { - role: 'system', - content: '{system_message}', - }, - { - role: 'user', - content: '{prompt}', - }, - ], - }) + const renderedTemplate = renderJinjaTemplate(metadata) return { ...metadata, parsed_chat_template: renderedTemplate, @@ -45,3 +26,34 @@ export const retrieveGGUFMetadata = async (ggufPath: string) => { console.log('[MODEL_EXT]', e) } } + +/** + * Convert metadata to jinja template + * @param metadata + */ +export const renderJinjaTemplate = (metadata: any): string => { + const template = new Template(metadata['tokenizer.chat_template']) + const eos_id = metadata['tokenizer.ggml.eos_token_id'] + const bos_id = metadata['tokenizer.ggml.bos_token_id'] + if (eos_id === undefined || bos_id === undefined) { + return '' + } + const eos_token = metadata['tokenizer.ggml.tokens'][eos_id] + const bos_token = metadata['tokenizer.ggml.tokens'][bos_id] + // Parse jinja template + return template.render({ + add_generation_prompt: true, + eos_token, + bos_token, + messages: [ + { + role: 'system', + content: '{system_message}', + }, + { + role: 'user', + content: '{prompt}', + }, + ], + }) +} diff --git a/extensions/model-extension/src/node/node.test.ts b/extensions/model-extension/src/node/node.test.ts new file mode 100644 index 000000000..afd2b8470 --- /dev/null +++ b/extensions/model-extension/src/node/node.test.ts @@ -0,0 +1,53 @@ +import { renderJinjaTemplate } from './index' +import { Template } from '@huggingface/jinja' + +jest.mock('@huggingface/jinja', () => ({ + Template: jest.fn((template: string) => ({ + render: jest.fn(() => `${template}_rendered`), + })), +})) + +describe('renderJinjaTemplate', () => { + beforeEach(() => { + jest.clearAllMocks() // Clear mocks between tests + }) + + it('should render the template with correct parameters', () => { + const metadata = { + 'tokenizer.chat_template': 'Hello, {{ messages }}!', + 'tokenizer.ggml.eos_token_id': 0, + 'tokenizer.ggml.bos_token_id': 1, + 'tokenizer.ggml.tokens': ['EOS', 'BOS'], + } + + const renderedTemplate = renderJinjaTemplate(metadata) + + expect(Template).toHaveBeenCalledWith('Hello, {{ messages }}!') + + expect(renderedTemplate).toBe('Hello, {{ messages }}!_rendered') + }) + + it('should handle missing token IDs gracefully', () => { + const metadata = { + 'tokenizer.chat_template': 'Hello, {{ messages }}!', + 'tokenizer.ggml.eos_token_id': 0, + 'tokenizer.ggml.tokens': ['EOS'], + } + + const renderedTemplate = renderJinjaTemplate(metadata) + + expect(Template).toHaveBeenCalledWith('Hello, {{ messages }}!') + + expect(renderedTemplate).toBe('') + }) + + it('should handle empty template gracefully', () => { + const metadata = {} + + const renderedTemplate = renderJinjaTemplate(metadata) + + expect(Template).toHaveBeenCalledWith(undefined) + + expect(renderedTemplate).toBe("") + }) +})