feat: preserve model settings (#3427)

* feat: preserve model settings

* feat: preserve model settings across new threads

* chore: lint fix

* fix: feature toggle off should also affect default value retrieve
This commit is contained in:
Louis 2024-08-21 21:28:29 +07:00 committed by GitHub
parent c8474c88ca
commit ad9a4a0b4d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 115 additions and 13 deletions

View File

@ -107,6 +107,9 @@ export type ModelMetadata = {
tags: string[]
size: number
cover?: string
// These settings are used to preserve model settings across threads
default_ctx_len?: number
default_max_tokens?: number
}
/**

View File

@ -351,7 +351,7 @@ export default class JanModelExtension extends ModelExtension {
}
/**
* Saves a machine learning model.
* Saves a model file.
* @param model - The model to save.
* @returns A Promise that resolves when the model is saved.
*/

View File

@ -45,6 +45,7 @@ import {
import { extensionManager } from '@/extension'
import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom'
import {
configuredModelsAtom,
@ -89,6 +90,7 @@ const ModelDropdown = ({
const featuredModel = configuredModels.filter((x) =>
x.metadata.tags.includes('Featured')
)
const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [
dropdownOptions,
@ -161,14 +163,25 @@ const ModelDropdown = ({
if (activeThread) {
// Apply a default ctx_len setting to the model for a better onboarding experience
// TODO: When Cortex support hardware instructions, we should remove this
const defaultContextLength = preserveModelSettings
? model?.metadata?.default_ctx_len
: 2048
const defaultMaxTokens = preserveModelSettings
? model?.metadata?.default_max_tokens
: 2048
const overriddenSettings =
model?.settings.ctx_len && model.settings.ctx_len > 2048
? { ctx_len: 2048 }
? { ctx_len: defaultContextLength }
: {}
const overriddenParameters =
model?.parameters.max_tokens && model.parameters.max_tokens
? { max_tokens: defaultMaxTokens }
: {}
const modelParams = {
...model?.parameters,
...model?.settings,
...overriddenParameters,
...overriddenSettings,
}

View File

@ -7,6 +7,7 @@ const VULKAN_ENABLED = 'vulkanEnabled'
const IGNORE_SSL = 'ignoreSSLFeature'
const HTTPS_PROXY_FEATURE = 'httpsProxyFeature'
const QUICK_ASK_ENABLED = 'quickAskEnabled'
const PRESERVE_MODEL_SETTINGS = 'preserveModelSettings'
export const janDataFolderPathAtom = atom('')
@ -23,3 +24,9 @@ export const vulkanEnabledAtom = atomWithStorage(VULKAN_ENABLED, false)
export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false)
export const hostAtom = atom('http://localhost:1337/')
// This feature allows users to cache model settings on thread creation
export const preserveModelSettingsAtom = atomWithStorage(
PRESERVE_MODEL_SETTINGS,
false
)

View File

@ -34,6 +34,17 @@ export const removeDownloadingModelAtom = atom(
export const downloadedModelsAtom = atom<Model[]>([])
export const updateDownloadedModelAtom = atom(
null,
(get, set, updatedModel: Model) => {
const models: Model[] = get(downloadedModelsAtom).map((c) =>
c.id === updatedModel.id ? updatedModel : c
)
set(downloadedModelsAtom, models)
}
)
export const removeDownloadedModelAtom = atom(
null,
(get, set, modelId: string) => {

View File

@ -10,7 +10,7 @@ import {
Model,
AssistantTool,
} from '@janhq/core'
import { atom, useAtomValue, useSetAtom } from 'jotai'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { copyOverInstructionEnabledAtom } from '@/containers/CopyInstruction'
import { fileUploadAtom } from '@/containers/Providers/Jotai'
@ -24,7 +24,10 @@ import useSetActiveThread from './useSetActiveThread'
import { extensionManager } from '@/extension'
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
import {
experimentalFeatureEnabledAtom,
preserveModelSettingsAtom,
} from '@/helpers/atoms/AppConfig.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
threadsAtom,
@ -62,6 +65,7 @@ export const useCreateNewThread = () => {
const copyOverInstructionEnabled = useAtomValue(
copyOverInstructionEnabledAtom
)
const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
const activeThread = useAtomValue(activeThreadAtom)
const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
@ -99,15 +103,20 @@ export const useCreateNewThread = () => {
enabled: true,
settings: assistant.tools && assistant.tools[0].settings,
}
const defaultContextLength = preserveModelSettings
? model?.metadata?.default_ctx_len
: 2048
const defaultMaxTokens = preserveModelSettings
? model?.metadata?.default_max_tokens
: 2048
const overriddenSettings =
defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048
? { ctx_len: 2048 }
? { ctx_len: defaultContextLength }
: {}
const overriddenParameters =
defaultModel?.parameters.max_tokens && defaultModel.parameters.max_tokens
? { max_tokens: 2048 }
? { max_tokens: defaultMaxTokens }
: {}
const createdAt = Date.now()

View File

@ -4,16 +4,22 @@ import {
ConversationalExtension,
ExtensionTypeEnum,
InferenceEngine,
Model,
ModelExtension,
Thread,
ThreadAssistantInfo,
} from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai'
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
import { extensionManager } from '@/extension'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
import {
selectedModelAtom,
updateDownloadedModelAtom,
} from '@/helpers/atoms/Model.atom'
import {
ModelParams,
getActiveThreadModelParamsAtom,
@ -28,8 +34,10 @@ export type UpdateModelParameter = {
export default function useUpdateModelParameters() {
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
const selectedModel = useAtomValue(selectedModelAtom)
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
const updateDownloadedModel = useSetAtom(updateDownloadedModelAtom)
const preserveModelFeatureEnabled = useAtomValue(preserveModelSettingsAtom)
const updateModelParameter = useCallback(
async (thread: Thread, settings: UpdateModelParameter) => {
@ -40,12 +48,11 @@ export default function useUpdateModelParameters() {
// update the state
setThreadModelParams(thread.id, updatedModelParams)
const runtimeParams = toRuntimeParams(updatedModelParams)
const settingParams = toSettingParams(updatedModelParams)
const assistants = thread.assistants.map(
(assistant: ThreadAssistantInfo) => {
const runtimeParams = toRuntimeParams(updatedModelParams)
const settingParams = toSettingParams(updatedModelParams)
assistant.model.parameters = runtimeParams
assistant.model.settings = settingParams
if (selectedModel) {
@ -65,6 +72,33 @@ export default function useUpdateModelParameters() {
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.saveThread(updatedThread)
// Persist the default settings to the model file,
// but do not overwrite the model's own ctx_len and max_tokens
if (preserveModelFeatureEnabled && selectedModel) {
const updatedModel = {
...selectedModel,
parameters: {
...runtimeParams,
max_tokens: selectedModel.parameters.max_tokens,
},
settings: {
...settingParams,
ctx_len: selectedModel.settings.ctx_len,
},
metadata: {
...selectedModel.metadata,
default_ctx_len: settingParams.ctx_len,
default_max_tokens: runtimeParams.max_tokens,
},
} as Model
await extensionManager
.get<ModelExtension>(ExtensionTypeEnum.Model)
?.saveModel(updatedModel)
setSelectedModel(updatedModel)
updateDownloadedModel(updatedModel)
}
},
[activeModelParams, selectedModel, setThreadModelParams]
)

View File

@ -35,6 +35,7 @@ import {
proxyEnabledAtom,
vulkanEnabledAtom,
quickAskEnabledAtom,
preserveModelSettingsAtom,
} from '@/helpers/atoms/AppConfig.atom'
type GPU = {
@ -64,6 +65,9 @@ const Advanced = () => {
const [proxyEnabled, setProxyEnabled] = useAtom(proxyEnabledAtom)
const quickAskEnabled = useAtomValue(quickAskEnabledAtom)
const [preserveModelSettings, setPreserveModelSettings] = useAtom(
preserveModelSettingsAtom
)
const [proxy, setProxy] = useAtom(proxyAtom)
const [ignoreSSL, setIgnoreSSL] = useAtom(ignoreSslAtom)
@ -385,6 +389,27 @@ const Advanced = () => {
</div>
)}
{experimentalEnabled && (
<div className="flex w-full flex-col items-start justify-between gap-4 border-b border-[hsla(var(--app-border))] py-4 first:pt-0 last:border-none sm:flex-row">
<div className="flex-shrink-0 space-y-1">
<div className="flex gap-x-2">
<h6 className="font-semibold capitalize">
Preserve Model Settings
</h6>
</div>
<p className="font-medium leading-relaxed text-[hsla(var(--text-secondary))]">
Save model settings changes directly to the model file so that
new threads will reuse the previous settings.
</p>
</div>
<Switch
checked={preserveModelSettings}
onChange={(e) => setPreserveModelSettings(e.target.checked)}
/>
</div>
)}
<DataFolder />
{/* Proxy */}