fix: correct model settings on startup and strip down irrelevant model parameters

parent 90c7420c34
commit 3643c8866e
@@ -3,3 +3,8 @@
  * @module
  */
 export { ModelManager } from './manager'
+
+/**
+ * Export all utils
+ */
+export * from './utils'

@@ -1,7 +1,10 @@
 // web/utils/modelParam.test.ts
-import { normalizeValue, validationRules } from './modelParam'
-import { extractModelLoadParams } from './modelParam';
-import { extractInferenceParams } from './modelParam';
+import {
+  normalizeValue,
+  validationRules,
+  extractModelLoadParams,
+  extractInferenceParams,
+} from './utils'

 describe('validationRules', () => {
   it('should validate temperature correctly', () => {

@@ -151,13 +154,12 @@ describe('validationRules', () => {
   })
 })

-
 it('should normalize invalid values for keys not listed in validationRules', () => {
   expect(normalizeValue('invalid_key', 'invalid')).toBe('invalid')
   expect(normalizeValue('invalid_key', 123)).toBe(123)
   expect(normalizeValue('invalid_key', true)).toBe(true)
   expect(normalizeValue('invalid_key', false)).toBe(false)
 })

 describe('normalizeValue', () => {
   it('should normalize ctx_len correctly', () => {

@@ -192,19 +194,16 @@ describe('normalizeValue', () => {
   })
 })

-
-
-it('should handle invalid values correctly by falling back to originParams', () => {
-  const modelParams = { temperature: 'invalid', token_limit: -1 };
-  const originParams = { temperature: 0.5, token_limit: 100 };
-  expect(extractInferenceParams(modelParams, originParams)).toEqual(originParams);
-});
-
-it('should return an empty object when no modelParams are provided', () => {
-  expect(extractModelLoadParams()).toEqual({});
-});
-
-
-it('should return an empty object when no modelParams are provided', () => {
-  expect(extractInferenceParams()).toEqual({});
-});
+it('should handle invalid values correctly by falling back to originParams', () => {
+  const modelParams = { temperature: 'invalid', token_limit: -1 }
+  const originParams = { temperature: 0.5, token_limit: 100 }
+  expect(extractInferenceParams(modelParams as any, originParams)).toEqual(originParams)
+})
+
+it('should return an empty object when no modelParams are provided', () => {
+  expect(extractModelLoadParams()).toEqual({})
+})
+
+it('should return an empty object when no modelParams are provided', () => {
+  expect(extractInferenceParams()).toEqual({})
+})

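Note: the tests above pin down the contract of the two extractors: invalid values fall back to the caller-supplied origin params, and calling an extractor with no arguments yields an empty object. The sketch below is one way an extractor consistent with those tests could look; it is an illustration only, not the code added by this commit, and it assumes `validationRules` is re-exported from the `@janhq/core` barrel as the `export * from './utils'` hunk suggests.

// Illustration only - not the repository implementation.
import { validationRules } from '@janhq/core'

type Params = Record<string, unknown>

// Inference-time keys, taken from the validation rules shown in this diff.
const inferenceKeys = [
  'temperature', 'token_limit', 'top_k', 'top_p', 'stream',
  'max_tokens', 'stop', 'frequency_penalty', 'presence_penalty',
]

export const extractInferenceParamsSketch = (
  modelParams?: Params,
  originParams?: Params
): Params => {
  if (!modelParams) return {}
  const result: Params = {}
  for (const key of inferenceKeys) {
    if (!(key in modelParams)) continue
    const value = modelParams[key]
    const valid = validationRules[key] ? validationRules[key](value) : true
    // Invalid values fall back to the previous (origin) value, as the tests expect.
    result[key] = valid ? value : originParams?.[key]
  }
  return result
}
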
@@ -1,26 +1,20 @@
 /* eslint-disable @typescript-eslint/no-explicit-any */
 /* eslint-disable @typescript-eslint/naming-convention */
-import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
-
-import { ModelParams } from '@/types/model'
+import { ModelParams, ModelRuntimeParams, ModelSettingParams } from '../../types'

 /**
  * Validation rules for model parameters
  */
 export const validationRules: { [key: string]: (value: any) => boolean } = {
-  temperature: (value: any) =>
-    typeof value === 'number' && value >= 0 && value <= 2,
+  temperature: (value: any) => typeof value === 'number' && value >= 0 && value <= 2,
   token_limit: (value: any) => Number.isInteger(value) && value >= 0,
   top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
   top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
   stream: (value: any) => typeof value === 'boolean',
   max_tokens: (value: any) => Number.isInteger(value) && value >= 0,
-  stop: (value: any) =>
-    Array.isArray(value) && value.every((v) => typeof v === 'string'),
-  frequency_penalty: (value: any) =>
-    typeof value === 'number' && value >= 0 && value <= 1,
-  presence_penalty: (value: any) =>
-    typeof value === 'number' && value >= 0 && value <= 1,
+  stop: (value: any) => Array.isArray(value) && value.every((v) => typeof v === 'string'),
+  frequency_penalty: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
+  presence_penalty: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,

   ctx_len: (value: any) => Number.isInteger(value) && value >= 0,
   ngl: (value: any) => Number.isInteger(value) && value >= 0,

@@ -76,6 +70,7 @@ export const extractInferenceParams = (
     stop: undefined,
     frequency_penalty: undefined,
     presence_penalty: undefined,
+    engine: undefined,
   }

   const runtimeParams: ModelRuntimeParams = {}

@@ -119,11 +114,18 @@ export const extractModelLoadParams = (
     embedding: undefined,
     n_parallel: undefined,
     cpu_threads: undefined,
+    pre_prompt: undefined,
+    system_prompt: undefined,
+    ai_prompt: undefined,
+    user_prompt: undefined,
     prompt_template: undefined,
+    model_path: undefined,
     llama_model_path: undefined,
     mmproj: undefined,
+    cont_batching: undefined,
     vision_model: undefined,
     text_model: undefined,
+    engine: undefined,
   }
   const settingParams: ModelSettingParams = {}

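Note: both extractors are driven by a template object whose keys are set to `undefined`, and the new entries above (`engine`, the prompt fields, `model_path`, `cont_batching`) widen the set of keys that survive extraction. The snippet below is a hedged sketch of that key-template filtering pattern, not the exact repository code; the key list is copied from the hunk above and is abridged (the parts of the template outside this hunk are not shown).

// Sketch of the key-template filtering idea (assumption, not the repository code).
const loadParamTemplate = {
  embedding: undefined,
  n_parallel: undefined,
  cpu_threads: undefined,
  pre_prompt: undefined,
  system_prompt: undefined,
  ai_prompt: undefined,
  user_prompt: undefined,
  prompt_template: undefined,
  model_path: undefined,
  llama_model_path: undefined,
  mmproj: undefined,
  cont_batching: undefined,
  vision_model: undefined,
  text_model: undefined,
  engine: undefined,
  // ...plus the other load-time keys not visible in this hunk
}

const pickLoadParams = (params: Record<string, unknown>) => {
  const settings: Record<string, unknown> = {}
  for (const key of Object.keys(loadParamTemplate)) {
    // Only keys named in the template are copied; anything else is dropped.
    if (params[key] !== undefined) settings[key] = params[key]
  }
  return settings
}

// pickLoadParams({ prompt_template: '...', temperature: 0.7 }) keeps prompt_template
// and drops temperature, which is how irrelevant chat parameters get stripped.
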
@@ -15,7 +15,6 @@ export type ModelInfo = {
  * Represents the inference engine.
  * @stored
  */
-
 export enum InferenceEngine {
   anthropic = 'anthropic',
   mistral = 'mistral',

@@ -34,6 +33,7 @@ export enum InferenceEngine {
   cortex_tensorrtllm = 'tensorrt-llm',
 }

+// Represents an artifact of a model, including its filename and URL
 export type ModelArtifact = {
   filename: string
   url: string

@@ -105,6 +105,7 @@ export type Model = {
   engine: InferenceEngine
 }

+// Represents metadata associated with a model
 export type ModelMetadata = {
   author: string
   tags: string[]

@@ -125,14 +126,20 @@ export type ModelSettingParams = {
   n_parallel?: number
   cpu_threads?: number
   prompt_template?: string
+  pre_prompt?: string
   system_prompt?: string
   ai_prompt?: string
   user_prompt?: string
+  // path param
+  model_path?: string
+  // legacy path param
   llama_model_path?: string
+  // clip model path
   mmproj?: string
   cont_batching?: boolean
   vision_model?: boolean
   text_model?: boolean
+  engine?: boolean
 }

 /**

@@ -151,6 +158,12 @@ export type ModelRuntimeParams = {
   engine?: string
 }

+// Represents a model that failed to initialize, including the error
 export type ModelInitFailed = Model & {
   error: Error
 }
+
+/**
+ * ModelParams types
+ */
+export type ModelParams = ModelRuntimeParams | ModelSettingParams

@@ -10,11 +10,12 @@ import {
   Model,
   executeOnMain,
   systemInformation,
-  log,
   joinPath,
   dirName,
   LocalOAIEngine,
   InferenceEngine,
+  getJanDataFolderPath,
+  extractModelLoadParams,
 } from '@janhq/core'
 import PQueue from 'p-queue'
 import ky from 'ky'

@@ -62,24 +63,38 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
   override async loadModel(
     model: Model & { file_path?: string }
   ): Promise<void> {
-    // Legacy model cache - should import
-    if (model.engine === InferenceEngine.nitro && model.file_path) {
-      // Try importing the model
-      const modelPath = await this.modelPath(model)
-      await this.queue.add(() =>
-        ky
-          .post(`${CORTEX_API_URL}/v1/models/${model.id}`, {
-            json: { model: model.id, modelPath: modelPath },
-          })
-          .json()
-          .catch((e) => log(e.message ?? e ?? ''))
-      )
+    if (
+      model.engine === InferenceEngine.nitro &&
+      model.settings.llama_model_path
+    ) {
+      // Legacy chat model support
+      model.settings = {
+        ...model.settings,
+        llama_model_path: await getModelFilePath(
+          model.id,
+          model.settings.llama_model_path
+        ),
+      }
+    } else {
+      const { llama_model_path, ...settings } = model.settings
+      model.settings = settings
+    }
+
+    if (model.engine === InferenceEngine.nitro && model.settings.mmproj) {
+      // Legacy clip vision model support
+      model.settings = {
+        ...model.settings,
+        mmproj: await getModelFilePath(model.id, model.settings.mmproj),
+      }
+    } else {
+      const { mmproj, ...settings } = model.settings
+      model.settings = settings
     }

     return await ky
       .post(`${CORTEX_API_URL}/v1/models/start`, {
         json: {
-          ...model.settings,
+          ...extractModelLoadParams(model.settings),
           model: model.id,
           engine:
             model.engine === InferenceEngine.nitro // Legacy model cache

@@ -131,3 +146,12 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
       .then(() => {})
   }
 }
+
+/// Legacy
+export const getModelFilePath = async (
+  id: string,
+  file: string
+): Promise<string> => {
+  return joinPath([await getJanDataFolderPath(), 'models', id, file])
+}
+///

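Note: with the change above, a legacy nitro model whose settings only store file names gets `llama_model_path`/`mmproj` resolved to absolute paths inside the Jan data folder before the start request, and models without those legacy entries have the keys stripped instead. The id and file name below are hypothetical, purely to illustrate what `getModelFilePath` returns.

const resolveLegacyPath = async () => {
  // Hypothetical model id and file name, for illustration only.
  const llamaPath = await getModelFilePath('llama3-8b-instruct', 'model.gguf')
  // llamaPath -> '<Jan data folder>/models/llama3-8b-instruct/model.gguf'
  return llamaPath
}
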
@@ -1,13 +1,7 @@
 import PQueue from 'p-queue'
 import ky from 'ky'
-import {
-  DownloadEvent,
-  events,
-  Model,
-  ModelEvent,
-  ModelRuntimeParams,
-  ModelSettingParams,
-} from '@janhq/core'
+import { events, extractModelLoadParams, Model, ModelEvent } from '@janhq/core'
+import { extractInferenceParams } from '@janhq/core'
 /**
  * cortex.cpp Model APIs interface
  */

@@ -204,20 +198,17 @@ export class CortexAPI implements ICortexAPI {
    * @returns
    */
   private transformModel(model: any) {
-    model.parameters = setParameters<ModelRuntimeParams>(model)
-    model.settings = setParameters<ModelSettingParams>(model)
-    model.metadata = {
+    model.parameters = {
+      ...extractInferenceParams(model),
+      ...model.parameters,
+    }
+    model.settings = {
+      ...extractModelLoadParams(model),
+      ...model.settings,
+    }
+    model.metadata = model.metadata ?? {
       tags: [],
     }
     return model as Model
   }
 }
-
-type FilteredParams<T> = {
-  [K in keyof T]: T[K]
-}
-
-function setParameters<T>(params: T): T {
-  const filteredParams: FilteredParams<T> = { ...params }
-  return filteredParams
-}

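Note: `transformModel` no longer copies the raw cortex.cpp model wholesale into both buckets through the removed `setParameters` passthrough; it routes fields through the shared extractors so `parameters` only holds inference-time keys and `settings` only holds load-time keys, and server-provided `metadata` is kept when present. A rough illustration with a made-up payload:

// Made-up cortex.cpp model payload, for illustration only.
const raw: any = {
  id: 'tinyllama-1.1b',
  temperature: 0.7,
  ctx_len: 2048,
  ngl: 33,
}
// After transformModel(raw), roughly:
//   raw.parameters -> inference-time keys only, e.g. { temperature: 0.7, ... }
//   raw.settings   -> load-time keys only, e.g. { ctx_len: 2048, ngl: 33, ... }
//   raw.metadata   -> the server-provided metadata, or { tags: [] } as a fallback
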
@@ -15,6 +15,7 @@ import {
   Thread,
   EngineManager,
   InferenceEngine,
+  extractInferenceParams,
 } from '@janhq/core'
 import { useAtomValue, useSetAtom } from 'jotai'
 import { ulid } from 'ulidx'

@@ -22,7 +23,6 @@ import { ulid } from 'ulidx'
 import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'

 import { isLocalEngine } from '@/utils/modelEngine'
-import { extractInferenceParams } from '@/utils/modelParam'

 import { extensionManager } from '@/extension'
 import {

@@ -12,6 +12,7 @@ import {
   ToolManager,
   ChatCompletionMessage,
 } from '@janhq/core'
+import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
 import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'

 import {

@@ -23,10 +24,6 @@ import {
 import { Stack } from '@/utils/Stack'
 import { compressImage, getBase64 } from '@/utils/base64'
 import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
-import {
-  extractInferenceParams,
-  extractModelLoadParams,
-} from '@/utils/modelParam'

 import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'

@@ -6,15 +6,12 @@ import {
   InferenceEngine,
   Thread,
   ThreadAssistantInfo,
+  extractInferenceParams,
+  extractModelLoadParams,
 } from '@janhq/core'

 import { useAtom, useAtomValue, useSetAtom } from 'jotai'

-import {
-  extractInferenceParams,
-  extractModelLoadParams,
-} from '@/utils/modelParam'
-
 import { extensionManager } from '@/extension'
 import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
 import {

@@ -1,5 +1,6 @@
 import { useCallback, useEffect, useMemo, useState } from 'react'

+import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
 import { Accordion, AccordionItem, Input, Tooltip } from '@janhq/joi'
 import { useAtomValue, useSetAtom } from 'jotai'
 import { AlertTriangleIcon, CheckIcon, CopyIcon, InfoIcon } from 'lucide-react'

@@ -16,11 +17,6 @@ import { useClipboard } from '@/hooks/useClipboard'

 import { getConfigurationsData } from '@/utils/componentSettings'

-import {
-  extractInferenceParams,
-  extractModelLoadParams,
-} from '@/utils/modelParam'
-
 import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
 import { selectedModelAtom } from '@/helpers/atoms/Model.atom'

@@ -4,6 +4,8 @@ import {
   InferenceEngine,
   SettingComponentProps,
   SliderComponentProps,
+  extractInferenceParams,
+  extractModelLoadParams,
 } from '@janhq/core'
 import {
   Tabs,

@@ -31,10 +33,6 @@ import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'

 import { getConfigurationsData } from '@/utils/componentSettings'
 import { isLocalEngine } from '@/utils/modelEngine'
-import {
-  extractInferenceParams,
-  extractModelLoadParams,
-} from '@/utils/modelParam'

 import PromptTemplateSetting from './PromptTemplateSetting'
 import Tools from './Tools'