test: add model parameter validation rules and persistence tests (#3618)
* test: add model parameter validation rules and persistence tests * chore: fix CI cov step * fix: invalid model settings should fallback to origin value * test: support fallback integer settings
This commit is contained in:
parent
3ffaa1ef7f
commit
98bef7b7cf
@ -227,7 +227,7 @@ function loadLLMModel(settings: any): Promise<Response> {
|
||||
if (!settings?.ngl) {
|
||||
settings.ngl = 100
|
||||
}
|
||||
log(`[CORTEX]::Debug: Loading model with params ${JSON.stringify(settings)}`)
|
||||
log(`[CORTEX]:: Loading model with params ${JSON.stringify(settings)}`)
|
||||
return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@ -239,7 +239,7 @@ function loadLLMModel(settings: any): Promise<Response> {
|
||||
})
|
||||
.then((res) => {
|
||||
log(
|
||||
`[CORTEX]::Debug: Load model success with response ${JSON.stringify(
|
||||
`[CORTEX]:: Load model success with response ${JSON.stringify(
|
||||
res
|
||||
)}`
|
||||
)
|
||||
@ -260,7 +260,7 @@ function loadLLMModel(settings: any): Promise<Response> {
|
||||
async function validateModelStatus(modelId: string): Promise<void> {
|
||||
// Send a GET request to the validation URL.
|
||||
// Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries.
|
||||
log(`[CORTEX]::Debug: Validating model ${modelId}`)
|
||||
log(`[CORTEX]:: Validating model ${modelId}`)
|
||||
return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, {
|
||||
method: 'POST',
|
||||
body: JSON.stringify({
|
||||
@ -275,7 +275,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
|
||||
retryDelay: 300,
|
||||
}).then(async (res: Response) => {
|
||||
log(
|
||||
`[CORTEX]::Debug: Validate model state with response ${JSON.stringify(
|
||||
`[CORTEX]:: Validate model state with response ${JSON.stringify(
|
||||
res.status
|
||||
)}`
|
||||
)
|
||||
@ -286,7 +286,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
|
||||
// Otherwise, return an object with an error message.
|
||||
if (body.model_loaded) {
|
||||
log(
|
||||
`[CORTEX]::Debug: Validate model state success with response ${JSON.stringify(
|
||||
`[CORTEX]:: Validate model state success with response ${JSON.stringify(
|
||||
body
|
||||
)}`
|
||||
)
|
||||
@ -295,7 +295,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
|
||||
}
|
||||
const errorBody = await res.text()
|
||||
log(
|
||||
`[CORTEX]::Debug: Validate model state failed with response ${errorBody} and status is ${JSON.stringify(
|
||||
`[CORTEX]:: Validate model state failed with response ${errorBody} and status is ${JSON.stringify(
|
||||
res.statusText
|
||||
)}`
|
||||
)
|
||||
@ -310,7 +310,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
|
||||
async function killSubprocess(): Promise<void> {
|
||||
const controller = new AbortController()
|
||||
setTimeout(() => controller.abort(), 5000)
|
||||
log(`[CORTEX]::Debug: Request to kill cortex`)
|
||||
log(`[CORTEX]:: Request to kill cortex`)
|
||||
|
||||
const killRequest = () => {
|
||||
return fetch(NITRO_HTTP_KILL_URL, {
|
||||
@ -321,17 +321,17 @@ async function killSubprocess(): Promise<void> {
|
||||
.then(() =>
|
||||
tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
|
||||
)
|
||||
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
|
||||
.then(() => log(`[CORTEX]:: cortex process is terminated`))
|
||||
.catch((err) => {
|
||||
log(
|
||||
`[CORTEX]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
|
||||
`[CORTEX]:: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
|
||||
)
|
||||
throw 'PORT_NOT_AVAILABLE'
|
||||
})
|
||||
}
|
||||
|
||||
if (subprocess?.pid && process.platform !== 'darwin') {
|
||||
log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`)
|
||||
log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
|
||||
const pid = subprocess.pid
|
||||
return new Promise((resolve, reject) => {
|
||||
terminate(pid, function (err) {
|
||||
@ -341,7 +341,7 @@ async function killSubprocess(): Promise<void> {
|
||||
} else {
|
||||
tcpPortUsed
|
||||
.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
|
||||
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
|
||||
.then(() => log(`[CORTEX]:: cortex process is terminated`))
|
||||
.then(() => resolve())
|
||||
.catch(() => {
|
||||
log(
|
||||
@ -362,7 +362,7 @@ async function killSubprocess(): Promise<void> {
|
||||
* @returns A promise that resolves when the Nitro subprocess is started.
|
||||
*/
|
||||
function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
|
||||
log(`[CORTEX]::Debug: Spawning cortex subprocess...`)
|
||||
log(`[CORTEX]:: Spawning cortex subprocess...`)
|
||||
|
||||
return new Promise<void>(async (resolve, reject) => {
|
||||
let executableOptions = executableNitroFile(
|
||||
@ -381,7 +381,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
|
||||
const args: string[] = ['1', LOCAL_HOST, PORT.toString()]
|
||||
// Execute the binary
|
||||
log(
|
||||
`[CORTEX]::Debug: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}`
|
||||
`[CORTEX]:: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}`
|
||||
)
|
||||
log(`[CORTEX]::Debug: Cortex engine path: ${executableOptions.enginePath}`)
|
||||
|
||||
@ -415,7 +415,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
|
||||
|
||||
// Handle subprocess output
|
||||
subprocess.stdout.on('data', (data: any) => {
|
||||
log(`[CORTEX]::Debug: ${data}`)
|
||||
log(`[CORTEX]:: ${data}`)
|
||||
})
|
||||
|
||||
subprocess.stderr.on('data', (data: any) => {
|
||||
@ -423,7 +423,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
|
||||
})
|
||||
|
||||
subprocess.on('close', (code: any) => {
|
||||
log(`[CORTEX]::Debug: cortex exited with code: ${code}`)
|
||||
log(`[CORTEX]:: cortex exited with code: ${code}`)
|
||||
subprocess = undefined
|
||||
reject(`child process exited with code ${code}`)
|
||||
})
|
||||
@ -431,7 +431,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
|
||||
tcpPortUsed
|
||||
.waitUntilUsed(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 30000)
|
||||
.then(() => {
|
||||
log(`[CORTEX]::Debug: cortex is ready`)
|
||||
log(`[CORTEX]:: cortex is ready`)
|
||||
resolve()
|
||||
})
|
||||
})
|
||||
|
||||
@ -97,7 +97,7 @@ function unloadModel(): Promise<void> {
|
||||
}
|
||||
|
||||
if (subprocess?.pid) {
|
||||
log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`)
|
||||
log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
|
||||
const pid = subprocess.pid
|
||||
return new Promise((resolve, reject) => {
|
||||
terminate(pid, function (err) {
|
||||
@ -107,7 +107,7 @@ function unloadModel(): Promise<void> {
|
||||
return tcpPortUsed
|
||||
.waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000)
|
||||
.then(() => resolve())
|
||||
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
|
||||
.then(() => log(`[CORTEX]:: cortex process is terminated`))
|
||||
.catch(() => {
|
||||
killRequest()
|
||||
})
|
||||
|
||||
@ -20,7 +20,7 @@ import { ulid } from 'ulidx'
|
||||
|
||||
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
|
||||
|
||||
import { toRuntimeParams } from '@/utils/modelParam'
|
||||
import { extractInferenceParams } from '@/utils/modelParam'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import {
|
||||
@ -256,7 +256,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
||||
},
|
||||
]
|
||||
|
||||
const runtimeParams = toRuntimeParams(activeModelParamsRef.current)
|
||||
const runtimeParams = extractInferenceParams(activeModelParamsRef.current)
|
||||
|
||||
const messageRequest: MessageRequest = {
|
||||
id: msgId,
|
||||
|
||||
@ -87,26 +87,28 @@ const SliderRightPanel = ({
|
||||
onValueChanged?.(Number(min))
|
||||
setVal(min.toString())
|
||||
setShowTooltip({ max: false, min: true })
|
||||
} else {
|
||||
setVal(Number(e.target.value).toString()) // There is a case .5 but not 0.5
|
||||
}
|
||||
}}
|
||||
onChange={(e) => {
|
||||
// Should not accept invalid value or NaN
|
||||
// E.g. anything changes that trigger onValueChanged
|
||||
// Which is incorrect
|
||||
if (Number(e.target.value) > Number(max)) {
|
||||
setVal(max.toString())
|
||||
} else if (
|
||||
Number(e.target.value) < Number(min) ||
|
||||
!e.target.value.length
|
||||
) {
|
||||
setVal(min.toString())
|
||||
} else if (Number.isNaN(Number(e.target.value))) return
|
||||
|
||||
onValueChanged?.(Number(e.target.value))
|
||||
// TODO: How to support negative number input?
|
||||
// Passthru since it validates again onBlur
|
||||
if (/^\d*\.?\d*$/.test(e.target.value)) {
|
||||
setVal(e.target.value)
|
||||
}
|
||||
|
||||
// Should not accept invalid value or NaN
|
||||
// E.g. anything changes that trigger onValueChanged
|
||||
// Which is incorrect
|
||||
if (
|
||||
Number(e.target.value) > Number(max) ||
|
||||
Number(e.target.value) < Number(min) ||
|
||||
Number.isNaN(Number(e.target.value))
|
||||
) {
|
||||
return
|
||||
}
|
||||
onValueChanged?.(Number(e.target.value))
|
||||
}}
|
||||
/>
|
||||
}
|
||||
|
||||
@ -23,7 +23,10 @@ import {
|
||||
import { Stack } from '@/utils/Stack'
|
||||
import { compressImage, getBase64 } from '@/utils/base64'
|
||||
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
|
||||
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
|
||||
import {
|
||||
extractInferenceParams,
|
||||
extractModelLoadParams,
|
||||
} from '@/utils/modelParam'
|
||||
|
||||
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
|
||||
|
||||
@ -189,8 +192,8 @@ export default function useSendChatMessage() {
|
||||
|
||||
if (engineParamsUpdate) setReloadModel(true)
|
||||
|
||||
const runtimeParams = toRuntimeParams(activeModelParams)
|
||||
const settingParams = toSettingParams(activeModelParams)
|
||||
const runtimeParams = extractInferenceParams(activeModelParams)
|
||||
const settingParams = extractModelLoadParams(activeModelParams)
|
||||
|
||||
const prompt = message.trim()
|
||||
|
||||
|
||||
314
web/hooks/useUpdateModelParameters.test.ts
Normal file
314
web/hooks/useUpdateModelParameters.test.ts
Normal file
@ -0,0 +1,314 @@
|
||||
import { renderHook, act } from '@testing-library/react'
|
||||
// Mock dependencies
|
||||
jest.mock('ulidx')
|
||||
jest.mock('@/extension')
|
||||
|
||||
import useUpdateModelParameters from './useUpdateModelParameters'
|
||||
import { extensionManager } from '@/extension'
|
||||
|
||||
// Mock data
|
||||
let model: any = {
|
||||
id: 'model-1',
|
||||
engine: 'nitro',
|
||||
}
|
||||
|
||||
let extension: any = {
|
||||
saveThread: jest.fn(),
|
||||
}
|
||||
|
||||
const mockThread: any = {
|
||||
id: 'thread-1',
|
||||
assistants: [
|
||||
{
|
||||
model: {
|
||||
parameters: {},
|
||||
settings: {},
|
||||
},
|
||||
},
|
||||
],
|
||||
object: 'thread',
|
||||
title: 'New Thread',
|
||||
created: 0,
|
||||
updated: 0,
|
||||
}
|
||||
|
||||
describe('useUpdateModelParameters', () => {
|
||||
beforeAll(() => {
|
||||
jest.clearAllMocks()
|
||||
jest.mock('./useRecommendedModel', () => ({
|
||||
useRecommendedModel: () => ({
|
||||
recommendedModel: model,
|
||||
setRecommendedModel: jest.fn(),
|
||||
downloadedModels: [],
|
||||
}),
|
||||
}))
|
||||
})
|
||||
|
||||
it('should update model parameters and save thread when params are valid', async () => {
|
||||
const mockValidParameters: any = {
|
||||
params: {
|
||||
// Inference
|
||||
stop: ['<eos>', '<eos2>'],
|
||||
temperature: 0.5,
|
||||
token_limit: 1000,
|
||||
top_k: 0.7,
|
||||
top_p: 0.1,
|
||||
stream: true,
|
||||
max_tokens: 1000,
|
||||
frequency_penalty: 0.3,
|
||||
presence_penalty: 0.2,
|
||||
|
||||
// Load model
|
||||
ctx_len: 1024,
|
||||
ngl: 12,
|
||||
embedding: true,
|
||||
n_parallel: 2,
|
||||
cpu_threads: 4,
|
||||
prompt_template: 'template',
|
||||
llama_model_path: 'path',
|
||||
mmproj: 'mmproj',
|
||||
vision_model: 'vision',
|
||||
text_model: 'text',
|
||||
},
|
||||
modelId: 'model-1',
|
||||
engine: 'nitro',
|
||||
}
|
||||
|
||||
// Spy functions
|
||||
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
||||
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
||||
|
||||
const { result } = renderHook(() => useUpdateModelParameters())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.updateModelParameter(mockThread, mockValidParameters)
|
||||
})
|
||||
|
||||
// Check if the model parameters are valid before persisting
|
||||
expect(extension.saveThread).toHaveBeenCalledWith({
|
||||
assistants: [
|
||||
{
|
||||
model: {
|
||||
parameters: {
|
||||
stop: ['<eos>', '<eos2>'],
|
||||
temperature: 0.5,
|
||||
token_limit: 1000,
|
||||
top_k: 0.7,
|
||||
top_p: 0.1,
|
||||
stream: true,
|
||||
max_tokens: 1000,
|
||||
frequency_penalty: 0.3,
|
||||
presence_penalty: 0.2,
|
||||
},
|
||||
settings: {
|
||||
ctx_len: 1024,
|
||||
ngl: 12,
|
||||
embedding: true,
|
||||
n_parallel: 2,
|
||||
cpu_threads: 4,
|
||||
prompt_template: 'template',
|
||||
llama_model_path: 'path',
|
||||
mmproj: 'mmproj',
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
created: 0,
|
||||
id: 'thread-1',
|
||||
object: 'thread',
|
||||
title: 'New Thread',
|
||||
updated: 0,
|
||||
})
|
||||
})
|
||||
|
||||
it('should not update invalid model parameters', async () => {
|
||||
const mockInvalidParameters: any = {
|
||||
params: {
|
||||
// Inference
|
||||
stop: [1, '<eos>'],
|
||||
temperature: '0.5',
|
||||
token_limit: '1000',
|
||||
top_k: '0.7',
|
||||
top_p: '0.1',
|
||||
stream: 'true',
|
||||
max_tokens: '1000',
|
||||
frequency_penalty: '0.3',
|
||||
presence_penalty: '0.2',
|
||||
|
||||
// Load model
|
||||
ctx_len: '1024',
|
||||
ngl: '12',
|
||||
embedding: 'true',
|
||||
n_parallel: '2',
|
||||
cpu_threads: '4',
|
||||
prompt_template: 'template',
|
||||
llama_model_path: 'path',
|
||||
mmproj: 'mmproj',
|
||||
vision_model: 'vision',
|
||||
text_model: 'text',
|
||||
},
|
||||
modelId: 'model-1',
|
||||
engine: 'nitro',
|
||||
}
|
||||
|
||||
// Spy functions
|
||||
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
||||
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
||||
|
||||
const { result } = renderHook(() => useUpdateModelParameters())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.updateModelParameter(
|
||||
mockThread,
|
||||
mockInvalidParameters
|
||||
)
|
||||
})
|
||||
|
||||
// Check if the model parameters are valid before persisting
|
||||
expect(extension.saveThread).toHaveBeenCalledWith({
|
||||
assistants: [
|
||||
{
|
||||
model: {
|
||||
parameters: {
|
||||
max_tokens: 1000,
|
||||
token_limit: 1000,
|
||||
},
|
||||
settings: {
|
||||
cpu_threads: 4,
|
||||
ctx_len: 1024,
|
||||
prompt_template: 'template',
|
||||
llama_model_path: 'path',
|
||||
mmproj: 'mmproj',
|
||||
n_parallel: 2,
|
||||
ngl: 12,
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
created: 0,
|
||||
id: 'thread-1',
|
||||
object: 'thread',
|
||||
title: 'New Thread',
|
||||
updated: 0,
|
||||
})
|
||||
})
|
||||
|
||||
it('should update valid model parameters only', async () => {
|
||||
const mockInvalidParameters: any = {
|
||||
params: {
|
||||
// Inference
|
||||
stop: ['<eos>'],
|
||||
temperature: -0.5,
|
||||
token_limit: 100.2,
|
||||
top_k: 0.7,
|
||||
top_p: 0.1,
|
||||
stream: true,
|
||||
max_tokens: 1000,
|
||||
frequency_penalty: 1.2,
|
||||
presence_penalty: 0.2,
|
||||
|
||||
// Load model
|
||||
ctx_len: 1024,
|
||||
ngl: 0,
|
||||
embedding: 'true',
|
||||
n_parallel: 2,
|
||||
cpu_threads: 4,
|
||||
prompt_template: 'template',
|
||||
llama_model_path: 'path',
|
||||
mmproj: 'mmproj',
|
||||
vision_model: 'vision',
|
||||
text_model: 'text',
|
||||
},
|
||||
modelId: 'model-1',
|
||||
engine: 'nitro',
|
||||
}
|
||||
|
||||
// Spy functions
|
||||
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
||||
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
||||
|
||||
const { result } = renderHook(() => useUpdateModelParameters())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.updateModelParameter(
|
||||
mockThread,
|
||||
mockInvalidParameters
|
||||
)
|
||||
})
|
||||
|
||||
// Check if the model parameters are valid before persisting
|
||||
expect(extension.saveThread).toHaveBeenCalledWith({
|
||||
assistants: [
|
||||
{
|
||||
model: {
|
||||
parameters: {
|
||||
stop: ['<eos>'],
|
||||
top_k: 0.7,
|
||||
top_p: 0.1,
|
||||
stream: true,
|
||||
token_limit: 100,
|
||||
max_tokens: 1000,
|
||||
presence_penalty: 0.2,
|
||||
},
|
||||
settings: {
|
||||
ctx_len: 1024,
|
||||
ngl: 0,
|
||||
n_parallel: 2,
|
||||
cpu_threads: 4,
|
||||
prompt_template: 'template',
|
||||
llama_model_path: 'path',
|
||||
mmproj: 'mmproj',
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
created: 0,
|
||||
id: 'thread-1',
|
||||
object: 'thread',
|
||||
title: 'New Thread',
|
||||
updated: 0,
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle missing modelId and engine gracefully', async () => {
|
||||
const mockParametersWithoutModelIdAndEngine: any = {
|
||||
params: {
|
||||
stop: ['<eos>', '<eos2>'],
|
||||
temperature: 0.5,
|
||||
},
|
||||
}
|
||||
|
||||
// Spy functions
|
||||
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
||||
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
||||
|
||||
const { result } = renderHook(() => useUpdateModelParameters())
|
||||
|
||||
await act(async () => {
|
||||
await result.current.updateModelParameter(
|
||||
mockThread,
|
||||
mockParametersWithoutModelIdAndEngine
|
||||
)
|
||||
})
|
||||
|
||||
// Check if the model parameters are valid before persisting
|
||||
expect(extension.saveThread).toHaveBeenCalledWith({
|
||||
assistants: [
|
||||
{
|
||||
model: {
|
||||
parameters: {
|
||||
stop: ['<eos>', '<eos2>'],
|
||||
temperature: 0.5,
|
||||
},
|
||||
settings: {},
|
||||
},
|
||||
},
|
||||
],
|
||||
created: 0,
|
||||
id: 'thread-1',
|
||||
object: 'thread',
|
||||
title: 'New Thread',
|
||||
updated: 0,
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -12,7 +12,10 @@ import {
|
||||
|
||||
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
|
||||
import {
|
||||
extractInferenceParams,
|
||||
extractModelLoadParams,
|
||||
} from '@/utils/modelParam'
|
||||
|
||||
import useRecommendedModel from './useRecommendedModel'
|
||||
|
||||
@ -47,12 +50,17 @@ export default function useUpdateModelParameters() {
|
||||
const toUpdateSettings = processStopWords(settings.params ?? {})
|
||||
const updatedModelParams = settings.modelId
|
||||
? toUpdateSettings
|
||||
: { ...activeModelParams, ...toUpdateSettings }
|
||||
: {
|
||||
...selectedModel?.parameters,
|
||||
...selectedModel?.settings,
|
||||
...activeModelParams,
|
||||
...toUpdateSettings,
|
||||
}
|
||||
|
||||
// update the state
|
||||
setThreadModelParams(thread.id, updatedModelParams)
|
||||
const runtimeParams = toRuntimeParams(updatedModelParams)
|
||||
const settingParams = toSettingParams(updatedModelParams)
|
||||
const runtimeParams = extractInferenceParams(updatedModelParams)
|
||||
const settingParams = extractModelLoadParams(updatedModelParams)
|
||||
|
||||
const assistants = thread.assistants.map(
|
||||
(assistant: ThreadAssistantInfo) => {
|
||||
|
||||
@ -14,7 +14,10 @@ import { loadModelErrorAtom } from '@/hooks/useActiveModel'
|
||||
|
||||
import { getConfigurationsData } from '@/utils/componentSettings'
|
||||
|
||||
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
|
||||
import {
|
||||
extractInferenceParams,
|
||||
extractModelLoadParams,
|
||||
} from '@/utils/modelParam'
|
||||
|
||||
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
|
||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
@ -27,16 +30,18 @@ const LocalServerRightPanel = () => {
|
||||
const selectedModel = useAtomValue(selectedModelAtom)
|
||||
|
||||
const [currentModelSettingParams, setCurrentModelSettingParams] = useState(
|
||||
toSettingParams(selectedModel?.settings)
|
||||
extractModelLoadParams(selectedModel?.settings)
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedModel) {
|
||||
setCurrentModelSettingParams(toSettingParams(selectedModel?.settings))
|
||||
setCurrentModelSettingParams(
|
||||
extractModelLoadParams(selectedModel?.settings)
|
||||
)
|
||||
}
|
||||
}, [selectedModel])
|
||||
|
||||
const modelRuntimeParams = toRuntimeParams(selectedModel?.settings)
|
||||
const modelRuntimeParams = extractInferenceParams(selectedModel?.settings)
|
||||
|
||||
const componentDataRuntimeSetting = getConfigurationsData(
|
||||
modelRuntimeParams,
|
||||
|
||||
@ -29,7 +29,10 @@ import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
|
||||
|
||||
import { getConfigurationsData } from '@/utils/componentSettings'
|
||||
import { localEngines } from '@/utils/modelEngine'
|
||||
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
|
||||
import {
|
||||
extractInferenceParams,
|
||||
extractModelLoadParams,
|
||||
} from '@/utils/modelParam'
|
||||
|
||||
import PromptTemplateSetting from './PromptTemplateSetting'
|
||||
import Tools from './Tools'
|
||||
@ -68,14 +71,26 @@ const ThreadRightPanel = () => {
|
||||
|
||||
const settings = useMemo(() => {
|
||||
// runtime setting
|
||||
const modelRuntimeParams = toRuntimeParams(activeModelParams)
|
||||
const modelRuntimeParams = extractInferenceParams(
|
||||
{
|
||||
...selectedModel?.parameters,
|
||||
...activeModelParams,
|
||||
},
|
||||
selectedModel?.parameters
|
||||
)
|
||||
const componentDataRuntimeSetting = getConfigurationsData(
|
||||
modelRuntimeParams,
|
||||
selectedModel
|
||||
).filter((x) => x.key !== 'prompt_template')
|
||||
|
||||
// engine setting
|
||||
const modelEngineParams = toSettingParams(activeModelParams)
|
||||
const modelEngineParams = extractModelLoadParams(
|
||||
{
|
||||
...selectedModel?.settings,
|
||||
...activeModelParams,
|
||||
},
|
||||
selectedModel?.settings
|
||||
)
|
||||
const componentDataEngineSetting = getConfigurationsData(
|
||||
modelEngineParams,
|
||||
selectedModel
|
||||
@ -126,7 +141,10 @@ const ThreadRightPanel = () => {
|
||||
}, [activeModelParams, selectedModel])
|
||||
|
||||
const promptTemplateSettings = useMemo(() => {
|
||||
const modelEngineParams = toSettingParams(activeModelParams)
|
||||
const modelEngineParams = extractModelLoadParams({
|
||||
...selectedModel?.settings,
|
||||
...activeModelParams,
|
||||
})
|
||||
const componentDataEngineSetting = getConfigurationsData(
|
||||
modelEngineParams,
|
||||
selectedModel
|
||||
|
||||
183
web/utils/modelParam.test.ts
Normal file
183
web/utils/modelParam.test.ts
Normal file
@ -0,0 +1,183 @@
|
||||
// web/utils/modelParam.test.ts
|
||||
import { normalizeValue, validationRules } from './modelParam'
|
||||
|
||||
describe('validationRules', () => {
|
||||
it('should validate temperature correctly', () => {
|
||||
expect(validationRules.temperature(0.5)).toBe(true)
|
||||
expect(validationRules.temperature(2)).toBe(true)
|
||||
expect(validationRules.temperature(0)).toBe(true)
|
||||
expect(validationRules.temperature(-0.1)).toBe(false)
|
||||
expect(validationRules.temperature(2.3)).toBe(false)
|
||||
expect(validationRules.temperature('0.5')).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate token_limit correctly', () => {
|
||||
expect(validationRules.token_limit(100)).toBe(true)
|
||||
expect(validationRules.token_limit(1)).toBe(true)
|
||||
expect(validationRules.token_limit(0)).toBe(true)
|
||||
expect(validationRules.token_limit(-1)).toBe(false)
|
||||
expect(validationRules.token_limit('100')).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate top_k correctly', () => {
|
||||
expect(validationRules.top_k(0.5)).toBe(true)
|
||||
expect(validationRules.top_k(1)).toBe(true)
|
||||
expect(validationRules.top_k(0)).toBe(true)
|
||||
expect(validationRules.top_k(-0.1)).toBe(false)
|
||||
expect(validationRules.top_k(1.1)).toBe(false)
|
||||
expect(validationRules.top_k('0.5')).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate top_p correctly', () => {
|
||||
expect(validationRules.top_p(0.5)).toBe(true)
|
||||
expect(validationRules.top_p(1)).toBe(true)
|
||||
expect(validationRules.top_p(0)).toBe(true)
|
||||
expect(validationRules.top_p(-0.1)).toBe(false)
|
||||
expect(validationRules.top_p(1.1)).toBe(false)
|
||||
expect(validationRules.top_p('0.5')).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate stream correctly', () => {
|
||||
expect(validationRules.stream(true)).toBe(true)
|
||||
expect(validationRules.stream(false)).toBe(true)
|
||||
expect(validationRules.stream('true')).toBe(false)
|
||||
expect(validationRules.stream(1)).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate max_tokens correctly', () => {
|
||||
expect(validationRules.max_tokens(100)).toBe(true)
|
||||
expect(validationRules.max_tokens(1)).toBe(true)
|
||||
expect(validationRules.max_tokens(0)).toBe(true)
|
||||
expect(validationRules.max_tokens(-1)).toBe(false)
|
||||
expect(validationRules.max_tokens('100')).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate stop correctly', () => {
|
||||
expect(validationRules.stop(['word1', 'word2'])).toBe(true)
|
||||
expect(validationRules.stop([])).toBe(true)
|
||||
expect(validationRules.stop(['word1', 2])).toBe(false)
|
||||
expect(validationRules.stop('word1')).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate frequency_penalty correctly', () => {
|
||||
expect(validationRules.frequency_penalty(0.5)).toBe(true)
|
||||
expect(validationRules.frequency_penalty(1)).toBe(true)
|
||||
expect(validationRules.frequency_penalty(0)).toBe(true)
|
||||
expect(validationRules.frequency_penalty(-0.1)).toBe(false)
|
||||
expect(validationRules.frequency_penalty(1.1)).toBe(false)
|
||||
expect(validationRules.frequency_penalty('0.5')).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate presence_penalty correctly', () => {
|
||||
expect(validationRules.presence_penalty(0.5)).toBe(true)
|
||||
expect(validationRules.presence_penalty(1)).toBe(true)
|
||||
expect(validationRules.presence_penalty(0)).toBe(true)
|
||||
expect(validationRules.presence_penalty(-0.1)).toBe(false)
|
||||
expect(validationRules.presence_penalty(1.1)).toBe(false)
|
||||
expect(validationRules.presence_penalty('0.5')).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate ctx_len correctly', () => {
|
||||
expect(validationRules.ctx_len(1024)).toBe(true)
|
||||
expect(validationRules.ctx_len(1)).toBe(true)
|
||||
expect(validationRules.ctx_len(0)).toBe(true)
|
||||
expect(validationRules.ctx_len(-1)).toBe(false)
|
||||
expect(validationRules.ctx_len('1024')).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate ngl correctly', () => {
|
||||
expect(validationRules.ngl(12)).toBe(true)
|
||||
expect(validationRules.ngl(1)).toBe(true)
|
||||
expect(validationRules.ngl(0)).toBe(true)
|
||||
expect(validationRules.ngl(-1)).toBe(false)
|
||||
expect(validationRules.ngl('12')).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate embedding correctly', () => {
|
||||
expect(validationRules.embedding(true)).toBe(true)
|
||||
expect(validationRules.embedding(false)).toBe(true)
|
||||
expect(validationRules.embedding('true')).toBe(false)
|
||||
expect(validationRules.embedding(1)).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate n_parallel correctly', () => {
|
||||
expect(validationRules.n_parallel(2)).toBe(true)
|
||||
expect(validationRules.n_parallel(1)).toBe(true)
|
||||
expect(validationRules.n_parallel(0)).toBe(true)
|
||||
expect(validationRules.n_parallel(-1)).toBe(false)
|
||||
expect(validationRules.n_parallel('2')).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate cpu_threads correctly', () => {
|
||||
expect(validationRules.cpu_threads(4)).toBe(true)
|
||||
expect(validationRules.cpu_threads(1)).toBe(true)
|
||||
expect(validationRules.cpu_threads(0)).toBe(true)
|
||||
expect(validationRules.cpu_threads(-1)).toBe(false)
|
||||
expect(validationRules.cpu_threads('4')).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate prompt_template correctly', () => {
|
||||
expect(validationRules.prompt_template('template')).toBe(true)
|
||||
expect(validationRules.prompt_template('')).toBe(true)
|
||||
expect(validationRules.prompt_template(123)).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate llama_model_path correctly', () => {
|
||||
expect(validationRules.llama_model_path('path')).toBe(true)
|
||||
expect(validationRules.llama_model_path('')).toBe(true)
|
||||
expect(validationRules.llama_model_path(123)).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate mmproj correctly', () => {
|
||||
expect(validationRules.mmproj('mmproj')).toBe(true)
|
||||
expect(validationRules.mmproj('')).toBe(true)
|
||||
expect(validationRules.mmproj(123)).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate vision_model correctly', () => {
|
||||
expect(validationRules.vision_model(true)).toBe(true)
|
||||
expect(validationRules.vision_model(false)).toBe(true)
|
||||
expect(validationRules.vision_model('true')).toBe(false)
|
||||
expect(validationRules.vision_model(1)).toBe(false)
|
||||
})
|
||||
|
||||
it('should validate text_model correctly', () => {
|
||||
expect(validationRules.text_model(true)).toBe(true)
|
||||
expect(validationRules.text_model(false)).toBe(true)
|
||||
expect(validationRules.text_model('true')).toBe(false)
|
||||
expect(validationRules.text_model(1)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('normalizeValue', () => {
|
||||
it('should normalize ctx_len correctly', () => {
|
||||
expect(normalizeValue('ctx_len', 100.5)).toBe(100)
|
||||
expect(normalizeValue('ctx_len', '2')).toBe(2)
|
||||
expect(normalizeValue('ctx_len', 100)).toBe(100)
|
||||
})
|
||||
it('should normalize token_limit correctly', () => {
|
||||
expect(normalizeValue('token_limit', 100.5)).toBe(100)
|
||||
expect(normalizeValue('token_limit', '1')).toBe(1)
|
||||
expect(normalizeValue('token_limit', 0)).toBe(0)
|
||||
})
|
||||
it('should normalize max_tokens correctly', () => {
|
||||
expect(normalizeValue('max_tokens', 100.5)).toBe(100)
|
||||
expect(normalizeValue('max_tokens', '1')).toBe(1)
|
||||
expect(normalizeValue('max_tokens', 0)).toBe(0)
|
||||
})
|
||||
it('should normalize ngl correctly', () => {
|
||||
expect(normalizeValue('ngl', 12.5)).toBe(12)
|
||||
expect(normalizeValue('ngl', '2')).toBe(2)
|
||||
expect(normalizeValue('ngl', 0)).toBe(0)
|
||||
})
|
||||
it('should normalize n_parallel correctly', () => {
|
||||
expect(normalizeValue('n_parallel', 2.5)).toBe(2)
|
||||
expect(normalizeValue('n_parallel', '2')).toBe(2)
|
||||
expect(normalizeValue('n_parallel', 0)).toBe(0)
|
||||
})
|
||||
it('should normalize cpu_threads correctly', () => {
|
||||
expect(normalizeValue('cpu_threads', 4.5)).toBe(4)
|
||||
expect(normalizeValue('cpu_threads', '4')).toBe(4)
|
||||
expect(normalizeValue('cpu_threads', 0)).toBe(0)
|
||||
})
|
||||
})
|
||||
@ -1,9 +1,69 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
/* eslint-disable @typescript-eslint/naming-convention */
|
||||
import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
|
||||
|
||||
import { ModelParams } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export const toRuntimeParams = (
|
||||
modelParams?: ModelParams
|
||||
/**
|
||||
* Validation rules for model parameters
|
||||
*/
|
||||
export const validationRules: { [key: string]: (value: any) => boolean } = {
|
||||
temperature: (value: any) =>
|
||||
typeof value === 'number' && value >= 0 && value <= 2,
|
||||
token_limit: (value: any) => Number.isInteger(value) && value >= 0,
|
||||
top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
|
||||
top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
|
||||
stream: (value: any) => typeof value === 'boolean',
|
||||
max_tokens: (value: any) => Number.isInteger(value) && value >= 0,
|
||||
stop: (value: any) =>
|
||||
Array.isArray(value) && value.every((v) => typeof v === 'string'),
|
||||
frequency_penalty: (value: any) =>
|
||||
typeof value === 'number' && value >= 0 && value <= 1,
|
||||
presence_penalty: (value: any) =>
|
||||
typeof value === 'number' && value >= 0 && value <= 1,
|
||||
|
||||
ctx_len: (value: any) => Number.isInteger(value) && value >= 0,
|
||||
ngl: (value: any) => Number.isInteger(value) && value >= 0,
|
||||
embedding: (value: any) => typeof value === 'boolean',
|
||||
n_parallel: (value: any) => Number.isInteger(value) && value >= 0,
|
||||
cpu_threads: (value: any) => Number.isInteger(value) && value >= 0,
|
||||
prompt_template: (value: any) => typeof value === 'string',
|
||||
llama_model_path: (value: any) => typeof value === 'string',
|
||||
mmproj: (value: any) => typeof value === 'string',
|
||||
vision_model: (value: any) => typeof value === 'boolean',
|
||||
text_model: (value: any) => typeof value === 'boolean',
|
||||
}
|
||||
|
||||
/**
|
||||
* There are some parameters that need to be normalized before being sent to the server
|
||||
* E.g. ctx_len should be an integer, but it can be a float from the input field
|
||||
* @param key
|
||||
* @param value
|
||||
* @returns
|
||||
*/
|
||||
export const normalizeValue = (key: string, value: any) => {
|
||||
if (
|
||||
key === 'token_limit' ||
|
||||
key === 'max_tokens' ||
|
||||
key === 'ctx_len' ||
|
||||
key === 'ngl' ||
|
||||
key === 'n_parallel' ||
|
||||
key === 'cpu_threads'
|
||||
) {
|
||||
// Convert to integer
|
||||
return Math.floor(Number(value))
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract inference parameters from flat model parameters
|
||||
* @param modelParams
|
||||
* @returns
|
||||
*/
|
||||
export const extractInferenceParams = (
|
||||
modelParams?: ModelParams,
|
||||
originParams?: ModelParams
|
||||
): ModelRuntimeParams => {
|
||||
if (!modelParams) return {}
|
||||
const defaultModelParams: ModelRuntimeParams = {
|
||||
@ -22,15 +82,35 @@ export const toRuntimeParams = (
|
||||
|
||||
for (const [key, value] of Object.entries(modelParams)) {
|
||||
if (key in defaultModelParams) {
|
||||
Object.assign(runtimeParams, { ...runtimeParams, [key]: value })
|
||||
const validate = validationRules[key]
|
||||
if (validate && !validate(normalizeValue(key, value))) {
|
||||
// Invalid value - fall back to origin value
|
||||
if (originParams && key in originParams) {
|
||||
Object.assign(runtimeParams, {
|
||||
...runtimeParams,
|
||||
[key]: originParams[key as keyof typeof originParams],
|
||||
})
|
||||
}
|
||||
} else {
|
||||
Object.assign(runtimeParams, {
|
||||
...runtimeParams,
|
||||
[key]: normalizeValue(key, value),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return runtimeParams
|
||||
}
|
||||
|
||||
export const toSettingParams = (
|
||||
modelParams?: ModelParams
|
||||
/**
|
||||
* Extract model load parameters from flat model parameters
|
||||
* @param modelParams
|
||||
* @returns
|
||||
*/
|
||||
export const extractModelLoadParams = (
|
||||
modelParams?: ModelParams,
|
||||
originParams?: ModelParams
|
||||
): ModelSettingParams => {
|
||||
if (!modelParams) return {}
|
||||
const defaultSettingParams: ModelSettingParams = {
|
||||
@ -49,7 +129,21 @@ export const toSettingParams = (
|
||||
|
||||
for (const [key, value] of Object.entries(modelParams)) {
|
||||
if (key in defaultSettingParams) {
|
||||
Object.assign(settingParams, { ...settingParams, [key]: value })
|
||||
const validate = validationRules[key]
|
||||
if (validate && !validate(normalizeValue(key, value))) {
|
||||
// Invalid value - fall back to origin value
|
||||
if (originParams && key in originParams) {
|
||||
Object.assign(modelParams, {
|
||||
...modelParams,
|
||||
[key]: originParams[key as keyof typeof originParams],
|
||||
})
|
||||
}
|
||||
} else {
|
||||
Object.assign(settingParams, {
|
||||
...settingParams,
|
||||
[key]: normalizeValue(key, value),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user