fix: correct model settings on startup and strip down irrelevant model parameters
This commit is contained in:
parent
90c7420c34
commit
3643c8866e
@ -3,3 +3,8 @@
|
||||
* @module
|
||||
*/
|
||||
export { ModelManager } from './manager'
|
||||
|
||||
/**
|
||||
* Export all utils
|
||||
*/
|
||||
export * from './utils'
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
// web/utils/modelParam.test.ts
|
||||
import { normalizeValue, validationRules } from './modelParam'
|
||||
import { extractModelLoadParams } from './modelParam';
|
||||
import { extractInferenceParams } from './modelParam';
|
||||
import {
|
||||
normalizeValue,
|
||||
validationRules,
|
||||
extractModelLoadParams,
|
||||
extractInferenceParams,
|
||||
} from './utils'
|
||||
|
||||
describe('validationRules', () => {
|
||||
it('should validate temperature correctly', () => {
|
||||
@ -151,13 +154,12 @@ describe('validationRules', () => {
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
it('should normalize invalid values for keys not listed in validationRules', () => {
|
||||
it('should normalize invalid values for keys not listed in validationRules', () => {
|
||||
expect(normalizeValue('invalid_key', 'invalid')).toBe('invalid')
|
||||
expect(normalizeValue('invalid_key', 123)).toBe(123)
|
||||
expect(normalizeValue('invalid_key', true)).toBe(true)
|
||||
expect(normalizeValue('invalid_key', false)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('normalizeValue', () => {
|
||||
it('should normalize ctx_len correctly', () => {
|
||||
@ -192,19 +194,16 @@ describe('normalizeValue', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle invalid values correctly by falling back to originParams', () => {
|
||||
const modelParams = { temperature: 'invalid', token_limit: -1 }
|
||||
const originParams = { temperature: 0.5, token_limit: 100 }
|
||||
expect(extractInferenceParams(modelParams as any, originParams)).toEqual(originParams)
|
||||
})
|
||||
|
||||
it('should handle invalid values correctly by falling back to originParams', () => {
|
||||
const modelParams = { temperature: 'invalid', token_limit: -1 };
|
||||
const originParams = { temperature: 0.5, token_limit: 100 };
|
||||
expect(extractInferenceParams(modelParams, originParams)).toEqual(originParams);
|
||||
});
|
||||
it('should return an empty object when no modelParams are provided', () => {
|
||||
expect(extractModelLoadParams()).toEqual({})
|
||||
})
|
||||
|
||||
|
||||
it('should return an empty object when no modelParams are provided', () => {
|
||||
expect(extractModelLoadParams()).toEqual({});
|
||||
});
|
||||
|
||||
|
||||
it('should return an empty object when no modelParams are provided', () => {
|
||||
expect(extractInferenceParams()).toEqual({});
|
||||
});
|
||||
it('should return an empty object when no modelParams are provided', () => {
|
||||
expect(extractInferenceParams()).toEqual({})
|
||||
})
|
||||
@ -1,26 +1,20 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
/* eslint-disable @typescript-eslint/naming-convention */
|
||||
import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
|
||||
|
||||
import { ModelParams } from '@/types/model'
|
||||
import { ModelParams, ModelRuntimeParams, ModelSettingParams } from '../../types'
|
||||
|
||||
/**
|
||||
* Validation rules for model parameters
|
||||
*/
|
||||
export const validationRules: { [key: string]: (value: any) => boolean } = {
|
||||
temperature: (value: any) =>
|
||||
typeof value === 'number' && value >= 0 && value <= 2,
|
||||
temperature: (value: any) => typeof value === 'number' && value >= 0 && value <= 2,
|
||||
token_limit: (value: any) => Number.isInteger(value) && value >= 0,
|
||||
top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
|
||||
top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
|
||||
stream: (value: any) => typeof value === 'boolean',
|
||||
max_tokens: (value: any) => Number.isInteger(value) && value >= 0,
|
||||
stop: (value: any) =>
|
||||
Array.isArray(value) && value.every((v) => typeof v === 'string'),
|
||||
frequency_penalty: (value: any) =>
|
||||
typeof value === 'number' && value >= 0 && value <= 1,
|
||||
presence_penalty: (value: any) =>
|
||||
typeof value === 'number' && value >= 0 && value <= 1,
|
||||
stop: (value: any) => Array.isArray(value) && value.every((v) => typeof v === 'string'),
|
||||
frequency_penalty: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
|
||||
presence_penalty: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
|
||||
|
||||
ctx_len: (value: any) => Number.isInteger(value) && value >= 0,
|
||||
ngl: (value: any) => Number.isInteger(value) && value >= 0,
|
||||
@ -76,6 +70,7 @@ export const extractInferenceParams = (
|
||||
stop: undefined,
|
||||
frequency_penalty: undefined,
|
||||
presence_penalty: undefined,
|
||||
engine: undefined,
|
||||
}
|
||||
|
||||
const runtimeParams: ModelRuntimeParams = {}
|
||||
@ -119,11 +114,18 @@ export const extractModelLoadParams = (
|
||||
embedding: undefined,
|
||||
n_parallel: undefined,
|
||||
cpu_threads: undefined,
|
||||
pre_prompt: undefined,
|
||||
system_prompt: undefined,
|
||||
ai_prompt: undefined,
|
||||
user_prompt: undefined,
|
||||
prompt_template: undefined,
|
||||
model_path: undefined,
|
||||
llama_model_path: undefined,
|
||||
mmproj: undefined,
|
||||
cont_batching: undefined,
|
||||
vision_model: undefined,
|
||||
text_model: undefined,
|
||||
engine: undefined,
|
||||
}
|
||||
const settingParams: ModelSettingParams = {}
|
||||
|
||||
@ -15,7 +15,6 @@ export type ModelInfo = {
|
||||
* Represents the inference engine.
|
||||
* @stored
|
||||
*/
|
||||
|
||||
export enum InferenceEngine {
|
||||
anthropic = 'anthropic',
|
||||
mistral = 'mistral',
|
||||
@ -34,6 +33,7 @@ export enum InferenceEngine {
|
||||
cortex_tensorrtllm = 'tensorrt-llm',
|
||||
}
|
||||
|
||||
// Represents an artifact of a model, including its filename and URL
|
||||
export type ModelArtifact = {
|
||||
filename: string
|
||||
url: string
|
||||
@ -105,6 +105,7 @@ export type Model = {
|
||||
engine: InferenceEngine
|
||||
}
|
||||
|
||||
// Represents metadata associated with a model
|
||||
export type ModelMetadata = {
|
||||
author: string
|
||||
tags: string[]
|
||||
@ -125,14 +126,20 @@ export type ModelSettingParams = {
|
||||
n_parallel?: number
|
||||
cpu_threads?: number
|
||||
prompt_template?: string
|
||||
pre_prompt?: string
|
||||
system_prompt?: string
|
||||
ai_prompt?: string
|
||||
user_prompt?: string
|
||||
// path param
|
||||
model_path?: string
|
||||
// legacy path param
|
||||
llama_model_path?: string
|
||||
// clip model path
|
||||
mmproj?: string
|
||||
cont_batching?: boolean
|
||||
vision_model?: boolean
|
||||
text_model?: boolean
|
||||
engine?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
@ -151,6 +158,12 @@ export type ModelRuntimeParams = {
|
||||
engine?: string
|
||||
}
|
||||
|
||||
// Represents a model that failed to initialize, including the error
|
||||
export type ModelInitFailed = Model & {
|
||||
error: Error
|
||||
}
|
||||
|
||||
/**
|
||||
* Model parameters: either ModelRuntimeParams or ModelSettingParams
|
||||
*/
|
||||
export type ModelParams = ModelRuntimeParams | ModelSettingParams
|
||||
|
||||
@ -10,11 +10,12 @@ import {
|
||||
Model,
|
||||
executeOnMain,
|
||||
systemInformation,
|
||||
log,
|
||||
joinPath,
|
||||
dirName,
|
||||
LocalOAIEngine,
|
||||
InferenceEngine,
|
||||
getJanDataFolderPath,
|
||||
extractModelLoadParams,
|
||||
} from '@janhq/core'
|
||||
import PQueue from 'p-queue'
|
||||
import ky from 'ky'
|
||||
@ -62,24 +63,38 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
||||
override async loadModel(
|
||||
model: Model & { file_path?: string }
|
||||
): Promise<void> {
|
||||
// Legacy model cache - should import
|
||||
if (model.engine === InferenceEngine.nitro && model.file_path) {
|
||||
// Try importing the model
|
||||
const modelPath = await this.modelPath(model)
|
||||
await this.queue.add(() =>
|
||||
ky
|
||||
.post(`${CORTEX_API_URL}/v1/models/${model.id}`, {
|
||||
json: { model: model.id, modelPath: modelPath },
|
||||
})
|
||||
.json()
|
||||
.catch((e) => log(e.message ?? e ?? ''))
|
||||
)
|
||||
if (
|
||||
model.engine === InferenceEngine.nitro &&
|
||||
model.settings.llama_model_path
|
||||
) {
|
||||
// Legacy chat model support
|
||||
model.settings = {
|
||||
...model.settings,
|
||||
llama_model_path: await getModelFilePath(
|
||||
model.id,
|
||||
model.settings.llama_model_path
|
||||
),
|
||||
}
|
||||
} else {
|
||||
const { llama_model_path, ...settings } = model.settings
|
||||
model.settings = settings
|
||||
}
|
||||
|
||||
if (model.engine === InferenceEngine.nitro && model.settings.mmproj) {
|
||||
// Legacy clip vision model support
|
||||
model.settings = {
|
||||
...model.settings,
|
||||
mmproj: await getModelFilePath(model.id, model.settings.mmproj),
|
||||
}
|
||||
} else {
|
||||
const { mmproj, ...settings } = model.settings
|
||||
model.settings = settings
|
||||
}
|
||||
|
||||
return await ky
|
||||
.post(`${CORTEX_API_URL}/v1/models/start`, {
|
||||
json: {
|
||||
...model.settings,
|
||||
...extractModelLoadParams(model.settings),
|
||||
model: model.id,
|
||||
engine:
|
||||
model.engine === InferenceEngine.nitro // Legacy model cache
|
||||
@ -131,3 +146,12 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
||||
.then(() => {})
|
||||
}
|
||||
}
|
||||
|
||||
/// Legacy
|
||||
/**
 * Builds the path of a model file stored under the Jan data folder.
 *
 * The result is the join of the folder returned by `getJanDataFolderPath()`,
 * the literal `models` segment, the model id and the file name.
 *
 * @param id - Model identifier, used as the folder name under `models`.
 * @param file - File name (or relative path) inside that model's folder.
 * @returns The joined path produced by `joinPath`.
 */
export const getModelFilePath = async (
  id: string,
  file: string
): Promise<string> => {
  // Resolve the data folder first, then hand all segments to joinPath at once.
  const dataFolder = await getJanDataFolderPath()
  const segments = [dataFolder, 'models', id, file]
  return joinPath(segments)
}
|
||||
///
|
||||
|
||||
@ -1,13 +1,7 @@
|
||||
import PQueue from 'p-queue'
|
||||
import ky from 'ky'
|
||||
import {
|
||||
DownloadEvent,
|
||||
events,
|
||||
Model,
|
||||
ModelEvent,
|
||||
ModelRuntimeParams,
|
||||
ModelSettingParams,
|
||||
} from '@janhq/core'
|
||||
import { events, extractModelLoadParams, Model, ModelEvent } from '@janhq/core'
|
||||
import { extractInferenceParams } from '@janhq/core'
|
||||
/**
|
||||
* cortex.cpp Model APIs interface
|
||||
*/
|
||||
@ -204,20 +198,17 @@ export class CortexAPI implements ICortexAPI {
|
||||
* @returns
|
||||
*/
|
||||
private transformModel(model: any) {
|
||||
model.parameters = setParameters<ModelRuntimeParams>(model)
|
||||
model.settings = setParameters<ModelSettingParams>(model)
|
||||
model.metadata = {
|
||||
model.parameters = {
|
||||
...extractInferenceParams(model),
|
||||
...model.parameters,
|
||||
}
|
||||
model.settings = {
|
||||
...extractModelLoadParams(model),
|
||||
...model.settings,
|
||||
}
|
||||
model.metadata = model.metadata ?? {
|
||||
tags: [],
|
||||
}
|
||||
return model as Model
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Mapped type that re-declares every property of `T` with its own type.
 * It selects all keys of `T`, so no property is actually removed.
 */
type FilteredParams<T> = {
  [Key in keyof T]: T[Key]
}

/**
 * Returns a shallow copy of `params`.
 *
 * Own enumerable properties are copied with object spread; nested objects and
 * arrays are shared with the input rather than cloned.
 *
 * @param params - Object whose properties are copied.
 * @returns A new object holding the same property values as `params`.
 */
function setParameters<T>(params: T): T {
  const shallowCopy: FilteredParams<T> = { ...params }

  return shallowCopy
}
|
||||
|
||||
@ -15,6 +15,7 @@ import {
|
||||
Thread,
|
||||
EngineManager,
|
||||
InferenceEngine,
|
||||
extractInferenceParams,
|
||||
} from '@janhq/core'
|
||||
import { useAtomValue, useSetAtom } from 'jotai'
|
||||
import { ulid } from 'ulidx'
|
||||
@ -22,7 +23,6 @@ import { ulid } from 'ulidx'
|
||||
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
|
||||
|
||||
import { isLocalEngine } from '@/utils/modelEngine'
|
||||
import { extractInferenceParams } from '@/utils/modelParam'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import {
|
||||
|
||||
@ -12,6 +12,7 @@ import {
|
||||
ToolManager,
|
||||
ChatCompletionMessage,
|
||||
} from '@janhq/core'
|
||||
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
|
||||
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import {
|
||||
@ -23,10 +24,6 @@ import {
|
||||
import { Stack } from '@/utils/Stack'
|
||||
import { compressImage, getBase64 } from '@/utils/base64'
|
||||
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
|
||||
import {
|
||||
extractInferenceParams,
|
||||
extractModelLoadParams,
|
||||
} from '@/utils/modelParam'
|
||||
|
||||
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
|
||||
|
||||
|
||||
@ -6,15 +6,12 @@ import {
|
||||
InferenceEngine,
|
||||
Thread,
|
||||
ThreadAssistantInfo,
|
||||
extractInferenceParams,
|
||||
extractModelLoadParams,
|
||||
} from '@janhq/core'
|
||||
|
||||
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import {
|
||||
extractInferenceParams,
|
||||
extractModelLoadParams,
|
||||
} from '@/utils/modelParam'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
import {
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
|
||||
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
|
||||
import { Accordion, AccordionItem, Input, Tooltip } from '@janhq/joi'
|
||||
import { useAtomValue, useSetAtom } from 'jotai'
|
||||
import { AlertTriangleIcon, CheckIcon, CopyIcon, InfoIcon } from 'lucide-react'
|
||||
@ -16,11 +17,6 @@ import { useClipboard } from '@/hooks/useClipboard'
|
||||
|
||||
import { getConfigurationsData } from '@/utils/componentSettings'
|
||||
|
||||
import {
|
||||
extractInferenceParams,
|
||||
extractModelLoadParams,
|
||||
} from '@/utils/modelParam'
|
||||
|
||||
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
|
||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
|
||||
|
||||
@ -4,6 +4,8 @@ import {
|
||||
InferenceEngine,
|
||||
SettingComponentProps,
|
||||
SliderComponentProps,
|
||||
extractInferenceParams,
|
||||
extractModelLoadParams,
|
||||
} from '@janhq/core'
|
||||
import {
|
||||
Tabs,
|
||||
@ -31,10 +33,6 @@ import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
|
||||
|
||||
import { getConfigurationsData } from '@/utils/componentSettings'
|
||||
import { isLocalEngine } from '@/utils/modelEngine'
|
||||
import {
|
||||
extractInferenceParams,
|
||||
extractModelLoadParams,
|
||||
} from '@/utils/modelParam'
|
||||
|
||||
import PromptTemplateSetting from './PromptTemplateSetting'
|
||||
import Tools from './Tools'
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user