fix: top_k validation
This commit is contained in:
parent
9e9bc49729
commit
0b88d93e18
@ -29,7 +29,7 @@ describe('validationRules', () => {
|
||||
expect(validationRules.top_k(1)).toBe(true)
|
||||
expect(validationRules.top_k(0)).toBe(true)
|
||||
expect(validationRules.top_k(-0.1)).toBe(false)
|
||||
expect(validationRules.top_k(1.1)).toBe(false)
|
||||
expect(validationRules.top_k(1.1)).toBe(true)
|
||||
expect(validationRules.top_k('0.5')).toBe(false)
|
||||
})
|
||||
|
||||
@ -68,8 +68,8 @@ describe('validationRules', () => {
|
||||
expect(validationRules.frequency_penalty(0.5)).toBe(true)
|
||||
expect(validationRules.frequency_penalty(1)).toBe(true)
|
||||
expect(validationRules.frequency_penalty(0)).toBe(true)
|
||||
expect(validationRules.frequency_penalty(-0.1)).toBe(false)
|
||||
expect(validationRules.frequency_penalty(1.1)).toBe(false)
|
||||
expect(validationRules.frequency_penalty(-0.1)).toBe(true)
|
||||
expect(validationRules.frequency_penalty(1.1)).toBe(true)
|
||||
expect(validationRules.frequency_penalty('0.5')).toBe(false)
|
||||
})
|
||||
|
||||
@ -77,8 +77,8 @@ describe('validationRules', () => {
|
||||
expect(validationRules.presence_penalty(0.5)).toBe(true)
|
||||
expect(validationRules.presence_penalty(1)).toBe(true)
|
||||
expect(validationRules.presence_penalty(0)).toBe(true)
|
||||
expect(validationRules.presence_penalty(-0.1)).toBe(false)
|
||||
expect(validationRules.presence_penalty(1.1)).toBe(false)
|
||||
expect(validationRules.presence_penalty(-0.1)).toBe(true)
|
||||
expect(validationRules.presence_penalty(1.1)).toBe(true)
|
||||
expect(validationRules.presence_penalty('0.5')).toBe(false)
|
||||
})
|
||||
|
||||
@ -255,16 +255,16 @@ describe('extractInferenceParams', () => {
|
||||
top_p: 0.9,
|
||||
stream: true,
|
||||
max_tokens: 50.3,
|
||||
invalid_param: 'should_be_ignored'
|
||||
invalid_param: 'should_be_ignored',
|
||||
}
|
||||
|
||||
|
||||
const result = extractInferenceParams(modelParams as any)
|
||||
expect(result).toEqual({
|
||||
temperature: 1.5,
|
||||
token_limit: 100,
|
||||
top_p: 0.9,
|
||||
stream: true,
|
||||
max_tokens: 50
|
||||
max_tokens: 50,
|
||||
})
|
||||
})
|
||||
|
||||
@ -296,9 +296,9 @@ describe('extractModelLoadParams', () => {
|
||||
prompt_template: 'template',
|
||||
llama_model_path: '/path/to/model',
|
||||
vision_model: false,
|
||||
invalid_param: 'should_be_ignored'
|
||||
invalid_param: 'should_be_ignored',
|
||||
}
|
||||
|
||||
|
||||
const result = extractModelLoadParams(modelParams as any)
|
||||
expect(result).toEqual({
|
||||
ctx_len: 2048,
|
||||
@ -308,23 +308,23 @@ describe('extractModelLoadParams', () => {
|
||||
cpu_threads: 8,
|
||||
prompt_template: 'template',
|
||||
llama_model_path: '/path/to/model',
|
||||
vision_model: false
|
||||
vision_model: false,
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle parameters without validation rules', () => {
|
||||
const modelParams = {
|
||||
const modelParams = {
|
||||
engine: 'llama',
|
||||
pre_prompt: 'System:',
|
||||
system_prompt: 'You are helpful',
|
||||
model_path: '/path'
|
||||
model_path: '/path',
|
||||
}
|
||||
const result = extractModelLoadParams(modelParams as any)
|
||||
expect(result).toEqual({
|
||||
engine: 'llama',
|
||||
pre_prompt: 'System:',
|
||||
system_prompt: 'You are helpful',
|
||||
model_path: '/path'
|
||||
model_path: '/path',
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@ -8,13 +8,13 @@ import { ModelParams, ModelRuntimeParams, ModelSettingParams } from '../../types
|
||||
export const validationRules: { [key: string]: (value: any) => boolean } = {
|
||||
temperature: (value: any) => typeof value === 'number' && value >= 0 && value <= 2,
|
||||
token_limit: (value: any) => Number.isInteger(value) && value >= 0,
|
||||
top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
|
||||
top_k: (value: any) => typeof value === 'number' && value >= 0,
|
||||
top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
|
||||
stream: (value: any) => typeof value === 'boolean',
|
||||
max_tokens: (value: any) => Number.isInteger(value) && value >= 0,
|
||||
stop: (value: any) => Array.isArray(value) && value.every((v) => typeof v === 'string'),
|
||||
frequency_penalty: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
|
||||
presence_penalty: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
|
||||
frequency_penalty: (value: any) => typeof value === 'number' && value >= -2 && value <= 2,
|
||||
presence_penalty: (value: any) => typeof value === 'number' && value >= -2 && value <= 2,
|
||||
repeat_last_n: (value: any) => typeof value === 'number',
|
||||
repeat_penalty: (value: any) => typeof value === 'number',
|
||||
min_p: (value: any) => typeof value === 'number',
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user