* fix: #3558 wrong model metadata import * chore: remove redundant metadata retrieval
This commit is contained in:
parent
87c8fdf5ac
commit
c0b59ece4d
@ -1,6 +1,13 @@
|
||||
/**
|
||||
* @jest-environment jsdom
|
||||
*/
|
||||
const readDirSyncMock = jest.fn()
|
||||
const existMock = jest.fn()
|
||||
const readFileSyncMock = jest.fn()
|
||||
const downloadMock = jest.fn()
|
||||
const mkdirMock = jest.fn()
|
||||
const writeFileSyncMock = jest.fn()
|
||||
const copyFileMock = jest.fn()
|
||||
|
||||
jest.mock('@janhq/core', () => ({
|
||||
...jest.requireActual('@janhq/core/node'),
|
||||
@ -8,6 +15,9 @@ jest.mock('@janhq/core', () => ({
|
||||
existsSync: existMock,
|
||||
readdirSync: readDirSyncMock,
|
||||
readFileSync: readFileSyncMock,
|
||||
writeFileSync: writeFileSyncMock,
|
||||
mkdir: mkdirMock,
|
||||
copyFile: copyFileMock,
|
||||
fileStat: () => ({
|
||||
isDirectory: false,
|
||||
}),
|
||||
@ -15,10 +25,20 @@ jest.mock('@janhq/core', () => ({
|
||||
dirName: jest.fn(),
|
||||
joinPath: (paths) => paths.join('/'),
|
||||
ModelExtension: jest.fn(),
|
||||
downloadFile: downloadMock,
|
||||
}))
|
||||
|
||||
global.fetch = jest.fn(() =>
|
||||
Promise.resolve({
|
||||
json: () => Promise.resolve({ test: 100 }),
|
||||
arrayBuffer: jest.fn(),
|
||||
})
|
||||
) as jest.Mock
|
||||
|
||||
import JanModelExtension from '.'
|
||||
import { fs, dirName } from '@janhq/core'
|
||||
import { renderJinjaTemplate } from './node/index'
|
||||
import { Template } from '@huggingface/jinja'
|
||||
|
||||
describe('JanModelExtension', () => {
|
||||
let sut: JanModelExtension
|
||||
@ -187,7 +207,6 @@ describe('JanModelExtension', () => {
|
||||
describe('no models downloaded', () => {
|
||||
it('should return empty array', async () => {
|
||||
// Mock downloaded models data
|
||||
const downloadedModels = []
|
||||
existMock.mockReturnValue(true)
|
||||
readDirSyncMock.mockReturnValue([])
|
||||
|
||||
@ -557,8 +576,41 @@ describe('JanModelExtension', () => {
|
||||
file_path: 'file://models/model1/model.json',
|
||||
} as any)
|
||||
|
||||
expect(fs.unlinkSync).toHaveBeenCalledWith('file://models/model1/test.engine')
|
||||
expect(fs.unlinkSync).toHaveBeenCalledWith(
|
||||
'file://models/model1/test.engine'
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('downloadModel', () => {
|
||||
const model: any = {
|
||||
id: 'model-id',
|
||||
name: 'Test Model',
|
||||
sources: [
|
||||
{ url: 'http://example.com/model.gguf', filename: 'model.gguf' },
|
||||
],
|
||||
engine: 'test-engine',
|
||||
}
|
||||
|
||||
const network = {
|
||||
ignoreSSL: true,
|
||||
proxy: 'http://proxy.example.com',
|
||||
}
|
||||
|
||||
const gpuSettings: any = {
|
||||
gpus: [{ name: 'nvidia-rtx-3080', arch: 'ampere' }],
|
||||
}
|
||||
|
||||
it('should reject with invalid gguf metadata', async () => {
|
||||
existMock.mockImplementation(() => false)
|
||||
|
||||
expect(
|
||||
sut.downloadModel(model, gpuSettings, network)
|
||||
).rejects.toBeTruthy()
|
||||
})
|
||||
|
||||
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
@ -24,6 +24,7 @@ import {
|
||||
ModelEvent,
|
||||
ModelFile,
|
||||
dirName,
|
||||
ModelSettingParams,
|
||||
} from '@janhq/core'
|
||||
|
||||
import { extractFileName } from './helpers/path'
|
||||
@ -80,11 +81,27 @@ export default class JanModelExtension extends ModelExtension {
|
||||
gpuSettings?: GpuSetting,
|
||||
network?: { ignoreSSL?: boolean; proxy?: string }
|
||||
): Promise<void> {
|
||||
// create corresponding directory
|
||||
// Create corresponding directory
|
||||
const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id])
|
||||
if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath)
|
||||
const modelJsonPath = await joinPath([modelDirPath, 'model.json'])
|
||||
|
||||
// Download HF model - model.json not exist
|
||||
if (!(await fs.existsSync(modelJsonPath))) {
|
||||
// It supports only one source for HF download
|
||||
const metadata = await this.fetchModelMetadata(model.sources[0].url)
|
||||
const updatedModel = await this.retrieveGGUFMetadata(metadata)
|
||||
if (updatedModel) {
|
||||
// Update model settings
|
||||
model.settings = {
|
||||
...model.settings,
|
||||
...updatedModel.settings,
|
||||
}
|
||||
model.parameters = {
|
||||
...model.parameters,
|
||||
...updatedModel.parameters,
|
||||
}
|
||||
}
|
||||
await fs.writeFileSync(modelJsonPath, JSON.stringify(model, null, 2))
|
||||
events.emit(ModelEvent.OnModelsUpdate, {})
|
||||
}
|
||||
@ -327,7 +344,7 @@ export default class JanModelExtension extends ModelExtension {
|
||||
// Should depend on sources?
|
||||
const isUserImportModel =
|
||||
modelInfo.metadata?.author?.toLowerCase() === 'user'
|
||||
if (isUserImportModel) {
|
||||
if (isUserImportModel) {
|
||||
// just delete the folder
|
||||
return fs.rm(dirPath)
|
||||
}
|
||||
@ -555,7 +572,7 @@ export default class JanModelExtension extends ModelExtension {
|
||||
])
|
||||
)
|
||||
|
||||
const eos_id = metadata?.['tokenizer.ggml.eos_token_id']
|
||||
const updatedModel = await this.retrieveGGUFMetadata(metadata)
|
||||
|
||||
if (!defaultModel) {
|
||||
console.error('Unable to find default model')
|
||||
@ -575,18 +592,11 @@ export default class JanModelExtension extends ModelExtension {
|
||||
],
|
||||
parameters: {
|
||||
...defaultModel.parameters,
|
||||
stop: eos_id
|
||||
? [metadata['tokenizer.ggml.tokens'][eos_id] ?? '']
|
||||
: defaultModel.parameters.stop,
|
||||
...updatedModel.parameters,
|
||||
},
|
||||
settings: {
|
||||
...defaultModel.settings,
|
||||
prompt_template:
|
||||
metadata?.parsed_chat_template ??
|
||||
defaultModel.settings.prompt_template,
|
||||
ctx_len:
|
||||
metadata?.['llama.context_length'] ?? defaultModel.settings.ctx_len,
|
||||
ngl: (metadata?.['llama.block_count'] ?? 32) + 1,
|
||||
...updatedModel.settings,
|
||||
llama_model_path: binaryFileName,
|
||||
},
|
||||
created: Date.now(),
|
||||
@ -666,9 +676,9 @@ export default class JanModelExtension extends ModelExtension {
|
||||
'retrieveGGUFMetadata',
|
||||
modelBinaryPath
|
||||
)
|
||||
const eos_id = metadata?.['tokenizer.ggml.eos_token_id']
|
||||
|
||||
const binaryFileName = await baseName(modelBinaryPath)
|
||||
const updatedModel = await this.retrieveGGUFMetadata(metadata)
|
||||
|
||||
const model: Model = {
|
||||
...defaultModel,
|
||||
@ -682,19 +692,12 @@ export default class JanModelExtension extends ModelExtension {
|
||||
],
|
||||
parameters: {
|
||||
...defaultModel.parameters,
|
||||
stop: eos_id
|
||||
? [metadata?.['tokenizer.ggml.tokens'][eos_id] ?? '']
|
||||
: defaultModel.parameters.stop,
|
||||
...updatedModel.parameters,
|
||||
},
|
||||
|
||||
settings: {
|
||||
...defaultModel.settings,
|
||||
prompt_template:
|
||||
metadata?.parsed_chat_template ??
|
||||
defaultModel.settings.prompt_template,
|
||||
ctx_len:
|
||||
metadata?.['llama.context_length'] ?? defaultModel.settings.ctx_len,
|
||||
ngl: (metadata?.['llama.block_count'] ?? 32) + 1,
|
||||
...updatedModel.settings,
|
||||
llama_model_path: binaryFileName,
|
||||
},
|
||||
created: Date.now(),
|
||||
@ -860,4 +863,35 @@ export default class JanModelExtension extends ModelExtension {
|
||||
importedModels
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieve Model Settings from GGUF Metadata
|
||||
* @param metadata
|
||||
* @returns
|
||||
*/
|
||||
async retrieveGGUFMetadata(metadata: any): Promise<Partial<Model>> {
|
||||
const template = await executeOnMain(NODE, 'renderJinjaTemplate', metadata)
|
||||
const defaultModel = DEFAULT_MODEL as Model
|
||||
const eos_id = metadata['tokenizer.ggml.eos_token_id']
|
||||
const architecture = metadata['general.architecture']
|
||||
|
||||
return {
|
||||
settings: {
|
||||
prompt_template: template ?? defaultModel.settings.prompt_template,
|
||||
ctx_len:
|
||||
metadata[`${architecture}.context_length`] ??
|
||||
metadata['llama.context_length'] ??
|
||||
4096,
|
||||
ngl:
|
||||
(metadata[`${architecture}.block_count`] ??
|
||||
metadata['llama.block_count'] ??
|
||||
32) + 1,
|
||||
},
|
||||
parameters: {
|
||||
stop: eos_id
|
||||
? [metadata?.['tokenizer.ggml.tokens'][eos_id] ?? '']
|
||||
: defaultModel.parameters.stop,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,27 +16,8 @@ export const retrieveGGUFMetadata = async (ggufPath: string) => {
|
||||
// Parse metadata and tensor info
|
||||
const { metadata } = ggufMetadata(buffer.buffer)
|
||||
|
||||
const template = new Template(metadata['tokenizer.chat_template'])
|
||||
const eos_id = metadata['tokenizer.ggml.eos_token_id']
|
||||
const bos_id = metadata['tokenizer.ggml.bos_token_id']
|
||||
const eos_token = metadata['tokenizer.ggml.tokens'][eos_id]
|
||||
const bos_token = metadata['tokenizer.ggml.tokens'][bos_id]
|
||||
// Parse jinja template
|
||||
const renderedTemplate = template.render({
|
||||
add_generation_prompt: true,
|
||||
eos_token,
|
||||
bos_token,
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content: '{system_message}',
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: '{prompt}',
|
||||
},
|
||||
],
|
||||
})
|
||||
const renderedTemplate = renderJinjaTemplate(metadata)
|
||||
return {
|
||||
...metadata,
|
||||
parsed_chat_template: renderedTemplate,
|
||||
@ -45,3 +26,34 @@ export const retrieveGGUFMetadata = async (ggufPath: string) => {
|
||||
console.log('[MODEL_EXT]', e)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert metadata to jinja template
|
||||
* @param metadata
|
||||
*/
|
||||
export const renderJinjaTemplate = (metadata: any): string => {
|
||||
const template = new Template(metadata['tokenizer.chat_template'])
|
||||
const eos_id = metadata['tokenizer.ggml.eos_token_id']
|
||||
const bos_id = metadata['tokenizer.ggml.bos_token_id']
|
||||
if (eos_id === undefined || bos_id === undefined) {
|
||||
return ''
|
||||
}
|
||||
const eos_token = metadata['tokenizer.ggml.tokens'][eos_id]
|
||||
const bos_token = metadata['tokenizer.ggml.tokens'][bos_id]
|
||||
// Parse jinja template
|
||||
return template.render({
|
||||
add_generation_prompt: true,
|
||||
eos_token,
|
||||
bos_token,
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content: '{system_message}',
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: '{prompt}',
|
||||
},
|
||||
],
|
||||
})
|
||||
}
|
||||
|
||||
53
extensions/model-extension/src/node/node.test.ts
Normal file
53
extensions/model-extension/src/node/node.test.ts
Normal file
@ -0,0 +1,53 @@
|
||||
import { renderJinjaTemplate } from './index'
|
||||
import { Template } from '@huggingface/jinja'
|
||||
|
||||
jest.mock('@huggingface/jinja', () => ({
|
||||
Template: jest.fn((template: string) => ({
|
||||
render: jest.fn(() => `${template}_rendered`),
|
||||
})),
|
||||
}))
|
||||
|
||||
describe('renderJinjaTemplate', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks() // Clear mocks between tests
|
||||
})
|
||||
|
||||
it('should render the template with correct parameters', () => {
|
||||
const metadata = {
|
||||
'tokenizer.chat_template': 'Hello, {{ messages }}!',
|
||||
'tokenizer.ggml.eos_token_id': 0,
|
||||
'tokenizer.ggml.bos_token_id': 1,
|
||||
'tokenizer.ggml.tokens': ['EOS', 'BOS'],
|
||||
}
|
||||
|
||||
const renderedTemplate = renderJinjaTemplate(metadata)
|
||||
|
||||
expect(Template).toHaveBeenCalledWith('Hello, {{ messages }}!')
|
||||
|
||||
expect(renderedTemplate).toBe('Hello, {{ messages }}!_rendered')
|
||||
})
|
||||
|
||||
it('should handle missing token IDs gracefully', () => {
|
||||
const metadata = {
|
||||
'tokenizer.chat_template': 'Hello, {{ messages }}!',
|
||||
'tokenizer.ggml.eos_token_id': 0,
|
||||
'tokenizer.ggml.tokens': ['EOS'],
|
||||
}
|
||||
|
||||
const renderedTemplate = renderJinjaTemplate(metadata)
|
||||
|
||||
expect(Template).toHaveBeenCalledWith('Hello, {{ messages }}!')
|
||||
|
||||
expect(renderedTemplate).toBe('')
|
||||
})
|
||||
|
||||
it('should handle empty template gracefully', () => {
|
||||
const metadata = {}
|
||||
|
||||
const renderedTemplate = renderJinjaTemplate(metadata)
|
||||
|
||||
expect(Template).toHaveBeenCalledWith(undefined)
|
||||
|
||||
expect(renderedTemplate).toBe("")
|
||||
})
|
||||
})
|
||||
Loading…
x
Reference in New Issue
Block a user