feat: preserve model settings (#3427)
* feat: preserve model settings * feat: preserve model settings across new threads * chore: lint fix * fix: feature toggle off should also affect default value retrieve
This commit is contained in:
parent
c8474c88ca
commit
ad9a4a0b4d
@ -107,6 +107,9 @@ export type ModelMetadata = {
|
||||
tags: string[]
|
||||
size: number
|
||||
cover?: string
|
||||
// These settings to preserve model settings across threads
|
||||
default_ctx_len?: number
|
||||
default_max_tokens?: number
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -351,7 +351,7 @@ export default class JanModelExtension extends ModelExtension {
|
||||
}
|
||||
|
||||
/**
|
||||
* Saves a machine learning model.
|
||||
* Saves a model file.
|
||||
* @param model - The model to save.
|
||||
* @returns A Promise that resolves when the model is saved.
|
||||
*/
|
||||
|
||||
@ -45,6 +45,7 @@ import {
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
|
||||
import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||
import { inActiveEngineProviderAtom } from '@/helpers/atoms/Extension.atom'
|
||||
import {
|
||||
configuredModelsAtom,
|
||||
@ -89,6 +90,7 @@ const ModelDropdown = ({
|
||||
const featuredModel = configuredModels.filter((x) =>
|
||||
x.metadata.tags.includes('Featured')
|
||||
)
|
||||
const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
|
||||
|
||||
useClickOutside(() => !filterOptionsOpen && setOpen(false), null, [
|
||||
dropdownOptions,
|
||||
@ -161,14 +163,25 @@ const ModelDropdown = ({
|
||||
if (activeThread) {
|
||||
// Default setting ctx_len for the model for a better onboarding experience
|
||||
// TODO: When Cortex support hardware instructions, we should remove this
|
||||
const defaultContextLength = preserveModelSettings
|
||||
? model?.metadata?.default_ctx_len
|
||||
: 2048
|
||||
const defaultMaxTokens = preserveModelSettings
|
||||
? model?.metadata?.default_max_tokens
|
||||
: 2048
|
||||
const overriddenSettings =
|
||||
model?.settings.ctx_len && model.settings.ctx_len > 2048
|
||||
? { ctx_len: 2048 }
|
||||
? { ctx_len: defaultContextLength }
|
||||
: {}
|
||||
const overriddenParameters =
|
||||
model?.parameters.max_tokens && model.parameters.max_tokens
|
||||
? { max_tokens: defaultMaxTokens }
|
||||
: {}
|
||||
|
||||
const modelParams = {
|
||||
...model?.parameters,
|
||||
...model?.settings,
|
||||
...overriddenParameters,
|
||||
...overriddenSettings,
|
||||
}
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ const VULKAN_ENABLED = 'vulkanEnabled'
|
||||
const IGNORE_SSL = 'ignoreSSLFeature'
|
||||
const HTTPS_PROXY_FEATURE = 'httpsProxyFeature'
|
||||
const QUICK_ASK_ENABLED = 'quickAskEnabled'
|
||||
const PRESERVE_MODEL_SETTINGS = 'preserveModelSettings'
|
||||
|
||||
export const janDataFolderPathAtom = atom('')
|
||||
|
||||
@ -23,3 +24,9 @@ export const vulkanEnabledAtom = atomWithStorage(VULKAN_ENABLED, false)
|
||||
export const quickAskEnabledAtom = atomWithStorage(QUICK_ASK_ENABLED, false)
|
||||
|
||||
export const hostAtom = atom('http://localhost:1337/')
|
||||
|
||||
// This feature is to allow user to cache model settings on thread creation
|
||||
export const preserveModelSettingsAtom = atomWithStorage(
|
||||
PRESERVE_MODEL_SETTINGS,
|
||||
false
|
||||
)
|
||||
|
||||
@ -34,6 +34,17 @@ export const removeDownloadingModelAtom = atom(
|
||||
|
||||
export const downloadedModelsAtom = atom<Model[]>([])
|
||||
|
||||
export const updateDownloadedModelAtom = atom(
|
||||
null,
|
||||
(get, set, updatedModel: Model) => {
|
||||
const models: Model[] = get(downloadedModelsAtom).map((c) =>
|
||||
c.id === updatedModel.id ? updatedModel : c
|
||||
)
|
||||
|
||||
set(downloadedModelsAtom, models)
|
||||
}
|
||||
)
|
||||
|
||||
export const removeDownloadedModelAtom = atom(
|
||||
null,
|
||||
(get, set, modelId: string) => {
|
||||
|
||||
@ -10,7 +10,7 @@ import {
|
||||
Model,
|
||||
AssistantTool,
|
||||
} from '@janhq/core'
|
||||
import { atom, useAtomValue, useSetAtom } from 'jotai'
|
||||
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import { copyOverInstructionEnabledAtom } from '@/containers/CopyInstruction'
|
||||
import { fileUploadAtom } from '@/containers/Providers/Jotai'
|
||||
@ -24,7 +24,10 @@ import useSetActiveThread from './useSetActiveThread'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
|
||||
import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||
import {
|
||||
experimentalFeatureEnabledAtom,
|
||||
preserveModelSettingsAtom,
|
||||
} from '@/helpers/atoms/AppConfig.atom'
|
||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
import {
|
||||
threadsAtom,
|
||||
@ -62,6 +65,7 @@ export const useCreateNewThread = () => {
|
||||
const copyOverInstructionEnabled = useAtomValue(
|
||||
copyOverInstructionEnabledAtom
|
||||
)
|
||||
const preserveModelSettings = useAtomValue(preserveModelSettingsAtom)
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
|
||||
const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom)
|
||||
@ -99,15 +103,20 @@ export const useCreateNewThread = () => {
|
||||
enabled: true,
|
||||
settings: assistant.tools && assistant.tools[0].settings,
|
||||
}
|
||||
|
||||
const defaultContextLength = preserveModelSettings
|
||||
? model?.metadata?.default_ctx_len
|
||||
: 2048
|
||||
const defaultMaxTokens = preserveModelSettings
|
||||
? model?.metadata?.default_max_tokens
|
||||
: 2048
|
||||
const overriddenSettings =
|
||||
defaultModel?.settings.ctx_len && defaultModel.settings.ctx_len > 2048
|
||||
? { ctx_len: 2048 }
|
||||
? { ctx_len: defaultContextLength }
|
||||
: {}
|
||||
|
||||
const overriddenParameters =
|
||||
defaultModel?.parameters.max_tokens && defaultModel.parameters.max_tokens
|
||||
? { max_tokens: 2048 }
|
||||
? { max_tokens: defaultMaxTokens }
|
||||
: {}
|
||||
|
||||
const createdAt = Date.now()
|
||||
|
||||
@ -4,16 +4,22 @@ import {
|
||||
ConversationalExtension,
|
||||
ExtensionTypeEnum,
|
||||
InferenceEngine,
|
||||
Model,
|
||||
ModelExtension,
|
||||
Thread,
|
||||
ThreadAssistantInfo,
|
||||
} from '@janhq/core'
|
||||
|
||||
import { useAtomValue, useSetAtom } from 'jotai'
|
||||
import { useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
import { preserveModelSettingsAtom } from '@/helpers/atoms/AppConfig.atom'
|
||||
import {
|
||||
selectedModelAtom,
|
||||
updateDownloadedModelAtom,
|
||||
} from '@/helpers/atoms/Model.atom'
|
||||
import {
|
||||
ModelParams,
|
||||
getActiveThreadModelParamsAtom,
|
||||
@ -28,8 +34,10 @@ export type UpdateModelParameter = {
|
||||
|
||||
export default function useUpdateModelParameters() {
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
|
||||
const selectedModel = useAtomValue(selectedModelAtom)
|
||||
const [selectedModel, setSelectedModel] = useAtom(selectedModelAtom)
|
||||
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
|
||||
const updateDownloadedModel = useSetAtom(updateDownloadedModelAtom)
|
||||
const preserveModelFeatureEnabled = useAtomValue(preserveModelSettingsAtom)
|
||||
|
||||
const updateModelParameter = useCallback(
|
||||
async (thread: Thread, settings: UpdateModelParameter) => {
|
||||
@ -40,12 +48,11 @@ export default function useUpdateModelParameters() {
|
||||
|
||||
// update the state
|
||||
setThreadModelParams(thread.id, updatedModelParams)
|
||||
const runtimeParams = toRuntimeParams(updatedModelParams)
|
||||
const settingParams = toSettingParams(updatedModelParams)
|
||||
|
||||
const assistants = thread.assistants.map(
|
||||
(assistant: ThreadAssistantInfo) => {
|
||||
const runtimeParams = toRuntimeParams(updatedModelParams)
|
||||
const settingParams = toSettingParams(updatedModelParams)
|
||||
|
||||
assistant.model.parameters = runtimeParams
|
||||
assistant.model.settings = settingParams
|
||||
if (selectedModel) {
|
||||
@ -65,6 +72,33 @@ export default function useUpdateModelParameters() {
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.saveThread(updatedThread)
|
||||
|
||||
// Persists default settings to model file
|
||||
// Do not overwrite ctx_len and max_tokens
|
||||
if (preserveModelFeatureEnabled && selectedModel) {
|
||||
const updatedModel = {
|
||||
...selectedModel,
|
||||
parameters: {
|
||||
...runtimeParams,
|
||||
max_tokens: selectedModel.parameters.max_tokens,
|
||||
},
|
||||
settings: {
|
||||
...settingParams,
|
||||
ctx_len: selectedModel.settings.ctx_len,
|
||||
},
|
||||
metadata: {
|
||||
...selectedModel.metadata,
|
||||
default_ctx_len: settingParams.ctx_len,
|
||||
default_max_tokens: runtimeParams.max_tokens,
|
||||
},
|
||||
} as Model
|
||||
|
||||
await extensionManager
|
||||
.get<ModelExtension>(ExtensionTypeEnum.Model)
|
||||
?.saveModel(updatedModel)
|
||||
setSelectedModel(updatedModel)
|
||||
updateDownloadedModel(updatedModel)
|
||||
}
|
||||
},
|
||||
[activeModelParams, selectedModel, setThreadModelParams]
|
||||
)
|
||||
|
||||
@ -35,6 +35,7 @@ import {
|
||||
proxyEnabledAtom,
|
||||
vulkanEnabledAtom,
|
||||
quickAskEnabledAtom,
|
||||
preserveModelSettingsAtom,
|
||||
} from '@/helpers/atoms/AppConfig.atom'
|
||||
|
||||
type GPU = {
|
||||
@ -64,6 +65,9 @@ const Advanced = () => {
|
||||
const [proxyEnabled, setProxyEnabled] = useAtom(proxyEnabledAtom)
|
||||
const quickAskEnabled = useAtomValue(quickAskEnabledAtom)
|
||||
|
||||
const [preserveModelSettings, setPreserveModelSettings] = useAtom(
|
||||
preserveModelSettingsAtom
|
||||
)
|
||||
const [proxy, setProxy] = useAtom(proxyAtom)
|
||||
const [ignoreSSL, setIgnoreSSL] = useAtom(ignoreSslAtom)
|
||||
|
||||
@ -385,6 +389,27 @@ const Advanced = () => {
|
||||
</div>
|
||||
)}
|
||||
|
||||
{experimentalEnabled && (
|
||||
<div className="flex w-full flex-col items-start justify-between gap-4 border-b border-[hsla(var(--app-border))] py-4 first:pt-0 last:border-none sm:flex-row">
|
||||
<div className="flex-shrink-0 space-y-1">
|
||||
<div className="flex gap-x-2">
|
||||
<h6 className="font-semibold capitalize">
|
||||
Preserve Model Settings
|
||||
</h6>
|
||||
</div>
|
||||
<p className="font-medium leading-relaxed text-[hsla(var(--text-secondary))]">
|
||||
Save model settings changes directly to the model file so that
|
||||
new threads will reuse the previous settings.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<Switch
|
||||
checked={preserveModelSettings}
|
||||
onChange={(e) => setPreserveModelSettings(e.target.checked)}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<DataFolder />
|
||||
|
||||
{/* Proxy */}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user