fix: correct model settings on startup and strip down irrelevant model parameters

Louis 2024-10-25 12:33:43 +07:00
parent 90c7420c34
commit 3643c8866e
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
11 changed files with 111 additions and 89 deletions

View File

@@ -3,3 +3,8 @@
  * @module
  */
 export { ModelManager } from './manager'
+
+/**
+ * Export all utils
+ */
+export * from './utils'

View File

@@ -1,7 +1,10 @@
 // web/utils/modelParam.test.ts
-import { normalizeValue, validationRules } from './modelParam'
-import { extractModelLoadParams } from './modelParam';
-import { extractInferenceParams } from './modelParam';
+import {
+  normalizeValue,
+  validationRules,
+  extractModelLoadParams,
+  extractInferenceParams,
+} from './utils'
 
 describe('validationRules', () => {
   it('should validate temperature correctly', () => {
@@ -151,13 +154,12 @@ describe('validationRules', () => {
   })
 })
 
-
 it('should normalize invalid values for keys not listed in validationRules', () => {
   expect(normalizeValue('invalid_key', 'invalid')).toBe('invalid')
   expect(normalizeValue('invalid_key', 123)).toBe(123)
   expect(normalizeValue('invalid_key', true)).toBe(true)
   expect(normalizeValue('invalid_key', false)).toBe(false)
 })
 
 describe('normalizeValue', () => {
   it('should normalize ctx_len correctly', () => {
@@ -192,19 +194,16 @@ describe('normalizeValue', () => {
   })
 })
 
-it('should handle invalid values correctly by falling back to originParams', () => {
-  const modelParams = { temperature: 'invalid', token_limit: -1 };
-  const originParams = { temperature: 0.5, token_limit: 100 };
-  expect(extractInferenceParams(modelParams, originParams)).toEqual(originParams);
-});
+it('should handle invalid values correctly by falling back to originParams', () => {
+  const modelParams = { temperature: 'invalid', token_limit: -1 }
+  const originParams = { temperature: 0.5, token_limit: 100 }
+  expect(extractInferenceParams(modelParams as any, originParams)).toEqual(originParams)
+})
 
-it('should return an empty object when no modelParams are provided', () => {
-  expect(extractModelLoadParams()).toEqual({});
-});
+it('should return an empty object when no modelParams are provided', () => {
+  expect(extractModelLoadParams()).toEqual({})
+})
 
-it('should return an empty object when no modelParams are provided', () => {
-  expect(extractInferenceParams()).toEqual({});
-});
+it('should return an empty object when no modelParams are provided', () => {
+  expect(extractInferenceParams()).toEqual({})
+})

View File

@@ -1,26 +1,20 @@
 /* eslint-disable @typescript-eslint/no-explicit-any */
 /* eslint-disable @typescript-eslint/naming-convention */
-import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
-import { ModelParams } from '@/types/model'
+import { ModelParams, ModelRuntimeParams, ModelSettingParams } from '../../types'
 
 /**
  * Validation rules for model parameters
  */
 export const validationRules: { [key: string]: (value: any) => boolean } = {
-  temperature: (value: any) =>
-    typeof value === 'number' && value >= 0 && value <= 2,
+  temperature: (value: any) => typeof value === 'number' && value >= 0 && value <= 2,
   token_limit: (value: any) => Number.isInteger(value) && value >= 0,
   top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
   top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
   stream: (value: any) => typeof value === 'boolean',
   max_tokens: (value: any) => Number.isInteger(value) && value >= 0,
-  stop: (value: any) =>
-    Array.isArray(value) && value.every((v) => typeof v === 'string'),
-  frequency_penalty: (value: any) =>
-    typeof value === 'number' && value >= 0 && value <= 1,
-  presence_penalty: (value: any) =>
-    typeof value === 'number' && value >= 0 && value <= 1,
+  stop: (value: any) => Array.isArray(value) && value.every((v) => typeof v === 'string'),
+  frequency_penalty: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
+  presence_penalty: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
   ctx_len: (value: any) => Number.isInteger(value) && value >= 0,
   ngl: (value: any) => Number.isInteger(value) && value >= 0,
@@ -76,6 +70,7 @@ export const extractInferenceParams = (
     stop: undefined,
     frequency_penalty: undefined,
     presence_penalty: undefined,
+    engine: undefined,
   }
 
   const runtimeParams: ModelRuntimeParams = {}
@@ -119,11 +114,18 @@ export const extractModelLoadParams = (
     embedding: undefined,
     n_parallel: undefined,
     cpu_threads: undefined,
+    pre_prompt: undefined,
+    system_prompt: undefined,
+    ai_prompt: undefined,
+    user_prompt: undefined,
     prompt_template: undefined,
+    model_path: undefined,
     llama_model_path: undefined,
     mmproj: undefined,
+    cont_batching: undefined,
     vision_model: undefined,
     text_model: undefined,
+    engine: undefined,
   }
 
   const settingParams: ModelSettingParams = {}
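For orientation, a rough usage sketch of the two helpers above, based on their key templates and the tests in this commit. The '@janhq/core' import path matches the web-side imports further down; the fate of keys outside both templates is an assumption.

import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'

// A mixed bag of values, e.g. merged from model.json and the UI.
const params = {
  temperature: 0.7,
  stream: true,
  ctx_len: 2048,
  prompt_template: '{system_message} {prompt}',
  some_unknown_key: 'ignored', // assumption: keys outside both templates are dropped
}

// Only inference/runtime keys survive (temperature, stream, ...).
const inference = extractInferenceParams(params)

// Only load-time keys survive (ctx_len, prompt_template, ...).
const load = extractModelLoadParams(params)

// Invalid values fall back to the previous params (see the test above).
const safe = extractInferenceParams(
  { temperature: 'invalid', token_limit: -1 } as any,
  { temperature: 0.5, token_limit: 100 }
)
// safe => { temperature: 0.5, token_limit: 100 }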

View File

@@ -15,7 +15,6 @@ export type ModelInfo = {
  * Represents the inference engine.
  * @stored
  */
-
 export enum InferenceEngine {
   anthropic = 'anthropic',
   mistral = 'mistral',
@@ -34,6 +33,7 @@ export enum InferenceEngine {
   cortex_tensorrtllm = 'tensorrt-llm',
 }
 
+// Represents an artifact of a model, including its filename and URL
 export type ModelArtifact = {
   filename: string
   url: string
@@ -105,6 +105,7 @@ export type Model = {
   engine: InferenceEngine
 }
 
+// Represents metadata associated with a model
 export type ModelMetadata = {
   author: string
   tags: string[]
@@ -125,14 +126,20 @@ export type ModelSettingParams = {
   n_parallel?: number
   cpu_threads?: number
   prompt_template?: string
+  pre_prompt?: string
   system_prompt?: string
   ai_prompt?: string
   user_prompt?: string
+  // path param
+  model_path?: string
+  // legacy path param
   llama_model_path?: string
+  // clip model path
   mmproj?: string
   cont_batching?: boolean
   vision_model?: boolean
   text_model?: boolean
+  engine?: boolean
 }
 
 /**
@@ -151,6 +158,12 @@
   engine?: string
 }
 
+// Represents a model that failed to initialize, including the error
 export type ModelInitFailed = Model & {
   error: Error
 }
+
+/**
+ * ModelParams types
+ */
+export type ModelParams = ModelRuntimeParams | ModelSettingParams
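As a small illustration of the new ModelParams union (assuming it is re-exported from '@janhq/core' like the other core types), either shape can flow through the same code path; the narrowing helper below is hypothetical and only sketches the idea.

import {
  ModelParams,
  ModelRuntimeParams,
  ModelSettingParams,
} from '@janhq/core'

// Hypothetical helper, for illustration only: report which side of the
// union a value appears to belong to.
const describeParams = (params: ModelParams): string => {
  const runtime = params as ModelRuntimeParams
  const settings = params as ModelSettingParams
  if (typeof runtime.temperature === 'number') {
    return `runtime params (temperature=${runtime.temperature})`
  }
  if (typeof settings.ctx_len === 'number') {
    return `load params (ctx_len=${settings.ctx_len})`
  }
  return 'unknown params'
}

describeParams({ temperature: 0.7, stream: true }) // 'runtime params (...)'
describeParams({ ctx_len: 2048, ngl: 32 }) // 'load params (...)'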

View File

@@ -10,11 +10,12 @@ import {
   Model,
   executeOnMain,
   systemInformation,
-  log,
   joinPath,
   dirName,
   LocalOAIEngine,
   InferenceEngine,
+  getJanDataFolderPath,
+  extractModelLoadParams,
 } from '@janhq/core'
 import PQueue from 'p-queue'
 import ky from 'ky'
@@ -62,24 +63,38 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
   override async loadModel(
     model: Model & { file_path?: string }
   ): Promise<void> {
-    // Legacy model cache - should import
-    if (model.engine === InferenceEngine.nitro && model.file_path) {
-      // Try importing the model
-      const modelPath = await this.modelPath(model)
-      await this.queue.add(() =>
-        ky
-          .post(`${CORTEX_API_URL}/v1/models/${model.id}`, {
-            json: { model: model.id, modelPath: modelPath },
-          })
-          .json()
-          .catch((e) => log(e.message ?? e ?? ''))
-      )
+    if (
+      model.engine === InferenceEngine.nitro &&
+      model.settings.llama_model_path
+    ) {
+      // Legacy chat model support
+      model.settings = {
+        ...model.settings,
+        llama_model_path: await getModelFilePath(
+          model.id,
+          model.settings.llama_model_path
+        ),
+      }
+    } else {
+      const { llama_model_path, ...settings } = model.settings
+      model.settings = settings
+    }
+
+    if (model.engine === InferenceEngine.nitro && model.settings.mmproj) {
+      // Legacy clip vision model support
+      model.settings = {
+        ...model.settings,
+        mmproj: await getModelFilePath(model.id, model.settings.mmproj),
+      }
+    } else {
+      const { mmproj, ...settings } = model.settings
+      model.settings = settings
     }
 
     return await ky
       .post(`${CORTEX_API_URL}/v1/models/start`, {
         json: {
-          ...model.settings,
+          ...extractModelLoadParams(model.settings),
           model: model.id,
           engine:
             model.engine === InferenceEngine.nitro // Legacy model cache
@@ -131,3 +146,12 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
       .then(() => {})
   }
 }
+
+/// Legacy
+export const getModelFilePath = async (
+  id: string,
+  file: string
+): Promise<string> => {
+  return joinPath([await getJanDataFolderPath(), 'models', id, file])
+}
+///
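For context, a hedged sketch of what the legacy-path handling above produces. The folder layout comes straight from getModelFilePath; the model id and file names are made up for illustration.

import { getJanDataFolderPath, joinPath } from '@janhq/core'

// Same shape as the helper added above: legacy settings stored only a
// bare file name, so it is resolved against the model's own folder.
const getModelFilePath = async (id: string, file: string): Promise<string> =>
  joinPath([await getJanDataFolderPath(), 'models', id, file])

// Hypothetical legacy nitro model whose settings reference bare file names.
const resolveLegacySettings = async () => {
  const modelId = 'llava-7b'
  const legacySettings = {
    llama_model_path: 'llava-v1.5-7b-Q4_K.gguf',
    mmproj: 'mmproj-model-f16.gguf',
  }
  return {
    ...legacySettings,
    llama_model_path: await getModelFilePath(modelId, legacySettings.llama_model_path),
    mmproj: await getModelFilePath(modelId, legacySettings.mmproj),
  }
}
// => e.g. { llama_model_path: '<jan data folder>/models/llava-7b/llava-v1.5-7b-Q4_K.gguf', ... }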

View File

@@ -1,13 +1,7 @@
 import PQueue from 'p-queue'
 import ky from 'ky'
-import {
-  DownloadEvent,
-  events,
-  Model,
-  ModelEvent,
-  ModelRuntimeParams,
-  ModelSettingParams,
-} from '@janhq/core'
+import { events, extractModelLoadParams, Model, ModelEvent } from '@janhq/core'
+import { extractInferenceParams } from '@janhq/core'
 
 /**
  * cortex.cpp Model APIs interface
  */
@@ -204,20 +198,17 @@ export class CortexAPI implements ICortexAPI {
    * @returns
    */
   private transformModel(model: any) {
-    model.parameters = setParameters<ModelRuntimeParams>(model)
-    model.settings = setParameters<ModelSettingParams>(model)
-    model.metadata = {
+    model.parameters = {
+      ...extractInferenceParams(model),
+      ...model.parameters,
+    }
+    model.settings = {
+      ...extractModelLoadParams(model),
+      ...model.settings,
+    }
+    model.metadata = model.metadata ?? {
       tags: [],
     }
     return model as Model
   }
 }
-
-type FilteredParams<T> = {
-  [K in keyof T]: T[K]
-}
-
-function setParameters<T>(params: T): T {
-  const filteredParams: FilteredParams<T> = { ...params }
-  return filteredParams
-}
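For reference, a sketch of the merge order used in transformModel above: the extracted defaults are spread first, so fields that cortex.cpp already returned under parameters or settings win. The raw record here is hypothetical.

import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'

// Hypothetical flat record as it might come back from the models API.
const raw: any = {
  id: 'tinyllama-1.1b',
  temperature: 0.7,
  ctx_len: 2048,
  parameters: { temperature: 0.9 }, // pre-populated value should win
  settings: {},
}

const model = {
  ...raw,
  parameters: {
    ...extractInferenceParams(raw), // e.g. picks up temperature: 0.7
    ...raw.parameters, // overrides with temperature: 0.9
  },
  settings: {
    ...extractModelLoadParams(raw), // e.g. picks up ctx_len: 2048
    ...raw.settings,
  },
  metadata: raw.metadata ?? { tags: [] },
}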

View File

@@ -15,6 +15,7 @@ import {
   Thread,
   EngineManager,
   InferenceEngine,
+  extractInferenceParams,
 } from '@janhq/core'
 import { useAtomValue, useSetAtom } from 'jotai'
 import { ulid } from 'ulidx'
@@ -22,7 +23,6 @@ import { ulid } from 'ulidx'
 import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
 import { isLocalEngine } from '@/utils/modelEngine'
-import { extractInferenceParams } from '@/utils/modelParam'
 
 import { extensionManager } from '@/extension'
 
 import {

View File

@@ -12,6 +12,7 @@ import {
   ToolManager,
   ChatCompletionMessage,
 } from '@janhq/core'
+import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
 import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
 
 import {
@@ -23,10 +24,6 @@ import {
 import { Stack } from '@/utils/Stack'
 import { compressImage, getBase64 } from '@/utils/base64'
 import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
-import {
-  extractInferenceParams,
-  extractModelLoadParams,
-} from '@/utils/modelParam'
 import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'

View File

@@ -6,15 +6,12 @@ import {
   InferenceEngine,
   Thread,
   ThreadAssistantInfo,
+  extractInferenceParams,
+  extractModelLoadParams,
 } from '@janhq/core'
 import { useAtom, useAtomValue, useSetAtom } from 'jotai'
 
-import {
-  extractInferenceParams,
-  extractModelLoadParams,
-} from '@/utils/modelParam'
-
 import { extensionManager } from '@/extension'
 
 import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
 import {

View File

@@ -1,5 +1,6 @@
 import { useCallback, useEffect, useMemo, useState } from 'react'
 
+import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
 import { Accordion, AccordionItem, Input, Tooltip } from '@janhq/joi'
 import { useAtomValue, useSetAtom } from 'jotai'
 import { AlertTriangleIcon, CheckIcon, CopyIcon, InfoIcon } from 'lucide-react'
@@ -16,11 +17,6 @@ import { useClipboard } from '@/hooks/useClipboard'
 import { getConfigurationsData } from '@/utils/componentSettings'
 
-import {
-  extractInferenceParams,
-  extractModelLoadParams,
-} from '@/utils/modelParam'
-
 import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
 import { selectedModelAtom } from '@/helpers/atoms/Model.atom'

View File

@@ -4,6 +4,8 @@ import {
   InferenceEngine,
   SettingComponentProps,
   SliderComponentProps,
+  extractInferenceParams,
+  extractModelLoadParams,
 } from '@janhq/core'
 import {
   Tabs,
@@ -31,10 +33,6 @@ import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
 import { getConfigurationsData } from '@/utils/componentSettings'
 import { isLocalEngine } from '@/utils/modelEngine'
-import {
-  extractInferenceParams,
-  extractModelLoadParams,
-} from '@/utils/modelParam'
 
 import PromptTemplateSetting from './PromptTemplateSetting'
 import Tools from './Tools'