fix: correct model settings on startup and strip down irrelevant model parameters

Louis 2024-10-25 12:33:43 +07:00
parent 90c7420c34
commit 3643c8866e
GPG Key ID: 44FA9F4D33C37DE2
11 changed files with 111 additions and 89 deletions
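At a glance: the modelParam helpers (extractInferenceParams, extractModelLoadParams) move into @janhq/core, loadModel stops spreading the raw model.settings object into the cortex.cpp start request, and transformModel seeds parameters/settings from the same extractors. A minimal sketch of the intended effect, using an invented settings object (extractModelLoadParams is the real core export; the extra keys are made-up examples of what gets stripped):

import { extractModelLoadParams } from '@janhq/core'

// Illustrative settings: a mix of real load parameters and unrelated keys.
const settings = {
  ctx_len: 4096,
  ngl: 33,
  cover: '/path/to/cover.png', // made-up UI-only key
  created: 1714000000, // made-up metadata key
}

// Previously the whole object was spread into the /v1/models/start body;
// now only recognized load parameters survive.
const body = { ...extractModelLoadParams(settings), model: 'llama3-8b' }
// body is roughly { ctx_len: 4096, ngl: 33, model: 'llama3-8b' }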

View File

@@ -3,3 +3,8 @@
* @module
*/
export { ModelManager } from './manager'
/**
* Export all utils
*/
export * from './utils'

View File

@@ -1,7 +1,10 @@
// web/utils/modelParam.test.ts
import { normalizeValue, validationRules } from './modelParam'
import { extractModelLoadParams } from './modelParam';
import { extractInferenceParams } from './modelParam';
import {
normalizeValue,
validationRules,
extractModelLoadParams,
extractInferenceParams,
} from './utils'
describe('validationRules', () => {
it('should validate temperature correctly', () => {
@@ -151,13 +154,12 @@ describe('validationRules', () => {
})
})
it('should normalize invalid values for keys not listed in validationRules', () => {
expect(normalizeValue('invalid_key', 'invalid')).toBe('invalid')
expect(normalizeValue('invalid_key', 123)).toBe(123)
expect(normalizeValue('invalid_key', true)).toBe(true)
expect(normalizeValue('invalid_key', false)).toBe(false)
})
})
describe('normalizeValue', () => {
it('should normalize ctx_len correctly', () => {
@@ -192,19 +194,16 @@ describe('normalizeValue', () => {
})
})
it('should handle invalid values correctly by falling back to originParams', () => {
const modelParams = { temperature: 'invalid', token_limit: -1 }
const originParams = { temperature: 0.5, token_limit: 100 }
expect(extractInferenceParams(modelParams as any, originParams)).toEqual(originParams)
})
it('should handle invalid values correctly by falling back to originParams', () => {
const modelParams = { temperature: 'invalid', token_limit: -1 };
const originParams = { temperature: 0.5, token_limit: 100 };
expect(extractInferenceParams(modelParams, originParams)).toEqual(originParams);
});
it('should return an empty object when no modelParams are provided', () => {
expect(extractModelLoadParams()).toEqual({})
})
it('should return an empty object when no modelParams are provided', () => {
expect(extractModelLoadParams()).toEqual({});
});
it('should return an empty object when no modelParams are provided', () => {
expect(extractInferenceParams()).toEqual({});
});
it('should return an empty object when no modelParams are provided', () => {
expect(extractInferenceParams()).toEqual({})
})

View File

@@ -1,26 +1,20 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/naming-convention */
import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
import { ModelParams } from '@/types/model'
import { ModelParams, ModelRuntimeParams, ModelSettingParams } from '../../types'
/**
* Validation rules for model parameters
*/
export const validationRules: { [key: string]: (value: any) => boolean } = {
temperature: (value: any) =>
typeof value === 'number' && value >= 0 && value <= 2,
temperature: (value: any) => typeof value === 'number' && value >= 0 && value <= 2,
token_limit: (value: any) => Number.isInteger(value) && value >= 0,
top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
stream: (value: any) => typeof value === 'boolean',
max_tokens: (value: any) => Number.isInteger(value) && value >= 0,
stop: (value: any) =>
Array.isArray(value) && value.every((v) => typeof v === 'string'),
frequency_penalty: (value: any) =>
typeof value === 'number' && value >= 0 && value <= 1,
presence_penalty: (value: any) =>
typeof value === 'number' && value >= 0 && value <= 1,
stop: (value: any) => Array.isArray(value) && value.every((v) => typeof v === 'string'),
frequency_penalty: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
presence_penalty: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
ctx_len: (value: any) => Number.isInteger(value) && value >= 0,
ngl: (value: any) => Number.isInteger(value) && value >= 0,
@@ -76,6 +70,7 @@ export const extractInferenceParams = (
stop: undefined,
frequency_penalty: undefined,
presence_penalty: undefined,
engine: undefined,
}
const runtimeParams: ModelRuntimeParams = {}
@@ -119,11 +114,18 @@ export const extractModelLoadParams = (
embedding: undefined,
n_parallel: undefined,
cpu_threads: undefined,
pre_prompt: undefined,
system_prompt: undefined,
ai_prompt: undefined,
user_prompt: undefined,
prompt_template: undefined,
model_path: undefined,
llama_model_path: undefined,
mmproj: undefined,
cont_batching: undefined,
vision_model: undefined,
text_model: undefined,
engine: undefined,
}
const settingParams: ModelSettingParams = {}
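Both extractors above work off a template object whose keys are the accepted parameters (initialized to undefined), copying matching keys from the input and falling back to originParams when a value fails its validation rule, which is what the tests earlier exercise. A rough sketch of that pattern (pickParams is an invented name; validationRules is the object defined above):

const pickParams = (
  template: Record<string, unknown>,
  modelParams: Record<string, unknown> = {},
  originParams: Record<string, unknown> = {}
): Record<string, unknown> => {
  const result: Record<string, unknown> = {}
  for (const key of Object.keys(template)) {
    if (!(key in modelParams)) continue
    const value = modelParams[key]
    const rule = validationRules[key]
    // Invalid values fall back to the original parameter, matching the tests.
    result[key] = rule && !rule(value) ? originParams[key] : value
  }
  return result
}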

View File

@@ -15,7 +15,6 @@ export type ModelInfo = {
* Represents the inference engine.
* @stored
*/
export enum InferenceEngine {
anthropic = 'anthropic',
mistral = 'mistral',
@@ -34,6 +33,7 @@ export enum InferenceEngine {
cortex_tensorrtllm = 'tensorrt-llm',
}
// Represents an artifact of a model, including its filename and URL
export type ModelArtifact = {
filename: string
url: string
@@ -105,6 +105,7 @@ export type Model = {
engine: InferenceEngine
}
// Represents metadata associated with a model
export type ModelMetadata = {
author: string
tags: string[]
@@ -125,14 +126,20 @@ export type ModelSettingParams = {
n_parallel?: number
cpu_threads?: number
prompt_template?: string
pre_prompt?: string
system_prompt?: string
ai_prompt?: string
user_prompt?: string
// path param
model_path?: string
// legacy path param
llama_model_path?: string
// clip model path
mmproj?: string
cont_batching?: boolean
vision_model?: boolean
text_model?: boolean
engine?: boolean
}
/**
@@ -151,6 +158,12 @@ export type ModelRuntimeParams = {
engine?: string
}
// Represents a model that failed to initialize, including the error
export type ModelInitFailed = Model & {
error: Error
}
/**
* ModelParams types
*/
export type ModelParams = ModelRuntimeParams | ModelSettingParams

View File

@@ -10,11 +10,12 @@ import {
Model,
executeOnMain,
systemInformation,
log,
joinPath,
dirName,
LocalOAIEngine,
InferenceEngine,
getJanDataFolderPath,
extractModelLoadParams,
} from '@janhq/core'
import PQueue from 'p-queue'
import ky from 'ky'
@@ -62,24 +63,38 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
override async loadModel(
model: Model & { file_path?: string }
): Promise<void> {
// Legacy model cache - should import
if (model.engine === InferenceEngine.nitro && model.file_path) {
// Try importing the model
const modelPath = await this.modelPath(model)
await this.queue.add(() =>
ky
.post(`${CORTEX_API_URL}/v1/models/${model.id}`, {
json: { model: model.id, modelPath: modelPath },
})
.json()
.catch((e) => log(e.message ?? e ?? ''))
)
if (
model.engine === InferenceEngine.nitro &&
model.settings.llama_model_path
) {
// Legacy chat model support
model.settings = {
...model.settings,
llama_model_path: await getModelFilePath(
model.id,
model.settings.llama_model_path
),
}
} else {
const { llama_model_path, ...settings } = model.settings
model.settings = settings
}
if (model.engine === InferenceEngine.nitro && model.settings.mmproj) {
// Legacy clip vision model support
model.settings = {
...model.settings,
mmproj: await getModelFilePath(model.id, model.settings.mmproj),
}
} else {
const { mmproj, ...settings } = model.settings
model.settings = settings
}
return await ky
.post(`${CORTEX_API_URL}/v1/models/start`, {
json: {
...model.settings,
...extractModelLoadParams(model.settings),
model: model.id,
engine:
model.engine === InferenceEngine.nitro // Legacy model cache
@@ -131,3 +146,12 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
.then(() => {})
}
}
/// Legacy
export const getModelFilePath = async (
id: string,
file: string
): Promise<string> => {
return joinPath([await getJanDataFolderPath(), 'models', id, file])
}
///
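Condensed, the legacy branch above rewrites model.settings for nitro models so that relative llama_model_path and mmproj entries become absolute paths under the Jan data folder before the filtered settings are posted to /v1/models/start. A sketch of that step in isolation (resolveLegacySettings and the example comment are invented; getModelFilePath is the helper defined above):

const resolveLegacySettings = async (
  id: string,
  settings: { llama_model_path?: string; mmproj?: string }
) => {
  const resolved = { ...settings }
  if (resolved.llama_model_path) {
    // e.g. 'model.gguf' becomes '<jan data folder>/models/<id>/model.gguf'
    resolved.llama_model_path = await getModelFilePath(id, resolved.llama_model_path)
  }
  if (resolved.mmproj) {
    resolved.mmproj = await getModelFilePath(id, resolved.mmproj)
  }
  return resolved
}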

View File

@@ -1,13 +1,7 @@
import PQueue from 'p-queue'
import ky from 'ky'
import {
DownloadEvent,
events,
Model,
ModelEvent,
ModelRuntimeParams,
ModelSettingParams,
} from '@janhq/core'
import { events, extractModelLoadParams, Model, ModelEvent } from '@janhq/core'
import { extractInferenceParams } from '@janhq/core'
/**
* cortex.cpp Model APIs interface
*/
@@ -204,20 +198,17 @@ export class CortexAPI implements ICortexAPI {
* @returns
*/
private transformModel(model: any) {
model.parameters = setParameters<ModelRuntimeParams>(model)
model.settings = setParameters<ModelSettingParams>(model)
model.metadata = {
model.parameters = {
...extractInferenceParams(model),
...model.parameters,
}
model.settings = {
...extractModelLoadParams(model),
...model.settings,
}
model.metadata = model.metadata ?? {
tags: [],
}
return model as Model
}
}
type FilteredParams<T> = {
[K in keyof T]: T[K]
}
function setParameters<T>(params: T): T {
const filteredParams: FilteredParams<T> = { ...params }
return filteredParams
}

View File

@@ -15,6 +15,7 @@ import {
Thread,
EngineManager,
InferenceEngine,
extractInferenceParams,
} from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
import { ulid } from 'ulidx'
@@ -22,7 +23,6 @@ import { ulid } from 'ulidx'
import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
import { isLocalEngine } from '@/utils/modelEngine'
import { extractInferenceParams } from '@/utils/modelParam'
import { extensionManager } from '@/extension'
import {

View File

@@ -12,6 +12,7 @@ import {
ToolManager,
ChatCompletionMessage,
} from '@janhq/core'
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import {
@@ -23,10 +24,6 @@ import {
import { Stack } from '@/utils/Stack'
import { compressImage, getBase64 } from '@/utils/base64'
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
import {
extractInferenceParams,
extractModelLoadParams,
} from '@/utils/modelParam'
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'

View File

@@ -6,15 +6,12 @@ import {
InferenceEngine,
Thread,
ThreadAssistantInfo,
extractInferenceParams,
extractModelLoadParams,
} from '@janhq/core'
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
import {
extractInferenceParams,
extractModelLoadParams,
} from '@/utils/modelParam'
import { extensionManager } from '@/extension'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {

View File

@@ -1,5 +1,6 @@
import { useCallback, useEffect, useMemo, useState } from 'react'
import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
import { Accordion, AccordionItem, Input, Tooltip } from '@janhq/joi'
import { useAtomValue, useSetAtom } from 'jotai'
import { AlertTriangleIcon, CheckIcon, CopyIcon, InfoIcon } from 'lucide-react'
@@ -16,11 +17,6 @@ import { useClipboard } from '@/hooks/useClipboard'
import { getConfigurationsData } from '@/utils/componentSettings'
import {
extractInferenceParams,
extractModelLoadParams,
} from '@/utils/modelParam'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'

View File

@@ -4,6 +4,8 @@ import {
InferenceEngine,
SettingComponentProps,
SliderComponentProps,
extractInferenceParams,
extractModelLoadParams,
} from '@janhq/core'
import {
Tabs,
@@ -31,10 +33,6 @@ import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
import { getConfigurationsData } from '@/utils/componentSettings'
import { isLocalEngine } from '@/utils/modelEngine'
import {
extractInferenceParams,
extractModelLoadParams,
} from '@/utils/modelParam'
import PromptTemplateSetting from './PromptTemplateSetting'
import Tools from './Tools'