fix(thread): #1043 default model to prefer active model (#1070)

Signed-off-by: James <james@jan.ai>
Co-authored-by: James <james@jan.ai>
This commit is contained in:
NamH 2023-12-19 10:51:41 +07:00 committed by GitHub
parent d528dc8a81
commit 55ab4ae70f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 137 additions and 48 deletions

View File

@ -1,4 +1,4 @@
import { useEffect, useState } from 'react' import { useCallback, useEffect, useState } from 'react'
import { InferenceEngine, Model } from '@janhq/core' import { InferenceEngine, Model } from '@janhq/core'
import { import {
@ -20,12 +20,12 @@ import { twMerge } from 'tailwind-merge'
import { MainViewState } from '@/constants/screens' import { MainViewState } from '@/constants/screens'
import { useActiveModel } from '@/hooks/useActiveModel'
import { useEngineSettings } from '@/hooks/useEngineSettings' import { useEngineSettings } from '@/hooks/useEngineSettings'
import { getDownloadedModels } from '@/hooks/useGetDownloadedModels'
import { useMainViewState } from '@/hooks/useMainViewState' import { useMainViewState } from '@/hooks/useMainViewState'
import useRecommendedModel from '@/hooks/useRecommendedModel'
import { toGigabytes } from '@/utils/converter' import { toGigabytes } from '@/utils/converter'
import { activeThreadAtom, threadStatesAtom } from '@/helpers/atoms/Thread.atom' import { activeThreadAtom, threadStatesAtom } from '@/helpers/atoms/Thread.atom'
@ -33,13 +33,12 @@ import { activeThreadAtom, threadStatesAtom } from '@/helpers/atoms/Thread.atom'
export const selectedModelAtom = atom<Model | undefined>(undefined) export const selectedModelAtom = atom<Model | undefined>(undefined)
export default function DropdownListSidebar() { export default function DropdownListSidebar() {
const [downloadedModels, setDownloadedModels] = useState<Model[]>([])
const setSelectedModel = useSetAtom(selectedModelAtom) const setSelectedModel = useSetAtom(selectedModelAtom)
const threadStates = useAtomValue(threadStatesAtom)
const activeThread = useAtomValue(activeThreadAtom) const activeThread = useAtomValue(activeThreadAtom)
const [selected, setSelected] = useState<Model | undefined>() const [selected, setSelected] = useState<Model | undefined>()
const { setMainViewState } = useMainViewState() const { setMainViewState } = useMainViewState()
const { activeModel, stateModel } = useActiveModel() const [openAISettings, setOpenAISettings] = useState<
const [opeenAISettings, setOpenAISettings] = useState<
{ api_key: string } | undefined { api_key: string } | undefined
>(undefined) >(undefined)
const { readOpenAISettings, saveOpenAISettings } = useEngineSettings() const { readOpenAISettings, saveOpenAISettings } = useEngineSettings()
@ -50,43 +49,27 @@ export default function DropdownListSidebar() {
}) })
}, []) }, [])
useEffect(() => { const { recommendedModel, downloadedModels } = useRecommendedModel()
getDownloadedModels().then((downloadedModels) => {
setDownloadedModels( useEffect(() => {
downloadedModels.sort((a, b) => setSelected(recommendedModel)
a.engine !== InferenceEngine.nitro && setSelectedModel(recommendedModel)
b.engine === InferenceEngine.nitro }, [recommendedModel, setSelectedModel])
? 1
: -1 const onValueSelected = useCallback(
) (modelId: string) => {
) const model = downloadedModels.find((m) => m.id === modelId)
if (downloadedModels.length > 0) { setSelected(model)
setSelected( setSelectedModel(model)
downloadedModels.filter( },
(x) => x.id === activeThread?.assistants[0].model.id [downloadedModels, setSelectedModel]
)[0] || downloadedModels[0] )
)
setSelectedModel(
downloadedModels.filter(
(x) => x.id === activeThread?.assistants[0].model.id
)[0] || downloadedModels[0]
)
}
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [activeThread, activeModel, stateModel.loading])
const threadStates = useAtomValue(threadStatesAtom)
if (!activeThread) { if (!activeThread) {
return null return null
} }
const finishInit = threadStates[activeThread.id].isFinishInit ?? true const finishInit = threadStates[activeThread.id].isFinishInit ?? true
const onValueSelected = (value: string) => {
setSelected(downloadedModels.filter((x) => x.id === value)[0])
setSelectedModel(downloadedModels.filter((x) => x.id === value)[0])
}
return ( return (
<> <>
<Select <Select
@ -151,7 +134,7 @@ export default function DropdownListSidebar() {
<Input <Input
id="assistant-instructions" id="assistant-instructions"
placeholder="Enter your API_KEY" placeholder="Enter your API_KEY"
defaultValue={opeenAISettings?.api_key} defaultValue={openAISettings?.api_key}
onChange={(e) => { onChange={(e) => {
saveOpenAISettings({ apiKey: e.target.value }) saveOpenAISettings({ apiKey: e.target.value })
}} }}

View File

@ -1,13 +1,11 @@
/* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/no-explicit-any */
import { EventName, events } from '@janhq/core' import { EventName, events, Model } from '@janhq/core'
import { Model } from '@janhq/core'
import { atom, useAtom } from 'jotai' import { atom, useAtom } from 'jotai'
import { toaster } from '@/containers/Toast' import { toaster } from '@/containers/Toast'
import { useGetDownloadedModels } from './useGetDownloadedModels' import { useGetDownloadedModels } from './useGetDownloadedModels'
import { LAST_USED_MODEL_ID } from './useRecommendedModel'
import { extensionManager } from '@/extension'
export const activeModelAtom = atom<Model | undefined>(undefined) export const activeModelAtom = atom<Model | undefined>(undefined)
@ -51,6 +49,7 @@ export function useActiveModel() {
return return
} }
localStorage.setItem(LAST_USED_MODEL_ID, model.id)
events.emit(EventName.OnModelInit, model) events.emit(EventName.OnModelInit, model)
} }

View File

@ -20,10 +20,7 @@ export function useGetDownloadedModels() {
return { downloadedModels, setDownloadedModels } return { downloadedModels, setDownloadedModels }
} }
export async function getDownloadedModels(): Promise<Model[]> { export const getDownloadedModels = async (): Promise<Model[]> =>
const models = await extensionManager extensionManager
.get<ModelExtension>(ExtensionType.Model) .get<ModelExtension>(ExtensionType.Model)
?.getDownloadedModels() ?.getDownloadedModels() ?? []
return models ?? []
}

View File

@ -0,0 +1,110 @@
import { useCallback, useEffect, useState } from 'react'
import { Model, InferenceEngine } from '@janhq/core'
import { atom, useAtomValue } from 'jotai'
import { activeModelAtom } from './useActiveModel'
import { getDownloadedModels } from './useGetDownloadedModels'
import { activeThreadAtom, threadStatesAtom } from '@/helpers/atoms/Thread.atom'
export const lastUsedModel = atom<Model | undefined>(undefined)
export const LAST_USED_MODEL_ID = 'last-used-model-id'
/**
* A hook that return the recommended model when user
* wants to create a new thread.
*
* The precedence is as follows:
* 1. Active model
* 2. If no active model(s), then the last used model
* 3. If no active or last used model, then the 1st model on the list
*/
export default function useRecommendedModel() {
const activeModel = useAtomValue(activeModelAtom)
const [downloadedModels, setDownloadedModels] = useState<Model[]>([])
const [recommendedModel, setRecommendedModel] = useState<Model | undefined>()
const threadStates = useAtomValue(threadStatesAtom)
const activeThread = useAtomValue(activeThreadAtom)
const getAndSortDownloadedModels = useCallback(async (): Promise<Model[]> => {
const models = (await getDownloadedModels()).sort((a, b) =>
a.engine !== InferenceEngine.nitro && b.engine === InferenceEngine.nitro
? 1
: -1
)
setDownloadedModels(models)
return models
}, [])
const getRecommendedModel = useCallback(async (): Promise<
Model | undefined
> => {
if (!activeThread) {
return
}
const finishInit = threadStates[activeThread.id].isFinishInit ?? true
if (finishInit) {
const modelId = activeThread.assistants[0]?.model.id
const models = await getAndSortDownloadedModels()
const model = models.find((model) => model.id === modelId)
if (model) {
setRecommendedModel(model)
}
return
}
if (activeModel) {
// if we have active model alr, then we can just use that
console.debug(`Using active model ${activeModel.id}`)
setRecommendedModel(activeModel)
return
}
// sort the model, for display purpose
const models = await getAndSortDownloadedModels()
if (models.length === 0) {
// if we have no downloaded models, then can't recommend anything
console.debug("No downloaded models, can't recommend anything")
return
}
// otherwise, get the last used model id
const lastUsedModelId = localStorage.getItem(LAST_USED_MODEL_ID)
// if we don't have [lastUsedModelId], then we can just use the first model
// in the downloaded list
if (!lastUsedModelId) {
console.debug(
`No last used model, using first model in list ${models[0].id}}`
)
setRecommendedModel(models[0])
return
}
const lastUsedModel = models.find((model) => model.id === lastUsedModelId)
if (!lastUsedModel) {
// if we can't find the last used model, then we can just use the first model
// in the downloaded list
console.debug(
`Last used model ${lastUsedModelId} not found, using first model in list ${models[0].id}}`
)
setRecommendedModel(models[0])
return
}
console.debug(`Using last used model ${lastUsedModel.id}`)
setRecommendedModel(lastUsedModel)
}, [getAndSortDownloadedModels, activeThread])
useEffect(() => {
getRecommendedModel()
}, [getRecommendedModel])
return { recommendedModel, downloadedModels }
}