feat: preserve model settings (#3427)

* feat: preserve model settings

* feat: preserve model settings across new threads

* chore: lint fix

* fix: feature toggle off should also affect default value retrieve
This commit is contained in:
Louis 2024-08-21 21:28:29 +07:00 committed by GitHub
parent c8474c88ca
commit ad9a4a0b4d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 115 additions and 13 deletions

View File

@ -107,6 +107,9 @@ export type ModelMetadata = {
tags: string[] tags: string[]
size: number size: number
cover?: string cover?: string
// Default values used to preserve the user's model settings across threads
default_ctx_len?: number
default_max_tokens?: number
} }
/** /**

View File

@ -351,7 +351,7 @@ export default class JanModelExtension extends ModelExtension {
} }
/** /**
* Saves a machine learning model. * Saves a model file.
* @param model - The model to save. * @param model - The model to save.
* @returns A Promise that resolves when the model is saved. * @returns A Promise that resolves when the model is saved.
*/ */

View File

@ -45,6 +45,7 @@ import {
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom' import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom'
import { import {
configuredModelsAtom, configuredModelsAtom,
@ -89,6 +90,7 @@ const ModelDropdown = ({
const featuredModel = configuredModels.filter((x) => const featuredModel = configuredModels.filter((x) =>
x.metadata.tags.includes('Featured') x.metadata.tags.includes('Featured')
) )
const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [ useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [
dropdownOptions, dropdownOptions,
@ -161,14 +163,25 @@ const ModelDropdown = ({
if (activeThread) { if (activeThread) {
// Default setting ctx_len for the model for a better onboarding experience // Default setting ctx_len for the model for a better onboarding experience
// TODO: When Cortex support hardware instructions, we should remove this // TODO: When Cortex support hardware instructions, we should remove this
const defaultContextLength = preserveModelSettings
? model?.metadata?.default_ctx_len
: 2048
const defaultMaxTokens = preserveModelSettings
? model?.metadata?.default_max_tokens
: 2048
const overriddenSettings = const overriddenSettings =
model?.settings.ctx_len && model.settings.ctx_len > 2048 model?.settings.ctx_len && model.settings.ctx_len > 2048
? { ctx_len: 2048 } ? { ctx_len: defaultContextLength }
: {}
const overriddenParameters =
model?.parameters.max_tokens && model.parameters.max_tokens
? { max_tokens: defaultMaxTokens }
: {} : {}
const modelParams = { const modelParams = {
...model?.parameters, ...model?.parameters,
...model?.settings, ...model?.settings,
...overriddenParameters,
...overriddenSettings, ...overriddenSettings,
} }

View File

@ -7,6 +7,7 @@ const VULKAN_ENABLED = 'vulkanEnabled'
const IGNORE_SSL = 'ignoreSSLFeature' const IGNORE_SSL = 'ignoreSSLFeature'
const HTTPS_PROXY_FEATURE = 'httpsProxyFeature' const HTTPS_PROXY_FEATURE = 'httpsProxyFeature'
const QUICK_ASK_ENABLED = 'quickAskEnabled' const QUICK_ASK_ENABLED = 'quickAskEnabled'
const PRESERVE_MODEL_SETTINGS = 'preserveModelSettings'
export const janDataFolderPathAtom = atom('') export const janDataFolderPathAtom = atom('')
@ -23,3 +24,9 @@ export const vulkanEnabledAtom = atomWithStorage(VULKAN_ENABLED, false)
export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false) export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false)
export const hostAtom = atom('http://localhost:1337/') export const hostAtom = atom('http://localhost:1337/')
/**
 * Toggle for the "Preserve Model Settings" feature: when enabled, changes to
 * model settings are cached so that newly created threads can reuse them.
 * Stored through `atomWithStorage` under the PRESERVE_MODEL_SETTINGS key;
 * the initial value is `false` (feature off).
 */
export const preserveModelSettingsAtom = atomWithStorage(
PRESERVE_MODEL_SETTINGS,
false
)

View File

@ -34,6 +34,17 @@ export const removeDownloadingModelAtom = atom(
export const downloadedModelsAtom = atom<Model[]>([]) export const downloadedModelsAtom = atom<Model[]>([])
/**
 * Write-only atom that swaps a downloaded model for an updated copy.
 *
 * Every entry of `downloadedModelsAtom` whose `id` equals `updatedModel.id`
 * is replaced by `updatedModel`; all other entries are kept unchanged and in
 * their original order. A fresh array is always written back.
 */
export const updateDownloadedModelAtom = atom(
  null,
  (get, set, updatedModel: Model) => {
    // Pick the updated copy for the matching id, otherwise keep the entry.
    const replaceIfSameId = (candidate: Model): Model =>
      candidate.id === updatedModel.id ? updatedModel : candidate

    const refreshed: Model[] = get(downloadedModelsAtom).map(replaceIfSameId)
    set(downloadedModelsAtom, refreshed)
  }
)
export const removeDownloadedModelAtom = atom( export const removeDownloadedModelAtom = atom(
null, null,
(get, set, modelId: string) => { (get, set, modelId: string) => {

View File

@ -10,7 +10,7 @@ import {
Model, Model,
AssistantTool, AssistantTool,
} from '@janhq/core' } from '@janhq/core'
import { atom, useAtomValue, useSetAtom } from 'jotai' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { copyOverInstructionEnabledAtom } from '@/containers/CopyInstruction' import { copyOverInstructionEnabledAtom } from '@/containers/CopyInstruction'
import { fileUploadAtom } from '@/containers/Providers/Jotai' import { fileUploadAtom } from '@/containers/Providers/Jotai'
@ -24,7 +24,10 @@ import useSetActiveThread from './useSetActiveThread'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' import {
experimentalFeatureEnabledAtom,
preserveModelSettingsAtom,
} from '@/helpers/atoms/AppConfig.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import { import {
threadsAtom, threadsAtom,
@ -62,6 +65,7 @@ export const useCreateNewThread = () => {
const copyOverInstructionEnabled = useAtomValue( const copyOverInstructionEnabled = useAtomValue(
copyOverInstructionEnabledAtom copyOverInstructionEnabledAtom
) )
const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
const activeThread = useAtomValue(activeThreadAtom) const activeThread = useAtomValue(activeThreadAtom)
const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom) const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
@ -99,15 +103,20 @@ export const useCreateNewThread = () => {
enabled: true, enabled: true,
settings: assistant.tools && assistant.tools[0].settings, settings: assistant.tools && assistant.tools[0].settings,
} }
const defaultContextLength = preserveModelSettings
? model?.metadata?.default_ctx_len
: 2048
const defaultMaxTokens = preserveModelSettings
? model?.metadata?.default_max_tokens
: 2048
const overriddenSettings = const overriddenSettings =
defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048 defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048
? { ctx_len: 2048 } ? { ctx_len: defaultContextLength }
: {} : {}
const overriddenParameters = const overriddenParameters =
defaultModel?.parameters.max_tokens && defaultModel.parameters.max_tokens defaultModel?.parameters.max_tokens && defaultModel.parameters.max_tokens
? { max_tokens: 2048 } ? { max_tokens: defaultMaxTokens }
: {} : {}
const createdAt = Date.now() const createdAt = Date.now()

View File

@ -4,16 +4,22 @@ import {
ConversationalExtension, ConversationalExtension,
ExtensionTypeEnum, ExtensionTypeEnum,
InferenceEngine, InferenceEngine,
Model,
ModelExtension,
Thread, Thread,
ThreadAssistantInfo, ThreadAssistantInfo,
} from '@janhq/core' } from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai' import { useAtom, useAtomValue, useSetAtom } from 'jotai'
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
import {
selectedModelAtom,
updateDownloadedModelAtom,
} from '@/helpers/atoms/Model.atom'
import { import {
ModelParams, ModelParams,
getActiveThreadModelParamsAtom, getActiveThreadModelParamsAtom,
@ -28,8 +34,10 @@ export type UpdateModelParameter = {
export default function useUpdateModelParameters() { export default function useUpdateModelParameters() {
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom) const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
const selectedModel = useAtomValue(selectedModelAtom) const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom) const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
const updateDownloadedModel = useSetAtom(updateDownloadedModelAtom)
const preserveModelFeatureEnabled = useAtomValue(preserveModelSettingsAtom)
const updateModelParameter = useCallback( const updateModelParameter = useCallback(
async (thread: Thread, settings: UpdateModelParameter) => { async (thread: Thread, settings: UpdateModelParameter) => {
@ -40,12 +48,11 @@ export default function useUpdateModelParameters() {
// update the state // update the state
setThreadModelParams(thread.id, updatedModelParams) setThreadModelParams(thread.id, updatedModelParams)
const runtimeParams = toRuntimeParams(updatedModelParams)
const settingParams = toSettingParams(updatedModelParams)
const assistants = thread.assistants.map( const assistants = thread.assistants.map(
(assistant: ThreadAssistantInfo) => { (assistant: ThreadAssistantInfo) => {
const runtimeParams = toRuntimeParams(updatedModelParams)
const settingParams = toSettingParams(updatedModelParams)
assistant.model.parameters = runtimeParams assistant.model.parameters = runtimeParams
assistant.model.settings = settingParams assistant.model.settings = settingParams
if (selectedModel) { if (selectedModel) {
@ -65,6 +72,33 @@ export default function useUpdateModelParameters() {
await extensionManager await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread(updatedThread) ?.saveThread(updatedThread)
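// Persist the updated settings to the model file.
// The model's own ctx_len and max_tokens are not overwritten; the values
// chosen by the user are stored in metadata.default_ctx_len and
// metadata.default_max_tokens instead.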
if (preserveModelFeatureEnabled && selectedModel) {
const updatedModel = {
...selectedModel,
parameters: {
...runtimeParams,
max_tokens: selectedModel.parameters.max_tokens,
},
settings: {
...settingParams,
ctx_len: selectedModel.settings.ctx_len,
},
metadata: {
...selectedModel.metadata,
default_ctx_len: settingParams.ctx_len,
default_max_tokens: runtimeParams.max_tokens,
},
} as Model
await extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model)
?.saveModel(updatedModel)
setSelectedModel(updatedModel)
updateDownloadedModel(updatedModel)
}
}, },
[activeModelParams, selectedModel, setThreadModelParams] [activeModelParams, selectedModel, setThreadModelParams]
) )

View File

@ -35,6 +35,7 @@ import {
proxyEnabledAtom, proxyEnabledAtom,
vulkanEnabledAtom, vulkanEnabledAtom,
quickAskEnabledAtom, quickAskEnabledAtom,
preserveModelSettingsAtom,
} from '@/helpers/atoms/AppConfig.atom' } from '@/helpers/atoms/AppConfig.atom'
type GPU = { type GPU = {
@ -64,6 +65,9 @@ const Advanced = () => {
const [proxyEnabled, setProxyEnabled] = useAtom(proxyEnabledAtom) const [proxyEnabled, setProxyEnabled] = useAtom(proxyEnabledAtom)
const quickAskEnabled = useAtomValue(quickAskEnabledAtom) const quickAskEnabled = useAtomValue(quickAskEnabledAtom)
const [preserveModelSettings, setPreserveModelSettings] = useAtom(
preserveModelSettingsAtom
)
const [proxy, setProxy] = useAtom(proxyAtom) const [proxy, setProxy] = useAtom(proxyAtom)
const [ignoreSSL, setIgnoreSSL] = useAtom(ignoreSslAtom) const [ignoreSSL, setIgnoreSSL] = useAtom(ignoreSslAtom)
@ -385,6 +389,27 @@ const Advanced = () => {
</div> </div>
)} )}
{experimentalEnabled && (
<div className="flex w-full flex-col items-start justify-between gap-4 border-b border-[hsla(var(--app-border))] py-4 first:pt-0 last:border-none sm:flex-row">
<div className="flex-shrink-0 space-y-1">
<div className="flex gap-x-2">
<h6 className="font-semibold capitalize">
Preserve Model Settings
</h6>
</div>
<p className="font-medium leading-relaxed text-[hsla(var(--text-secondary))]">
Save model settings changes directly to the model file so that
new threads will reuse the previous settings.
</p>
</div>
<Switch
checked={preserveModelSettings}
onChange={(e) => setPreserveModelSettings(e.target.checked)}
/>
</div>
)}
<DataFolder /> <DataFolder />
{/* Proxy */} {/* Proxy */}