feat: handle stop streaming message, scroll to bottom and model loads (#5023)
This commit is contained in:
parent
b69a9ceb0f
commit
f6433544af
@ -185,7 +185,7 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
||||
console.log('Loaded models:', loadedModels)
|
||||
|
||||
// This is to avoid loading the same model multiple times
|
||||
if (loadedModels.some((model) => model.id === model.id)) {
|
||||
if (loadedModels.some((e) => e.id === model.id)) {
|
||||
console.log(`Model ${model.id} already loaded`)
|
||||
return
|
||||
}
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
import TextareaAutosize from 'react-textarea-autosize'
|
||||
import { cn } from '@/lib/utils'
|
||||
import { usePrompt } from '@/hooks/usePrompt'
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { ArrowRight } from 'lucide-react'
|
||||
import {
|
||||
@ -44,7 +44,7 @@ const ChatInput = ({
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null)
|
||||
const [isFocused, setIsFocused] = useState(false)
|
||||
const [rows, setRows] = useState(1)
|
||||
const { streamingContent, updateTools } = useAppState()
|
||||
const { streamingContent, updateTools, abortControllers } = useAppState()
|
||||
const { prompt, setPrompt } = usePrompt()
|
||||
const { t } = useTranslation()
|
||||
const { spellCheckChatInput } = useGeneralSetting()
|
||||
@ -97,6 +97,13 @@ const ChatInput = ({
|
||||
}
|
||||
}, [])
|
||||
|
||||
const stopStreaming = useCallback(
|
||||
(threadId: string) => {
|
||||
abortControllers[threadId]?.abort()
|
||||
},
|
||||
[abortControllers]
|
||||
)
|
||||
|
||||
return (
|
||||
<div className="relative">
|
||||
<div
|
||||
@ -218,7 +225,11 @@ const ChatInput = ({
|
||||
</div>
|
||||
|
||||
{streamingContent ? (
|
||||
<Button variant="destructive" size="icon">
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="icon"
|
||||
onClick={() => stopStreaming(streamingContent.thread_id)}
|
||||
>
|
||||
<IconPlayerStopFilled />
|
||||
</Button>
|
||||
) : (
|
||||
|
||||
@ -2,11 +2,15 @@ import { useAppState } from '@/hooks/useAppState'
|
||||
import { ThreadContent } from './ThreadContent'
|
||||
import { memo } from 'react'
|
||||
|
||||
type Props = {
|
||||
threadId: string
|
||||
}
|
||||
|
||||
// Use memo with no dependencies to allow re-renders when props change
|
||||
export const StreamingContent = memo(() => {
|
||||
export const StreamingContent = memo(({ threadId }: Props) => {
|
||||
const { streamingContent } = useAppState()
|
||||
|
||||
if (!streamingContent) return null
|
||||
if (!streamingContent || streamingContent.thread_id !== threadId) return null
|
||||
|
||||
// Pass a new object to ThreadContent to avoid reference issues
|
||||
// The streaming content is always the last message
|
||||
|
||||
@ -7,10 +7,12 @@ type AppState = {
|
||||
loadingModel?: boolean
|
||||
tools: MCPTool[]
|
||||
serverStatus: 'running' | 'stopped' | 'pending'
|
||||
abortControllers: Record<string, AbortController>
|
||||
setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
|
||||
updateStreamingContent: (content: ThreadMessage | undefined) => void
|
||||
updateLoadingModel: (loading: boolean) => void
|
||||
updateTools: (tools: MCPTool[]) => void
|
||||
setAbortController: (threadId: string, controller: AbortController) => void
|
||||
}
|
||||
|
||||
export const useAppState = create<AppState>()((set) => ({
|
||||
@ -18,6 +20,7 @@ export const useAppState = create<AppState>()((set) => ({
|
||||
loadingModel: false,
|
||||
tools: [],
|
||||
serverStatus: 'stopped',
|
||||
abortControllers: {},
|
||||
updateStreamingContent: (content) => {
|
||||
set({ streamingContent: content })
|
||||
},
|
||||
@ -28,4 +31,12 @@ export const useAppState = create<AppState>()((set) => ({
|
||||
set({ tools })
|
||||
},
|
||||
setServerStatus: (value) => set({ serverStatus: value }),
|
||||
setAbortController: (threadId, controller) => {
|
||||
set((state) => ({
|
||||
abortControllers: {
|
||||
...state.abortControllers,
|
||||
[threadId]: controller,
|
||||
},
|
||||
}))
|
||||
},
|
||||
}))
|
||||
|
||||
@ -27,7 +27,8 @@ export const useChat = () => {
|
||||
useModelProvider()
|
||||
|
||||
const { getCurrentThread: retrieveThread, createThread } = useThreads()
|
||||
const { updateStreamingContent, updateLoadingModel } = useAppState()
|
||||
const { updateStreamingContent, updateLoadingModel, setAbortController } =
|
||||
useAppState()
|
||||
const { addMessage } = useMessages()
|
||||
const router = useRouter()
|
||||
|
||||
@ -83,12 +84,14 @@ export const useChat = () => {
|
||||
builder.addUserMessage(message)
|
||||
|
||||
let isCompleted = false
|
||||
|
||||
const abortController = new AbortController()
|
||||
setAbortController(activeThread.id, abortController)
|
||||
while (!isCompleted) {
|
||||
const completion = await sendCompletion(
|
||||
activeThread,
|
||||
provider,
|
||||
builder.getMessages(),
|
||||
abortController,
|
||||
tools
|
||||
)
|
||||
|
||||
@ -141,6 +144,7 @@ export const useChat = () => {
|
||||
setPrompt,
|
||||
selectedModel,
|
||||
tools,
|
||||
setAbortController,
|
||||
updateLoadingModel,
|
||||
]
|
||||
)
|
||||
|
||||
@ -110,6 +110,7 @@ export const sendCompletion = async (
|
||||
thread: Thread,
|
||||
provider: ModelProvider,
|
||||
messages: ChatCompletionMessageParam[],
|
||||
abortController: AbortController,
|
||||
tools: MCPTool[] = []
|
||||
): Promise<StreamCompletionResponse | undefined> => {
|
||||
if (!thread?.model?.id || !provider) return undefined
|
||||
@ -126,14 +127,19 @@ export const sendCompletion = async (
|
||||
})
|
||||
|
||||
// TODO: Add message history
|
||||
const completion = await tokenJS.chat.completions.create({
|
||||
stream: true,
|
||||
provider: providerName,
|
||||
model: thread.model?.id,
|
||||
messages,
|
||||
tools: normalizeTools(tools),
|
||||
tool_choice: tools.length ? 'auto' : undefined,
|
||||
})
|
||||
const completion = await tokenJS.chat.completions.create(
|
||||
{
|
||||
stream: true,
|
||||
provider: providerName,
|
||||
model: thread.model?.id,
|
||||
messages,
|
||||
tools: normalizeTools(tools),
|
||||
tool_choice: tools.length ? 'auto' : undefined,
|
||||
},
|
||||
{
|
||||
signal: abortController.signal,
|
||||
}
|
||||
)
|
||||
return completion
|
||||
}
|
||||
|
||||
|
||||
@ -94,12 +94,16 @@ function ThreadDetail() {
|
||||
useEffect(() => {
|
||||
// Only auto-scroll when the user is not actively scrolling
|
||||
// AND either at the bottom OR there's streaming content
|
||||
if (!isUserScrolling && (streamingContent || isAtBottom)) {
|
||||
if (
|
||||
!isUserScrolling &&
|
||||
(streamingContent || isAtBottom) &&
|
||||
messages?.length
|
||||
) {
|
||||
// Use non-smooth scrolling for auto-scroll to prevent jank
|
||||
scrollToBottom(false)
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [streamingContent, isUserScrolling])
|
||||
}, [streamingContent, isUserScrolling, messages])
|
||||
|
||||
const scrollToBottom = (smooth = false) => {
|
||||
if (scrollContainerRef.current) {
|
||||
@ -194,7 +198,7 @@ function ThreadDetail() {
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
<StreamingContent />
|
||||
<StreamingContent threadId={threadId} />
|
||||
</div>
|
||||
</div>
|
||||
<div className="w-4/6 mx-auto pt-2 pb-3 shrink-0 relative">
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user