From f6433544afd0b16488012f6e7eb2b98ecdd18d96 Mon Sep 17 00:00:00 2001 From: Louis Date: Mon, 19 May 2025 23:32:55 +0700 Subject: [PATCH] feat: handle stop streaming message, scroll to bottom and model loads (#5023) --- .../inference-cortex-extension/src/index.ts | 2 +- web-app/src/containers/ChatInput.tsx | 17 +++++++++++--- web-app/src/containers/StreamingContent.tsx | 8 +++++-- web-app/src/hooks/useAppState.ts | 11 ++++++++++ web-app/src/hooks/useChat.ts | 8 +++++-- web-app/src/lib/completion.ts | 22 ++++++++++++------- web-app/src/routes/threads/$threadId.tsx | 10 ++++++--- 7 files changed, 59 insertions(+), 19 deletions(-) diff --git a/extensions/inference-cortex-extension/src/index.ts b/extensions/inference-cortex-extension/src/index.ts index 703c2007e..49f4392af 100644 --- a/extensions/inference-cortex-extension/src/index.ts +++ b/extensions/inference-cortex-extension/src/index.ts @@ -185,7 +185,7 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { console.log('Loaded models:', loadedModels) // This is to avoid loading the same model multiple times - if (loadedModels.some((model) => model.id === model.id)) { + if (loadedModels.some((e) => e.id === model.id)) { console.log(`Model ${model.id} already loaded`) return } diff --git a/web-app/src/containers/ChatInput.tsx b/web-app/src/containers/ChatInput.tsx index 2d5b80170..63a396f34 100644 --- a/web-app/src/containers/ChatInput.tsx +++ b/web-app/src/containers/ChatInput.tsx @@ -3,7 +3,7 @@ import TextareaAutosize from 'react-textarea-autosize' import { cn } from '@/lib/utils' import { usePrompt } from '@/hooks/usePrompt' -import { useEffect, useRef, useState } from 'react' +import { useCallback, useEffect, useRef, useState } from 'react' import { Button } from '@/components/ui/button' import { ArrowRight } from 'lucide-react' import { @@ -44,7 +44,7 @@ const ChatInput = ({ const textareaRef = useRef(null) const [isFocused, setIsFocused] = useState(false) const [rows, setRows] = useState(1) - const { streamingContent, updateTools } = useAppState() + const { streamingContent, updateTools, abortControllers } = useAppState() const { prompt, setPrompt } = usePrompt() const { t } = useTranslation() const { spellCheckChatInput } = useGeneralSetting() @@ -97,6 +97,13 @@ const ChatInput = ({ } }, []) + const stopStreaming = useCallback( + (threadId: string) => { + abortControllers[threadId]?.abort() + }, + [abortControllers] + ) + return (
{streamingContent ? ( - ) : ( diff --git a/web-app/src/containers/StreamingContent.tsx b/web-app/src/containers/StreamingContent.tsx index 31bf4db3b..d391ff5a6 100644 --- a/web-app/src/containers/StreamingContent.tsx +++ b/web-app/src/containers/StreamingContent.tsx @@ -2,11 +2,15 @@ import { useAppState } from '@/hooks/useAppState' import { ThreadContent } from './ThreadContent' import { memo } from 'react' +type Props = { + threadId: string +} + // Use memo with no dependencies to allow re-renders when props change -export const StreamingContent = memo(() => { +export const StreamingContent = memo(({ threadId }: Props) => { const { streamingContent } = useAppState() - if (!streamingContent) return null + if (!streamingContent || streamingContent.thread_id !== threadId) return null // Pass a new object to ThreadContent to avoid reference issues // The streaming content is always the last message diff --git a/web-app/src/hooks/useAppState.ts b/web-app/src/hooks/useAppState.ts index 1c4e37db6..9e96476a3 100644 --- a/web-app/src/hooks/useAppState.ts +++ b/web-app/src/hooks/useAppState.ts @@ -7,10 +7,12 @@ type AppState = { loadingModel?: boolean tools: MCPTool[] serverStatus: 'running' | 'stopped' | 'pending' + abortControllers: Record setServerStatus: (value: 'running' | 'stopped' | 'pending') => void updateStreamingContent: (content: ThreadMessage | undefined) => void updateLoadingModel: (loading: boolean) => void updateTools: (tools: MCPTool[]) => void + setAbortController: (threadId: string, controller: AbortController) => void } export const useAppState = create()((set) => ({ @@ -18,6 +20,7 @@ export const useAppState = create()((set) => ({ loadingModel: false, tools: [], serverStatus: 'stopped', + abortControllers: {}, updateStreamingContent: (content) => { set({ streamingContent: content }) }, @@ -28,4 +31,12 @@ export const useAppState = create()((set) => ({ set({ tools }) }, setServerStatus: (value) => set({ serverStatus: value }), + setAbortController: (threadId, controller) => { + set((state) => ({ + abortControllers: { + ...state.abortControllers, + [threadId]: controller, + }, + })) + }, })) diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts index 92d27bf54..12effe82d 100644 --- a/web-app/src/hooks/useChat.ts +++ b/web-app/src/hooks/useChat.ts @@ -27,7 +27,8 @@ export const useChat = () => { useModelProvider() const { getCurrentThread: retrieveThread, createThread } = useThreads() - const { updateStreamingContent, updateLoadingModel } = useAppState() + const { updateStreamingContent, updateLoadingModel, setAbortController } = + useAppState() const { addMessage } = useMessages() const router = useRouter() @@ -83,12 +84,14 @@ export const useChat = () => { builder.addUserMessage(message) let isCompleted = false - + const abortController = new AbortController() + setAbortController(activeThread.id, abortController) while (!isCompleted) { const completion = await sendCompletion( activeThread, provider, builder.getMessages(), + abortController, tools ) @@ -141,6 +144,7 @@ export const useChat = () => { setPrompt, selectedModel, tools, + setAbortController, updateLoadingModel, ] ) diff --git a/web-app/src/lib/completion.ts b/web-app/src/lib/completion.ts index 4004a3bdb..78a16b2ed 100644 --- a/web-app/src/lib/completion.ts +++ b/web-app/src/lib/completion.ts @@ -110,6 +110,7 @@ export const sendCompletion = async ( thread: Thread, provider: ModelProvider, messages: ChatCompletionMessageParam[], + abortController: AbortController, tools: MCPTool[] = [] ): Promise => { if (!thread?.model?.id || !provider) return undefined @@ -126,14 +127,19 @@ export const sendCompletion = async ( }) // TODO: Add message history - const completion = await tokenJS.chat.completions.create({ - stream: true, - provider: providerName, - model: thread.model?.id, - messages, - tools: normalizeTools(tools), - tool_choice: tools.length ? 'auto' : undefined, - }) + const completion = await tokenJS.chat.completions.create( + { + stream: true, + provider: providerName, + model: thread.model?.id, + messages, + tools: normalizeTools(tools), + tool_choice: tools.length ? 'auto' : undefined, + }, + { + signal: abortController.signal, + } + ) return completion } diff --git a/web-app/src/routes/threads/$threadId.tsx b/web-app/src/routes/threads/$threadId.tsx index bb7a297a5..d99350ca6 100644 --- a/web-app/src/routes/threads/$threadId.tsx +++ b/web-app/src/routes/threads/$threadId.tsx @@ -94,12 +94,16 @@ function ThreadDetail() { useEffect(() => { // Only auto-scroll when the user is not actively scrolling // AND either at the bottom OR there's streaming content - if (!isUserScrolling && (streamingContent || isAtBottom)) { + if ( + !isUserScrolling && + (streamingContent || isAtBottom) && + messages?.length + ) { // Use non-smooth scrolling for auto-scroll to prevent jank scrollToBottom(false) } // eslint-disable-next-line react-hooks/exhaustive-deps - }, [streamingContent, isUserScrolling]) + }, [streamingContent, isUserScrolling, messages]) const scrollToBottom = (smooth = false) => { if (scrollContainerRef.current) { @@ -194,7 +198,7 @@ function ThreadDetail() {
) })} - +