diff --git a/extensions/inference-cortex-extension/bin/version.txt b/extensions/inference-cortex-extension/bin/version.txt index af0b7ddbf..238d6e882 100644 --- a/extensions/inference-cortex-extension/bin/version.txt +++ b/extensions/inference-cortex-extension/bin/version.txt @@ -1 +1 @@ -1.0.6 +1.0.7 diff --git a/web/containers/Providers/ModelHandler.tsx b/web/containers/Providers/ModelHandler.tsx index 8c565bab1..42376c081 100644 --- a/web/containers/Providers/ModelHandler.tsx +++ b/web/containers/Providers/ModelHandler.tsx @@ -56,7 +56,7 @@ export default function ModelHandler() { const activeModel = useAtomValue(activeModelAtom) const setActiveModel = useSetAtom(activeModelAtom) const setStateModel = useSetAtom(stateModelAtom) - const [subscribedGeneratingMessage, setSubscribedGeneratingMessage] = useAtom( + const subscribedGeneratingMessage = useAtomValue( subscribedGeneratingMessageAtom ) const activeThread = useAtomValue(activeThreadAtom) diff --git a/web/helpers/atoms/Thread.atom.ts b/web/helpers/atoms/Thread.atom.ts index c94d287b5..6704f8e57 100644 --- a/web/helpers/atoms/Thread.atom.ts +++ b/web/helpers/atoms/Thread.atom.ts @@ -1,7 +1,7 @@ import { Thread, ThreadContent, ThreadState } from '@janhq/core' import { atom } from 'jotai' -import { atomWithStorage } from 'jotai/utils' +import { atomWithStorage, selectAtom } from 'jotai/utils' import { ModelParams } from '@/types/model' @@ -34,6 +34,22 @@ export const threadStatesAtom = atomWithStorage>( {} ) +/** + * Returns whether there is a thread waiting for response or not + */ +const isWaitingForResponseAtom = selectAtom(threadStatesAtom, (threads) => + Object.values(threads).some((t) => t.waitingForResponse) +) + +/** + * Combine 2 states to reduce rerender + * 1. isWaitingForResponse + * 2. isGenerating + */ +export const isBlockingSendAtom = atom( + (get) => get(isWaitingForResponseAtom) || get(isGeneratingResponseAtom) +) + /** * Stores all threads for the current user */ @@ -173,6 +189,29 @@ export const updateThreadWaitingForResponseAtom = atom( } ) +/** + * Reset the thread waiting for response state + */ +export const resetThreadWaitingForResponseAtom = atom(null, (get, set) => { + const currentState = { ...get(threadStatesAtom) } + Object.keys(currentState).forEach((threadId) => { + currentState[threadId] = { + ...currentState[threadId], + waitingForResponse: false, + error: undefined, + } + }) + set(threadStatesAtom, currentState) +}) + +/** + * Reset all generating states + **/ +export const resetGeneratingResponseAtom = atom(null, (get, set) => { + set(resetThreadWaitingForResponseAtom) + set(isGeneratingResponseAtom, false) +}) + /** * Update the thread last message */ diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts index ed704dd61..67023d1d3 100644 --- a/web/hooks/useActiveModel.ts +++ b/web/hooks/useActiveModel.ts @@ -10,6 +10,10 @@ import { LAST_USED_MODEL_ID } from './useRecommendedModel' import { vulkanEnabledAtom } from '@/helpers/atoms/AppConfig.atom' import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' +import { + isGeneratingResponseAtom, + resetThreadWaitingForResponseAtom, +} from '@/helpers/atoms/Thread.atom' export const activeModelAtom = atom(undefined) export const loadModelErrorAtom = atom(undefined) diff --git a/web/hooks/useCreateNewThread.test.ts b/web/hooks/useCreateNewThread.test.ts index d98983830..0ef8ef195 100644 --- a/web/hooks/useCreateNewThread.test.ts +++ b/web/hooks/useCreateNewThread.test.ts @@ -67,7 +67,6 @@ describe('useCreateNewThread', () => { } as any) }) - expect(mockSetAtom).toHaveBeenCalledTimes(1) expect(extensionManager.get).toHaveBeenCalled() }) @@ -113,7 +112,6 @@ describe('useCreateNewThread', () => { } as any) }) - expect(mockSetAtom).toHaveBeenCalledTimes(1) // Check if all the necessary atoms were set expect(extensionManager.get).toHaveBeenCalled() }) @@ -158,7 +156,6 @@ describe('useCreateNewThread', () => { } as any) }) - expect(mockSetAtom).toHaveBeenCalledTimes(1) // Check if all the necessary atoms were set expect(extensionManager.get).toHaveBeenCalled() }) diff --git a/web/hooks/useCreateNewThread.ts b/web/hooks/useCreateNewThread.ts index 90024b3da..b57384344 100644 --- a/web/hooks/useCreateNewThread.ts +++ b/web/hooks/useCreateNewThread.ts @@ -20,8 +20,6 @@ import { toaster } from '@/containers/Toast' import { isLocalEngine } from '@/utils/modelEngine' -import { useActiveModel } from './useActiveModel' - import useRecommendedModel from './useRecommendedModel' import useSetActiveThread from './useSetActiveThread' @@ -52,10 +50,8 @@ export const useCreateNewThread = () => { const [activeAssistant, setActiveAssistant] = useAtom(activeAssistantAtom) const experimentalEnabled = useAtomValue(experimentalFeatureEnabledAtom) - const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom) const threads = useAtomValue(threadsAtom) - const { stopInference } = useActiveModel() const { recommendedModel } = useRecommendedModel() @@ -63,10 +59,6 @@ export const useCreateNewThread = () => { assistant: (ThreadAssistantInfo & { id: string; name: string }) | Assistant, model?: Model | undefined ) => { - // Stop generating if any - setIsGeneratingResponse(false) - stopInference() - const defaultModel = model || recommendedModel if (!model) { diff --git a/web/hooks/useDeleteThread.ts b/web/hooks/useDeleteThread.ts index 1ea6c7579..29b509631 100644 --- a/web/hooks/useDeleteThread.ts +++ b/web/hooks/useDeleteThread.ts @@ -48,6 +48,7 @@ export default function useDeleteThread() { if (thread) { const updatedThread = { ...thread, + title: 'New Thread', metadata: { ...thread.metadata, title: 'New Thread', diff --git a/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx b/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx index 19ec3328a..d04f9b233 100644 --- a/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/AssistantSetting/index.tsx @@ -12,6 +12,7 @@ import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' import { activeThreadAtom, engineParamsUpdateAtom, + resetGeneratingResponseAtom, } from '@/helpers/atoms/Thread.atom' type Props = { @@ -24,6 +25,7 @@ const AssistantSetting: React.FC = ({ componentData }) => { const { updateThreadMetadata } = useCreateNewThread() const { stopModel } = useActiveModel() const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom) + const resetGenerating = useSetAtom(resetGeneratingResponseAtom) const onValueChanged = useCallback( (key: string, value: string | number | boolean | string[]) => { @@ -32,6 +34,7 @@ const AssistantSetting: React.FC = ({ componentData }) => { componentData.find((x) => x.key === key)?.requireModelReload ?? false if (shouldReloadModel) { setEngineParamsUpdate(true) + resetGenerating() stopModel() } diff --git a/web/screens/Thread/ThreadCenterPanel/ChatBody/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatBody/index.tsx index f384611c5..c47d19d67 100644 --- a/web/screens/Thread/ThreadCenterPanel/ChatBody/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ChatBody/index.tsx @@ -16,8 +16,7 @@ import EmptyThread from './EmptyThread' import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' import { activeThreadAtom, - isGeneratingResponseAtom, - threadStatesAtom, + isBlockingSendAtom, } from '@/helpers/atoms/Thread.atom' const ChatConfigurator = memo(() => { @@ -65,12 +64,7 @@ const ChatBody = memo( const prevScrollTop = useRef(0) const isUserManuallyScrollingUp = useRef(false) const currentThread = useAtomValue(activeThreadAtom) - const threadStates = useAtomValue(threadStatesAtom) - const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom) - - const isStreamingResponse = Object.values(threadStates).some( - (threadState) => threadState.waitingForResponse - ) + const isBlockingSend = useAtomValue(isBlockingSendAtom) const count = useMemo( () => (messages?.length ?? 0) + (loadModelError ? 1 : 0), @@ -85,34 +79,13 @@ const ChatBody = memo( overscan: 5, }) - useEffect(() => { - if (parentRef.current) { - parentRef.current.scrollTo({ top: parentRef.current.scrollHeight }) - virtualizer.scrollToIndex(count - 1) - } - }, [count, virtualizer]) - - useEffect(() => { - if (parentRef.current && isGeneratingResponse) { - parentRef.current.scrollTo({ top: parentRef.current.scrollHeight }) - virtualizer.scrollToIndex(count - 1) - } - }, [count, virtualizer, isGeneratingResponse]) - - useEffect(() => { - if (parentRef.current && isGeneratingResponse) { - parentRef.current.scrollTo({ top: parentRef.current.scrollHeight }) - virtualizer.scrollToIndex(count - 1) - } - }, [count, virtualizer, isGeneratingResponse, currentThread?.id]) - useEffect(() => { isUserManuallyScrollingUp.current = false - if (parentRef.current) { + if (parentRef.current && isBlockingSend) { parentRef.current.scrollTo({ top: parentRef.current.scrollHeight }) virtualizer.scrollToIndex(count - 1) } - }, [count, currentThread?.id, virtualizer]) + }, [count, virtualizer, isBlockingSend, currentThread?.id]) const items = virtualizer.getVirtualItems() @@ -121,7 +94,7 @@ const ChatBody = memo( _, instance ) => { - if (isUserManuallyScrollingUp.current === true && isStreamingResponse) + if (isUserManuallyScrollingUp.current === true && isBlockingSend) return false return ( // item.start < (instance.scrollOffset ?? 0) && @@ -133,7 +106,7 @@ const ChatBody = memo( (event: React.UIEvent) => { const currentScrollTop = event.currentTarget.scrollTop - if (prevScrollTop.current > currentScrollTop && isStreamingResponse) { + if (prevScrollTop.current > currentScrollTop && isBlockingSend) { isUserManuallyScrollingUp.current = true } else { const currentScrollTop = event.currentTarget.scrollTop @@ -151,7 +124,7 @@ const ChatBody = memo( } prevScrollTop.current = currentScrollTop }, - [isStreamingResponse] + [isBlockingSend] ) return ( diff --git a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx index 0ba50880b..990d24c7a 100644 --- a/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/ChatInput/index.tsx @@ -35,22 +35,19 @@ import RichTextEditor from './RichTextEditor' import { showRightPanelAtom } from '@/helpers/atoms/App.atom' import { experimentalFeatureEnabledAtom } from '@/helpers/atoms/AppConfig.atom' import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom' -import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { spellCheckAtom } from '@/helpers/atoms/Setting.atom' import { activeSettingInputBoxAtom, activeThreadAtom, getActiveThreadIdAtom, - isGeneratingResponseAtom, - threadStatesAtom, + isBlockingSendAtom, } from '@/helpers/atoms/Thread.atom' import { activeTabThreadRightPanelAtom } from '@/helpers/atoms/ThreadRightPanel.atom' const ChatInput = () => { const activeThread = useAtomValue(activeThreadAtom) const { stateModel } = useActiveModel() - const messages = useAtomValue(getCurrentChatMessagesAtom) const spellCheck = useAtomValue(spellCheckAtom) const [currentPrompt, setCurrentPrompt] = useAtom(currentPromptAtom) @@ -67,8 +64,7 @@ const ChatInput = () => { const fileInputRef = useRef(null) const imageInputRef = useRef(null) const experimentalFeature = useAtomValue(experimentalFeatureEnabledAtom) - const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom) - const threadStates = useAtomValue(threadStatesAtom) + const isBlockingSend = useAtomValue(isBlockingSendAtom) const activeAssistant = useAtomValue(activeAssistantAtom) const { stopInference } = useActiveModel() @@ -77,10 +73,6 @@ const ChatInput = () => { activeTabThreadRightPanelAtom ) - const isStreamingResponse = Object.values(threadStates).some( - (threadState) => threadState.waitingForResponse - ) - const refAttachmentMenus = useClickOutside(() => setShowAttacmentMenus(false)) const [showRightPanel, setShowRightPanel] = useAtom(showRightPanelAtom) @@ -302,9 +294,7 @@ const ChatInput = () => { )} - {messages[messages.length - 1]?.status !== MessageStatus.Pending && - !isGeneratingResponse && - !isStreamingResponse ? ( + {!isBlockingSend ? ( <> {currentPrompt.length !== 0 && (