From ad9a4a0b4d49edc20316186422f3149e3d2d0f35 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 21 Aug 2024 21:28:29 +0700 Subject: [PATCH] feat: preserve model settings (#3427) * feat: preserve model settings * feat: preserve model settings across new threads * chore: lint fix * fix: feature toggle off should also affect default value retrieve --- core/src/types/model/modelEntity.ts | 3 ++ extensions/model-extension/src/index.ts | 2 +- web/containers/ModelDropdown/index.tsx | 15 +++++++- web/helpers/atoms/AppConfig.atom.ts | 7 ++++ web/helpers/atoms/Model.atom.ts | 11 ++++++ web/hooks/useCreateNewThread.ts | 19 +++++++--- web/hooks/useUpdateModelParameters.ts | 46 +++++++++++++++++++++---- web/screens/Settings/Advanced/index.tsx | 25 ++++++++++++++ 8 files changed, 115 insertions(+), 13 deletions(-) diff --git a/core/src/types/model/modelEntity.ts b/core/src/types/model/modelEntity.ts index ef6d59316..f154f7f04 100644 --- a/core/src/types/model/modelEntity.ts +++ b/core/src/types/model/modelEntity.ts @@ -107,6 +107,9 @@ export type ModelMetadata = { tags: string[] size: number cover?: string + // These settings to preserve model settings across threads + default_ctx_len?: number + default_max_tokens?: number } /** diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index a8977e07e..3855ff73d 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -351,7 +351,7 @@ export default class JanModelExtension extends ModelExtension { } /** - * Saves a machine learning model. + * Saves a model file. * @param model - The model to save. * @returns A Promise that resolves when the model is saved. */ diff --git a/web/containers/ModelDropdown/index.tsx b/web/containers/ModelDropdown/index.tsx index 0f0db390f..38ad2ccaf 100644 --- a/web/containers/ModelDropdown/index.tsx +++ b/web/containers/ModelDropdown/index.tsx @@ -45,6 +45,7 @@ import { import { extensionManager } from '@/extension' +import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom' import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom' import { configuredModelsAtom, @@ -89,6 +90,7 @@ const ModelDropdown = ({ const featuredModel = configuredModels.filter((x) => x.metadata.tags.includes('Featured') ) + const preserveModelSettings = useAtomValue(preserveModelSettingsAtom) useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [ dropdownOptions, @@ -161,14 +163,25 @@ const ModelDropdown = ({ if (activeThread) { // Default setting ctx_len for the model for a better onboarding experience // TODO: When Cortex support hardware instructions, we should remove this + const defaultContextLength = preserveModelSettings + ? model?.metadata?.default_ctx_len + : 2048 + const defaultMaxTokens = preserveModelSettings + ? model?.metadata?.default_max_tokens + : 2048 const overriddenSettings = model?.settings.ctx_len && model.settings.ctx_len > 2048 - ? { ctx_len: 2048 } + ? { ctx_len: defaultContextLength } + : {} + const overriddenParameters = + model?.parameters.max_tokens && model.parameters.max_tokens + ? { max_tokens: defaultMaxTokens } : {} const modelParams = { ...model?.parameters, ...model?.settings, + ...overriddenParameters, ...overriddenSettings, } diff --git a/web/helpers/atoms/AppConfig.atom.ts b/web/helpers/atoms/AppConfig.atom.ts index f4acc7dc2..e7b7efaec 100644 --- a/web/helpers/atoms/AppConfig.atom.ts +++ b/web/helpers/atoms/AppConfig.atom.ts @@ -7,6 +7,7 @@ const VULKAN_ENABLED = 'vulkanEnabled' const IGNORE_SSL = 'ignoreSSLFeature' const HTTPS_PROXY_FEATURE = 'httpsProxyFeature' const QUICK_ASK_ENABLED = 'quickAskEnabled' +const PRESERVE_MODEL_SETTINGS = 'preserveModelSettings' export const janDataFolderPathAtom = atom('') @@ -23,3 +24,9 @@ export const vulkanEnabledAtom = atomWithStorage(VULKAN_ENABLED, false) export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false) export const hostAtom = atom('http://localhost:1337/') + +// This feature is to allow user to cache model settings on thread creation +export const preserveModelSettingsAtom = atomWithStorage( + PRESERVE_MODEL_SETTINGS, + false +) diff --git a/web/helpers/atoms/Model.atom.ts b/web/helpers/atoms/Model.atom.ts index deb7b8622..77b1bfa4e 100644 --- a/web/helpers/atoms/Model.atom.ts +++ b/web/helpers/atoms/Model.atom.ts @@ -34,6 +34,17 @@ export const removeDownloadingModelAtom = atom( export const downloadedModelsAtom = atom([]) +export const updateDownloadedModelAtom = atom( + null, + (get, set, updatedModel: Model) => { + const models: Model[] = get(downloadedModelsAtom).map((c) => + c.id === updatedModel.id ? updatedModel : c + ) + + set(downloadedModelsAtom, models) + } +) + export const removeDownloadedModelAtom = atom( null, (get, set, modelId: string) => { diff --git a/web/hooks/useCreateNewThread.ts b/web/hooks/useCreateNewThread.ts index 03c3edf90..954b249a1 100644 --- a/web/hooks/useCreateNewThread.ts +++ b/web/hooks/useCreateNewThread.ts @@ -10,7 +10,7 @@ import { Model, AssistantTool, } from '@janhq/core' -import { atom, useAtomValue, useSetAtom } from 'jotai' +import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' import { copyOverInstructionEnabledAtom } from '@/containers/CopyInstruction' import { fileUploadAtom } from '@/containers/Providers/Jotai' @@ -24,7 +24,10 @@ import useSetActiveThread from './useSetActiveThread' import { extensionManager } from '@/extension' -import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' +import { + experimentalFeatureEnabledAtom, + preserveModelSettingsAtom, +} from '@/helpers/atoms/AppConfig.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { threadsAtom, @@ -62,6 +65,7 @@ export const useCreateNewThread = () => { const copyOverInstructionEnabled = useAtomValue( copyOverInstructionEnabledAtom ) + const preserveModelSettings = useAtomValue(preserveModelSettingsAtom) const activeThread = useAtomValue(activeThreadAtom) const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom) @@ -99,15 +103,20 @@ export const useCreateNewThread = () => { enabled: true, settings: assistant.tools && assistant.tools[0].settings, } - + const defaultContextLength = preserveModelSettings + ? model?.metadata?.default_ctx_len + : 2048 + const defaultMaxTokens = preserveModelSettings + ? model?.metadata?.default_max_tokens + : 2048 const overriddenSettings = defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048 - ? { ctx_len: 2048 } + ? { ctx_len: defaultContextLength } : {} const overriddenParameters = defaultModel?.parameters.max_tokens && defaultModel.parameters.max_tokens - ? { max_tokens: 2048 } + ? { max_tokens: defaultMaxTokens } : {} const createdAt = Date.now() diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts index d819a85ff..b39e5f00b 100644 --- a/web/hooks/useUpdateModelParameters.ts +++ b/web/hooks/useUpdateModelParameters.ts @@ -4,16 +4,22 @@ import { ConversationalExtension, ExtensionTypeEnum, InferenceEngine, + Model, + ModelExtension, Thread, ThreadAssistantInfo, } from '@janhq/core' -import { useAtomValue, useSetAtom } from 'jotai' +import { useAtom, useAtomValue, useSetAtom } from 'jotai' import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' import { extensionManager } from '@/extension' -import { selectedModelAtom } from '@/helpers/atoms/Model.atom' +import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom' +import { + selectedModelAtom, + updateDownloadedModelAtom, +} from '@/helpers/atoms/Model.atom' import { ModelParams, getActiveThreadModelParamsAtom, @@ -28,8 +34,10 @@ export type UpdateModelParameter = { export default function useUpdateModelParameters() { const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom) - const selectedModel = useAtomValue(selectedModelAtom) + const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom) const setThreadModelParams = useSetAtom(setThreadModelParamsAtom) + const updateDownloadedModel = useSetAtom(updateDownloadedModelAtom) + const preserveModelFeatureEnabled = useAtomValue(preserveModelSettingsAtom) const updateModelParameter = useCallback( async (thread: Thread, settings: UpdateModelParameter) => { @@ -40,12 +48,11 @@ export default function useUpdateModelParameters() { // update the state setThreadModelParams(thread.id, updatedModelParams) + const runtimeParams = toRuntimeParams(updatedModelParams) + const settingParams = toSettingParams(updatedModelParams) const assistants = thread.assistants.map( (assistant: ThreadAssistantInfo) => { - const runtimeParams = toRuntimeParams(updatedModelParams) - const settingParams = toSettingParams(updatedModelParams) - assistant.model.parameters = runtimeParams assistant.model.settings = settingParams if (selectedModel) { @@ -65,6 +72,33 @@ export default function useUpdateModelParameters() { await extensionManager .get(ExtensionTypeEnum.Conversational) ?.saveThread(updatedThread) + + // Persists default settings to model file + // Do not overwrite ctx_len and max_tokens + if (preserveModelFeatureEnabled && selectedModel) { + const updatedModel = { + ...selectedModel, + parameters: { + ...runtimeParams, + max_tokens: selectedModel.parameters.max_tokens, + }, + settings: { + ...settingParams, + ctx_len: selectedModel.settings.ctx_len, + }, + metadata: { + ...selectedModel.metadata, + default_ctx_len: settingParams.ctx_len, + default_max_tokens: runtimeParams.max_tokens, + }, + } as Model + + await extensionManager + .get(ExtensionTypeEnum.Model) + ?.saveModel(updatedModel) + setSelectedModel(updatedModel) + updateDownloadedModel(updatedModel) + } }, [activeModelParams, selectedModel, setThreadModelParams] ) diff --git a/web/screens/Settings/Advanced/index.tsx b/web/screens/Settings/Advanced/index.tsx index f132f81e7..b66dc7b86 100644 --- a/web/screens/Settings/Advanced/index.tsx +++ b/web/screens/Settings/Advanced/index.tsx @@ -35,6 +35,7 @@ import { proxyEnabledAtom, vulkanEnabledAtom, quickAskEnabledAtom, + preserveModelSettingsAtom, } from '@/helpers/atoms/AppConfig.atom' type GPU = { @@ -64,6 +65,9 @@ const Advanced = () => { const [proxyEnabled, setProxyEnabled] = useAtom(proxyEnabledAtom) const quickAskEnabled = useAtomValue(quickAskEnabledAtom) + const [preserveModelSettings, setPreserveModelSettings] = useAtom( + preserveModelSettingsAtom + ) const [proxy, setProxy] = useAtom(proxyAtom) const [ignoreSSL, setIgnoreSSL] = useAtom(ignoreSslAtom) @@ -385,6 +389,27 @@ const Advanced = () => { )} + {experimentalEnabled && ( +
+
+
+
+ Preserve Model Settings +
+
+

+ Save model settings changes directly to the model file so that + new threads will reuse the previous settings. +

+
+ + setPreserveModelSettings(e.target.checked)} + /> +
+ )} + {/* Proxy */}