jan/web/utils/modelParam.ts
Louis 98bef7b7cf
test: add model parameter validation rules and persistence tests (#3618)
* test: add model parameter validation rules and persistence tests

* chore: fix CI cov step

* fix: invalid model settings should fallback to origin value

* test: support fallback integer settings
2024-09-17 08:34:58 +07:00

152 lines
4.7 KiB
TypeScript

/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/naming-convention */
import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
import { ModelParams } from '@/helpers/atoms/Thread.atom'
/**
* Validation rules for model parameters
*/
export const validationRules: { [key: string]: (value: any) => boolean } = {
temperature: (value: any) =>
typeof value === 'number' && value >= 0 && value <= 2,
token_limit: (value: any) => Number.isInteger(value) && value >= 0,
top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
stream: (value: any) => typeof value === 'boolean',
max_tokens: (value: any) => Number.isInteger(value) && value >= 0,
stop: (value: any) =>
Array.isArray(value) && value.every((v) => typeof v === 'string'),
frequency_penalty: (value: any) =>
typeof value === 'number' && value >= 0 && value <= 1,
presence_penalty: (value: any) =>
typeof value === 'number' && value >= 0 && value <= 1,
ctx_len: (value: any) => Number.isInteger(value) && value >= 0,
ngl: (value: any) => Number.isInteger(value) && value >= 0,
embedding: (value: any) => typeof value === 'boolean',
n_parallel: (value: any) => Number.isInteger(value) && value >= 0,
cpu_threads: (value: any) => Number.isInteger(value) && value >= 0,
prompt_template: (value: any) => typeof value === 'string',
llama_model_path: (value: any) => typeof value === 'string',
mmproj: (value: any) => typeof value === 'string',
vision_model: (value: any) => typeof value === 'boolean',
text_model: (value: any) => typeof value === 'boolean',
}
/**
* There are some parameters that need to be normalized before being sent to the server
* E.g. ctx_len should be an integer, but it can be a float from the input field
* @param key
* @param value
* @returns
*/
export const normalizeValue = (key: string, value: any) => {
if (
key === 'token_limit' ||
key === 'max_tokens' ||
key === 'ctx_len' ||
key === 'ngl' ||
key === 'n_parallel' ||
key === 'cpu_threads'
) {
// Convert to integer
return Math.floor(Number(value))
}
return value
}
/**
* Extract inference parameters from flat model parameters
* @param modelParams
* @returns
*/
export const extractInferenceParams = (
modelParams?: ModelParams,
originParams?: ModelParams
): ModelRuntimeParams => {
if (!modelParams) return {}
const defaultModelParams: ModelRuntimeParams = {
temperature: undefined,
token_limit: undefined,
top_k: undefined,
top_p: undefined,
stream: undefined,
max_tokens: undefined,
stop: undefined,
frequency_penalty: undefined,
presence_penalty: undefined,
}
const runtimeParams: ModelRuntimeParams = {}
for (const [key, value] of Object.entries(modelParams)) {
if (key in defaultModelParams) {
const validate = validationRules[key]
if (validate && !validate(normalizeValue(key, value))) {
// Invalid value - fall back to origin value
if (originParams && key in originParams) {
Object.assign(runtimeParams, {
...runtimeParams,
[key]: originParams[key as keyof typeof originParams],
})
}
} else {
Object.assign(runtimeParams, {
...runtimeParams,
[key]: normalizeValue(key, value),
})
}
}
}
return runtimeParams
}
/**
* Extract model load parameters from flat model parameters
* @param modelParams
* @returns
*/
export const extractModelLoadParams = (
modelParams?: ModelParams,
originParams?: ModelParams
): ModelSettingParams => {
if (!modelParams) return {}
const defaultSettingParams: ModelSettingParams = {
ctx_len: undefined,
ngl: undefined,
embedding: undefined,
n_parallel: undefined,
cpu_threads: undefined,
prompt_template: undefined,
llama_model_path: undefined,
mmproj: undefined,
vision_model: undefined,
text_model: undefined,
}
const settingParams: ModelSettingParams = {}
for (const [key, value] of Object.entries(modelParams)) {
if (key in defaultSettingParams) {
const validate = validationRules[key]
if (validate && !validate(normalizeValue(key, value))) {
// Invalid value - fall back to origin value
if (originParams && key in originParams) {
Object.assign(modelParams, {
...modelParams,
[key]: originParams[key as keyof typeof originParams],
})
}
} else {
Object.assign(settingParams, {
...settingParams,
[key]: normalizeValue(key, value),
})
}
}
}
return settingParams
}