* fix: #3558 wrong model metadata import * chore: remove redundant metadata retrieval
This commit is contained in:
parent
87c8fdf5ac
commit
c0b59ece4d
@ -1,6 +1,13 @@
|
|||||||
|
/**
|
||||||
|
* @jest-environment jsdom
|
||||||
|
*/
|
||||||
const readDirSyncMock = jest.fn()
|
const readDirSyncMock = jest.fn()
|
||||||
const existMock = jest.fn()
|
const existMock = jest.fn()
|
||||||
const readFileSyncMock = jest.fn()
|
const readFileSyncMock = jest.fn()
|
||||||
|
const downloadMock = jest.fn()
|
||||||
|
const mkdirMock = jest.fn()
|
||||||
|
const writeFileSyncMock = jest.fn()
|
||||||
|
const copyFileMock = jest.fn()
|
||||||
|
|
||||||
jest.mock('@janhq/core', () => ({
|
jest.mock('@janhq/core', () => ({
|
||||||
...jest.requireActual('@janhq/core/node'),
|
...jest.requireActual('@janhq/core/node'),
|
||||||
@ -8,6 +15,9 @@ jest.mock('@janhq/core', () => ({
|
|||||||
existsSync: existMock,
|
existsSync: existMock,
|
||||||
readdirSync: readDirSyncMock,
|
readdirSync: readDirSyncMock,
|
||||||
readFileSync: readFileSyncMock,
|
readFileSync: readFileSyncMock,
|
||||||
|
writeFileSync: writeFileSyncMock,
|
||||||
|
mkdir: mkdirMock,
|
||||||
|
copyFile: copyFileMock,
|
||||||
fileStat: () => ({
|
fileStat: () => ({
|
||||||
isDirectory: false,
|
isDirectory: false,
|
||||||
}),
|
}),
|
||||||
@ -15,10 +25,20 @@ jest.mock('@janhq/core', () => ({
|
|||||||
dirName: jest.fn(),
|
dirName: jest.fn(),
|
||||||
joinPath: (paths) => paths.join('/'),
|
joinPath: (paths) => paths.join('/'),
|
||||||
ModelExtension: jest.fn(),
|
ModelExtension: jest.fn(),
|
||||||
|
downloadFile: downloadMock,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
global.fetch = jest.fn(() =>
|
||||||
|
Promise.resolve({
|
||||||
|
json: () => Promise.resolve({ test: 100 }),
|
||||||
|
arrayBuffer: jest.fn(),
|
||||||
|
})
|
||||||
|
) as jest.Mock
|
||||||
|
|
||||||
import JanModelExtension from '.'
|
import JanModelExtension from '.'
|
||||||
import { fs, dirName } from '@janhq/core'
|
import { fs, dirName } from '@janhq/core'
|
||||||
|
import { renderJinjaTemplate } from './node/index'
|
||||||
|
import { Template } from '@huggingface/jinja'
|
||||||
|
|
||||||
describe('JanModelExtension', () => {
|
describe('JanModelExtension', () => {
|
||||||
let sut: JanModelExtension
|
let sut: JanModelExtension
|
||||||
@ -187,7 +207,6 @@ describe('JanModelExtension', () => {
|
|||||||
describe('no models downloaded', () => {
|
describe('no models downloaded', () => {
|
||||||
it('should return empty array', async () => {
|
it('should return empty array', async () => {
|
||||||
// Mock downloaded models data
|
// Mock downloaded models data
|
||||||
const downloadedModels = []
|
|
||||||
existMock.mockReturnValue(true)
|
existMock.mockReturnValue(true)
|
||||||
readDirSyncMock.mockReturnValue([])
|
readDirSyncMock.mockReturnValue([])
|
||||||
|
|
||||||
@ -557,8 +576,41 @@ describe('JanModelExtension', () => {
|
|||||||
file_path: 'file://models/model1/model.json',
|
file_path: 'file://models/model1/model.json',
|
||||||
} as any)
|
} as any)
|
||||||
|
|
||||||
expect(fs.unlinkSync).toHaveBeenCalledWith('file://models/model1/test.engine')
|
expect(fs.unlinkSync).toHaveBeenCalledWith(
|
||||||
|
'file://models/model1/test.engine'
|
||||||
|
)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('downloadModel', () => {
|
||||||
|
const model: any = {
|
||||||
|
id: 'model-id',
|
||||||
|
name: 'Test Model',
|
||||||
|
sources: [
|
||||||
|
{ url: 'http://example.com/model.gguf', filename: 'model.gguf' },
|
||||||
|
],
|
||||||
|
engine: 'test-engine',
|
||||||
|
}
|
||||||
|
|
||||||
|
const network = {
|
||||||
|
ignoreSSL: true,
|
||||||
|
proxy: 'http://proxy.example.com',
|
||||||
|
}
|
||||||
|
|
||||||
|
const gpuSettings: any = {
|
||||||
|
gpus: [{ name: 'nvidia-rtx-3080', arch: 'ampere' }],
|
||||||
|
}
|
||||||
|
|
||||||
|
it('should reject with invalid gguf metadata', async () => {
|
||||||
|
existMock.mockImplementation(() => false)
|
||||||
|
|
||||||
|
expect(
|
||||||
|
sut.downloadModel(model, gpuSettings, network)
|
||||||
|
).rejects.toBeTruthy()
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|||||||
@ -24,6 +24,7 @@ import {
|
|||||||
ModelEvent,
|
ModelEvent,
|
||||||
ModelFile,
|
ModelFile,
|
||||||
dirName,
|
dirName,
|
||||||
|
ModelSettingParams,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
|
|
||||||
import { extractFileName } from './helpers/path'
|
import { extractFileName } from './helpers/path'
|
||||||
@ -80,11 +81,27 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
gpuSettings?: GpuSetting,
|
gpuSettings?: GpuSetting,
|
||||||
network?: { ignoreSSL?: boolean; proxy?: string }
|
network?: { ignoreSSL?: boolean; proxy?: string }
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
// create corresponding directory
|
// Create corresponding directory
|
||||||
const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id])
|
const modelDirPath = await joinPath([JanModelExtension._homeDir, model.id])
|
||||||
if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath)
|
if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath)
|
||||||
const modelJsonPath = await joinPath([modelDirPath, 'model.json'])
|
const modelJsonPath = await joinPath([modelDirPath, 'model.json'])
|
||||||
|
|
||||||
|
// Download HF model - model.json not exist
|
||||||
if (!(await fs.existsSync(modelJsonPath))) {
|
if (!(await fs.existsSync(modelJsonPath))) {
|
||||||
|
// It supports only one source for HF download
|
||||||
|
const metadata = await this.fetchModelMetadata(model.sources[0].url)
|
||||||
|
const updatedModel = await this.retrieveGGUFMetadata(metadata)
|
||||||
|
if (updatedModel) {
|
||||||
|
// Update model settings
|
||||||
|
model.settings = {
|
||||||
|
...model.settings,
|
||||||
|
...updatedModel.settings,
|
||||||
|
}
|
||||||
|
model.parameters = {
|
||||||
|
...model.parameters,
|
||||||
|
...updatedModel.parameters,
|
||||||
|
}
|
||||||
|
}
|
||||||
await fs.writeFileSync(modelJsonPath, JSON.stringify(model, null, 2))
|
await fs.writeFileSync(modelJsonPath, JSON.stringify(model, null, 2))
|
||||||
events.emit(ModelEvent.OnModelsUpdate, {})
|
events.emit(ModelEvent.OnModelsUpdate, {})
|
||||||
}
|
}
|
||||||
@ -327,7 +344,7 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
// Should depend on sources?
|
// Should depend on sources?
|
||||||
const isUserImportModel =
|
const isUserImportModel =
|
||||||
modelInfo.metadata?.author?.toLowerCase() === 'user'
|
modelInfo.metadata?.author?.toLowerCase() === 'user'
|
||||||
if (isUserImportModel) {
|
if (isUserImportModel) {
|
||||||
// just delete the folder
|
// just delete the folder
|
||||||
return fs.rm(dirPath)
|
return fs.rm(dirPath)
|
||||||
}
|
}
|
||||||
@ -555,7 +572,7 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
])
|
])
|
||||||
)
|
)
|
||||||
|
|
||||||
const eos_id = metadata?.['tokenizer.ggml.eos_token_id']
|
const updatedModel = await this.retrieveGGUFMetadata(metadata)
|
||||||
|
|
||||||
if (!defaultModel) {
|
if (!defaultModel) {
|
||||||
console.error('Unable to find default model')
|
console.error('Unable to find default model')
|
||||||
@ -575,18 +592,11 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
],
|
],
|
||||||
parameters: {
|
parameters: {
|
||||||
...defaultModel.parameters,
|
...defaultModel.parameters,
|
||||||
stop: eos_id
|
...updatedModel.parameters,
|
||||||
? [metadata['tokenizer.ggml.tokens'][eos_id] ?? '']
|
|
||||||
: defaultModel.parameters.stop,
|
|
||||||
},
|
},
|
||||||
settings: {
|
settings: {
|
||||||
...defaultModel.settings,
|
...defaultModel.settings,
|
||||||
prompt_template:
|
...updatedModel.settings,
|
||||||
metadata?.parsed_chat_template ??
|
|
||||||
defaultModel.settings.prompt_template,
|
|
||||||
ctx_len:
|
|
||||||
metadata?.['llama.context_length'] ?? defaultModel.settings.ctx_len,
|
|
||||||
ngl: (metadata?.['llama.block_count'] ?? 32) + 1,
|
|
||||||
llama_model_path: binaryFileName,
|
llama_model_path: binaryFileName,
|
||||||
},
|
},
|
||||||
created: Date.now(),
|
created: Date.now(),
|
||||||
@ -666,9 +676,9 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
'retrieveGGUFMetadata',
|
'retrieveGGUFMetadata',
|
||||||
modelBinaryPath
|
modelBinaryPath
|
||||||
)
|
)
|
||||||
const eos_id = metadata?.['tokenizer.ggml.eos_token_id']
|
|
||||||
|
|
||||||
const binaryFileName = await baseName(modelBinaryPath)
|
const binaryFileName = await baseName(modelBinaryPath)
|
||||||
|
const updatedModel = await this.retrieveGGUFMetadata(metadata)
|
||||||
|
|
||||||
const model: Model = {
|
const model: Model = {
|
||||||
...defaultModel,
|
...defaultModel,
|
||||||
@ -682,19 +692,12 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
],
|
],
|
||||||
parameters: {
|
parameters: {
|
||||||
...defaultModel.parameters,
|
...defaultModel.parameters,
|
||||||
stop: eos_id
|
...updatedModel.parameters,
|
||||||
? [metadata?.['tokenizer.ggml.tokens'][eos_id] ?? '']
|
|
||||||
: defaultModel.parameters.stop,
|
|
||||||
},
|
},
|
||||||
|
|
||||||
settings: {
|
settings: {
|
||||||
...defaultModel.settings,
|
...defaultModel.settings,
|
||||||
prompt_template:
|
...updatedModel.settings,
|
||||||
metadata?.parsed_chat_template ??
|
|
||||||
defaultModel.settings.prompt_template,
|
|
||||||
ctx_len:
|
|
||||||
metadata?.['llama.context_length'] ?? defaultModel.settings.ctx_len,
|
|
||||||
ngl: (metadata?.['llama.block_count'] ?? 32) + 1,
|
|
||||||
llama_model_path: binaryFileName,
|
llama_model_path: binaryFileName,
|
||||||
},
|
},
|
||||||
created: Date.now(),
|
created: Date.now(),
|
||||||
@ -860,4 +863,35 @@ export default class JanModelExtension extends ModelExtension {
|
|||||||
importedModels
|
importedModels
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retrieve Model Settings from GGUF Metadata
|
||||||
|
* @param metadata
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
async retrieveGGUFMetadata(metadata: any): Promise<Partial<Model>> {
|
||||||
|
const template = await executeOnMain(NODE, 'renderJinjaTemplate', metadata)
|
||||||
|
const defaultModel = DEFAULT_MODEL as Model
|
||||||
|
const eos_id = metadata['tokenizer.ggml.eos_token_id']
|
||||||
|
const architecture = metadata['general.architecture']
|
||||||
|
|
||||||
|
return {
|
||||||
|
settings: {
|
||||||
|
prompt_template: template ?? defaultModel.settings.prompt_template,
|
||||||
|
ctx_len:
|
||||||
|
metadata[`${architecture}.context_length`] ??
|
||||||
|
metadata['llama.context_length'] ??
|
||||||
|
4096,
|
||||||
|
ngl:
|
||||||
|
(metadata[`${architecture}.block_count`] ??
|
||||||
|
metadata['llama.block_count'] ??
|
||||||
|
32) + 1,
|
||||||
|
},
|
||||||
|
parameters: {
|
||||||
|
stop: eos_id
|
||||||
|
? [metadata?.['tokenizer.ggml.tokens'][eos_id] ?? '']
|
||||||
|
: defaultModel.parameters.stop,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,27 +16,8 @@ export const retrieveGGUFMetadata = async (ggufPath: string) => {
|
|||||||
// Parse metadata and tensor info
|
// Parse metadata and tensor info
|
||||||
const { metadata } = ggufMetadata(buffer.buffer)
|
const { metadata } = ggufMetadata(buffer.buffer)
|
||||||
|
|
||||||
const template = new Template(metadata['tokenizer.chat_template'])
|
|
||||||
const eos_id = metadata['tokenizer.ggml.eos_token_id']
|
|
||||||
const bos_id = metadata['tokenizer.ggml.bos_token_id']
|
|
||||||
const eos_token = metadata['tokenizer.ggml.tokens'][eos_id]
|
|
||||||
const bos_token = metadata['tokenizer.ggml.tokens'][bos_id]
|
|
||||||
// Parse jinja template
|
// Parse jinja template
|
||||||
const renderedTemplate = template.render({
|
const renderedTemplate = renderJinjaTemplate(metadata)
|
||||||
add_generation_prompt: true,
|
|
||||||
eos_token,
|
|
||||||
bos_token,
|
|
||||||
messages: [
|
|
||||||
{
|
|
||||||
role: 'system',
|
|
||||||
content: '{system_message}',
|
|
||||||
},
|
|
||||||
{
|
|
||||||
role: 'user',
|
|
||||||
content: '{prompt}',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
})
|
|
||||||
return {
|
return {
|
||||||
...metadata,
|
...metadata,
|
||||||
parsed_chat_template: renderedTemplate,
|
parsed_chat_template: renderedTemplate,
|
||||||
@ -45,3 +26,34 @@ export const retrieveGGUFMetadata = async (ggufPath: string) => {
|
|||||||
console.log('[MODEL_EXT]', e)
|
console.log('[MODEL_EXT]', e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert metadata to jinja template
|
||||||
|
* @param metadata
|
||||||
|
*/
|
||||||
|
export const renderJinjaTemplate = (metadata: any): string => {
|
||||||
|
const template = new Template(metadata['tokenizer.chat_template'])
|
||||||
|
const eos_id = metadata['tokenizer.ggml.eos_token_id']
|
||||||
|
const bos_id = metadata['tokenizer.ggml.bos_token_id']
|
||||||
|
if (eos_id === undefined || bos_id === undefined) {
|
||||||
|
return ''
|
||||||
|
}
|
||||||
|
const eos_token = metadata['tokenizer.ggml.tokens'][eos_id]
|
||||||
|
const bos_token = metadata['tokenizer.ggml.tokens'][bos_id]
|
||||||
|
// Parse jinja template
|
||||||
|
return template.render({
|
||||||
|
add_generation_prompt: true,
|
||||||
|
eos_token,
|
||||||
|
bos_token,
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: 'system',
|
||||||
|
content: '{system_message}',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: '{prompt}',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
53
extensions/model-extension/src/node/node.test.ts
Normal file
53
extensions/model-extension/src/node/node.test.ts
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import { renderJinjaTemplate } from './index'
|
||||||
|
import { Template } from '@huggingface/jinja'
|
||||||
|
|
||||||
|
jest.mock('@huggingface/jinja', () => ({
|
||||||
|
Template: jest.fn((template: string) => ({
|
||||||
|
render: jest.fn(() => `${template}_rendered`),
|
||||||
|
})),
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('renderJinjaTemplate', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks() // Clear mocks between tests
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should render the template with correct parameters', () => {
|
||||||
|
const metadata = {
|
||||||
|
'tokenizer.chat_template': 'Hello, {{ messages }}!',
|
||||||
|
'tokenizer.ggml.eos_token_id': 0,
|
||||||
|
'tokenizer.ggml.bos_token_id': 1,
|
||||||
|
'tokenizer.ggml.tokens': ['EOS', 'BOS'],
|
||||||
|
}
|
||||||
|
|
||||||
|
const renderedTemplate = renderJinjaTemplate(metadata)
|
||||||
|
|
||||||
|
expect(Template).toHaveBeenCalledWith('Hello, {{ messages }}!')
|
||||||
|
|
||||||
|
expect(renderedTemplate).toBe('Hello, {{ messages }}!_rendered')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle missing token IDs gracefully', () => {
|
||||||
|
const metadata = {
|
||||||
|
'tokenizer.chat_template': 'Hello, {{ messages }}!',
|
||||||
|
'tokenizer.ggml.eos_token_id': 0,
|
||||||
|
'tokenizer.ggml.tokens': ['EOS'],
|
||||||
|
}
|
||||||
|
|
||||||
|
const renderedTemplate = renderJinjaTemplate(metadata)
|
||||||
|
|
||||||
|
expect(Template).toHaveBeenCalledWith('Hello, {{ messages }}!')
|
||||||
|
|
||||||
|
expect(renderedTemplate).toBe('')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle empty template gracefully', () => {
|
||||||
|
const metadata = {}
|
||||||
|
|
||||||
|
const renderedTemplate = renderJinjaTemplate(metadata)
|
||||||
|
|
||||||
|
expect(Template).toHaveBeenCalledWith(undefined)
|
||||||
|
|
||||||
|
expect(renderedTemplate).toBe("")
|
||||||
|
})
|
||||||
|
})
|
||||||
Loading…
x
Reference in New Issue
Block a user