test: add model parameter validation rules and persistence tests (#3618)

* test: add model parameter validation rules and persistence tests

* chore: fix CI cov step

* fix: invalid model settings should fall back to origin value

* test: support fallback integer settings
Louis 2024-09-17 08:34:58 +07:00 committed by GitHub
parent 3ffaa1ef7f
commit 98bef7b7cf
11 changed files with 681 additions and 54 deletions

View File

@@ -227,7 +227,7 @@ function loadLLMModel(settings: any): Promise<Response> {
if (!settings?.ngl) {
settings.ngl = 100
}
log(`[CORTEX]::Debug: Loading model with params ${JSON.stringify(settings)}`)
log(`[CORTEX]:: Loading model with params ${JSON.stringify(settings)}`)
return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, {
method: 'POST',
headers: {
@@ -239,7 +239,7 @@ function loadLLMModel(settings: any): Promise<Response> {
})
.then((res) => {
log(
`[CORTEX]::Debug: Load model success with response ${JSON.stringify(
`[CORTEX]:: Load model success with response ${JSON.stringify(
res
)}`
)
@@ -260,7 +260,7 @@ function loadLLMModel(settings: any): Promise<Response> {
async function validateModelStatus(modelId: string): Promise<void> {
// Send a POST request to the validation URL.
// Retry the request up to 3 times if it fails, with a delay of 300 milliseconds between retries.
log(`[CORTEX]::Debug: Validating model ${modelId}`)
log(`[CORTEX]:: Validating model ${modelId}`)
return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, {
method: 'POST',
body: JSON.stringify({
@@ -275,7 +275,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
retryDelay: 300,
}).then(async (res: Response) => {
log(
`[CORTEX]::Debug: Validate model state with response ${JSON.stringify(
`[CORTEX]:: Validate model state with response ${JSON.stringify(
res.status
)}`
)
@@ -286,7 +286,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
// Otherwise, return an object with an error message.
if (body.model_loaded) {
log(
`[CORTEX]::Debug: Validate model state success with response ${JSON.stringify(
`[CORTEX]:: Validate model state success with response ${JSON.stringify(
body
)}`
)
@@ -295,7 +295,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
}
const errorBody = await res.text()
log(
`[CORTEX]::Debug: Validate model state failed with response ${errorBody} and status is ${JSON.stringify(
`[CORTEX]:: Validate model state failed with response ${errorBody} and status is ${JSON.stringify(
res.statusText
)}`
)
@@ -310,7 +310,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
async function killSubprocess(): Promise<void> {
const controller = new AbortController()
setTimeout(() => controller.abort(), 5000)
log(`[CORTEX]::Debug: Request to kill cortex`)
log(`[CORTEX]:: Request to kill cortex`)
const killRequest = () => {
return fetch(NITRO_HTTP_KILL_URL, {
@@ -321,17 +321,17 @@ async function killSubprocess(): Promise<void> {
.then(() =>
tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
)
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
.then(() => log(`[CORTEX]:: cortex process is terminated`))
.catch((err) => {
log(
`[CORTEX]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
`[CORTEX]:: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
)
throw 'PORT_NOT_AVAILABLE'
})
}
if (subprocess?.pid && process.platform !== 'darwin') {
log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`)
log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
const pid = subprocess.pid
return new Promise((resolve, reject) => {
terminate(pid, function (err) {
@@ -341,7 +341,7 @@ async function killSubprocess(): Promise<void> {
} else {
tcpPortUsed
.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
.then(() => log(`[CORTEX]:: cortex process is terminated`))
.then(() => resolve())
.catch(() => {
log(
@@ -362,7 +362,7 @@ async function killSubprocess(): Promise<void> {
* @returns A promise that resolves when the Nitro subprocess is started.
*/
function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
log(`[CORTEX]::Debug: Spawning cortex subprocess...`)
log(`[CORTEX]:: Spawning cortex subprocess...`)
return new Promise<void>(async (resolve, reject) => {
let executableOptions = executableNitroFile(
@@ -381,7 +381,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
const args: string[] = ['1', LOCAL_HOST, PORT.toString()]
// Execute the binary
log(
`[CORTEX]::Debug: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}`
`[CORTEX]:: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}`
)
log(`[CORTEX]::Debug: Cortex engine path: ${executableOptions.enginePath}`)
@@ -415,7 +415,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
// Handle subprocess output
subprocess.stdout.on('data', (data: any) => {
log(`[CORTEX]::Debug: ${data}`)
log(`[CORTEX]:: ${data}`)
})
subprocess.stderr.on('data', (data: any) => {
@@ -423,7 +423,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
})
subprocess.on('close', (code: any) => {
log(`[CORTEX]::Debug: cortex exited with code: ${code}`)
log(`[CORTEX]:: cortex exited with code: ${code}`)
subprocess = undefined
reject(`child process exited with code ${code}`)
})
@@ -431,7 +431,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
tcpPortUsed
.waitUntilUsed(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 30000)
.then(() => {
log(`[CORTEX]::Debug: cortex is ready`)
log(`[CORTEX]:: cortex is ready`)
resolve()
})
})

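For context, the load and validate requests above go through a fetchRetry helper driven by retries and retryDelay options. A minimal sketch of what such a helper could look like, assuming a plain fetch-based retry loop (the helper's real implementation is not part of this diff and may differ):

// Hypothetical sketch of a fetch-with-retry helper; the real fetchRetry may differ.
async function fetchRetry(
  url: string,
  opts: RequestInit & { retries?: number; retryDelay?: number }
): Promise<Response> {
  const { retries = 3, retryDelay = 500, ...init } = opts
  let lastError: unknown
  for (let attempt = 0; attempt <= retries; attempt++) {
    try {
      const res = await fetch(url, init)
      if (res.ok) return res
      lastError = new Error(`HTTP ${res.status}`)
    } catch (err) {
      lastError = err
    }
    // Wait before the next attempt
    await new Promise((resolve) => setTimeout(resolve, retryDelay))
  }
  throw lastError
}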
View File

@@ -97,7 +97,7 @@ function unloadModel(): Promise<void> {
}
if (subprocess?.pid) {
log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`)
log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
const pid = subprocess.pid
return new Promise((resolve, reject) => {
terminate(pid, function (err) {
@@ -107,7 +107,7 @@ function unloadModel(): Promise<void> {
return tcpPortUsed
.waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000)
.then(() => resolve())
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
.then(() => log(`[CORTEX]:: cortex process is terminated`))
.catch(() => {
killRequest()
})

View File

@@ -20,7 +20,7 @@ import { ulid } from 'ulidx'
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
import { toRuntimeParams } from '@/utils/modelParam'
import { extractInferenceParams } from '@/utils/modelParam'
import { extensionManager } from '@/extension'
import {
@@ -256,7 +256,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {
},
]
const runtimeParams = toRuntimeParams(activeModelParamsRef.current)
const runtimeParams = extractInferenceParams(activeModelParamsRef.current)
const messageRequest: MessageRequest = {
id: msgId,

View File

@@ -87,26 +87,28 @@ const SliderRightPanel = ({
onValueChanged?.(Number(min))
setVal(min.toString())
setShowTooltip({ max: false, min: true })
} else {
setVal(Number(e.target.value).toString()) // Normalize inputs like '.5' to '0.5'
}
}}
onChange={(e) => {
// Should not accept invalid value or NaN
// E.g. anything changes that trigger onValueChanged
// Which is incorrect
if (Number(e.target.value) > Number(max)) {
setVal(max.toString())
} else if (
Number(e.target.value) < Number(min) ||
!e.target.value.length
) {
setVal(min.toString())
} else if (Number.isNaN(Number(e.target.value))) return
onValueChanged?.(Number(e.target.value))
// TODO: How do we support negative number input?
// Pass the raw string through; it is validated again on blur
if (/^\d*\.?\d*$/.test(e.target.value)) {
setVal(e.target.value)
}
// Do not accept invalid values or NaN;
// otherwise onValueChanged would fire
// with an incorrect value
if (
Number(e.target.value) > Number(max) ||
Number(e.target.value) < Number(min) ||
Number.isNaN(Number(e.target.value))
) {
return
}
onValueChanged?.(Number(e.target.value))
}}
/>
}

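The slider fix above separates what is displayed from what is committed: while typing, any partial decimal string (including '.5' or '12.') updates the local display, but onValueChanged only fires for numbers inside [min, max]; blur then clamps and normalizes whatever remains. A standalone sketch of that logic, with hypothetical names, outside the component:

// Illustrative reduction of the input rules above; not the component's actual code.
const isPartialDecimal = (s: string): boolean => /^\d*\.?\d*$/.test(s)

// onChange: show any partial decimal, but commit only valid in-range numbers
function handleChange(
  raw: string,
  min: number,
  max: number,
  commit: (n: number) => void
): string | undefined {
  const display = isPartialDecimal(raw) ? raw : undefined
  const n = Number(raw)
  if (!Number.isNaN(n) && n >= min && n <= max) commit(n)
  return display
}

// onBlur: clamp out-of-range or empty input, and normalize e.g. '.5' to '0.5'
function handleBlur(raw: string, min: number, max: number): string {
  const n = Number(raw)
  if (Number.isNaN(n) || raw.length === 0 || n < min) return min.toString()
  if (n > max) return max.toString()
  return n.toString()
}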
View File

@@ -23,7 +23,10 @@ import {
import { Stack } from '@/utils/Stack'
import { compressImage, getBase64 } from '@/utils/base64'
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
import {
extractInferenceParams,
extractModelLoadParams,
} from '@/utils/modelParam'
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
@@ -189,8 +192,8 @@ export default function useSendChatMessage() {
if (engineParamsUpdate) setReloadModel(true)
const runtimeParams = toRuntimeParams(activeModelParams)
const settingParams = toSettingParams(activeModelParams)
const runtimeParams = extractInferenceParams(activeModelParams)
const settingParams = extractModelLoadParams(activeModelParams)
const prompt = message.trim()

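The renamed helpers make the split explicit: a thread's flat parameter bag is divided by key into inference-time parameters and model-load settings. Roughly, assuming the default key sets defined in modelParam.ts (values here are illustrative):

// Illustrative: splitting one flat params object by key category.
import {
  extractInferenceParams,
  extractModelLoadParams,
} from '@/utils/modelParam'

const activeModelParams = {
  temperature: 0.7, // inference-time
  max_tokens: 2048, // inference-time
  ctx_len: 4096, // model-load
  ngl: 32, // model-load
}

const runtimeParams = extractInferenceParams(activeModelParams)
// -> { temperature: 0.7, max_tokens: 2048 }
const settingParams = extractModelLoadParams(activeModelParams)
// -> { ctx_len: 4096, ngl: 32 }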
View File

@@ -0,0 +1,314 @@
import { renderHook, act } from '@testing-library/react'
// Mock dependencies
jest.mock('ulidx')
jest.mock('@/extension')
import useUpdateModelParameters from './useUpdateModelParameters'
import { extensionManager } from '@/extension'
// Mock data
let model: any = {
id: 'model-1',
engine: 'nitro',
}
let extension: any = {
saveThread: jest.fn(),
}
const mockThread: any = {
id: 'thread-1',
assistants: [
{
model: {
parameters: {},
settings: {},
},
},
],
object: 'thread',
title: 'New Thread',
created: 0,
updated: 0,
}
describe('useUpdateModelParameters', () => {
beforeAll(() => {
jest.clearAllMocks()
jest.mock('./useRecommendedModel', () => ({
useRecommendedModel: () => ({
recommendedModel: model,
setRecommendedModel: jest.fn(),
downloadedModels: [],
}),
}))
})
it('should update model parameters and save thread when params are valid', async () => {
const mockValidParameters: any = {
params: {
// Inference
stop: ['<eos>', '<eos2>'],
temperature: 0.5,
token_limit: 1000,
top_k: 0.7,
top_p: 0.1,
stream: true,
max_tokens: 1000,
frequency_penalty: 0.3,
presence_penalty: 0.2,
// Load model
ctx_len: 1024,
ngl: 12,
embedding: true,
n_parallel: 2,
cpu_threads: 4,
prompt_template: 'template',
llama_model_path: 'path',
mmproj: 'mmproj',
vision_model: 'vision',
text_model: 'text',
},
modelId: 'model-1',
engine: 'nitro',
}
// Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters())
await act(async () => {
await result.current.updateModelParameter(mockThread, mockValidParameters)
})
// Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({
assistants: [
{
model: {
parameters: {
stop: ['<eos>', '<eos2>'],
temperature: 0.5,
token_limit: 1000,
top_k: 0.7,
top_p: 0.1,
stream: true,
max_tokens: 1000,
frequency_penalty: 0.3,
presence_penalty: 0.2,
},
settings: {
ctx_len: 1024,
ngl: 12,
embedding: true,
n_parallel: 2,
cpu_threads: 4,
prompt_template: 'template',
llama_model_path: 'path',
mmproj: 'mmproj',
},
},
},
],
created: 0,
id: 'thread-1',
object: 'thread',
title: 'New Thread',
updated: 0,
})
})
it('should not update invalid model parameters', async () => {
const mockInvalidParameters: any = {
params: {
// Inference
stop: [1, '<eos>'],
temperature: '0.5',
token_limit: '1000',
top_k: '0.7',
top_p: '0.1',
stream: 'true',
max_tokens: '1000',
frequency_penalty: '0.3',
presence_penalty: '0.2',
// Load model
ctx_len: '1024',
ngl: '12',
embedding: 'true',
n_parallel: '2',
cpu_threads: '4',
prompt_template: 'template',
llama_model_path: 'path',
mmproj: 'mmproj',
vision_model: 'vision',
text_model: 'text',
},
modelId: 'model-1',
engine: 'nitro',
}
// Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters())
await act(async () => {
await result.current.updateModelParameter(
mockThread,
mockInvalidParameters
)
})
// Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({
assistants: [
{
model: {
parameters: {
max_tokens: 1000,
token_limit: 1000,
},
settings: {
cpu_threads: 4,
ctx_len: 1024,
prompt_template: 'template',
llama_model_path: 'path',
mmproj: 'mmproj',
n_parallel: 2,
ngl: 12,
},
},
},
],
created: 0,
id: 'thread-1',
object: 'thread',
title: 'New Thread',
updated: 0,
})
})
it('should update valid model parameters only', async () => {
const mockInvalidParameters: any = {
params: {
// Inference
stop: ['<eos>'],
temperature: -0.5,
token_limit: 100.2,
top_k: 0.7,
top_p: 0.1,
stream: true,
max_tokens: 1000,
frequency_penalty: 1.2,
presence_penalty: 0.2,
// Load model
ctx_len: 1024,
ngl: 0,
embedding: 'true',
n_parallel: 2,
cpu_threads: 4,
prompt_template: 'template',
llama_model_path: 'path',
mmproj: 'mmproj',
vision_model: 'vision',
text_model: 'text',
},
modelId: 'model-1',
engine: 'nitro',
}
// Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters())
await act(async () => {
await result.current.updateModelParameter(
mockThread,
mockInvalidParameters
)
})
// Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({
assistants: [
{
model: {
parameters: {
stop: ['<eos>'],
top_k: 0.7,
top_p: 0.1,
stream: true,
token_limit: 100,
max_tokens: 1000,
presence_penalty: 0.2,
},
settings: {
ctx_len: 1024,
ngl: 0,
n_parallel: 2,
cpu_threads: 4,
prompt_template: 'template',
llama_model_path: 'path',
mmproj: 'mmproj',
},
},
},
],
created: 0,
id: 'thread-1',
object: 'thread',
title: 'New Thread',
updated: 0,
})
})
it('should handle missing modelId and engine gracefully', async () => {
const mockParametersWithoutModelIdAndEngine: any = {
params: {
stop: ['<eos>', '<eos2>'],
temperature: 0.5,
},
}
// Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters())
await act(async () => {
await result.current.updateModelParameter(
mockThread,
mockParametersWithoutModelIdAndEngine
)
})
// Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({
assistants: [
{
model: {
parameters: {
stop: ['<eos>', '<eos2>'],
temperature: 0.5,
},
settings: {},
},
},
],
created: 0,
id: 'thread-1',
object: 'thread',
title: 'New Thread',
updated: 0,
})
})
})

View File

@@ -12,7 +12,10 @@ import {
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
import {
extractInferenceParams,
extractModelLoadParams,
} from '@/utils/modelParam'
import useRecommendedModel from './useRecommendedModel'
@@ -47,12 +50,17 @@ export default function useUpdateModelParameters() {
const toUpdateSettings = processStopWords(settings.params ?? {})
const updatedModelParams = settings.modelId
? toUpdateSettings
: { ...activeModelParams, ...toUpdateSettings }
: {
...selectedModel?.parameters,
...selectedModel?.settings,
...activeModelParams,
...toUpdateSettings,
}
// update the state
setThreadModelParams(thread.id, updatedModelParams)
const runtimeParams = toRuntimeParams(updatedModelParams)
const settingParams = toSettingParams(updatedModelParams)
const runtimeParams = extractInferenceParams(updatedModelParams)
const settingParams = extractModelLoadParams(updatedModelParams)
const assistants = thread.assistants.map(
(assistant: ThreadAssistantInfo) => {

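Because later spreads override earlier ones, the incoming toUpdateSettings always wins, with the thread's active params next and the selected model's stock values as the base. For instance:

// Later spreads win: the incoming change overrides thread state, which overrides model defaults.
const updatedModelParams = {
  ...{ ctx_len: 2048, temperature: 0.8 }, // selectedModel parameters/settings
  ...{ temperature: 0.6 }, // activeModelParams (thread state)
  ...{ ctx_len: 4096 }, // toUpdateSettings (incoming change)
}
// -> { ctx_len: 4096, temperature: 0.6 }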
View File

@@ -14,7 +14,10 @@ import { loadModelErrorAtom } from '@/hooks/useActiveModel'
import { getConfigurationsData } from '@/utils/componentSettings'
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
import {
extractInferenceParams,
extractModelLoadParams,
} from '@/utils/modelParam'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
@@ -27,16 +30,18 @@ const LocalServerRightPanel = () => {
const selectedModel = useAtomValue(selectedModelAtom)
const [currentModelSettingParams, setCurrentModelSettingParams] = useState(
toSettingParams(selectedModel?.settings)
extractModelLoadParams(selectedModel?.settings)
)
useEffect(() => {
if (selectedModel) {
setCurrentModelSettingParams(toSettingParams(selectedModel?.settings))
setCurrentModelSettingParams(
extractModelLoadParams(selectedModel?.settings)
)
}
}, [selectedModel])
const modelRuntimeParams = toRuntimeParams(selectedModel?.settings)
const modelRuntimeParams = extractInferenceParams(selectedModel?.settings)
const componentDataRuntimeSetting = getConfigurationsData(
modelRuntimeParams,

View File

@@ -29,7 +29,10 @@ import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
import { getConfigurationsData } from '@/utils/componentSettings'
import { localEngines } from '@/utils/modelEngine'
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
import {
extractInferenceParams,
extractModelLoadParams,
} from '@/utils/modelParam'
import PromptTemplateSetting from './PromptTemplateSetting'
import Tools from './Tools'
@@ -68,14 +71,26 @@ const ThreadRightPanel = () => {
const settings = useMemo(() => {
// runtime setting
const modelRuntimeParams = toRuntimeParams(activeModelParams)
const modelRuntimeParams = extractInferenceParams(
{
...selectedModel?.parameters,
...activeModelParams,
},
selectedModel?.parameters
)
const componentDataRuntimeSetting = getConfigurationsData(
modelRuntimeParams,
selectedModel
).filter((x) => x.key !== 'prompt_template')
// engine setting
const modelEngineParams = toSettingParams(activeModelParams)
const modelEngineParams = extractModelLoadParams(
{
...selectedModel?.settings,
...activeModelParams,
},
selectedModel?.settings
)
const componentDataEngineSetting = getConfigurationsData(
modelEngineParams,
selectedModel
@@ -126,7 +141,10 @@
}, [activeModelParams, selectedModel])
const promptTemplateSettings = useMemo(() => {
const modelEngineParams = toSettingParams(activeModelParams)
const modelEngineParams = extractModelLoadParams({
...selectedModel?.settings,
...activeModelParams,
})
const componentDataEngineSetting = getConfigurationsData(
modelEngineParams,
selectedModel

View File

@@ -0,0 +1,183 @@
// web/utils/modelParam.test.ts
import { normalizeValue, validationRules } from './modelParam'
describe('validationRules', () => {
it('should validate temperature correctly', () => {
expect(validationRules.temperature(0.5)).toBe(true)
expect(validationRules.temperature(2)).toBe(true)
expect(validationRules.temperature(0)).toBe(true)
expect(validationRules.temperature(-0.1)).toBe(false)
expect(validationRules.temperature(2.3)).toBe(false)
expect(validationRules.temperature('0.5')).toBe(false)
})
it('should validate token_limit correctly', () => {
expect(validationRules.token_limit(100)).toBe(true)
expect(validationRules.token_limit(1)).toBe(true)
expect(validationRules.token_limit(0)).toBe(true)
expect(validationRules.token_limit(-1)).toBe(false)
expect(validationRules.token_limit('100')).toBe(false)
})
it('should validate top_k correctly', () => {
expect(validationRules.top_k(0.5)).toBe(true)
expect(validationRules.top_k(1)).toBe(true)
expect(validationRules.top_k(0)).toBe(true)
expect(validationRules.top_k(-0.1)).toBe(false)
expect(validationRules.top_k(1.1)).toBe(false)
expect(validationRules.top_k('0.5')).toBe(false)
})
it('should validate top_p correctly', () => {
expect(validationRules.top_p(0.5)).toBe(true)
expect(validationRules.top_p(1)).toBe(true)
expect(validationRules.top_p(0)).toBe(true)
expect(validationRules.top_p(-0.1)).toBe(false)
expect(validationRules.top_p(1.1)).toBe(false)
expect(validationRules.top_p('0.5')).toBe(false)
})
it('should validate stream correctly', () => {
expect(validationRules.stream(true)).toBe(true)
expect(validationRules.stream(false)).toBe(true)
expect(validationRules.stream('true')).toBe(false)
expect(validationRules.stream(1)).toBe(false)
})
it('should validate max_tokens correctly', () => {
expect(validationRules.max_tokens(100)).toBe(true)
expect(validationRules.max_tokens(1)).toBe(true)
expect(validationRules.max_tokens(0)).toBe(true)
expect(validationRules.max_tokens(-1)).toBe(false)
expect(validationRules.max_tokens('100')).toBe(false)
})
it('should validate stop correctly', () => {
expect(validationRules.stop(['word1', 'word2'])).toBe(true)
expect(validationRules.stop([])).toBe(true)
expect(validationRules.stop(['word1', 2])).toBe(false)
expect(validationRules.stop('word1')).toBe(false)
})
it('should validate frequency_penalty correctly', () => {
expect(validationRules.frequency_penalty(0.5)).toBe(true)
expect(validationRules.frequency_penalty(1)).toBe(true)
expect(validationRules.frequency_penalty(0)).toBe(true)
expect(validationRules.frequency_penalty(-0.1)).toBe(false)
expect(validationRules.frequency_penalty(1.1)).toBe(false)
expect(validationRules.frequency_penalty('0.5')).toBe(false)
})
it('should validate presence_penalty correctly', () => {
expect(validationRules.presence_penalty(0.5)).toBe(true)
expect(validationRules.presence_penalty(1)).toBe(true)
expect(validationRules.presence_penalty(0)).toBe(true)
expect(validationRules.presence_penalty(-0.1)).toBe(false)
expect(validationRules.presence_penalty(1.1)).toBe(false)
expect(validationRules.presence_penalty('0.5')).toBe(false)
})
it('should validate ctx_len correctly', () => {
expect(validationRules.ctx_len(1024)).toBe(true)
expect(validationRules.ctx_len(1)).toBe(true)
expect(validationRules.ctx_len(0)).toBe(true)
expect(validationRules.ctx_len(-1)).toBe(false)
expect(validationRules.ctx_len('1024')).toBe(false)
})
it('should validate ngl correctly', () => {
expect(validationRules.ngl(12)).toBe(true)
expect(validationRules.ngl(1)).toBe(true)
expect(validationRules.ngl(0)).toBe(true)
expect(validationRules.ngl(-1)).toBe(false)
expect(validationRules.ngl('12')).toBe(false)
})
it('should validate embedding correctly', () => {
expect(validationRules.embedding(true)).toBe(true)
expect(validationRules.embedding(false)).toBe(true)
expect(validationRules.embedding('true')).toBe(false)
expect(validationRules.embedding(1)).toBe(false)
})
it('should validate n_parallel correctly', () => {
expect(validationRules.n_parallel(2)).toBe(true)
expect(validationRules.n_parallel(1)).toBe(true)
expect(validationRules.n_parallel(0)).toBe(true)
expect(validationRules.n_parallel(-1)).toBe(false)
expect(validationRules.n_parallel('2')).toBe(false)
})
it('should validate cpu_threads correctly', () => {
expect(validationRules.cpu_threads(4)).toBe(true)
expect(validationRules.cpu_threads(1)).toBe(true)
expect(validationRules.cpu_threads(0)).toBe(true)
expect(validationRules.cpu_threads(-1)).toBe(false)
expect(validationRules.cpu_threads('4')).toBe(false)
})
it('should validate prompt_template correctly', () => {
expect(validationRules.prompt_template('template')).toBe(true)
expect(validationRules.prompt_template('')).toBe(true)
expect(validationRules.prompt_template(123)).toBe(false)
})
it('should validate llama_model_path correctly', () => {
expect(validationRules.llama_model_path('path')).toBe(true)
expect(validationRules.llama_model_path('')).toBe(true)
expect(validationRules.llama_model_path(123)).toBe(false)
})
it('should validate mmproj correctly', () => {
expect(validationRules.mmproj('mmproj')).toBe(true)
expect(validationRules.mmproj('')).toBe(true)
expect(validationRules.mmproj(123)).toBe(false)
})
it('should validate vision_model correctly', () => {
expect(validationRules.vision_model(true)).toBe(true)
expect(validationRules.vision_model(false)).toBe(true)
expect(validationRules.vision_model('true')).toBe(false)
expect(validationRules.vision_model(1)).toBe(false)
})
it('should validate text_model correctly', () => {
expect(validationRules.text_model(true)).toBe(true)
expect(validationRules.text_model(false)).toBe(true)
expect(validationRules.text_model('true')).toBe(false)
expect(validationRules.text_model(1)).toBe(false)
})
})
describe('normalizeValue', () => {
it('should normalize ctx_len correctly', () => {
expect(normalizeValue('ctx_len', 100.5)).toBe(100)
expect(normalizeValue('ctx_len', '2')).toBe(2)
expect(normalizeValue('ctx_len', 100)).toBe(100)
})
it('should normalize token_limit correctly', () => {
expect(normalizeValue('token_limit', 100.5)).toBe(100)
expect(normalizeValue('token_limit', '1')).toBe(1)
expect(normalizeValue('token_limit', 0)).toBe(0)
})
it('should normalize max_tokens correctly', () => {
expect(normalizeValue('max_tokens', 100.5)).toBe(100)
expect(normalizeValue('max_tokens', '1')).toBe(1)
expect(normalizeValue('max_tokens', 0)).toBe(0)
})
it('should normalize ngl correctly', () => {
expect(normalizeValue('ngl', 12.5)).toBe(12)
expect(normalizeValue('ngl', '2')).toBe(2)
expect(normalizeValue('ngl', 0)).toBe(0)
})
it('should normalize n_parallel correctly', () => {
expect(normalizeValue('n_parallel', 2.5)).toBe(2)
expect(normalizeValue('n_parallel', '2')).toBe(2)
expect(normalizeValue('n_parallel', 0)).toBe(0)
})
it('should normalize cpu_threads correctly', () => {
expect(normalizeValue('cpu_threads', 4.5)).toBe(4)
expect(normalizeValue('cpu_threads', '4')).toBe(4)
expect(normalizeValue('cpu_threads', 0)).toBe(0)
})
})

View File

@@ -1,9 +1,69 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/naming-convention */
import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
import { ModelParams } from '@/helpers/atoms/Thread.atom'
export const toRuntimeParams = (
modelParams?: ModelParams
/**
 * Validation rules for model parameters
*/
export const validationRules: { [key: string]: (value: any) => boolean } = {
temperature: (value: any) =>
typeof value === 'number' && value >= 0 && value <= 2,
token_limit: (value: any) => Number.isInteger(value) && value >= 0,
top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
stream: (value: any) => typeof value === 'boolean',
max_tokens: (value: any) => Number.isInteger(value) && value >= 0,
stop: (value: any) =>
Array.isArray(value) && value.every((v) => typeof v === 'string'),
frequency_penalty: (value: any) =>
typeof value === 'number' && value >= 0 && value <= 1,
presence_penalty: (value: any) =>
typeof value === 'number' && value >= 0 && value <= 1,
ctx_len: (value: any) => Number.isInteger(value) && value >= 0,
ngl: (value: any) => Number.isInteger(value) && value >= 0,
embedding: (value: any) => typeof value === 'boolean',
n_parallel: (value: any) => Number.isInteger(value) && value >= 0,
cpu_threads: (value: any) => Number.isInteger(value) && value >= 0,
prompt_template: (value: any) => typeof value === 'string',
llama_model_path: (value: any) => typeof value === 'string',
mmproj: (value: any) => typeof value === 'string',
vision_model: (value: any) => typeof value === 'boolean',
text_model: (value: any) => typeof value === 'boolean',
}
/**
 * Some parameters need to be normalized before being sent to the server.
 * E.g. ctx_len must be an integer, but the input field may provide a float.
 * @param key - the parameter name
 * @param value - the raw value to normalize
 * @returns the normalized value
*/
export const normalizeValue = (key: string, value: any) => {
if (
key === 'token_limit' ||
key === 'max_tokens' ||
key === 'ctx_len' ||
key === 'ngl' ||
key === 'n_parallel' ||
key === 'cpu_threads'
) {
// Convert to integer
return Math.floor(Number(value))
}
return value
}
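For example, the integer-valued keys are floored while everything else passes through unchanged:

normalizeValue('ctx_len', 2048.7) // -> 2048
normalizeValue('temperature', '0.5') // -> '0.5' (not an integer key, returned as-is)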
/**
 * Extract inference parameters from flat model parameters
 * @param modelParams - flat params to split
 * @param originParams - previous values to fall back to when validation fails
 * @returns the inference-time (runtime) parameters
*/
export const extractInferenceParams = (
modelParams?: ModelParams,
originParams?: ModelParams
): ModelRuntimeParams => {
if (!modelParams) return {}
const defaultModelParams: ModelRuntimeParams = {
@@ -22,15 +82,35 @@
for (const [key, value] of Object.entries(modelParams)) {
if (key in defaultModelParams) {
Object.assign(runtimeParams, { ...runtimeParams, [key]: value })
const validate = validationRules[key]
if (validate && !validate(normalizeValue(key, value))) {
// Invalid value - fall back to origin value
if (originParams && key in originParams) {
Object.assign(runtimeParams, {
...runtimeParams,
[key]: originParams[key as keyof typeof originParams],
})
}
} else {
Object.assign(runtimeParams, {
...runtimeParams,
[key]: normalizeValue(key, value),
})
}
}
}
return runtimeParams
}
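With the new originParams argument, an invalid incoming value no longer leaks through: it falls back to the previous value when one exists and is dropped otherwise. For example:

// temperature 3.5 fails its rule (0-2), so the origin value 0.7 is kept;
// top_p 0.9 passes validation and is taken as-is.
extractInferenceParams({ temperature: 3.5, top_p: 0.9 }, { temperature: 0.7 })
// -> { temperature: 0.7, top_p: 0.9 }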
export const toSettingParams = (
modelParams?: ModelParams
/**
 * Extract model load parameters from flat model parameters
 * @param modelParams - flat params to split
 * @param originParams - previous values to fall back to when validation fails
 * @returns the model load (setting) parameters
*/
export const extractModelLoadParams = (
modelParams?: ModelParams,
originParams?: ModelParams
): ModelSettingParams => {
if (!modelParams) return {}
const defaultSettingParams: ModelSettingParams = {
@@ -49,7 +129,21 @@
for (const [key, value] of Object.entries(modelParams)) {
if (key in defaultSettingParams) {
Object.assign(settingParams, { ...settingParams, [key]: value })
const validate = validationRules[key]
if (validate && !validate(normalizeValue(key, value))) {
// Invalid value - fall back to origin value
if (originParams && key in originParams) {
          Object.assign(settingParams, {
            ...settingParams,
            [key]: originParams[key as keyof typeof originParams],
          })
}
} else {
Object.assign(settingParams, {
...settingParams,
[key]: normalizeValue(key, value),
})
}
}
}
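The model-load counterpart behaves the same way for setting keys, including the integer normalization of string inputs:

extractModelLoadParams({ ctx_len: '4096', ngl: 12.8, embedding: 'yes' })
// -> { ctx_len: 4096, ngl: 12 }
// embedding is dropped: 'yes' is not a boolean and no origin fallback was given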