Merge pull request #5650 from menloresearch/fix/top_k-model-setting-validation

fix: top_k validation
This commit is contained in:
Louis 2025-07-01 17:31:03 +07:00 committed by GitHub
commit 94b25ec6e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 17 deletions

View File

@ -29,7 +29,7 @@ describe('validationRules', () => {
expect(validationRules.top_k(1)).toBe(true) expect(validationRules.top_k(1)).toBe(true)
expect(validationRules.top_k(0)).toBe(true) expect(validationRules.top_k(0)).toBe(true)
expect(validationRules.top_k(-0.1)).toBe(false) expect(validationRules.top_k(-0.1)).toBe(false)
expect(validationRules.top_k(1.1)).toBe(false) expect(validationRules.top_k(1.1)).toBe(true)
expect(validationRules.top_k('0.5')).toBe(false) expect(validationRules.top_k('0.5')).toBe(false)
}) })
@ -68,8 +68,8 @@ describe('validationRules', () => {
expect(validationRules.frequency_penalty(0.5)).toBe(true) expect(validationRules.frequency_penalty(0.5)).toBe(true)
expect(validationRules.frequency_penalty(1)).toBe(true) expect(validationRules.frequency_penalty(1)).toBe(true)
expect(validationRules.frequency_penalty(0)).toBe(true) expect(validationRules.frequency_penalty(0)).toBe(true)
expect(validationRules.frequency_penalty(-0.1)).toBe(false) expect(validationRules.frequency_penalty(-0.1)).toBe(true)
expect(validationRules.frequency_penalty(1.1)).toBe(false) expect(validationRules.frequency_penalty(1.1)).toBe(true)
expect(validationRules.frequency_penalty('0.5')).toBe(false) expect(validationRules.frequency_penalty('0.5')).toBe(false)
}) })
@ -77,8 +77,8 @@ describe('validationRules', () => {
expect(validationRules.presence_penalty(0.5)).toBe(true) expect(validationRules.presence_penalty(0.5)).toBe(true)
expect(validationRules.presence_penalty(1)).toBe(true) expect(validationRules.presence_penalty(1)).toBe(true)
expect(validationRules.presence_penalty(0)).toBe(true) expect(validationRules.presence_penalty(0)).toBe(true)
expect(validationRules.presence_penalty(-0.1)).toBe(false) expect(validationRules.presence_penalty(-0.1)).toBe(true)
expect(validationRules.presence_penalty(1.1)).toBe(false) expect(validationRules.presence_penalty(1.1)).toBe(true)
expect(validationRules.presence_penalty('0.5')).toBe(false) expect(validationRules.presence_penalty('0.5')).toBe(false)
}) })
@ -255,16 +255,16 @@ describe('extractInferenceParams', () => {
top_p: 0.9, top_p: 0.9,
stream: true, stream: true,
max_tokens: 50.3, max_tokens: 50.3,
invalid_param: 'should_be_ignored' invalid_param: 'should_be_ignored',
} }
const result = extractInferenceParams(modelParams as any) const result = extractInferenceParams(modelParams as any)
expect(result).toEqual({ expect(result).toEqual({
temperature: 1.5, temperature: 1.5,
token_limit: 100, token_limit: 100,
top_p: 0.9, top_p: 0.9,
stream: true, stream: true,
max_tokens: 50 max_tokens: 50,
}) })
}) })
@ -296,9 +296,9 @@ describe('extractModelLoadParams', () => {
prompt_template: 'template', prompt_template: 'template',
llama_model_path: '/path/to/model', llama_model_path: '/path/to/model',
vision_model: false, vision_model: false,
invalid_param: 'should_be_ignored' invalid_param: 'should_be_ignored',
} }
const result = extractModelLoadParams(modelParams as any) const result = extractModelLoadParams(modelParams as any)
expect(result).toEqual({ expect(result).toEqual({
ctx_len: 2048, ctx_len: 2048,
@ -308,23 +308,23 @@ describe('extractModelLoadParams', () => {
cpu_threads: 8, cpu_threads: 8,
prompt_template: 'template', prompt_template: 'template',
llama_model_path: '/path/to/model', llama_model_path: '/path/to/model',
vision_model: false vision_model: false,
}) })
}) })
it('should handle parameters without validation rules', () => { it('should handle parameters without validation rules', () => {
const modelParams = { const modelParams = {
engine: 'llama', engine: 'llama',
pre_prompt: 'System:', pre_prompt: 'System:',
system_prompt: 'You are helpful', system_prompt: 'You are helpful',
model_path: '/path' model_path: '/path',
} }
const result = extractModelLoadParams(modelParams as any) const result = extractModelLoadParams(modelParams as any)
expect(result).toEqual({ expect(result).toEqual({
engine: 'llama', engine: 'llama',
pre_prompt: 'System:', pre_prompt: 'System:',
system_prompt: 'You are helpful', system_prompt: 'You are helpful',
model_path: '/path' model_path: '/path',
}) })
}) })

View File

@ -8,13 +8,13 @@ import { ModelParams, ModelRuntimeParams, ModelSettingParams } from '../../types
export const validationRules: { [key: string]: (value: any) => boolean } = { export const validationRules: { [key: string]: (value: any) => boolean } = {
temperature: (value: any) => typeof value === 'number' && value >= 0 && value <= 2, temperature: (value: any) => typeof value === 'number' && value >= 0 && value <= 2,
token_limit: (value: any) => Number.isInteger(value) && value >= 0, token_limit: (value: any) => Number.isInteger(value) && value >= 0,
top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1, top_k: (value: any) => typeof value === 'number' && value >= 0,
top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1, top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
stream: (value: any) => typeof value === 'boolean', stream: (value: any) => typeof value === 'boolean',
max_tokens: (value: any) => Number.isInteger(value) && value >= 0, max_tokens: (value: any) => Number.isInteger(value) && value >= 0,
stop: (value: any) => Array.isArray(value) && value.every((v) => typeof v === 'string'), stop: (value: any) => Array.isArray(value) && value.every((v) => typeof v === 'string'),
frequency_penalty: (value: any) => typeof value === 'number' && value >= 0 && value <= 1, frequency_penalty: (value: any) => typeof value === 'number' && value >= -2 && value <= 2,
presence_penalty: (value: any) => typeof value === 'number' && value >= 0 && value <= 1, presence_penalty: (value: any) => typeof value === 'number' && value >= -2 && value <= 2,
repeat_last_n: (value: any) => typeof value === 'number', repeat_last_n: (value: any) => typeof value === 'number',
repeat_penalty: (value: any) => typeof value === 'number', repeat_penalty: (value: any) => typeof value === 'number',
min_p: (value: any) => typeof value === 'number', min_p: (value: any) => typeof value === 'number',