diff --git a/web-app/src/containers/ScrollToBottom.tsx b/web-app/src/containers/ScrollToBottom.tsx index ac924df91..b1259480f 100644 --- a/web-app/src/containers/ScrollToBottom.tsx +++ b/web-app/src/containers/ScrollToBottom.tsx @@ -18,7 +18,7 @@ const ScrollToBottom = ({ }) => { const { t } = useTranslation() const appMainViewBgColor = useAppearance((state) => state.appMainViewBgColor) - const { showScrollToBottomBtn, scrollToBottom, setIsUserScrolling } = + const { showScrollToBottomBtn, scrollToBottom } = useThreadScrolling(threadId, scrollContainerRef) const { messages } = useMessages( useShallow((state) => ({ @@ -50,7 +50,6 @@ const ScrollToBottom = ({ className="bg-main-view-fg/10 px-2 border border-main-view-fg/5 flex items-center justify-center rounded-xl gap-x-2 cursor-pointer pointer-events-auto" onClick={() => { scrollToBottom(true) - setIsUserScrolling(false) }} >

{t('scrollToBottom')}

diff --git a/web-app/src/hooks/useThreadScrolling.tsx b/web-app/src/hooks/useThreadScrolling.tsx index 9dfbeefb7..9352a88bb 100644 --- a/web-app/src/hooks/useThreadScrolling.tsx +++ b/web-app/src/hooks/useThreadScrolling.tsx @@ -1,8 +1,9 @@ -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useAppState } from './useAppState' import { useMessages } from './useMessages' -import { useShallow } from 'zustand/react/shallow' -import debounce from 'lodash.debounce' + +const VIEWPORT_PADDING = 40 // Offset from viewport bottom for user message positioning +const MAX_DOM_RETRY_ATTEMPTS = 3 // Maximum attempts to find DOM elements before giving up export const useThreadScrolling = ( threadId: string, @@ -10,18 +11,36 @@ export const useThreadScrolling = ( ) => { const streamingContent = useAppState((state) => state.streamingContent) const isFirstRender = useRef(true) - const { messages } = useMessages( - useShallow((state) => ({ - messages: state.messages[threadId], - })) - ) const wasStreamingRef = useRef(false) const userIntendedPositionRef = useRef(null) - const [isUserScrolling, setIsUserScrolling] = useState(false) const [isAtBottom, setIsAtBottom] = useState(true) const [hasScrollbar, setHasScrollbar] = useState(false) const lastScrollTopRef = useRef(0) - const messagesCount = useMemo(() => messages?.length ?? 0, [messages]) + + const messageCount = useMessages((state) => state.messages[threadId]?.length ?? 0) + const lastMessageRole = useMessages((state) => { + const msgs = state.messages[threadId] + return msgs && msgs.length > 0 ? msgs[msgs.length - 1].role : null + }) + + const [paddingHeight, setPaddingHeightInternal] = useState(0) + const setPaddingHeight = setPaddingHeightInternal + const originalPaddingRef = useRef(0) + + const getDOMElements = useCallback(() => { + const scrollContainer = scrollContainerRef.current + if (!scrollContainer) return null + + const userMessages = scrollContainer.querySelectorAll('[data-message-author-role="user"]') + const assistantMessages = scrollContainer.querySelectorAll('[data-message-author-role="assistant"]') + + return { + scrollContainer, + lastUserMessage: userMessages[userMessages.length - 1] as HTMLElement, + lastAssistantMessage: assistantMessages[assistantMessages.length - 1] as HTMLElement, + } + }, []) + const showScrollToBottomBtn = !isAtBottom && hasScrollbar @@ -32,20 +51,16 @@ export const useThreadScrolling = ( ...(smooth ? { behavior: 'smooth' } : {}), }) } - }, []) + }, [scrollContainerRef]) + const handleScroll = useCallback((e: Event) => { const target = e.target as HTMLDivElement const { scrollTop, scrollHeight, clientHeight } = target - // Use a small tolerance to better detect when we're at the bottom const isBottom = Math.abs(scrollHeight - scrollTop - clientHeight) < 10 const hasScroll = scrollHeight > clientHeight - // Detect if this is a user-initiated scroll if (Math.abs(scrollTop - lastScrollTopRef.current) > 10) { - setIsUserScrolling(!isBottom) - - // If user scrolls during streaming and moves away from bottom, record their intended position if (streamingContent && !isBottom) { userIntendedPositionRef.current = scrollTop } @@ -76,117 +91,103 @@ export const useThreadScrolling = ( setHasScrollbar(hasScroll) }, []) - // Single useEffect for all auto-scrolling logic useEffect(() => { - // Track streaming state changes - const isCurrentlyStreaming = !!streamingContent - const justFinishedStreaming = - wasStreamingRef.current && !isCurrentlyStreaming - wasStreamingRef.current = isCurrentlyStreaming - - // If streaming just finished and user had an intended position, restore it - if (justFinishedStreaming && userIntendedPositionRef.current !== null) { - // Small delay to ensure DOM has updated - setTimeout(() => { - if ( - scrollContainerRef.current && - userIntendedPositionRef.current !== null - ) { - scrollContainerRef.current.scrollTo({ - top: userIntendedPositionRef.current, - behavior: 'smooth', - }) - userIntendedPositionRef.current = null - setIsUserScrolling(false) - } - }, 100) - return - } - // Clear intended position when streaming starts fresh - if (isCurrentlyStreaming && !wasStreamingRef.current) { - userIntendedPositionRef.current = null - } - - // Only auto-scroll when the user is not actively scrolling - // AND either at the bottom OR there's streaming content - if (!isUserScrolling && (streamingContent || isAtBottom) && messagesCount) { - // Use non-smooth scrolling for auto-scroll to prevent jank - scrollToBottom(false) - } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [streamingContent, isUserScrolling, messagesCount]) - - useEffect(() => { - if (streamingContent) { - const interval = setInterval(checkScrollState, 100) - return () => clearInterval(interval) - } - }, [streamingContent, checkScrollState]) - - // Auto-scroll to bottom when component mounts or thread content changes - useEffect(() => { - const scrollContainer = scrollContainerRef.current - if (!scrollContainer) return - - // Always scroll to bottom on first render or when thread changes + if (!scrollContainerRef.current) return if (isFirstRender.current) { isFirstRender.current = false - scrollToBottom() - setIsAtBottom(true) - setIsUserScrolling(false) userIntendedPositionRef.current = null wasStreamingRef.current = false + scrollToBottom(false) checkScrollState() - return } }, [checkScrollState, scrollToBottom]) - const handleDOMScroll = (e: Event) => { - const target = e.target as HTMLDivElement - const { scrollTop, scrollHeight, clientHeight } = target - // Use a small tolerance to better detect when we're at the bottom - const isBottom = Math.abs(scrollHeight - scrollTop - clientHeight) < 10 - const hasScroll = scrollHeight > clientHeight - // Detect if this is a user-initiated scroll - if (Math.abs(scrollTop - lastScrollTopRef.current) > 10) { - setIsUserScrolling(!isBottom) + const prevCountRef = useRef(messageCount) + useEffect(() => { + const prevCount = prevCountRef.current + const becameLonger = messageCount > prevCount + const isUserMessage = lastMessageRole === 'user' - // If user scrolls during streaming and moves away from bottom, record their intended position - if (streamingContent && !isBottom) { - userIntendedPositionRef.current = scrollTop + if (becameLonger && messageCount > 0 && isUserMessage) { + const calculatePadding = () => { + const elements = getDOMElements() + if (!elements?.lastUserMessage) return + + const viewableHeight = elements.scrollContainer.clientHeight + const userMessageHeight = elements.lastUserMessage.offsetHeight + const calculatedPadding = Math.max(0, viewableHeight - VIEWPORT_PADDING - userMessageHeight) + + setPaddingHeight(calculatedPadding) + originalPaddingRef.current = calculatedPadding + + requestAnimationFrame(() => { + elements.scrollContainer.scrollTo({ + top: elements.scrollContainer.scrollHeight, + behavior: 'smooth', + }) + }) } + + let retryCount = 0 + + const tryCalculatePadding = () => { + if (getDOMElements()?.lastUserMessage) { + calculatePadding() + } else if (retryCount < MAX_DOM_RETRY_ATTEMPTS) { + retryCount++ + requestAnimationFrame(tryCalculatePadding) + } + } + + requestAnimationFrame(tryCalculatePadding) } - setIsAtBottom(isBottom) - setHasScrollbar(hasScroll) - lastScrollTopRef.current = scrollTop - } - // Use a shorter debounce time for more responsive scrolling - const debouncedScroll = debounce(handleDOMScroll) + + prevCountRef.current = messageCount + }, [messageCount, lastMessageRole]) useEffect(() => { - const chatHistoryElement = scrollContainerRef.current - if (chatHistoryElement) { - chatHistoryElement.addEventListener('scroll', debouncedScroll) - return () => - chatHistoryElement.removeEventListener('scroll', debouncedScroll) - } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, []) + const previouslyStreaming = wasStreamingRef.current + const currentlyStreaming = !!streamingContent && streamingContent.thread_id === threadId + + const streamingEnded = previouslyStreaming && !currentlyStreaming + const hasPaddingToAdjust = originalPaddingRef.current > 0 + + if (streamingEnded && hasPaddingToAdjust) { + requestAnimationFrame(() => { + const elements = getDOMElements() + if (!elements?.lastAssistantMessage || !elements?.lastUserMessage) return + + const userRect = elements.lastUserMessage.getBoundingClientRect() + const assistantRect = elements.lastAssistantMessage.getBoundingClientRect() + const actualSpacing = assistantRect.top - userRect.bottom + const totalAssistantHeight = elements.lastAssistantMessage.offsetHeight + actualSpacing + const newPadding = Math.max(0, originalPaddingRef.current - totalAssistantHeight) + + setPaddingHeight(newPadding) + originalPaddingRef.current = newPadding + }) + } + + wasStreamingRef.current = currentlyStreaming + }, [streamingContent, threadId]) - // Reset scroll state when thread changes useEffect(() => { - isFirstRender.current = true - scrollToBottom() - setIsAtBottom(true) - setIsUserScrolling(false) userIntendedPositionRef.current = null wasStreamingRef.current = false + setPaddingHeight(0) + originalPaddingRef.current = 0 + prevCountRef.current = messageCount + scrollToBottom(false) checkScrollState() - }, [threadId, checkScrollState, scrollToBottom]) + }, [threadId]) return useMemo( - () => ({ showScrollToBottomBtn, scrollToBottom, setIsUserScrolling }), - [showScrollToBottomBtn, scrollToBottom, setIsUserScrolling] + () => ({ + showScrollToBottomBtn, + scrollToBottom, + paddingHeight + }), + [showScrollToBottomBtn, scrollToBottom, paddingHeight] ) } diff --git a/web-app/src/routes/threads/$threadId.tsx b/web-app/src/routes/threads/$threadId.tsx index 80740935d..384cb764f 100644 --- a/web-app/src/routes/threads/$threadId.tsx +++ b/web-app/src/routes/threads/$threadId.tsx @@ -21,6 +21,7 @@ import { PlatformFeatures } from '@/lib/platform/const' import { PlatformFeature } from '@/lib/platform/types' import ScrollToBottom from '@/containers/ScrollToBottom' import { PromptProgress } from '@/components/PromptProgress' +import { useThreadScrolling } from '@/hooks/useThreadScrolling' // as route.threadsDetail export const Route = createFileRoute('/threads/$threadId')({ @@ -48,6 +49,9 @@ function ThreadDetail() { const thread = useThreads(useShallow((state) => state.threads[threadId])) const scrollContainerRef = useRef(null) + // Get padding height for ChatGPT-style message positioning + const { paddingHeight } = useThreadScrolling(threadId, scrollContainerRef) + useEffect(() => { setCurrentThreadId(threadId) const assistant = assistants.find( @@ -186,6 +190,12 @@ function ThreadDetail() { threadId={threadId} data-test-id="thread-content-text" /> + {/* Persistent padding element for ChatGPT-style message positioning */} +