diff --git a/core/src/browser/models/index.ts b/core/src/browser/models/index.ts
index c16479b2b..81d37e501 100644
--- a/core/src/browser/models/index.ts
+++ b/core/src/browser/models/index.ts
@@ -3,3 +3,8 @@
  * @module
  */
 export { ModelManager } from './manager'
+
+/**
+ * Export all utils
+ */
+export * from './utils'
diff --git a/web/utils/modelParam.test.ts b/core/src/browser/models/utils.test.ts
similarity index 87%
rename from web/utils/modelParam.test.ts
rename to core/src/browser/models/utils.test.ts
index 97325d277..ac876c3dc 100644
--- a/web/utils/modelParam.test.ts
+++ b/core/src/browser/models/utils.test.ts
@@ -1,7 +1,10 @@
 // web/utils/modelParam.test.ts
-import { normalizeValue, validationRules } from './modelParam'
-import { extractModelLoadParams } from './modelParam';
-import { extractInferenceParams } from './modelParam';
+import {
+  normalizeValue,
+  validationRules,
+  extractModelLoadParams,
+  extractInferenceParams,
+} from './utils'
 
 describe('validationRules', () => {
   it('should validate temperature correctly', () => {
@@ -151,13 +154,12 @@
   })
 })
 
-
-  it('should normalize invalid values for keys not listed in validationRules', () => {
-    expect(normalizeValue('invalid_key', 'invalid')).toBe('invalid')
-    expect(normalizeValue('invalid_key', 123)).toBe(123)
-    expect(normalizeValue('invalid_key', true)).toBe(true)
-    expect(normalizeValue('invalid_key', false)).toBe(false)
-  })
+it('should normalize invalid values for keys not listed in validationRules', () => {
+  expect(normalizeValue('invalid_key', 'invalid')).toBe('invalid')
+  expect(normalizeValue('invalid_key', 123)).toBe(123)
+  expect(normalizeValue('invalid_key', true)).toBe(true)
+  expect(normalizeValue('invalid_key', false)).toBe(false)
+})
 
 describe('normalizeValue', () => {
   it('should normalize ctx_len correctly', () => {
@@ -192,19 +194,16 @@
   })
 })
 
+it('should handle invalid values correctly by falling back to originParams', () => {
+  const modelParams = { temperature: 'invalid', token_limit: -1 }
+  const originParams = { temperature: 0.5, token_limit: 100 }
+  expect(extractInferenceParams(modelParams as any, originParams)).toEqual(originParams)
+})
-
-  it('should handle invalid values correctly by falling back to originParams', () => {
-    const modelParams = { temperature: 'invalid', token_limit: -1 };
-    const originParams = { temperature: 0.5, token_limit: 100 };
-    expect(extractInferenceParams(modelParams, originParams)).toEqual(originParams);
-  });
+it('should return an empty object when no modelParams are provided', () => {
+  expect(extractModelLoadParams()).toEqual({})
+})
-
-
-  it('should return an empty object when no modelParams are provided', () => {
-    expect(extractModelLoadParams()).toEqual({});
-  });
-
-
-  it('should return an empty object when no modelParams are provided', () => {
-    expect(extractInferenceParams()).toEqual({});
-  });
+it('should return an empty object when no modelParams are provided', () => {
+  expect(extractInferenceParams()).toEqual({})
+})
diff --git a/web/utils/modelParam.ts b/core/src/browser/models/utils.ts
similarity index 86%
rename from web/utils/modelParam.ts
rename to core/src/browser/models/utils.ts
index 315aeaeb3..0e52441b2 100644
--- a/web/utils/modelParam.ts
+++ b/core/src/browser/models/utils.ts
@@ -1,26 +1,20 @@
 /* eslint-disable @typescript-eslint/no-explicit-any */
 /* eslint-disable @typescript-eslint/naming-convention */
-import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
-
-import { ModelParams } from '@/types/model'
+import { ModelParams, ModelRuntimeParams, ModelSettingParams } from '../../types'
 
 /**
  * Validation rules for model parameters
  */
 export const validationRules: { [key: string]: (value: any) => boolean } = {
-  temperature: (value: any) =>
-    typeof value === 'number' && value >= 0 && value <= 2,
+  temperature: (value: any) => typeof value === 'number' && value >= 0 && value <= 2,
   token_limit: (value: any) => Number.isInteger(value) && value >= 0,
   top_k: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
   top_p: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
   stream: (value: any) => typeof value === 'boolean',
   max_tokens: (value: any) => Number.isInteger(value) && value >= 0,
-  stop: (value: any) =>
-    Array.isArray(value) && value.every((v) => typeof v === 'string'),
-  frequency_penalty: (value: any) =>
-    typeof value === 'number' && value >= 0 && value <= 1,
-  presence_penalty: (value: any) =>
-    typeof value === 'number' && value >= 0 && value <= 1,
+  stop: (value: any) => Array.isArray(value) && value.every((v) => typeof v === 'string'),
+  frequency_penalty: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
+  presence_penalty: (value: any) => typeof value === 'number' && value >= 0 && value <= 1,
   ctx_len: (value: any) => Number.isInteger(value) && value >= 0,
   ngl: (value: any) => Number.isInteger(value) && value >= 0,
@@ -76,6 +70,7 @@ export const extractInferenceParams = (
     stop: undefined,
     frequency_penalty: undefined,
     presence_penalty: undefined,
+    engine: undefined,
   }
 
   const runtimeParams: ModelRuntimeParams = {}
@@ -119,11 +114,18 @@ export const extractModelLoadParams = (
     embedding: undefined,
     n_parallel: undefined,
     cpu_threads: undefined,
+    pre_prompt: undefined,
+    system_prompt: undefined,
+    ai_prompt: undefined,
+    user_prompt: undefined,
     prompt_template: undefined,
+    model_path: undefined,
     llama_model_path: undefined,
     mmproj: undefined,
+    cont_batching: undefined,
     vision_model: undefined,
     text_model: undefined,
+    engine: undefined,
   }
 
   const settingParams: ModelSettingParams = {}
diff --git a/core/src/types/model/modelEntity.ts b/core/src/types/model/modelEntity.ts
index 25ed95b8d..7b67a8e94 100644
--- a/core/src/types/model/modelEntity.ts
+++ b/core/src/types/model/modelEntity.ts
@@ -15,7 +15,6 @@ export type ModelInfo = {
  * Represents the inference engine.
  * @stored
  */
-
 export enum InferenceEngine {
   anthropic = 'anthropic',
   mistral = 'mistral',
@@ -34,6 +33,7 @@ export enum InferenceEngine {
   cortex_tensorrtllm = 'tensorrt-llm',
 }
 
+// Represents an artifact of a model, including its filename and URL
 export type ModelArtifact = {
   filename: string
   url: string
@@ -105,6 +105,7 @@ export type Model = {
   engine: InferenceEngine
 }
 
+// Represents metadata associated with a model
 export type ModelMetadata = {
   author: string
   tags: string[]
@@ -125,14 +126,20 @@ export type ModelSettingParams = {
   n_parallel?: number
   cpu_threads?: number
   prompt_template?: string
+  pre_prompt?: string
   system_prompt?: string
   ai_prompt?: string
   user_prompt?: string
+  // path param
+  model_path?: string
+  // legacy path param
   llama_model_path?: string
+  // clip model path
   mmproj?: string
   cont_batching?: boolean
   vision_model?: boolean
   text_model?: boolean
+  engine?: string
 }
 
 /**
@@ -151,6 +158,12 @@ export type ModelRuntimeParams = {
   engine?: string
 }
 
+// Represents a model that failed to initialize, including the error
 export type ModelInitFailed = Model & {
   error: Error
 }
+
+/**
+ * ModelParams types
+ */
+export type ModelParams = ModelRuntimeParams | ModelSettingParams
diff --git a/extensions/inference-cortex-extension/src/index.ts b/extensions/inference-cortex-extension/src/index.ts
index 364bfe79c..8143a71cf 100644
--- a/extensions/inference-cortex-extension/src/index.ts
+++ b/extensions/inference-cortex-extension/src/index.ts
@@ -10,11 +10,12 @@ import {
   Model,
   executeOnMain,
   systemInformation,
-  log,
   joinPath,
   dirName,
   LocalOAIEngine,
   InferenceEngine,
+  getJanDataFolderPath,
+  extractModelLoadParams,
 } from '@janhq/core'
 import PQueue from 'p-queue'
 import ky from 'ky'
@@ -62,24 +63,38 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
   override async loadModel(
     model: Model & { file_path?: string }
   ): Promise<void> {
-    // Legacy model cache - should import
-    if (model.engine === InferenceEngine.nitro && model.file_path) {
-      // Try importing the model
-      const modelPath = await this.modelPath(model)
-      await this.queue.add(() =>
-        ky
-          .post(`${CORTEX_API_URL}/v1/models/${model.id}`, {
-            json: { model: model.id, modelPath: modelPath },
-          })
-          .json()
-          .catch((e) => log(e.message ?? e ?? ''))
-      )
+    if (
+      model.engine === InferenceEngine.nitro &&
+      model.settings.llama_model_path
+    ) {
+      // Legacy chat model support
+      model.settings = {
+        ...model.settings,
+        llama_model_path: await getModelFilePath(
+          model.id,
+          model.settings.llama_model_path
+        ),
+      }
+    } else {
+      const { llama_model_path, ...settings } = model.settings
+      model.settings = settings
+    }
+
+    if (model.engine === InferenceEngine.nitro && model.settings.mmproj) {
+      // Legacy clip vision model support
+      model.settings = {
+        ...model.settings,
+        mmproj: await getModelFilePath(model.id, model.settings.mmproj),
+      }
+    } else {
+      const { mmproj, ...settings } = model.settings
+      model.settings = settings
     }
 
     return await ky
       .post(`${CORTEX_API_URL}/v1/models/start`, {
         json: {
-          ...model.settings,
+          ...extractModelLoadParams(model.settings),
           model: model.id,
           engine:
             model.engine === InferenceEngine.nitro // Legacy model cache
@@ -131,3 +146,12 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
       .then(() => {})
   }
 }
+
+/// Legacy
+export const getModelFilePath = async (
+  id: string,
+  file: string
+): Promise<string> => {
+  return joinPath([await getJanDataFolderPath(), 'models', id, file])
+}
+///
diff --git a/extensions/model-extension/src/cortex.ts b/extensions/model-extension/src/cortex.ts
index c690f0c16..7f48f10ec 100644
--- a/extensions/model-extension/src/cortex.ts
+++ b/extensions/model-extension/src/cortex.ts
@@ -1,13 +1,7 @@
 import PQueue from 'p-queue'
 import ky from 'ky'
-import {
-  DownloadEvent,
-  events,
-  Model,
-  ModelEvent,
-  ModelRuntimeParams,
-  ModelSettingParams,
-} from '@janhq/core'
+import { events, extractModelLoadParams, Model, ModelEvent } from '@janhq/core'
+import { extractInferenceParams } from '@janhq/core'
 /**
  * cortex.cpp Model APIs interface
  */
@@ -204,20 +198,17 @@ export class CortexAPI implements ICortexAPI {
    * @returns
    */
   private transformModel(model: any) {
-    model.parameters = setParameters(model)
-    model.settings = setParameters(model)
-    model.metadata = {
+    model.parameters = {
+      ...extractInferenceParams(model),
+      ...model.parameters,
+    }
+    model.settings = {
+      ...extractModelLoadParams(model),
+      ...model.settings,
+    }
+    model.metadata = model.metadata ?? {
       tags: [],
     }
     return model as Model
   }
 }
-
-type FilteredParams<T> = {
-  [K in keyof T]: T[K]
-}
-
-function setParameters<T>(params: T): T {
-  const filteredParams: FilteredParams<T> = { ...params }
-  return filteredParams
-}
diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx
index 72d35aad3..0f5cf389d 100644
--- a/web/containers/Providers/EventHandler.tsx
+++ b/web/containers/Providers/EventHandler.tsx
@@ -15,6 +15,7 @@ import {
   Thread,
   EngineManager,
   InferenceEngine,
+  extractInferenceParams,
 } from '@janhq/core'
 import { useAtomValue, useSetAtom } from 'jotai'
 import { ulid } from 'ulidx'
@@ -22,7 +23,6 @@ import { ulid } from 'ulidx'
 
 import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel'
 
 import { isLocalEngine } from '@/utils/modelEngine'
-import { extractInferenceParams } from '@/utils/modelParam'
 import { extensionManager } from '@/extension'
 import {
diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts
index 4bc91cad2..cda53b24a 100644
--- a/web/hooks/useSendChatMessage.ts
+++ b/web/hooks/useSendChatMessage.ts
@@ -12,6 +12,7 @@ import {
   ToolManager,
   ChatCompletionMessage,
 } from '@janhq/core'
+import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
 import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
 
 import {
@@ -23,10 +24,6 @@ import {
 import { Stack } from '@/utils/Stack'
 import { compressImage, getBase64 } from '@/utils/base64'
 import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
-import {
-  extractInferenceParams,
-  extractModelLoadParams,
-} from '@/utils/modelParam'
 
 import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
 
diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts
index 2af6e3323..6eb7c3c5a 100644
--- a/web/hooks/useUpdateModelParameters.ts
+++ b/web/hooks/useUpdateModelParameters.ts
@@ -6,15 +6,12 @@ import {
   InferenceEngine,
   Thread,
   ThreadAssistantInfo,
+  extractInferenceParams,
+  extractModelLoadParams,
 } from '@janhq/core'
 import { useAtom, useAtomValue, useSetAtom } from 'jotai'
 
-import {
-  extractInferenceParams,
-  extractModelLoadParams,
-} from '@/utils/modelParam'
-
 
 import { extensionManager } from '@/extension'
 import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
 import {
diff --git a/web/screens/LocalServer/LocalServerRightPanel/index.tsx b/web/screens/LocalServer/LocalServerRightPanel/index.tsx
index 628a61512..0d2fe0f7c 100644
--- a/web/screens/LocalServer/LocalServerRightPanel/index.tsx
+++ b/web/screens/LocalServer/LocalServerRightPanel/index.tsx
@@ -1,5 +1,6 @@
 import { useCallback, useEffect, useMemo, useState } from 'react'
 
+import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'
 import { Accordion, AccordionItem, Input, Tooltip } from '@janhq/joi'
 import { useAtomValue, useSetAtom } from 'jotai'
 import { AlertTriangleIcon, CheckIcon, CopyIcon, InfoIcon } from 'lucide-react'
@@ -16,11 +17,6 @@ import { useClipboard } from '@/hooks/useClipboard'
 
 import { getConfigurationsData } from '@/utils/componentSettings'
 
-import {
-  extractInferenceParams,
-  extractModelLoadParams,
-} from '@/utils/modelParam'
-
 import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
 import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
diff --git a/web/screens/Thread/ThreadRightPanel/index.tsx b/web/screens/Thread/ThreadRightPanel/index.tsx
index 5a8fd3ebb..674c97766 100644
--- a/web/screens/Thread/ThreadRightPanel/index.tsx
+++ b/web/screens/Thread/ThreadRightPanel/index.tsx
@@ -4,6 +4,8 @@ import {
   InferenceEngine,
   SettingComponentProps,
   SliderComponentProps,
+  extractInferenceParams,
+  extractModelLoadParams,
 } from '@janhq/core'
 import {
   Tabs,
@@ -31,10 +33,6 @@ import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
 
 import { getConfigurationsData } from '@/utils/componentSettings'
 import { isLocalEngine } from '@/utils/modelEngine'
-import {
-  extractInferenceParams,
-  extractModelLoadParams,
-}
 from '@/utils/modelParam'
 
 import PromptTemplateSetting from './PromptTemplateSetting'
 import Tools from './Tools'
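
---
Usage sketch (editor's note, not part of the patch): with this change, both helpers ship from @janhq/core, so web/ and extension code can split a mixed parameter bag without the old '@/utils/modelParam' import. This is a minimal sketch; the parameter values are invented for illustration, and the exact key split assumes temperature/stream count as runtime (inference) params while ctx_len/ngl count as load (setting) params, as the validation rules in utils.ts suggest.

import { extractInferenceParams, extractModelLoadParams } from '@janhq/core'

// A mixed bag of runtime and load-time parameters, e.g. a thread's model config
const params = { temperature: 0.7, stream: true, ctx_len: 2048, ngl: 33 }

// Expected: only runtime keys survive, e.g. { temperature: 0.7, stream: true }
const inference = extractInferenceParams(params)

// Expected: only load-time keys survive, e.g. { ctx_len: 2048, ngl: 33 }
const settings = extractModelLoadParams(params)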