fix: #3558 wrong model metadata import or download from HuggingFace (#3725)

* fix: #3558 wrong model metadata import

* chore: remove redundant metadata retrieval
This commit is contained in:
Louis 2024-09-24 10:07:53 +07:00 committed by GitHub
parent 87c8fdf5ac
commit c0b59ece4d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 195 additions and 44 deletions

View File

@ -1,6 +1,13 @@
/**
* @jest-environment jsdom
*/
const readDirSyncMock = jest.fn()
const existMock = jest.fn()
const readFileSyncMock = jest.fn()
const downloadMock = jest.fn()
const mkdirMock = jest.fn()
const writeFileSyncMock = jest.fn()
const copyFileMock = jest.fn()
jest.mock('@janhq/core', () => ({
...jest.requireActual('@janhq/core/node'),
@ -8,6 +15,9 @@ jest.mock('@janhq/core', () => ({
existsSync: existMock,
readdirSync: readDirSyncMock,
readFileSync: readFileSyncMock,
writeFileSync: writeFileSyncMock,
mkdir: mkdirMock,
copyFile: copyFileMock,
fileStat: () => ({
isDirectory: false,
}),
@ -15,10 +25,20 @@ jest.mock('@janhq/core', () => ({
dirName: jest.fn(),
joinPath: (paths) => paths.join('/'),
ModelExtension: jest.fn(),
downloadFile: downloadMock,
}))
global.fetch = jest.fn(() =>
Promise.resolve({
json: () => Promise.resolve({ test: 100 }),
arrayBuffer: jest.fn(),
})
) as jest.Mock
import JanModelExtension from '.'
import { fs, dirName } from '@janhq/core'
import { renderJinjaTemplate } from './node/index'
import { Template } from '@huggingface/jinja'
describe('JanModelExtension', () => {
let sut: JanModelExtension
@ -187,7 +207,6 @@ describe('JanModelExtension', () => {
describe('no models downloaded', () => {
it('should return empty array', async () => {
// Mock downloaded models data
const downloadedModels = []
existMock.mockReturnValue(true)
readDirSyncMock.mockReturnValue([])
@ -557,8 +576,41 @@ describe('JanModelExtension', () => {
file_path: 'file://models/model1/model.json',
} as any)
expect(fs.unlinkSync).toHaveBeenCalledWith('file://models/model1/test.engine')
expect(fs.unlinkSync).toHaveBeenCalledWith(
'file://models/model1/test.engine'
)
})
})
})
describe('downloadModel', () => {
const model: any = {
id: 'model-id',
name: 'Test Model',
sources: [
{ url: 'http://example.com/model.gguf', filename: 'model.gguf' },
],
engine: 'test-engine',
}
const network = {
ignoreSSL: true,
proxy: 'http://proxy.example.com',
}
const gpuSettings: any = {
gpus: [{ name: 'nvidia-rtx-3080', arch: 'ampere' }],
}
it('should reject with invalid gguf metadata', async () => {
existMock.mockImplementation(() => false)
expect(
sut.downloadModel(model, gpuSettings, network)
).rejects.toBeTruthy()
})
})
})

View File

@ -24,6 +24,7 @@ import {
ModelEvent,
ModelFile,
dirName,
ModelSettingParams,
} from '@janhq/core'
import { extractFileName } from './helpers/path'
@ -80,11 +81,27 @@ export default class JanModelExtension extends ModelExtension {
gpuSettings?: GpuSetting,
network?: { ignoreSSL?: boolean; proxy?: string }
): Promise<void> {
// create corresponding directory
// Create corresponding directory
const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id])
if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath)
const modelJsonPath = await joinPath([modelDirPath, 'model.json'])
// Download HF model - model.json not exist
if (!(await fs.existsSync(modelJsonPath))) {
// It supports only one source for HF download
const metadata = await this.fetchModelMetadata(model.sources[0].url)
const updatedModel = await this.retrieveGGUFMetadata(metadata)
if (updatedModel) {
// Update model settings
model.settings = {
...model.settings,
...updatedModel.settings,
}
model.parameters = {
...model.parameters,
...updatedModel.parameters,
}
}
await fs.writeFileSync(modelJsonPath, JSON.stringify(model, null, 2))
events.emit(ModelEvent.OnModelsUpdate, {})
}
@ -327,7 +344,7 @@ export default class JanModelExtension extends ModelExtension {
// Should depend on sources?
const isUserImportModel =
modelInfo.metadata?.author?.toLowerCase() === 'user'
if (isUserImportModel) {
if (isUserImportModel) {
// just delete the folder
return fs.rm(dirPath)
}
@ -555,7 +572,7 @@ export default class JanModelExtension extends ModelExtension {
])
)
const eos_id = metadata?.['tokenizer.ggml.eos_token_id']
const updatedModel = await this.retrieveGGUFMetadata(metadata)
if (!defaultModel) {
console.error('Unable to find default model')
@ -575,18 +592,11 @@ export default class JanModelExtension extends ModelExtension {
],
parameters: {
...defaultModel.parameters,
stop: eos_id
? [metadata['tokenizer.ggml.tokens'][eos_id] ?? '']
: defaultModel.parameters.stop,
...updatedModel.parameters,
},
settings: {
...defaultModel.settings,
prompt_template:
metadata?.parsed_chat_template ??
defaultModel.settings.prompt_template,
ctx_len:
metadata?.['llama.context_length'] ?? defaultModel.settings.ctx_len,
ngl: (metadata?.['llama.block_count'] ?? 32) + 1,
...updatedModel.settings,
llama_model_path: binaryFileName,
},
created: Date.now(),
@ -666,9 +676,9 @@ export default class JanModelExtension extends ModelExtension {
'retrieveGGUFMetadata',
modelBinaryPath
)
const eos_id = metadata?.['tokenizer.ggml.eos_token_id']
const binaryFileName = await baseName(modelBinaryPath)
const updatedModel = await this.retrieveGGUFMetadata(metadata)
const model: Model = {
...defaultModel,
@ -682,19 +692,12 @@ export default class JanModelExtension extends ModelExtension {
],
parameters: {
...defaultModel.parameters,
stop: eos_id
? [metadata?.['tokenizer.ggml.tokens'][eos_id] ?? '']
: defaultModel.parameters.stop,
...updatedModel.parameters,
},
settings: {
...defaultModel.settings,
prompt_template:
metadata?.parsed_chat_template ??
defaultModel.settings.prompt_template,
ctx_len:
metadata?.['llama.context_length'] ?? defaultModel.settings.ctx_len,
ngl: (metadata?.['llama.block_count'] ?? 32) + 1,
...updatedModel.settings,
llama_model_path: binaryFileName,
},
created: Date.now(),
@ -860,4 +863,35 @@ export default class JanModelExtension extends ModelExtension {
importedModels
)
}
/**
* Retrieve Model Settings from GGUF Metadata
* @param metadata
* @returns
*/
async retrieveGGUFMetadata(metadata: any): Promise<Partial<Model>> {
const template = await executeOnMain(NODE, 'renderJinjaTemplate', metadata)
const defaultModel = DEFAULT_MODEL as Model
const eos_id = metadata['tokenizer.ggml.eos_token_id']
const architecture = metadata['general.architecture']
return {
settings: {
prompt_template: template ?? defaultModel.settings.prompt_template,
ctx_len:
metadata[`${architecture}.context_length`] ??
metadata['llama.context_length'] ??
4096,
ngl:
(metadata[`${architecture}.block_count`] ??
metadata['llama.block_count'] ??
32) + 1,
},
parameters: {
stop: eos_id
? [metadata?.['tokenizer.ggml.tokens'][eos_id] ?? '']
: defaultModel.parameters.stop,
},
}
}
}

View File

@ -16,27 +16,8 @@ export const retrieveGGUFMetadata = async (ggufPath: string) => {
// Parse metadata and tensor info
const { metadata } = ggufMetadata(buffer.buffer)
const template = new Template(metadata['tokenizer.chat_template'])
const eos_id = metadata['tokenizer.ggml.eos_token_id']
const bos_id = metadata['tokenizer.ggml.bos_token_id']
const eos_token = metadata['tokenizer.ggml.tokens'][eos_id]
const bos_token = metadata['tokenizer.ggml.tokens'][bos_id]
// Parse jinja template
const renderedTemplate = template.render({
add_generation_prompt: true,
eos_token,
bos_token,
messages: [
{
role: 'system',
content: '{system_message}',
},
{
role: 'user',
content: '{prompt}',
},
],
})
const renderedTemplate = renderJinjaTemplate(metadata)
return {
...metadata,
parsed_chat_template: renderedTemplate,
@ -45,3 +26,34 @@ export const retrieveGGUFMetadata = async (ggufPath: string) => {
console.log('[MODEL_EXT]', e)
}
}
/**
* Convert metadata to jinja template
* @param metadata
*/
export const renderJinjaTemplate = (metadata: any): string => {
const template = new Template(metadata['tokenizer.chat_template'])
const eos_id = metadata['tokenizer.ggml.eos_token_id']
const bos_id = metadata['tokenizer.ggml.bos_token_id']
if (eos_id === undefined || bos_id === undefined) {
return ''
}
const eos_token = metadata['tokenizer.ggml.tokens'][eos_id]
const bos_token = metadata['tokenizer.ggml.tokens'][bos_id]
// Parse jinja template
return template.render({
add_generation_prompt: true,
eos_token,
bos_token,
messages: [
{
role: 'system',
content: '{system_message}',
},
{
role: 'user',
content: '{prompt}',
},
],
})
}

View File

@ -0,0 +1,53 @@
import { renderJinjaTemplate } from './index'
import { Template } from '@huggingface/jinja'
jest.mock('@huggingface/jinja', () => ({
Template: jest.fn((template: string) => ({
render: jest.fn(() => `${template}_rendered`),
})),
}))
describe('renderJinjaTemplate', () => {
beforeEach(() => {
jest.clearAllMocks() // Clear mocks between tests
})
it('should render the template with correct parameters', () => {
const metadata = {
'tokenizer.chat_template': 'Hello, {{ messages }}!',
'tokenizer.ggml.eos_token_id': 0,
'tokenizer.ggml.bos_token_id': 1,
'tokenizer.ggml.tokens': ['EOS', 'BOS'],
}
const renderedTemplate = renderJinjaTemplate(metadata)
expect(Template).toHaveBeenCalledWith('Hello, {{ messages }}!')
expect(renderedTemplate).toBe('Hello, {{ messages }}!_rendered')
})
it('should handle missing token IDs gracefully', () => {
const metadata = {
'tokenizer.chat_template': 'Hello, {{ messages }}!',
'tokenizer.ggml.eos_token_id': 0,
'tokenizer.ggml.tokens': ['EOS'],
}
const renderedTemplate = renderJinjaTemplate(metadata)
expect(Template).toHaveBeenCalledWith('Hello, {{ messages }}!')
expect(renderedTemplate).toBe('')
})
it('should handle empty template gracefully', () => {
const metadata = {}
const renderedTemplate = renderJinjaTemplate(metadata)
expect(Template).toHaveBeenCalledWith(undefined)
expect(renderedTemplate).toBe("")
})
})