feat: handle stop streaming message, scroll to bottom and model loads (#5023)

This commit is contained in:
Louis 2025-05-19 23:32:55 +07:00 committed by GitHub
parent b69a9ceb0f
commit f6433544af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 59 additions and 19 deletions

View File

@ -185,7 +185,7 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
console.log('Loaded models:', loadedModels)
// This is to avoid loading the same model multiple times
if (loadedModels.some((model) => model.id === model.id)) {
if (loadedModels.some((e) => e.id === model.id)) {
console.log(`Model ${model.id} already loaded`)
return
}

View File

@ -3,7 +3,7 @@
import TextareaAutosize from 'react-textarea-autosize'
import { cn } from '@/lib/utils'
import { usePrompt } from '@/hooks/usePrompt'
import { useEffect, useRef, useState } from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import { Button } from '@/components/ui/button'
import { ArrowRight } from 'lucide-react'
import {
@ -44,7 +44,7 @@ const ChatInput = ({
const textareaRef = useRef<HTMLTextAreaElement>(null)
const [isFocused, setIsFocused] = useState(false)
const [rows, setRows] = useState(1)
const { streamingContent, updateTools } = useAppState()
const { streamingContent, updateTools, abortControllers } = useAppState()
const { prompt, setPrompt } = usePrompt()
const { t } = useTranslation()
const { spellCheckChatInput } = useGeneralSetting()
@ -97,6 +97,13 @@ const ChatInput = ({
}
}, [])
const stopStreaming = useCallback(
(threadId: string) => {
abortControllers[threadId]?.abort()
},
[abortControllers]
)
return (
<div className="relative">
<div
@ -218,7 +225,11 @@ const ChatInput = ({
</div>
{streamingContent ? (
<Button variant="destructive" size="icon">
<Button
variant="destructive"
size="icon"
onClick={() => stopStreaming(streamingContent.thread_id)}
>
<IconPlayerStopFilled />
</Button>
) : (

View File

@ -2,11 +2,15 @@ import { useAppState } from '@/hooks/useAppState'
import { ThreadContent } from './ThreadContent'
import { memo } from 'react'
type Props = {
threadId: string
}
// Use memo with no dependencies to allow re-renders when props change
export const StreamingContent = memo(() => {
export const StreamingContent = memo(({ threadId }: Props) => {
const { streamingContent } = useAppState()
if (!streamingContent) return null
if (!streamingContent || streamingContent.thread_id !== threadId) return null
// Pass a new object to ThreadContent to avoid reference issues
// The streaming content is always the last message

View File

@ -7,10 +7,12 @@ type AppState = {
loadingModel?: boolean
tools: MCPTool[]
serverStatus: 'running' | 'stopped' | 'pending'
abortControllers: Record<string, AbortController>
setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
updateStreamingContent: (content: ThreadMessage | undefined) => void
updateLoadingModel: (loading: boolean) => void
updateTools: (tools: MCPTool[]) => void
setAbortController: (threadId: string, controller: AbortController) => void
}
export const useAppState = create<AppState>()((set) => ({
@ -18,6 +20,7 @@ export const useAppState = create<AppState>()((set) => ({
loadingModel: false,
tools: [],
serverStatus: 'stopped',
abortControllers: {},
updateStreamingContent: (content) => {
set({ streamingContent: content })
},
@ -28,4 +31,12 @@ export const useAppState = create<AppState>()((set) => ({
set({ tools })
},
setServerStatus: (value) => set({ serverStatus: value }),
setAbortController: (threadId, controller) => {
set((state) => ({
abortControllers: {
...state.abortControllers,
[threadId]: controller,
},
}))
},
}))

View File

@ -27,7 +27,8 @@ export const useChat = () => {
useModelProvider()
const { getCurrentThread: retrieveThread, createThread } = useThreads()
const { updateStreamingContent, updateLoadingModel } = useAppState()
const { updateStreamingContent, updateLoadingModel, setAbortController } =
useAppState()
const { addMessage } = useMessages()
const router = useRouter()
@ -83,12 +84,14 @@ export const useChat = () => {
builder.addUserMessage(message)
let isCompleted = false
const abortController = new AbortController()
setAbortController(activeThread.id, abortController)
while (!isCompleted) {
const completion = await sendCompletion(
activeThread,
provider,
builder.getMessages(),
abortController,
tools
)
@ -141,6 +144,7 @@ export const useChat = () => {
setPrompt,
selectedModel,
tools,
setAbortController,
updateLoadingModel,
]
)

View File

@ -110,6 +110,7 @@ export const sendCompletion = async (
thread: Thread,
provider: ModelProvider,
messages: ChatCompletionMessageParam[],
abortController: AbortController,
tools: MCPTool[] = []
): Promise<StreamCompletionResponse | undefined> => {
if (!thread?.model?.id || !provider) return undefined
@ -126,14 +127,19 @@ export const sendCompletion = async (
})
// TODO: Add message history
const completion = await tokenJS.chat.completions.create({
stream: true,
provider: providerName,
model: thread.model?.id,
messages,
tools: normalizeTools(tools),
tool_choice: tools.length ? 'auto' : undefined,
})
const completion = await tokenJS.chat.completions.create(
{
stream: true,
provider: providerName,
model: thread.model?.id,
messages,
tools: normalizeTools(tools),
tool_choice: tools.length ? 'auto' : undefined,
},
{
signal: abortController.signal,
}
)
return completion
}

View File

@ -94,12 +94,16 @@ function ThreadDetail() {
useEffect(() => {
// Only auto-scroll when the user is not actively scrolling
// AND either at the bottom OR there's streaming content
if (!isUserScrolling && (streamingContent || isAtBottom)) {
if (
!isUserScrolling &&
(streamingContent || isAtBottom) &&
messages?.length
) {
// Use non-smooth scrolling for auto-scroll to prevent jank
scrollToBottom(false)
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [streamingContent, isUserScrolling])
}, [streamingContent, isUserScrolling, messages])
const scrollToBottom = (smooth = false) => {
if (scrollContainerRef.current) {
@ -194,7 +198,7 @@ function ThreadDetail() {
</div>
)
})}
<StreamingContent />
<StreamingContent threadId={threadId} />
</div>
</div>
<div className="w-4/6 mx-auto pt-2 pb-3 shrink-0 relative">