feat: add engine settings (#1199)
* feat: add engine settings Signed-off-by: James <james@jan.ai> --------- Signed-off-by: James <james@jan.ai> Co-authored-by: Louis <louis@jan.ai>
This commit is contained in:
parent
2df43e91ef
commit
c580b4c848
@ -1,48 +1,33 @@
|
||||
import { FieldValues, UseFormRegister } from 'react-hook-form'
|
||||
import React from 'react'
|
||||
|
||||
import { ModelRuntimeParams } from '@janhq/core'
|
||||
import { Switch } from '@janhq/uikit'
|
||||
|
||||
import { useAtomValue } from 'jotai'
|
||||
|
||||
import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
|
||||
|
||||
import {
|
||||
getActiveThreadIdAtom,
|
||||
getActiveThreadModelRuntimeParamsAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
import { getActiveThreadIdAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
type Props = {
|
||||
name: string
|
||||
title: string
|
||||
checked: boolean
|
||||
register: UseFormRegister<FieldValues>
|
||||
}
|
||||
|
||||
const Checkbox: React.FC<Props> = ({ name, title, checked, register }) => {
|
||||
const Checkbox: React.FC<Props> = ({ name, title, checked }) => {
|
||||
const { updateModelParameter } = useUpdateModelParameters()
|
||||
const threadId = useAtomValue(getActiveThreadIdAtom)
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom)
|
||||
|
||||
const onCheckedChange = (checked: boolean) => {
|
||||
if (!threadId || !activeModelParams) return
|
||||
if (!threadId) return
|
||||
|
||||
const updatedModelParams: ModelRuntimeParams = {
|
||||
...activeModelParams,
|
||||
[name]: checked,
|
||||
}
|
||||
|
||||
updateModelParameter(threadId, updatedModelParams)
|
||||
updateModelParameter(threadId, name, checked)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex justify-between">
|
||||
<label>{title}</label>
|
||||
<Switch
|
||||
checked={checked}
|
||||
{...register(name)}
|
||||
onCheckedChange={onCheckedChange}
|
||||
/>
|
||||
<p className="mb-2 text-sm font-semibold text-gray-600">{title}</p>
|
||||
<Switch checked={checked} onCheckedChange={onCheckedChange} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@ -28,14 +28,22 @@ import useRecommendedModel from '@/hooks/useRecommendedModel'
|
||||
|
||||
import { toGigabytes } from '@/utils/converter'
|
||||
|
||||
import { activeThreadAtom, threadStatesAtom } from '@/helpers/atoms/Thread.atom'
|
||||
import {
|
||||
activeThreadAtom,
|
||||
getActiveThreadIdAtom,
|
||||
setThreadModelParamsAtom,
|
||||
threadStatesAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export const selectedModelAtom = atom<Model | undefined>(undefined)
|
||||
|
||||
export default function DropdownListSidebar() {
|
||||
const setSelectedModel = useSetAtom(selectedModelAtom)
|
||||
const threadStates = useAtomValue(threadStatesAtom)
|
||||
const activeThreadId = useAtomValue(getActiveThreadIdAtom)
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const threadStates = useAtomValue(threadStatesAtom)
|
||||
const setSelectedModel = useSetAtom(selectedModelAtom)
|
||||
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
|
||||
|
||||
const [selected, setSelected] = useState<Model | undefined>()
|
||||
const { setMainViewState } = useMainViewState()
|
||||
const [openAISettings, setOpenAISettings] = useState<
|
||||
@ -54,15 +62,39 @@ export default function DropdownListSidebar() {
|
||||
useEffect(() => {
|
||||
setSelected(recommendedModel)
|
||||
setSelectedModel(recommendedModel)
|
||||
}, [recommendedModel, setSelectedModel])
|
||||
|
||||
if (activeThread) {
|
||||
const finishInit = threadStates[activeThread.id].isFinishInit ?? true
|
||||
if (finishInit) return
|
||||
const modelParams = {
|
||||
...recommendedModel?.parameters,
|
||||
...recommendedModel?.settings,
|
||||
}
|
||||
setThreadModelParams(activeThread.id, modelParams)
|
||||
}
|
||||
}, [
|
||||
recommendedModel,
|
||||
activeThread,
|
||||
setSelectedModel,
|
||||
setThreadModelParams,
|
||||
threadStates,
|
||||
])
|
||||
|
||||
const onValueSelected = useCallback(
|
||||
(modelId: string) => {
|
||||
const model = downloadedModels.find((m) => m.id === modelId)
|
||||
setSelected(model)
|
||||
setSelectedModel(model)
|
||||
|
||||
if (activeThreadId) {
|
||||
const modelParams = {
|
||||
...model?.parameters,
|
||||
...model?.settings,
|
||||
}
|
||||
setThreadModelParams(activeThreadId, modelParams)
|
||||
}
|
||||
},
|
||||
[downloadedModels, setSelectedModel]
|
||||
[downloadedModels, activeThreadId, setSelectedModel, setThreadModelParams]
|
||||
)
|
||||
|
||||
if (!activeThread) {
|
||||
|
||||
43
web/containers/ModelConfigInput/index.tsx
Normal file
43
web/containers/ModelConfigInput/index.tsx
Normal file
@ -0,0 +1,43 @@
|
||||
import { Textarea } from '@janhq/uikit'
|
||||
|
||||
import { useAtomValue } from 'jotai'
|
||||
|
||||
import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
|
||||
|
||||
import { getActiveThreadIdAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
type Props = {
|
||||
title: string
|
||||
name: string
|
||||
placeholder: string
|
||||
value: string
|
||||
}
|
||||
|
||||
const ModelConfigInput: React.FC<Props> = ({
|
||||
title,
|
||||
name,
|
||||
value,
|
||||
placeholder,
|
||||
}) => {
|
||||
const { updateModelParameter } = useUpdateModelParameters()
|
||||
const threadId = useAtomValue(getActiveThreadIdAtom)
|
||||
|
||||
const onValueChanged = (e: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
if (!threadId) return
|
||||
|
||||
updateModelParameter(threadId, name, e.target.value)
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col">
|
||||
<p className="mb-2 text-sm font-semibold text-gray-600">{title}</p>
|
||||
<Textarea
|
||||
placeholder={placeholder}
|
||||
onChange={onValueChanged}
|
||||
value={value}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default ModelConfigInput
|
||||
@ -1,15 +1,11 @@
|
||||
import { FieldValues, UseFormRegister } from 'react-hook-form'
|
||||
import React from 'react'
|
||||
|
||||
import { ModelRuntimeParams } from '@janhq/core'
|
||||
import { Slider, Input } from '@janhq/uikit'
|
||||
import { useAtomValue } from 'jotai'
|
||||
|
||||
import useUpdateModelParameters from '@/hooks/useUpdateModelParameters'
|
||||
|
||||
import {
|
||||
getActiveThreadIdAtom,
|
||||
getActiveThreadModelRuntimeParamsAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
import { getActiveThreadIdAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
type Props = {
|
||||
name: string
|
||||
@ -18,7 +14,6 @@ type Props = {
|
||||
max: number
|
||||
step: number
|
||||
value: number
|
||||
register: UseFormRegister<FieldValues>
|
||||
}
|
||||
|
||||
const SliderRightPanel: React.FC<Props> = ({
|
||||
@ -28,21 +23,14 @@ const SliderRightPanel: React.FC<Props> = ({
|
||||
max,
|
||||
step,
|
||||
value,
|
||||
register,
|
||||
}) => {
|
||||
const { updateModelParameter } = useUpdateModelParameters()
|
||||
const threadId = useAtomValue(getActiveThreadIdAtom)
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom)
|
||||
|
||||
const onValueChanged = (e: number[]) => {
|
||||
if (!threadId || !activeModelParams) return
|
||||
if (!threadId) return
|
||||
|
||||
const updatedModelParams: ModelRuntimeParams = {
|
||||
...activeModelParams,
|
||||
[name]: Number(e[0]),
|
||||
}
|
||||
|
||||
updateModelParameter(threadId, updatedModelParams)
|
||||
updateModelParameter(threadId, name, e[0])
|
||||
}
|
||||
|
||||
return (
|
||||
@ -51,9 +39,6 @@ const SliderRightPanel: React.FC<Props> = ({
|
||||
<div className="flex items-center gap-x-4">
|
||||
<div className="relative w-full">
|
||||
<Slider
|
||||
{...register(name, {
|
||||
setValueAs: (v: string) => parseInt(v),
|
||||
})}
|
||||
value={[value]}
|
||||
onValueChange={onValueChanged}
|
||||
min={min}
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import {
|
||||
ModelRuntimeParams,
|
||||
ModelSettingParams,
|
||||
Thread,
|
||||
ThreadContent,
|
||||
ThreadState,
|
||||
@ -110,30 +111,26 @@ export const activeThreadAtom = atom<Thread | undefined>((get) =>
|
||||
/**
|
||||
* Store model params at thread level settings
|
||||
*/
|
||||
export const threadModelRuntimeParamsAtom = atom<
|
||||
Record<string, ModelRuntimeParams>
|
||||
>({})
|
||||
export const threadModelParamsAtom = atom<Record<string, ModelParams>>({})
|
||||
|
||||
export const getActiveThreadModelRuntimeParamsAtom = atom<
|
||||
ModelRuntimeParams | undefined
|
||||
>((get) => {
|
||||
const threadId = get(activeThreadIdAtom)
|
||||
if (!threadId) {
|
||||
console.debug('Active thread id is undefined')
|
||||
return undefined
|
||||
export type ModelParams = ModelRuntimeParams | ModelSettingParams
|
||||
|
||||
export const getActiveThreadModelParamsAtom = atom<ModelParams | undefined>(
|
||||
(get) => {
|
||||
const threadId = get(activeThreadIdAtom)
|
||||
if (!threadId) {
|
||||
console.debug('Active thread id is undefined')
|
||||
return undefined
|
||||
}
|
||||
|
||||
return get(threadModelParamsAtom)[threadId]
|
||||
}
|
||||
|
||||
return get(threadModelRuntimeParamsAtom)[threadId]
|
||||
})
|
||||
|
||||
export const getThreadModelRuntimeParamsAtom = atom(
|
||||
(get, threadId: string) => get(threadModelRuntimeParamsAtom)[threadId]
|
||||
)
|
||||
|
||||
export const setThreadModelRuntimeParamsAtom = atom(
|
||||
export const setThreadModelParamsAtom = atom(
|
||||
null,
|
||||
(get, set, threadId: string, params: ModelRuntimeParams) => {
|
||||
const currentState = { ...get(threadModelRuntimeParamsAtom) }
|
||||
(get, set, threadId: string, params: ModelParams) => {
|
||||
const currentState = { ...get(threadModelParamsAtom) }
|
||||
currentState[threadId] = params
|
||||
console.debug(
|
||||
`Update model params for thread ${threadId}, ${JSON.stringify(
|
||||
@ -142,6 +139,6 @@ export const setThreadModelRuntimeParamsAtom = atom(
|
||||
2
|
||||
)}`
|
||||
)
|
||||
set(threadModelRuntimeParamsAtom, currentState)
|
||||
set(threadModelParamsAtom, currentState)
|
||||
}
|
||||
)
|
||||
|
||||
@ -19,7 +19,6 @@ import {
|
||||
setActiveThreadIdAtom,
|
||||
threadStatesAtom,
|
||||
updateThreadAtom,
|
||||
setThreadModelRuntimeParamsAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
|
||||
@ -45,10 +44,6 @@ export const useCreateNewThread = () => {
|
||||
const createNewThread = useSetAtom(createNewThreadAtom)
|
||||
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
|
||||
const updateThread = useSetAtom(updateThreadAtom)
|
||||
const setThreadModelRuntimeParams = useSetAtom(
|
||||
setThreadModelRuntimeParamsAtom
|
||||
)
|
||||
|
||||
const { deleteThread } = useDeleteThread()
|
||||
|
||||
const requestCreateNewThread = async (
|
||||
@ -77,10 +72,7 @@ export const useCreateNewThread = () => {
|
||||
model: {
|
||||
id: modelId,
|
||||
settings: {},
|
||||
parameters: {
|
||||
stream: true,
|
||||
max_tokens: 1024,
|
||||
},
|
||||
parameters: {},
|
||||
engine: undefined,
|
||||
},
|
||||
instructions: assistant.instructions,
|
||||
@ -94,7 +86,6 @@ export const useCreateNewThread = () => {
|
||||
created: createdAt,
|
||||
updated: createdAt,
|
||||
}
|
||||
setThreadModelRuntimeParams(thread.id, assistantInfo.model.parameters)
|
||||
|
||||
// add the new thread on top of the thread list to the state
|
||||
createNewThread(thread)
|
||||
|
||||
@ -42,6 +42,7 @@ export default function useRecommendedModel() {
|
||||
const getRecommendedModel = useCallback(async (): Promise<
|
||||
Model | undefined
|
||||
> => {
|
||||
const models = await getAndSortDownloadedModels()
|
||||
if (!activeThread) {
|
||||
return
|
||||
}
|
||||
@ -49,7 +50,6 @@ export default function useRecommendedModel() {
|
||||
const finishInit = threadStates[activeThread.id].isFinishInit ?? true
|
||||
if (finishInit) {
|
||||
const modelId = activeThread.assistants[0]?.model.id
|
||||
const models = await getAndSortDownloadedModels()
|
||||
const model = models.find((model) => model.id === modelId)
|
||||
|
||||
if (model) {
|
||||
@ -60,7 +60,6 @@ export default function useRecommendedModel() {
|
||||
} else {
|
||||
const modelId = activeThread.assistants[0]?.model.id
|
||||
if (modelId !== '*') {
|
||||
const models = await getAndSortDownloadedModels()
|
||||
const model = models.find((model) => model.id === modelId)
|
||||
|
||||
if (model) {
|
||||
@ -78,7 +77,7 @@ export default function useRecommendedModel() {
|
||||
}
|
||||
|
||||
// sort the model, for display purpose
|
||||
const models = await getAndSortDownloadedModels()
|
||||
|
||||
if (models.length === 0) {
|
||||
// if we have no downloaded models, then can't recommend anything
|
||||
console.debug("No downloaded models, can't recommend anything")
|
||||
|
||||
@ -24,6 +24,8 @@ import { currentPromptAtom } from '@/containers/Providers/Jotai'
|
||||
|
||||
import { toaster } from '@/containers/Toast'
|
||||
|
||||
import { toRuntimeParams, toSettingParams } from '@/utils/model_param'
|
||||
|
||||
import { useActiveModel } from './useActiveModel'
|
||||
|
||||
import { extensionManager } from '@/extension/ExtensionManager'
|
||||
@ -33,7 +35,7 @@ import {
|
||||
} from '@/helpers/atoms/ChatMessage.atom'
|
||||
import {
|
||||
activeThreadAtom,
|
||||
getActiveThreadModelRuntimeParamsAtom,
|
||||
getActiveThreadModelParamsAtom,
|
||||
threadStatesAtom,
|
||||
updateThreadAtom,
|
||||
updateThreadInitSuccessAtom,
|
||||
@ -56,7 +58,7 @@ export default function useSendChatMessage() {
|
||||
const modelRef = useRef<Model | undefined>()
|
||||
const threadStates = useAtomValue(threadStatesAtom)
|
||||
const updateThreadInitSuccess = useSetAtom(updateThreadInitSuccessAtom)
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom)
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
|
||||
|
||||
useEffect(() => {
|
||||
modelRef.current = activeModel
|
||||
@ -128,15 +130,17 @@ export default function useSendChatMessage() {
|
||||
}
|
||||
|
||||
const sendChatMessage = async () => {
|
||||
if (!currentPrompt || currentPrompt.trim().length === 0) {
|
||||
return
|
||||
}
|
||||
if (!currentPrompt || currentPrompt.trim().length === 0) return
|
||||
|
||||
if (!activeThread) {
|
||||
console.error('No active thread')
|
||||
return
|
||||
}
|
||||
const activeThreadState = threadStates[activeThread.id]
|
||||
|
||||
const runtimeParams = toRuntimeParams(activeModelParams)
|
||||
const settingParams = toSettingParams(activeModelParams)
|
||||
|
||||
// if the thread is not initialized, we need to initialize it first
|
||||
if (
|
||||
!activeThreadState.isFinishInit ||
|
||||
@ -150,11 +154,6 @@ export default function useSendChatMessage() {
|
||||
const assistantName = activeThread.assistants[0].assistant_name ?? ''
|
||||
const instructions = activeThread.assistants[0].instructions ?? ''
|
||||
|
||||
const modelParams: ModelRuntimeParams = {
|
||||
...selectedModel.parameters,
|
||||
...activeModelParams,
|
||||
}
|
||||
|
||||
const updatedThread: Thread = {
|
||||
...activeThread,
|
||||
assistants: [
|
||||
@ -164,8 +163,8 @@ export default function useSendChatMessage() {
|
||||
instructions: instructions,
|
||||
model: {
|
||||
id: selectedModel.id,
|
||||
settings: selectedModel.settings,
|
||||
parameters: modelParams,
|
||||
settings: settingParams,
|
||||
parameters: runtimeParams,
|
||||
engine: selectedModel.engine,
|
||||
},
|
||||
},
|
||||
@ -211,13 +210,17 @@ export default function useSendChatMessage() {
|
||||
const msgId = ulid()
|
||||
|
||||
const modelRequest = selectedModel ?? activeThread.assistants[0].model
|
||||
if (runtimeParams.stream == null) {
|
||||
runtimeParams.stream = true
|
||||
}
|
||||
const messageRequest: MessageRequest = {
|
||||
id: msgId,
|
||||
threadId: activeThread.id,
|
||||
messages,
|
||||
model: {
|
||||
...modelRequest,
|
||||
...(activeModelParams ? { parameters: activeModelParams } : {}),
|
||||
settings: settingParams,
|
||||
parameters: runtimeParams,
|
||||
},
|
||||
}
|
||||
const timestamp = Date.now()
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import {
|
||||
ExtensionType,
|
||||
ModelRuntimeParams,
|
||||
Thread,
|
||||
ThreadState,
|
||||
ConversationalExtension,
|
||||
@ -12,7 +11,8 @@ import useSetActiveThread from './useSetActiveThread'
|
||||
|
||||
import { extensionManager } from '@/extension/ExtensionManager'
|
||||
import {
|
||||
threadModelRuntimeParamsAtom,
|
||||
ModelParams,
|
||||
threadModelParamsAtom,
|
||||
threadStatesAtom,
|
||||
threadsAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
@ -21,7 +21,7 @@ const useThreads = () => {
|
||||
const [threadStates, setThreadStates] = useAtom(threadStatesAtom)
|
||||
const [threads, setThreads] = useAtom(threadsAtom)
|
||||
const [threadModelRuntimeParams, setThreadModelRuntimeParams] = useAtom(
|
||||
threadModelRuntimeParamsAtom
|
||||
threadModelParamsAtom
|
||||
)
|
||||
const { setActiveThread } = useSetActiveThread()
|
||||
|
||||
@ -29,7 +29,7 @@ const useThreads = () => {
|
||||
try {
|
||||
const localThreads = await getLocalThreads()
|
||||
const localThreadStates: Record<string, ThreadState> = {}
|
||||
const threadModelParams: Record<string, ModelRuntimeParams> = {}
|
||||
const threadModelParams: Record<string, ModelParams> = {}
|
||||
|
||||
localThreads.forEach((thread) => {
|
||||
if (thread.id != null) {
|
||||
@ -42,9 +42,12 @@ const useThreads = () => {
|
||||
isFinishInit: true,
|
||||
}
|
||||
|
||||
// model params
|
||||
const modelParams = thread.assistants?.[0]?.model?.parameters
|
||||
threadModelParams[thread.id] = modelParams
|
||||
const engineParams = thread.assistants?.[0]?.model?.settings
|
||||
threadModelParams[thread.id] = {
|
||||
...modelParams,
|
||||
...engineParams,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@ -1,31 +1,34 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import {
|
||||
ConversationalExtension,
|
||||
ExtensionType,
|
||||
ModelRuntimeParams,
|
||||
Thread,
|
||||
ThreadAssistantInfo,
|
||||
} from '@janhq/core'
|
||||
|
||||
import { useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import { toRuntimeParams, toSettingParams } from '@/utils/model_param'
|
||||
|
||||
import { extensionManager } from '@/extension'
|
||||
import {
|
||||
ModelParams,
|
||||
activeThreadStateAtom,
|
||||
setThreadModelRuntimeParamsAtom,
|
||||
getActiveThreadModelParamsAtom,
|
||||
setThreadModelParamsAtom,
|
||||
threadsAtom,
|
||||
updateThreadAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export default function useUpdateModelParameters() {
|
||||
const threads = useAtomValue(threadsAtom)
|
||||
const updateThread = useSetAtom(updateThreadAtom)
|
||||
const setThreadModelRuntimeParams = useSetAtom(
|
||||
setThreadModelRuntimeParamsAtom
|
||||
)
|
||||
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
|
||||
const activeThreadState = useAtomValue(activeThreadStateAtom)
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
|
||||
|
||||
const updateModelParameter = async (
|
||||
threadId: string,
|
||||
params: ModelRuntimeParams
|
||||
name: string,
|
||||
value: number | boolean | string
|
||||
) => {
|
||||
const thread = threads.find((thread) => thread.id === threadId)
|
||||
if (!thread) {
|
||||
@ -37,27 +40,37 @@ export default function useUpdateModelParameters() {
|
||||
console.error('No active thread')
|
||||
return
|
||||
}
|
||||
const updatedModelParams: ModelParams = {
|
||||
...activeModelParams,
|
||||
[name]: value,
|
||||
}
|
||||
|
||||
// update the state
|
||||
setThreadModelRuntimeParams(thread.id, params)
|
||||
setThreadModelParams(thread.id, updatedModelParams)
|
||||
|
||||
if (!activeThreadState.isFinishInit) {
|
||||
// if thread is not initialized, we don't need to update thread.json
|
||||
return
|
||||
}
|
||||
|
||||
const assistants = thread.assistants.map((assistant) => {
|
||||
assistant.model.parameters = params
|
||||
return assistant
|
||||
})
|
||||
const assistants = thread.assistants.map(
|
||||
(assistant: ThreadAssistantInfo) => {
|
||||
const runtimeParams = toRuntimeParams(updatedModelParams)
|
||||
const settingParams = toSettingParams(updatedModelParams)
|
||||
|
||||
assistant.model.parameters = runtimeParams
|
||||
assistant.model.settings = settingParams
|
||||
return assistant
|
||||
}
|
||||
)
|
||||
|
||||
// update thread
|
||||
const updatedThread: Thread = {
|
||||
...thread,
|
||||
assistants,
|
||||
}
|
||||
updateThread(updatedThread)
|
||||
extensionManager
|
||||
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionType.Conversational)
|
||||
?.saveThread(updatedThread)
|
||||
}
|
||||
|
||||
31
web/screens/Chat/EngineSetting/index.tsx
Normal file
31
web/screens/Chat/EngineSetting/index.tsx
Normal file
@ -0,0 +1,31 @@
|
||||
import { useAtomValue } from 'jotai'
|
||||
|
||||
import { selectedModelAtom } from '@/containers/DropdownListSidebar'
|
||||
|
||||
import { getConfigurationsData } from '@/utils/componentSettings'
|
||||
import { toSettingParams } from '@/utils/model_param'
|
||||
|
||||
import settingComponentBuilder from '../ModelSetting/settingComponentBuilder'
|
||||
|
||||
import { getActiveThreadModelParamsAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
const EngineSetting: React.FC = () => {
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
|
||||
const selectedModel = useAtomValue(selectedModelAtom)
|
||||
|
||||
if (!selectedModel || !activeModelParams) return null
|
||||
|
||||
const modelSettingParams = toSettingParams(activeModelParams)
|
||||
|
||||
const componentData = getConfigurationsData(modelSettingParams)
|
||||
|
||||
componentData.sort((a, b) => a.title.localeCompare(b.title))
|
||||
|
||||
return (
|
||||
<form className="flex flex-col">
|
||||
{settingComponentBuilder(componentData)}
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
export default EngineSetting
|
||||
@ -1,47 +1,33 @@
|
||||
import { useForm } from 'react-hook-form'
|
||||
|
||||
import { ModelRuntimeParams } from '@janhq/core'
|
||||
import React from 'react'
|
||||
|
||||
import { useAtomValue } from 'jotai'
|
||||
|
||||
import { presetConfiguration } from './predefinedComponent'
|
||||
import settingComponentBuilder, {
|
||||
SettingComponentData,
|
||||
} from './settingComponentBuilder'
|
||||
import { selectedModelAtom } from '@/containers/DropdownListSidebar'
|
||||
|
||||
import { getActiveThreadModelRuntimeParamsAtom } from '@/helpers/atoms/Thread.atom'
|
||||
import { getConfigurationsData } from '@/utils/componentSettings'
|
||||
import { toRuntimeParams } from '@/utils/model_param'
|
||||
|
||||
export default function ModelSetting() {
|
||||
const { register } = useForm()
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom)
|
||||
import settingComponentBuilder from './settingComponentBuilder'
|
||||
|
||||
if (!activeModelParams) {
|
||||
return null
|
||||
}
|
||||
import { getActiveThreadModelParamsAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
const componentData: SettingComponentData[] = []
|
||||
Object.keys(activeModelParams).forEach((key) => {
|
||||
const componentSetting = presetConfiguration[key]
|
||||
const ModelSetting: React.FC = () => {
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
|
||||
const selectedModel = useAtomValue(selectedModelAtom)
|
||||
|
||||
if (componentSetting) {
|
||||
if ('value' in componentSetting.controllerData) {
|
||||
componentSetting.controllerData.value = Number(
|
||||
activeModelParams[key as keyof ModelRuntimeParams]
|
||||
)
|
||||
} else if ('checked' in componentSetting.controllerData) {
|
||||
const checked = activeModelParams[
|
||||
key as keyof ModelRuntimeParams
|
||||
] as boolean
|
||||
if (!selectedModel || !activeModelParams) return null
|
||||
|
||||
componentSetting.controllerData.checked = checked
|
||||
}
|
||||
componentData.push(componentSetting)
|
||||
}
|
||||
})
|
||||
const modelRuntimeParams = toRuntimeParams(activeModelParams)
|
||||
|
||||
const componentData = getConfigurationsData(modelRuntimeParams)
|
||||
|
||||
componentData.sort((a, b) => a.title.localeCompare(b.title))
|
||||
|
||||
return (
|
||||
<form className="flex flex-col">
|
||||
{settingComponentBuilder(componentData, register)}
|
||||
{settingComponentBuilder(componentData)}
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
export default React.memo(ModelSetting)
|
||||
|
||||
@ -1,10 +1,43 @@
|
||||
import { SettingComponentData } from './settingComponentBuilder'
|
||||
|
||||
export const presetConfiguration: Record<string, SettingComponentData> = {
|
||||
prompt_template: {
|
||||
name: 'prompt_template',
|
||||
title: 'Prompt template',
|
||||
description: 'Prompt template',
|
||||
controllerType: 'input',
|
||||
controllerData: {
|
||||
placeholder: 'Prompt template',
|
||||
value: '',
|
||||
},
|
||||
},
|
||||
stop: {
|
||||
name: 'stop',
|
||||
title: 'Stop',
|
||||
description: 'Stop',
|
||||
controllerType: 'input',
|
||||
controllerData: {
|
||||
placeholder: 'Stop',
|
||||
value: '',
|
||||
},
|
||||
},
|
||||
ctx_len: {
|
||||
name: 'ctx_len',
|
||||
title: 'Context Length',
|
||||
description: 'Context Length',
|
||||
controllerType: 'slider',
|
||||
controllerData: {
|
||||
min: 0,
|
||||
max: 4096,
|
||||
step: 128,
|
||||
value: 1024,
|
||||
},
|
||||
},
|
||||
max_tokens: {
|
||||
name: 'max_tokens',
|
||||
title: 'Max Tokens',
|
||||
description: 'Maximum context length the model can handle.',
|
||||
description:
|
||||
'The maximum number of tokens the model will generate in a single response.',
|
||||
controllerType: 'slider',
|
||||
controllerData: {
|
||||
min: 0,
|
||||
@ -56,4 +89,52 @@ export const presetConfiguration: Record<string, SettingComponentData> = {
|
||||
value: 0.7,
|
||||
},
|
||||
},
|
||||
frequency_penalty: {
|
||||
name: 'frequency_penalty',
|
||||
title: 'Frequency Penalty',
|
||||
description: 'Frequency Penalty',
|
||||
controllerType: 'slider',
|
||||
controllerData: {
|
||||
min: 0,
|
||||
max: 1,
|
||||
step: 0.1,
|
||||
value: 0.7,
|
||||
},
|
||||
},
|
||||
presence_penalty: {
|
||||
name: 'presence_penalty',
|
||||
title: 'Presence Penalty',
|
||||
description: 'Presence Penalty',
|
||||
controllerType: 'slider',
|
||||
controllerData: {
|
||||
min: 0,
|
||||
max: 1,
|
||||
step: 0.1,
|
||||
value: 0.7,
|
||||
},
|
||||
},
|
||||
top_p: {
|
||||
name: 'top_p',
|
||||
title: 'Top P',
|
||||
description: 'Top P',
|
||||
controllerType: 'slider',
|
||||
controllerData: {
|
||||
min: 0,
|
||||
max: 1,
|
||||
step: 0.1,
|
||||
value: 0.95,
|
||||
},
|
||||
},
|
||||
n_parallel: {
|
||||
name: 'n_parallel',
|
||||
title: 'N Parallel',
|
||||
description: 'N Parallel',
|
||||
controllerType: 'slider',
|
||||
controllerData: {
|
||||
min: 1,
|
||||
max: 4,
|
||||
step: 1,
|
||||
value: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -1,17 +1,21 @@
|
||||
/* eslint-disable no-case-declarations */
|
||||
import { FieldValues, UseFormRegister } from 'react-hook-form'
|
||||
|
||||
import Checkbox from '@/containers/Checkbox'
|
||||
import ModelConfigInput from '@/containers/ModelConfigInput'
|
||||
import Slider from '@/containers/Slider'
|
||||
|
||||
export type ControllerType = 'slider' | 'checkbox'
|
||||
export type ControllerType = 'slider' | 'checkbox' | 'input'
|
||||
|
||||
export type SettingComponentData = {
|
||||
name: string
|
||||
title: string
|
||||
description: string
|
||||
controllerType: ControllerType
|
||||
controllerData: SliderData | CheckboxData
|
||||
controllerData: SliderData | CheckboxData | InputData
|
||||
}
|
||||
|
||||
export type InputData = {
|
||||
placeholder: string
|
||||
value: string
|
||||
}
|
||||
|
||||
export type SliderData = {
|
||||
@ -25,10 +29,7 @@ type CheckboxData = {
|
||||
checked: boolean
|
||||
}
|
||||
|
||||
const settingComponentBuilder = (
|
||||
componentData: SettingComponentData[],
|
||||
register: UseFormRegister<FieldValues>
|
||||
) => {
|
||||
const settingComponentBuilder = (componentData: SettingComponentData[]) => {
|
||||
const components = componentData.map((data) => {
|
||||
switch (data.controllerType) {
|
||||
case 'slider':
|
||||
@ -42,7 +43,18 @@ const settingComponentBuilder = (
|
||||
step={step}
|
||||
value={value}
|
||||
name={data.name}
|
||||
register={register}
|
||||
/>
|
||||
)
|
||||
case 'input':
|
||||
const { placeholder, value: textValue } =
|
||||
data.controllerData as InputData
|
||||
return (
|
||||
<ModelConfigInput
|
||||
title={data.title}
|
||||
key={data.name}
|
||||
name={data.name}
|
||||
placeholder={placeholder}
|
||||
value={textValue}
|
||||
/>
|
||||
)
|
||||
case 'checkbox':
|
||||
@ -50,7 +62,6 @@ const settingComponentBuilder = (
|
||||
return (
|
||||
<Checkbox
|
||||
key={data.name}
|
||||
register={register}
|
||||
name={data.name}
|
||||
title={data.title}
|
||||
checked={checked}
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import { join } from 'path'
|
||||
|
||||
import React from 'react'
|
||||
|
||||
import { getUserSpace, openFileExplorer } from '@janhq/core'
|
||||
|
||||
import { Input, Textarea } from '@janhq/uikit'
|
||||
@ -16,19 +18,29 @@ import DropdownListSidebar, {
|
||||
|
||||
import { useCreateNewThread } from '@/hooks/useCreateNewThread'
|
||||
|
||||
import { toSettingParams } from '@/utils/model_param'
|
||||
|
||||
import EngineSetting from '../EngineSetting'
|
||||
import ModelSetting from '../ModelSetting'
|
||||
|
||||
import { activeThreadAtom, threadStatesAtom } from '@/helpers/atoms/Thread.atom'
|
||||
import {
|
||||
activeThreadAtom,
|
||||
getActiveThreadModelParamsAtom,
|
||||
threadStatesAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export const showRightSideBarAtom = atom<boolean>(true)
|
||||
|
||||
export default function Sidebar() {
|
||||
const Sidebar: React.FC = () => {
|
||||
const showing = useAtomValue(showRightSideBarAtom)
|
||||
const activeThread = useAtomValue(activeThreadAtom)
|
||||
const selectedModel = useAtomValue(selectedModelAtom)
|
||||
const { updateThreadMetadata } = useCreateNewThread()
|
||||
const threadStates = useAtomValue(threadStatesAtom)
|
||||
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
|
||||
const modelSettingParams = toSettingParams(activeModelParams)
|
||||
|
||||
const onReviewInFinderClick = async (type: string) => {
|
||||
if (!activeThread) return
|
||||
const activeThreadState = threadStates[activeThread.id]
|
||||
@ -187,6 +199,17 @@ export default function Sidebar() {
|
||||
</div>
|
||||
</div>
|
||||
</CardSidebar>
|
||||
{Object.keys(modelSettingParams).length ? (
|
||||
<CardSidebar
|
||||
title="Engine"
|
||||
onRevealInFinderClick={onReviewInFinderClick}
|
||||
onViewJsonClick={onViewJsonClick}
|
||||
>
|
||||
<div className="p-2">
|
||||
<EngineSetting />
|
||||
</div>
|
||||
</CardSidebar>
|
||||
) : null}
|
||||
<CardSidebar
|
||||
title="Model"
|
||||
onRevealInFinderClick={onReviewInFinderClick}
|
||||
@ -203,3 +226,5 @@ export default function Sidebar() {
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default React.memo(Sidebar)
|
||||
|
||||
39
web/utils/componentSettings.ts
Normal file
39
web/utils/componentSettings.ts
Normal file
@ -0,0 +1,39 @@
|
||||
import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
|
||||
|
||||
import { presetConfiguration } from '@/screens/Chat/ModelSetting/predefinedComponent'
|
||||
|
||||
import { SettingComponentData } from '@/screens/Chat/ModelSetting/settingComponentBuilder'
|
||||
|
||||
import { ModelParams } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export const getConfigurationsData = (
|
||||
settings: ModelSettingParams | ModelRuntimeParams
|
||||
) => {
|
||||
const componentData: SettingComponentData[] = []
|
||||
Object.keys(settings).forEach((key: string) => {
|
||||
const componentSetting = presetConfiguration[key]
|
||||
|
||||
if (!componentSetting) {
|
||||
return
|
||||
}
|
||||
if ('slider' === componentSetting.controllerType) {
|
||||
const value = Number(settings[key as keyof ModelParams])
|
||||
if ('value' in componentSetting.controllerData)
|
||||
componentSetting.controllerData.value = value
|
||||
} else if ('input' === componentSetting.controllerType) {
|
||||
const value = settings[key as keyof ModelParams] as string
|
||||
const placeholder = settings[key as keyof ModelParams] as string
|
||||
if ('value' in componentSetting.controllerData)
|
||||
componentSetting.controllerData.value = value
|
||||
if ('placeholder' in componentSetting.controllerData)
|
||||
componentSetting.controllerData.placeholder = placeholder
|
||||
} else if ('checkbox' === componentSetting.controllerType) {
|
||||
const checked = settings[key as keyof ModelParams] as boolean
|
||||
|
||||
if ('checked' in componentSetting.controllerData)
|
||||
componentSetting.controllerData.checked = checked
|
||||
}
|
||||
componentData.push(componentSetting)
|
||||
})
|
||||
return componentData
|
||||
}
|
||||
55
web/utils/model_param.ts
Normal file
55
web/utils/model_param.ts
Normal file
@ -0,0 +1,55 @@
|
||||
import { ModelRuntimeParams, ModelSettingParams } from '@janhq/core'
|
||||
|
||||
import { ModelParams } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export const toRuntimeParams = (
|
||||
modelParams?: ModelParams
|
||||
): ModelRuntimeParams => {
|
||||
if (!modelParams) return {}
|
||||
const defaultModelParams: ModelRuntimeParams = {
|
||||
temperature: undefined,
|
||||
token_limit: undefined,
|
||||
top_k: undefined,
|
||||
top_p: undefined,
|
||||
stream: undefined,
|
||||
max_tokens: undefined,
|
||||
stop: undefined,
|
||||
frequency_penalty: undefined,
|
||||
presence_penalty: undefined,
|
||||
}
|
||||
|
||||
const runtimeParams: ModelRuntimeParams = {}
|
||||
|
||||
for (const [key, value] of Object.entries(modelParams)) {
|
||||
if (key in defaultModelParams) {
|
||||
// @ts-ignore
|
||||
runtimeParams[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return runtimeParams
|
||||
}
|
||||
|
||||
export const toSettingParams = (
|
||||
modelParams?: ModelParams
|
||||
): ModelSettingParams => {
|
||||
if (!modelParams) return {}
|
||||
const defaultSettingParams: ModelSettingParams = {
|
||||
ctx_len: undefined,
|
||||
ngl: undefined,
|
||||
embedding: undefined,
|
||||
n_parallel: undefined,
|
||||
cpu_threads: undefined,
|
||||
prompt_template: undefined,
|
||||
}
|
||||
const settingParams: ModelSettingParams = {}
|
||||
|
||||
for (const [key, value] of Object.entries(modelParams)) {
|
||||
if (key in defaultSettingParams) {
|
||||
// @ts-ignore
|
||||
settingParams[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return settingParams
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user