test: add model parameter validation rules and persistence tests (#3618)
* test: add model parameter validation rules and persistence tests * chore: fix CI cov step * fix: invalid model settings should fallback to origin value * test: support fallback integer settings
This commit is contained in:
parent
3ffaa1ef7f
commit
98bef7b7cf
@ -227,7 +227,7 @@ function loadLLMModel(settings: any): Promise<Response> {
|
|||||||
if (!settings?.ngl) {
|
if (!settings?.ngl) {
|
||||||
settings.ngl = 100
|
settings.ngl = 100
|
||||||
}
|
}
|
||||||
log(`[CORTEX]::Debug: Loading model with params ${JSON.stringify(settings)}`)
|
log(`[CORTEX]:: Loading model with params ${JSON.stringify(settings)}`)
|
||||||
return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, {
|
return fetchRetry(NITRO_HTTP_LOAD_MODEL_URL, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
@ -239,7 +239,7 @@ function loadLLMModel(settings: any): Promise<Response> {
|
|||||||
})
|
})
|
||||||
.then((res) => {
|
.then((res) => {
|
||||||
log(
|
log(
|
||||||
`[CORTEX]::Debug: Load model success with response ${JSON.stringify(
|
`[CORTEX]:: Load model success with response ${JSON.stringify(
|
||||||
res
|
res
|
||||||
)}`
|
)}`
|
||||||
)
|
)
|
||||||
@ -260,7 +260,7 @@ function loadLLMModel(settings: any): Promise<Response> {
|
|||||||
async function validateModelStatus(modelId: string): Promise<void> {
|
async function validateModelStatus(modelId: string): Promise<void> {
|
||||||
// Send a GET request to the validation URL.
|
// Send a GET request to the validation URL.
|
||||||
// Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries.
|
// Retry the request up to 3 times if it fails, with a delay of 500 milliseconds between retries.
|
||||||
log(`[CORTEX]::Debug: Validating model ${modelId}`)
|
log(`[CORTEX]:: Validating model ${modelId}`)
|
||||||
return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, {
|
return fetchRetry(NITRO_HTTP_VALIDATE_MODEL_URL, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
@ -275,7 +275,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
|
|||||||
retryDelay: 300,
|
retryDelay: 300,
|
||||||
}).then(async (res: Response) => {
|
}).then(async (res: Response) => {
|
||||||
log(
|
log(
|
||||||
`[CORTEX]::Debug: Validate model state with response ${JSON.stringify(
|
`[CORTEX]:: Validate model state with response ${JSON.stringify(
|
||||||
res.status
|
res.status
|
||||||
)}`
|
)}`
|
||||||
)
|
)
|
||||||
@ -286,7 +286,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
|
|||||||
// Otherwise, return an object with an error message.
|
// Otherwise, return an object with an error message.
|
||||||
if (body.model_loaded) {
|
if (body.model_loaded) {
|
||||||
log(
|
log(
|
||||||
`[CORTEX]::Debug: Validate model state success with response ${JSON.stringify(
|
`[CORTEX]:: Validate model state success with response ${JSON.stringify(
|
||||||
body
|
body
|
||||||
)}`
|
)}`
|
||||||
)
|
)
|
||||||
@ -295,7 +295,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
|
|||||||
}
|
}
|
||||||
const errorBody = await res.text()
|
const errorBody = await res.text()
|
||||||
log(
|
log(
|
||||||
`[CORTEX]::Debug: Validate model state failed with response ${errorBody} and status is ${JSON.stringify(
|
`[CORTEX]:: Validate model state failed with response ${errorBody} and status is ${JSON.stringify(
|
||||||
res.statusText
|
res.statusText
|
||||||
)}`
|
)}`
|
||||||
)
|
)
|
||||||
@ -310,7 +310,7 @@ async function validateModelStatus(modelId: string): Promise<void> {
|
|||||||
async function killSubprocess(): Promise<void> {
|
async function killSubprocess(): Promise<void> {
|
||||||
const controller = new AbortController()
|
const controller = new AbortController()
|
||||||
setTimeout(() => controller.abort(), 5000)
|
setTimeout(() => controller.abort(), 5000)
|
||||||
log(`[CORTEX]::Debug: Request to kill cortex`)
|
log(`[CORTEX]:: Request to kill cortex`)
|
||||||
|
|
||||||
const killRequest = () => {
|
const killRequest = () => {
|
||||||
return fetch(NITRO_HTTP_KILL_URL, {
|
return fetch(NITRO_HTTP_KILL_URL, {
|
||||||
@ -321,17 +321,17 @@ async function killSubprocess(): Promise<void> {
|
|||||||
.then(() =>
|
.then(() =>
|
||||||
tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
|
tcpPortUsed.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
|
||||||
)
|
)
|
||||||
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
|
.then(() => log(`[CORTEX]:: cortex process is terminated`))
|
||||||
.catch((err) => {
|
.catch((err) => {
|
||||||
log(
|
log(
|
||||||
`[CORTEX]::Debug: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
|
`[CORTEX]:: Could not kill running process on port ${PORT}. Might be another process running on the same port? ${err}`
|
||||||
)
|
)
|
||||||
throw 'PORT_NOT_AVAILABLE'
|
throw 'PORT_NOT_AVAILABLE'
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if (subprocess?.pid && process.platform !== 'darwin') {
|
if (subprocess?.pid && process.platform !== 'darwin') {
|
||||||
log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`)
|
log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
|
||||||
const pid = subprocess.pid
|
const pid = subprocess.pid
|
||||||
return new Promise((resolve, reject) => {
|
return new Promise((resolve, reject) => {
|
||||||
terminate(pid, function (err) {
|
terminate(pid, function (err) {
|
||||||
@ -341,7 +341,7 @@ async function killSubprocess(): Promise<void> {
|
|||||||
} else {
|
} else {
|
||||||
tcpPortUsed
|
tcpPortUsed
|
||||||
.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
|
.waitUntilFree(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 5000)
|
||||||
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
|
.then(() => log(`[CORTEX]:: cortex process is terminated`))
|
||||||
.then(() => resolve())
|
.then(() => resolve())
|
||||||
.catch(() => {
|
.catch(() => {
|
||||||
log(
|
log(
|
||||||
@ -362,7 +362,7 @@ async function killSubprocess(): Promise<void> {
|
|||||||
* @returns A promise that resolves when the Nitro subprocess is started.
|
* @returns A promise that resolves when the Nitro subprocess is started.
|
||||||
*/
|
*/
|
||||||
function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
|
function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
|
||||||
log(`[CORTEX]::Debug: Spawning cortex subprocess...`)
|
log(`[CORTEX]:: Spawning cortex subprocess...`)
|
||||||
|
|
||||||
return new Promise<void>(async (resolve, reject) => {
|
return new Promise<void>(async (resolve, reject) => {
|
||||||
let executableOptions = executableNitroFile(
|
let executableOptions = executableNitroFile(
|
||||||
@ -381,7 +381,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
|
|||||||
const args: string[] = ['1', LOCAL_HOST, PORT.toString()]
|
const args: string[] = ['1', LOCAL_HOST, PORT.toString()]
|
||||||
// Execute the binary
|
// Execute the binary
|
||||||
log(
|
log(
|
||||||
`[CORTEX]::Debug: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}`
|
`[CORTEX]:: Spawn cortex at path: ${executableOptions.executablePath}, and args: ${args}`
|
||||||
)
|
)
|
||||||
log(`[CORTEX]::Debug: Cortex engine path: ${executableOptions.enginePath}`)
|
log(`[CORTEX]::Debug: Cortex engine path: ${executableOptions.enginePath}`)
|
||||||
|
|
||||||
@ -415,7 +415,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
|
|||||||
|
|
||||||
// Handle subprocess output
|
// Handle subprocess output
|
||||||
subprocess.stdout.on('data', (data: any) => {
|
subprocess.stdout.on('data', (data: any) => {
|
||||||
log(`[CORTEX]::Debug: ${data}`)
|
log(`[CORTEX]:: ${data}`)
|
||||||
})
|
})
|
||||||
|
|
||||||
subprocess.stderr.on('data', (data: any) => {
|
subprocess.stderr.on('data', (data: any) => {
|
||||||
@ -423,7 +423,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
|
|||||||
})
|
})
|
||||||
|
|
||||||
subprocess.on('close', (code: any) => {
|
subprocess.on('close', (code: any) => {
|
||||||
log(`[CORTEX]::Debug: cortex exited with code: ${code}`)
|
log(`[CORTEX]:: cortex exited with code: ${code}`)
|
||||||
subprocess = undefined
|
subprocess = undefined
|
||||||
reject(`child process exited with code ${code}`)
|
reject(`child process exited with code ${code}`)
|
||||||
})
|
})
|
||||||
@ -431,7 +431,7 @@ function spawnNitroProcess(systemInfo?: SystemInformation): Promise<any> {
|
|||||||
tcpPortUsed
|
tcpPortUsed
|
||||||
.waitUntilUsed(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 30000)
|
.waitUntilUsed(PORT, NITRO_PORT_FREE_CHECK_INTERVAL, 30000)
|
||||||
.then(() => {
|
.then(() => {
|
||||||
log(`[CORTEX]::Debug: cortex is ready`)
|
log(`[CORTEX]:: cortex is ready`)
|
||||||
resolve()
|
resolve()
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@ -97,7 +97,7 @@ function unloadModel(): Promise<void> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (subprocess?.pid) {
|
if (subprocess?.pid) {
|
||||||
log(`[CORTEX]::Debug: Killing PID ${subprocess.pid}`)
|
log(`[CORTEX]:: Killing PID ${subprocess.pid}`)
|
||||||
const pid = subprocess.pid
|
const pid = subprocess.pid
|
||||||
return new Promise((resolve, reject) => {
|
return new Promise((resolve, reject) => {
|
||||||
terminate(pid, function (err) {
|
terminate(pid, function (err) {
|
||||||
@ -107,7 +107,7 @@ function unloadModel(): Promise<void> {
|
|||||||
return tcpPortUsed
|
return tcpPortUsed
|
||||||
.waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000)
|
.waitUntilFree(parseInt(ENGINE_PORT), PORT_CHECK_INTERVAL, 5000)
|
||||||
.then(() => resolve())
|
.then(() => resolve())
|
||||||
.then(() => log(`[CORTEX]::Debug: cortex process is terminated`))
|
.then(() => log(`[CORTEX]:: cortex process is terminated`))
|
||||||
.catch(() => {
|
.catch(() => {
|
||||||
killRequest()
|
killRequest()
|
||||||
})
|
})
|
||||||
|
|||||||
@ -20,7 +20,7 @@ import { ulid } from 'ulidx'
|
|||||||
|
|
||||||
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
|
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
|
||||||
|
|
||||||
import { toRuntimeParams } from '@/utils/modelParam'
|
import { extractInferenceParams } from '@/utils/modelParam'
|
||||||
|
|
||||||
import { extensionManager } from '@/extension'
|
import { extensionManager } from '@/extension'
|
||||||
import {
|
import {
|
||||||
@ -256,7 +256,7 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
const runtimeParams = toRuntimeParams(activeModelParamsRef.current)
|
const runtimeParams = extractInferenceParams(activeModelParamsRef.current)
|
||||||
|
|
||||||
const messageRequest: MessageRequest = {
|
const messageRequest: MessageRequest = {
|
||||||
id: msgId,
|
id: msgId,
|
||||||
|
|||||||
@ -87,26 +87,28 @@ const SliderRightPanel = ({
|
|||||||
onValueChanged?.(Number(min))
|
onValueChanged?.(Number(min))
|
||||||
setVal(min.toString())
|
setVal(min.toString())
|
||||||
setShowTooltip({ max: false, min: true })
|
setShowTooltip({ max: false, min: true })
|
||||||
|
} else {
|
||||||
|
setVal(Number(e.target.value).toString()) // There is a case .5 but not 0.5
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
onChange={(e) => {
|
onChange={(e) => {
|
||||||
// Should not accept invalid value or NaN
|
|
||||||
// E.g. anything changes that trigger onValueChanged
|
|
||||||
// Which is incorrect
|
|
||||||
if (Number(e.target.value) > Number(max)) {
|
|
||||||
setVal(max.toString())
|
|
||||||
} else if (
|
|
||||||
Number(e.target.value) < Number(min) ||
|
|
||||||
!e.target.value.length
|
|
||||||
) {
|
|
||||||
setVal(min.toString())
|
|
||||||
} else if (Number.isNaN(Number(e.target.value))) return
|
|
||||||
|
|
||||||
onValueChanged?.(Number(e.target.value))
|
|
||||||
// TODO: How to support negative number input?
|
// TODO: How to support negative number input?
|
||||||
|
// Passthru since it validates again onBlur
|
||||||
if (/^\d*\.?\d*$/.test(e.target.value)) {
|
if (/^\d*\.?\d*$/.test(e.target.value)) {
|
||||||
setVal(e.target.value)
|
setVal(e.target.value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Should not accept invalid value or NaN
|
||||||
|
// E.g. anything changes that trigger onValueChanged
|
||||||
|
// Which is incorrect
|
||||||
|
if (
|
||||||
|
Number(e.target.value) > Number(max) ||
|
||||||
|
Number(e.target.value) < Number(min) ||
|
||||||
|
Number.isNaN(Number(e.target.value))
|
||||||
|
) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
onValueChanged?.(Number(e.target.value))
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
}
|
}
|
||||||
|
|||||||
@ -23,7 +23,10 @@ import {
|
|||||||
import { Stack } from '@/utils/Stack'
|
import { Stack } from '@/utils/Stack'
|
||||||
import { compressImage, getBase64 } from '@/utils/base64'
|
import { compressImage, getBase64 } from '@/utils/base64'
|
||||||
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
|
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
|
||||||
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
|
import {
|
||||||
|
extractInferenceParams,
|
||||||
|
extractModelLoadParams,
|
||||||
|
} from '@/utils/modelParam'
|
||||||
|
|
||||||
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
|
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
|
||||||
|
|
||||||
@ -189,8 +192,8 @@ export default function useSendChatMessage() {
|
|||||||
|
|
||||||
if (engineParamsUpdate) setReloadModel(true)
|
if (engineParamsUpdate) setReloadModel(true)
|
||||||
|
|
||||||
const runtimeParams = toRuntimeParams(activeModelParams)
|
const runtimeParams = extractInferenceParams(activeModelParams)
|
||||||
const settingParams = toSettingParams(activeModelParams)
|
const settingParams = extractModelLoadParams(activeModelParams)
|
||||||
|
|
||||||
const prompt = message.trim()
|
const prompt = message.trim()
|
||||||
|
|
||||||
|
|||||||
314
web/hooks/useUpdateModelParameters.test.ts
Normal file
314
web/hooks/useUpdateModelParameters.test.ts
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
import { renderHook, act } from '@testing-library/react'
|
||||||
|
// Mock dependencies
|
||||||
|
jest.mock('ulidx')
|
||||||
|
jest.mock('@/extension')
|
||||||
|
|
||||||
|
import useUpdateModelParameters from './useUpdateModelParameters'
|
||||||
|
import { extensionManager } from '@/extension'
|
||||||
|
|
||||||
|
// Mock data
|
||||||
|
let model: any = {
|
||||||
|
id: 'model-1',
|
||||||
|
engine: 'nitro',
|
||||||
|
}
|
||||||
|
|
||||||
|
let extension: any = {
|
||||||
|
saveThread: jest.fn(),
|
||||||
|
}
|
||||||
|
|
||||||
|
const mockThread: any = {
|
||||||
|
id: 'thread-1',
|
||||||
|
assistants: [
|
||||||
|
{
|
||||||
|
model: {
|
||||||
|
parameters: {},
|
||||||
|
settings: {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
object: 'thread',
|
||||||
|
title: 'New Thread',
|
||||||
|
created: 0,
|
||||||
|
updated: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('useUpdateModelParameters', () => {
|
||||||
|
beforeAll(() => {
|
||||||
|
jest.clearAllMocks()
|
||||||
|
jest.mock('./useRecommendedModel', () => ({
|
||||||
|
useRecommendedModel: () => ({
|
||||||
|
recommendedModel: model,
|
||||||
|
setRecommendedModel: jest.fn(),
|
||||||
|
downloadedModels: [],
|
||||||
|
}),
|
||||||
|
}))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should update model parameters and save thread when params are valid', async () => {
|
||||||
|
const mockValidParameters: any = {
|
||||||
|
params: {
|
||||||
|
// Inference
|
||||||
|
stop: ['<eos>', '<eos2>'],
|
||||||
|
temperature: 0.5,
|
||||||
|
token_limit: 1000,
|
||||||
|
top_k: 0.7,
|
||||||
|
top_p: 0.1,
|
||||||
|
stream: true,
|
||||||
|
max_tokens: 1000,
|
||||||
|
frequency_penalty: 0.3,
|
||||||
|
presence_penalty: 0.2,
|
||||||
|
|
||||||
|
// Load model
|
||||||
|
ctx_len: 1024,
|
||||||
|
ngl: 12,
|
||||||
|
embedding: true,
|
||||||
|
n_parallel: 2,
|
||||||
|
cpu_threads: 4,
|
||||||
|
prompt_template: 'template',
|
||||||
|
llama_model_path: 'path',
|
||||||
|
mmproj: 'mmproj',
|
||||||
|
vision_model: 'vision',
|
||||||
|
text_model: 'text',
|
||||||
|
},
|
||||||
|
modelId: 'model-1',
|
||||||
|
engine: 'nitro',
|
||||||
|
}
|
||||||
|
|
||||||
|
// Spy functions
|
||||||
|
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
||||||
|
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useUpdateModelParameters())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current.updateModelParameter(mockThread, mockValidParameters)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Check if the model parameters are valid before persisting
|
||||||
|
expect(extension.saveThread).toHaveBeenCalledWith({
|
||||||
|
assistants: [
|
||||||
|
{
|
||||||
|
model: {
|
||||||
|
parameters: {
|
||||||
|
stop: ['<eos>', '<eos2>'],
|
||||||
|
temperature: 0.5,
|
||||||
|
token_limit: 1000,
|
||||||
|
top_k: 0.7,
|
||||||
|
top_p: 0.1,
|
||||||
|
stream: true,
|
||||||
|
max_tokens: 1000,
|
||||||
|
frequency_penalty: 0.3,
|
||||||
|
presence_penalty: 0.2,
|
||||||
|
},
|
||||||
|
settings: {
|
||||||
|
ctx_len: 1024,
|
||||||
|
ngl: 12,
|
||||||
|
embedding: true,
|
||||||
|
n_parallel: 2,
|
||||||
|
cpu_threads: 4,
|
||||||
|
prompt_template: 'template',
|
||||||
|
llama_model_path: 'path',
|
||||||
|
mmproj: 'mmproj',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
created: 0,
|
||||||
|
id: 'thread-1',
|
||||||
|
object: 'thread',
|
||||||
|
title: 'New Thread',
|
||||||
|
updated: 0,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not update invalid model parameters', async () => {
|
||||||
|
const mockInvalidParameters: any = {
|
||||||
|
params: {
|
||||||
|
// Inference
|
||||||
|
stop: [1, '<eos>'],
|
||||||
|
temperature: '0.5',
|
||||||
|
token_limit: '1000',
|
||||||
|
top_k: '0.7',
|
||||||
|
top_p: '0.1',
|
||||||
|
stream: 'true',
|
||||||
|
max_tokens: '1000',
|
||||||
|
frequency_penalty: '0.3',
|
||||||
|
presence_penalty: '0.2',
|
||||||
|
|
||||||
|
// Load model
|
||||||
|
ctx_len: '1024',
|
||||||
|
ngl: '12',
|
||||||
|
embedding: 'true',
|
||||||
|
n_parallel: '2',
|
||||||
|
cpu_threads: '4',
|
||||||
|
prompt_template: 'template',
|
||||||
|
llama_model_path: 'path',
|
||||||
|
mmproj: 'mmproj',
|
||||||
|
vision_model: 'vision',
|
||||||
|
text_model: 'text',
|
||||||
|
},
|
||||||
|
modelId: 'model-1',
|
||||||
|
engine: 'nitro',
|
||||||
|
}
|
||||||
|
|
||||||
|
// Spy functions
|
||||||
|
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
||||||
|
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useUpdateModelParameters())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current.updateModelParameter(
|
||||||
|
mockThread,
|
||||||
|
mockInvalidParameters
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Check if the model parameters are valid before persisting
|
||||||
|
expect(extension.saveThread).toHaveBeenCalledWith({
|
||||||
|
assistants: [
|
||||||
|
{
|
||||||
|
model: {
|
||||||
|
parameters: {
|
||||||
|
max_tokens: 1000,
|
||||||
|
token_limit: 1000,
|
||||||
|
},
|
||||||
|
settings: {
|
||||||
|
cpu_threads: 4,
|
||||||
|
ctx_len: 1024,
|
||||||
|
prompt_template: 'template',
|
||||||
|
llama_model_path: 'path',
|
||||||
|
mmproj: 'mmproj',
|
||||||
|
n_parallel: 2,
|
||||||
|
ngl: 12,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
created: 0,
|
||||||
|
id: 'thread-1',
|
||||||
|
object: 'thread',
|
||||||
|
title: 'New Thread',
|
||||||
|
updated: 0,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should update valid model parameters only', async () => {
|
||||||
|
const mockInvalidParameters: any = {
|
||||||
|
params: {
|
||||||
|
// Inference
|
||||||
|
stop: ['<eos>'],
|
||||||
|
temperature: -0.5,
|
||||||
|
token_limit: 100.2,
|
||||||
|
top_k: 0.7,
|
||||||
|
top_p: 0.1,
|
||||||
|
stream: true,
|
||||||
|
max_tokens: 1000,
|
||||||
|
frequency_penalty: 1.2,
|
||||||
|
presence_penalty: 0.2,
|
||||||
|
|
||||||
|
// Load model
|
||||||
|
ctx_len: 1024,
|
||||||
|
ngl: 0,
|
||||||
|
embedding: 'true',
|
||||||
|
n_parallel: 2,
|
||||||
|
cpu_threads: 4,
|
||||||
|
prompt_template: 'template',
|
||||||
|
llama_model_path: 'path',
|
||||||
|
mmproj: 'mmproj',
|
||||||
|
vision_model: 'vision',
|
||||||
|
text_model: 'text',
|
||||||
|
},
|
||||||
|
modelId: 'model-1',
|
||||||
|
engine: 'nitro',
|
||||||
|
}
|
||||||
|
|
||||||
|
// Spy functions
|
||||||
|
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
||||||
|
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useUpdateModelParameters())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current.updateModelParameter(
|
||||||
|
mockThread,
|
||||||
|
mockInvalidParameters
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Check if the model parameters are valid before persisting
|
||||||
|
expect(extension.saveThread).toHaveBeenCalledWith({
|
||||||
|
assistants: [
|
||||||
|
{
|
||||||
|
model: {
|
||||||
|
parameters: {
|
||||||
|
stop: ['<eos>'],
|
||||||
|
top_k: 0.7,
|
||||||
|
top_p: 0.1,
|
||||||
|
stream: true,
|
||||||
|
token_limit: 100,
|
||||||
|
max_tokens: 1000,
|
||||||
|
presence_penalty: 0.2,
|
||||||
|
},
|
||||||
|
settings: {
|
||||||
|
ctx_len: 1024,
|
||||||
|
ngl: 0,
|
||||||
|
n_parallel: 2,
|
||||||
|
cpu_threads: 4,
|
||||||
|
prompt_template: 'template',
|
||||||
|
llama_model_path: 'path',
|
||||||
|
mmproj: 'mmproj',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
created: 0,
|
||||||
|
id: 'thread-1',
|
||||||
|
object: 'thread',
|
||||||
|
title: 'New Thread',
|
||||||
|
updated: 0,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle missing modelId and engine gracefully', async () => {
|
||||||
|
const mockParametersWithoutModelIdAndEngine: any = {
|
||||||
|
params: {
|
||||||
|
stop: ['<eos>', '<eos2>'],
|
||||||
|
temperature: 0.5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Spy functions
|
||||||
|
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
|
||||||
|
jest.spyOn(extension, 'saveThread').mockReturnValue({})
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useUpdateModelParameters())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current.updateModelParameter(
|
||||||
|
mockThread,
|
||||||
|
mockParametersWithoutModelIdAndEngine
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Check if the model parameters are valid before persisting
|
||||||
|
expect(extension.saveThread).toHaveBeenCalledWith({
|
||||||
|
assistants: [
|
||||||
|
{
|
||||||
|
model: {
|
||||||
|
parameters: {
|
||||||
|
stop: ['<eos>', '<eos2>'],
|
||||||
|
temperature: 0.5,
|
||||||
|
},
|
||||||
|
settings: {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
created: 0,
|
||||||
|
id: 'thread-1',
|
||||||
|
object: 'thread',
|
||||||
|
title: 'New Thread',
|
||||||
|
updated: 0,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@ -12,7 +12,10 @@ import {
|
|||||||
|
|
||||||
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||||
|
|
||||||
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
|
import {
|
||||||
|
extractInferenceParams,
|
||||||
|
extractModelLoadParams,
|
||||||
|
} from '@/utils/modelParam'
|
||||||
|
|
||||||
import useRecommendedModel from './useRecommendedModel'
|
import useRecommendedModel from './useRecommendedModel'
|
||||||
|
|
||||||
@ -47,12 +50,17 @@ export default function useUpdateModelParameters() {
|
|||||||
const toUpdateSettings = processStopWords(settings.params ?? {})
|
const toUpdateSettings = processStopWords(settings.params ?? {})
|
||||||
const updatedModelParams = settings.modelId
|
const updatedModelParams = settings.modelId
|
||||||
? toUpdateSettings
|
? toUpdateSettings
|
||||||
: { ...activeModelParams, ...toUpdateSettings }
|
: {
|
||||||
|
...selectedModel?.parameters,
|
||||||
|
...selectedModel?.settings,
|
||||||
|
...activeModelParams,
|
||||||
|
...toUpdateSettings,
|
||||||
|
}
|
||||||
|
|
||||||
// update the state
|
// update the state
|
||||||
setThreadModelParams(thread.id, updatedModelParams)
|
setThreadModelParams(thread.id, updatedModelParams)
|
||||||
const runtimeParams = toRuntimeParams(updatedModelParams)
|
const runtimeParams = extractInferenceParams(updatedModelParams)
|
||||||
const settingParams = toSettingParams(updatedModelParams)
|
const settingParams = extractModelLoadParams(updatedModelParams)
|
||||||
|
|
||||||
const assistants = thread.assistants.map(
|
const assistants = thread.assistants.map(
|
||||||
(assistant: ThreadAssistantInfo) => {
|
(assistant: ThreadAssistantInfo) => {
|
||||||
|
|||||||
@ -14,7 +14,10 @@ import { loadModelErrorAtom } from '@/hooks/useActiveModel'
|
|||||||
|
|
||||||
import { getConfigurationsData } from '@/utils/componentSettings'
|
import { getConfigurationsData } from '@/utils/componentSettings'
|
||||||
|
|
||||||
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
|
import {
|
||||||
|
extractInferenceParams,
|
||||||
|
extractModelLoadParams,
|
||||||
|
} from '@/utils/modelParam'
|
||||||
|
|
||||||
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
|
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
|
||||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||||
@ -27,16 +30,18 @@ const LocalServerRightPanel = () => {
|
|||||||
const selectedModel = useAtomValue(selectedModelAtom)
|
const selectedModel = useAtomValue(selectedModelAtom)
|
||||||
|
|
||||||
const [currentModelSettingParams, setCurrentModelSettingParams] = useState(
|
const [currentModelSettingParams, setCurrentModelSettingParams] = useState(
|
||||||
toSettingParams(selectedModel?.settings)
|
extractModelLoadParams(selectedModel?.settings)
|
||||||
)
|
)
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (selectedModel) {
|
if (selectedModel) {
|
||||||
setCurrentModelSettingParams(toSettingParams(selectedModel?.settings))
|
setCurrentModelSettingParams(
|
||||||
|
extractModelLoadParams(selectedModel?.settings)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}, [selectedModel])
|
}, [selectedModel])
|
||||||
|
|
||||||
const modelRuntimeParams = toRuntimeParams(selectedModel?.settings)
|
const modelRuntimeParams = extractInferenceParams(selectedModel?.settings)
|
||||||
|
|
||||||
const componentDataRuntimeSetting = getConfigurationsData(
|
const componentDataRuntimeSetting = getConfigurationsData(
|
||||||
modelRuntimeParams,
|
modelRuntimeParams,
|
||||||
|
|||||||
@ -29,7 +29,10 @@ import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
|
|||||||
|
|
||||||
import { getConfigurationsData } from '@/utils/componentSettings'
|
import { getConfigurationsData } from '@/utils/componentSettings'
|
||||||
import { localEngines } from '@/utils/modelEngine'
|
import { localEngines } from '@/utils/modelEngine'
|
||||||
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
|
import {
|
||||||
|
extractInferenceParams,
|
||||||
|
extractModelLoadParams,
|
||||||
|
} from '@/utils/modelParam'
|
||||||
|
|
||||||
import PromptTemplateSetting from './PromptTemplateSetting'
|
import PromptTemplateSetting from './PromptTemplateSetting'
|
||||||
import Tools from './Tools'
|
import Tools from './Tools'
|
||||||
@ -68,14 +71,26 @@ const ThreadRightPanel = () => {
|
|||||||
|
|
||||||
const settings = useMemo(() => {
|
const settings = useMemo(() => {
|
||||||
// runtime setting
|
// runtime setting
|
||||||
const modelRuntimeParams = toRuntimeParams(activeModelParams)
|
const modelRuntimeParams = extractInferenceParams(
|
||||||
|
{
|
||||||
|
...selectedModel?.parameters,
|
||||||
|
...activeModelParams,
|
||||||
|
},
|
||||||
|
selectedModel?.parameters
|
||||||
|
)
|
||||||
const componentDataRuntimeSetting = getConfigurationsData(
|
const componentDataRuntimeSetting = getConfigurationsData(
|
||||||
modelRuntimeParams,
|
modelRuntimeParams,
|
||||||
selectedModel
|
selectedModel
|
||||||
).filter((x) => x.key !== 'prompt_template')
|
).filter((x) => x.key !== 'prompt_template')
|
||||||
|
|
||||||
// engine setting
|
// engine setting
|
||||||
const modelEngineParams = toSettingParams(activeModelParams)
|
const modelEngineParams = extractModelLoadParams(
|
||||||
|
{
|
||||||
|
...selectedModel?.settings,
|
||||||
|
...activeModelParams,
|
||||||
|
},
|
||||||
|
selectedModel?.settings
|
||||||
|
)
|
||||||
const componentDataEngineSetting = getConfigurationsData(
|
const componentDataEngineSetting = getConfigurationsData(
|
||||||
modelEngineParams,
|
modelEngineParams,
|
||||||
selectedModel
|
selectedModel
|
||||||
@ -126,7 +141,10 @@ const ThreadRightPanel = () => {
|
|||||||
}, [activeModelParams, selectedModel])
|
}, [activeModelParams, selectedModel])
|
||||||
|
|
||||||
const promptTemplateSettings = useMemo(() => {
|
const promptTemplateSettings = useMemo(() => {
|
||||||
const modelEngineParams = toSettingParams(activeModelParams)
|
const modelEngineParams = extractModelLoadParams({
|
||||||
|
...selectedModel?.settings,
|
||||||
|
...activeModelParams,
|
||||||
|
})
|
||||||
const componentDataEngineSetting = getConfigurationsData(
|
const componentDataEngineSetting = getConfigurationsData(
|
||||||
modelEngineParams,
|
modelEngineParams,
|
||||||
selectedModel
|
selectedModel
|
||||||
|
|||||||
183
web/utils/modelParam.test.ts
Normal file
183
web/utils/modelParam.test.ts
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
// web/utils/modelParam.test.ts
|
||||||
|
import { normalizeValue, validationRules } from './modelParam'
|
||||||
|
|
||||||
|
describe('validationRules', () => {
  // One row per rule: [rule name, ordered list of [input, expected verdict]].
  // Each row becomes a single `it` whose expectations run in the listed order.
  const ruleCases: Array<[string, Array<[unknown, boolean]>]> = [
    [
      'temperature',
      [
        [0.5, true],
        [2, true],
        [0, true],
        [-0.1, false],
        [2.3, false],
        ['0.5', false],
      ],
    ],
    [
      'token_limit',
      [
        [100, true],
        [1, true],
        [0, true],
        [-1, false],
        ['100', false],
      ],
    ],
    [
      'top_k',
      [
        [0.5, true],
        [1, true],
        [0, true],
        [-0.1, false],
        [1.1, false],
        ['0.5', false],
      ],
    ],
    [
      'top_p',
      [
        [0.5, true],
        [1, true],
        [0, true],
        [-0.1, false],
        [1.1, false],
        ['0.5', false],
      ],
    ],
    [
      'stream',
      [
        [true, true],
        [false, true],
        ['true', false],
        [1, false],
      ],
    ],
    [
      'max_tokens',
      [
        [100, true],
        [1, true],
        [0, true],
        [-1, false],
        ['100', false],
      ],
    ],
    [
      'stop',
      [
        [['word1', 'word2'], true],
        [[], true],
        [['word1', 2], false],
        ['word1', false],
      ],
    ],
    [
      'frequency_penalty',
      [
        [0.5, true],
        [1, true],
        [0, true],
        [-0.1, false],
        [1.1, false],
        ['0.5', false],
      ],
    ],
    [
      'presence_penalty',
      [
        [0.5, true],
        [1, true],
        [0, true],
        [-0.1, false],
        [1.1, false],
        ['0.5', false],
      ],
    ],
    [
      'ctx_len',
      [
        [1024, true],
        [1, true],
        [0, true],
        [-1, false],
        ['1024', false],
      ],
    ],
    [
      'ngl',
      [
        [12, true],
        [1, true],
        [0, true],
        [-1, false],
        ['12', false],
      ],
    ],
    [
      'embedding',
      [
        [true, true],
        [false, true],
        ['true', false],
        [1, false],
      ],
    ],
    [
      'n_parallel',
      [
        [2, true],
        [1, true],
        [0, true],
        [-1, false],
        ['2', false],
      ],
    ],
    [
      'cpu_threads',
      [
        [4, true],
        [1, true],
        [0, true],
        [-1, false],
        ['4', false],
      ],
    ],
    [
      'prompt_template',
      [
        ['template', true],
        ['', true],
        [123, false],
      ],
    ],
    [
      'llama_model_path',
      [
        ['path', true],
        ['', true],
        [123, false],
      ],
    ],
    [
      'mmproj',
      [
        ['mmproj', true],
        ['', true],
        [123, false],
      ],
    ],
    [
      'vision_model',
      [
        [true, true],
        [false, true],
        ['true', false],
        [1, false],
      ],
    ],
    [
      'text_model',
      [
        [true, true],
        [false, true],
        ['true', false],
        [1, false],
      ],
    ],
  ]

  for (const [ruleName, samples] of ruleCases) {
    it(`should validate ${ruleName} correctly`, () => {
      for (const [input, verdict] of samples) {
        expect(validationRules[ruleName](input)).toBe(verdict)
      }
    })
  }
})
|
||||||
|
|
||||||
|
describe('normalizeValue', () => {
  // One row per integer-valued key: [key, ordered list of [raw input, expected output]].
  const integerKeyCases: Array<[string, Array<[unknown, number]>]> = [
    [
      'ctx_len',
      [
        [100.5, 100],
        ['2', 2],
        [100, 100],
      ],
    ],
    [
      'token_limit',
      [
        [100.5, 100],
        ['1', 1],
        [0, 0],
      ],
    ],
    [
      'max_tokens',
      [
        [100.5, 100],
        ['1', 1],
        [0, 0],
      ],
    ],
    [
      'ngl',
      [
        [12.5, 12],
        ['2', 2],
        [0, 0],
      ],
    ],
    [
      'n_parallel',
      [
        [2.5, 2],
        ['2', 2],
        [0, 0],
      ],
    ],
    [
      'cpu_threads',
      [
        [4.5, 4],
        ['4', 4],
        [0, 0],
      ],
    ],
  ]

  for (const [paramKey, samples] of integerKeyCases) {
    it(`should normalize ${paramKey} correctly`, () => {
      for (const [raw, normalized] of samples) {
        expect(normalizeValue(paramKey, raw)).toBe(normalized)
      }
    })
  }
})
|
||||||
@ -1,9 +1,69 @@
|
|||||||
|
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||||
|
/* eslint-disable @typescript-eslint/naming-convention */
|
||||||
import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
|
import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
|
||||||
|
|
||||||
import { ModelParams } from '@/helpers/atoms/Thread.atom'
|
import { ModelParams } from '@/helpers/atoms/Thread.atom'
|
||||||
|
|
||||||
export const toRuntimeParams = (
|
/**
|
||||||
modelParams?: ModelParams
|
* Validation rules for model parameters
|
||||||
|
*/
|
||||||
|
export const validationRules: { [key: string]: (value: any) => boolean } = {
  // Each entry maps a parameter key to a predicate that returns true only
  // when the (already normalized) value is acceptable for that key.
  temperature: numberBetween(0, 2),
  token_limit: isNonNegativeInteger,
  top_k: numberBetween(0, 1),
  top_p: numberBetween(0, 1),
  stream: isBoolean,
  max_tokens: isNonNegativeInteger,
  stop: isStringArray,
  frequency_penalty: numberBetween(0, 1),
  presence_penalty: numberBetween(0, 1),

  ctx_len: isNonNegativeInteger,
  ngl: isNonNegativeInteger,
  embedding: isBoolean,
  n_parallel: isNonNegativeInteger,
  cpu_threads: isNonNegativeInteger,
  prompt_template: isString,
  llama_model_path: isString,
  mmproj: isString,
  vision_model: isBoolean,
  text_model: isBoolean,
}

// The predicates below are function declarations on purpose: declarations are
// hoisted, so the table above can reference them while staying first in the
// file order.

/**
 * Builds a predicate accepting numbers in the inclusive range [min, max].
 * NaN and non-number values (including numeric strings) are rejected.
 */
function numberBetween(
  min: number,
  max: number
): (value: unknown) => boolean {
  return (value) => typeof value === 'number' && value >= min && value <= max
}

/** Accepts integers greater than or equal to zero; rejects floats, NaN and strings. */
function isNonNegativeInteger(value: unknown): boolean {
  return typeof value === 'number' && Number.isInteger(value) && value >= 0
}

/** Accepts only the primitives `true` and `false`. */
function isBoolean(value: unknown): boolean {
  return typeof value === 'boolean'
}

/** Accepts any string, including the empty string. */
function isString(value: unknown): boolean {
  return typeof value === 'string'
}

/** Accepts arrays (possibly empty) whose elements are all strings. */
function isStringArray(value: unknown): boolean {
  return Array.isArray(value) && value.every((v) => typeof v === 'string')
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* There are some parameters that need to be normalized before being sent to the server
|
||||||
|
* E.g. ctx_len should be an integer, but it can be a float from the input field
|
||||||
|
* @param key
|
||||||
|
* @param value
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
export const normalizeValue = (key: string, value: any) => {
  switch (key) {
    // Keys whose values must be integers.
    case 'token_limit':
    case 'max_tokens':
    case 'ctx_len':
    case 'ngl':
    case 'n_parallel':
    case 'cpu_threads': {
      // Number() maps '', whitespace-only strings, null, [] and booleans to
      // 0 or 1. Those would then pass the integer validation rules and be
      // used as if they were real settings, so only genuine numbers and
      // non-blank strings are converted. Everything else becomes NaN, which
      // the validation rules reject.
      const convertible =
        typeof value === 'number' ||
        (typeof value === 'string' && value.trim() !== '')
      // Convert to integer (non-numeric strings still yield NaN).
      return convertible ? Math.floor(Number(value)) : NaN
    }
    default:
      // All other keys are passed through untouched.
      return value
  }
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extract inference parameters from flat model parameters
|
||||||
|
* @param modelParams
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
export const extractInferenceParams = (
|
||||||
|
modelParams?: ModelParams,
|
||||||
|
originParams?: ModelParams
|
||||||
): ModelRuntimeParams => {
|
): ModelRuntimeParams => {
|
||||||
if (!modelParams) return {}
|
if (!modelParams) return {}
|
||||||
const defaultModelParams: ModelRuntimeParams = {
|
const defaultModelParams: ModelRuntimeParams = {
|
||||||
@ -22,15 +82,35 @@ export const toRuntimeParams = (
|
|||||||
|
|
||||||
for (const [key, value] of Object.entries(modelParams)) {
  // Only keys declared on defaultModelParams are copied into the result.
  if (!(key in defaultModelParams)) continue

  // Normalize once so the value that is validated is the value that is stored.
  const normalized = normalizeValue(key, value)
  const validate = validationRules[key]

  if (!validate || validate(normalized)) {
    // Keys without a rule are accepted as-is (after normalization).
    Object.assign(runtimeParams, { [key]: normalized })
  } else if (originParams && key in originParams) {
    // Invalid value - fall back to origin value
    Object.assign(runtimeParams, {
      [key]: originParams[key as keyof typeof originParams],
    })
  }
  // Invalid value with no origin entry: the key is left out of the result.
}
|
||||||
|
|
||||||
return runtimeParams
|
return runtimeParams
|
||||||
}
|
}
|
||||||
|
|
||||||
export const toSettingParams = (
|
/**
|
||||||
modelParams?: ModelParams
|
* Extract model load parameters from flat model parameters
|
||||||
|
* @param modelParams
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
export const extractModelLoadParams = (
|
||||||
|
modelParams?: ModelParams,
|
||||||
|
originParams?: ModelParams
|
||||||
): ModelSettingParams => {
|
): ModelSettingParams => {
|
||||||
if (!modelParams) return {}
|
if (!modelParams) return {}
|
||||||
const defaultSettingParams: ModelSettingParams = {
|
const defaultSettingParams: ModelSettingParams = {
|
||||||
@ -49,7 +129,21 @@ export const toSettingParams = (
|
|||||||
|
|
||||||
for (const [key, value] of Object.entries(modelParams)) {
  // Only keys declared on defaultSettingParams are copied into the result.
  if (!(key in defaultSettingParams)) continue

  // Normalize once so the value that is validated is the value that is stored.
  const normalized = normalizeValue(key, value)
  const validate = validationRules[key]

  if (!validate || validate(normalized)) {
    // Keys without a rule are accepted as-is (after normalization).
    Object.assign(settingParams, { [key]: normalized })
  } else if (originParams && key in originParams) {
    // Invalid value - fall back to origin value.
    // The fallback must be written to settingParams (the result being built),
    // not to modelParams: writing to modelParams mutated the caller's input
    // and the origin value never reached the returned settings.
    Object.assign(settingParams, {
      [key]: originParams[key as keyof typeof originParams],
    })
  }
  // Invalid value with no origin entry: the key is left out of the result.
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user