feat: handle stop streaming message, scroll to bottom and model loads (#5023)
This commit is contained in:
parent
b69a9ceb0f
commit
f6433544af
@ -185,7 +185,7 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine {
|
|||||||
console.log('Loaded models:', loadedModels)
|
console.log('Loaded models:', loadedModels)
|
||||||
|
|
||||||
// This is to avoid loading the same model multiple times
|
// This is to avoid loading the same model multiple times
|
||||||
if (loadedModels.some((model) => model.id === model.id)) {
|
if (loadedModels.some((e) => e.id === model.id)) {
|
||||||
console.log(`Model ${model.id} already loaded`)
|
console.log(`Model ${model.id} already loaded`)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
import TextareaAutosize from 'react-textarea-autosize'
|
import TextareaAutosize from 'react-textarea-autosize'
|
||||||
import { cn } from '@/lib/utils'
|
import { cn } from '@/lib/utils'
|
||||||
import { usePrompt } from '@/hooks/usePrompt'
|
import { usePrompt } from '@/hooks/usePrompt'
|
||||||
import { useEffect, useRef, useState } from 'react'
|
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||||
import { Button } from '@/components/ui/button'
|
import { Button } from '@/components/ui/button'
|
||||||
import { ArrowRight } from 'lucide-react'
|
import { ArrowRight } from 'lucide-react'
|
||||||
import {
|
import {
|
||||||
@ -44,7 +44,7 @@ const ChatInput = ({
|
|||||||
const textareaRef = useRef<HTMLTextAreaElement>(null)
|
const textareaRef = useRef<HTMLTextAreaElement>(null)
|
||||||
const [isFocused, setIsFocused] = useState(false)
|
const [isFocused, setIsFocused] = useState(false)
|
||||||
const [rows, setRows] = useState(1)
|
const [rows, setRows] = useState(1)
|
||||||
const { streamingContent, updateTools } = useAppState()
|
const { streamingContent, updateTools, abortControllers } = useAppState()
|
||||||
const { prompt, setPrompt } = usePrompt()
|
const { prompt, setPrompt } = usePrompt()
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const { spellCheckChatInput } = useGeneralSetting()
|
const { spellCheckChatInput } = useGeneralSetting()
|
||||||
@ -97,6 +97,13 @@ const ChatInput = ({
|
|||||||
}
|
}
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
|
const stopStreaming = useCallback(
|
||||||
|
(threadId: string) => {
|
||||||
|
abortControllers[threadId]?.abort()
|
||||||
|
},
|
||||||
|
[abortControllers]
|
||||||
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="relative">
|
<div className="relative">
|
||||||
<div
|
<div
|
||||||
@ -218,7 +225,11 @@ const ChatInput = ({
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
{streamingContent ? (
|
{streamingContent ? (
|
||||||
<Button variant="destructive" size="icon">
|
<Button
|
||||||
|
variant="destructive"
|
||||||
|
size="icon"
|
||||||
|
onClick={() => stopStreaming(streamingContent.thread_id)}
|
||||||
|
>
|
||||||
<IconPlayerStopFilled />
|
<IconPlayerStopFilled />
|
||||||
</Button>
|
</Button>
|
||||||
) : (
|
) : (
|
||||||
|
|||||||
@ -2,11 +2,15 @@ import { useAppState } from '@/hooks/useAppState'
|
|||||||
import { ThreadContent } from './ThreadContent'
|
import { ThreadContent } from './ThreadContent'
|
||||||
import { memo } from 'react'
|
import { memo } from 'react'
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
threadId: string
|
||||||
|
}
|
||||||
|
|
||||||
// Use memo with no dependencies to allow re-renders when props change
|
// Use memo with no dependencies to allow re-renders when props change
|
||||||
export const StreamingContent = memo(() => {
|
export const StreamingContent = memo(({ threadId }: Props) => {
|
||||||
const { streamingContent } = useAppState()
|
const { streamingContent } = useAppState()
|
||||||
|
|
||||||
if (!streamingContent) return null
|
if (!streamingContent || streamingContent.thread_id !== threadId) return null
|
||||||
|
|
||||||
// Pass a new object to ThreadContent to avoid reference issues
|
// Pass a new object to ThreadContent to avoid reference issues
|
||||||
// The streaming content is always the last message
|
// The streaming content is always the last message
|
||||||
|
|||||||
@ -7,10 +7,12 @@ type AppState = {
|
|||||||
loadingModel?: boolean
|
loadingModel?: boolean
|
||||||
tools: MCPTool[]
|
tools: MCPTool[]
|
||||||
serverStatus: 'running' | 'stopped' | 'pending'
|
serverStatus: 'running' | 'stopped' | 'pending'
|
||||||
|
abortControllers: Record<string, AbortController>
|
||||||
setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
|
setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
|
||||||
updateStreamingContent: (content: ThreadMessage | undefined) => void
|
updateStreamingContent: (content: ThreadMessage | undefined) => void
|
||||||
updateLoadingModel: (loading: boolean) => void
|
updateLoadingModel: (loading: boolean) => void
|
||||||
updateTools: (tools: MCPTool[]) => void
|
updateTools: (tools: MCPTool[]) => void
|
||||||
|
setAbortController: (threadId: string, controller: AbortController) => void
|
||||||
}
|
}
|
||||||
|
|
||||||
export const useAppState = create<AppState>()((set) => ({
|
export const useAppState = create<AppState>()((set) => ({
|
||||||
@ -18,6 +20,7 @@ export const useAppState = create<AppState>()((set) => ({
|
|||||||
loadingModel: false,
|
loadingModel: false,
|
||||||
tools: [],
|
tools: [],
|
||||||
serverStatus: 'stopped',
|
serverStatus: 'stopped',
|
||||||
|
abortControllers: {},
|
||||||
updateStreamingContent: (content) => {
|
updateStreamingContent: (content) => {
|
||||||
set({ streamingContent: content })
|
set({ streamingContent: content })
|
||||||
},
|
},
|
||||||
@ -28,4 +31,12 @@ export const useAppState = create<AppState>()((set) => ({
|
|||||||
set({ tools })
|
set({ tools })
|
||||||
},
|
},
|
||||||
setServerStatus: (value) => set({ serverStatus: value }),
|
setServerStatus: (value) => set({ serverStatus: value }),
|
||||||
|
setAbortController: (threadId, controller) => {
|
||||||
|
set((state) => ({
|
||||||
|
abortControllers: {
|
||||||
|
...state.abortControllers,
|
||||||
|
[threadId]: controller,
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
},
|
||||||
}))
|
}))
|
||||||
|
|||||||
@ -27,7 +27,8 @@ export const useChat = () => {
|
|||||||
useModelProvider()
|
useModelProvider()
|
||||||
|
|
||||||
const { getCurrentThread: retrieveThread, createThread } = useThreads()
|
const { getCurrentThread: retrieveThread, createThread } = useThreads()
|
||||||
const { updateStreamingContent, updateLoadingModel } = useAppState()
|
const { updateStreamingContent, updateLoadingModel, setAbortController } =
|
||||||
|
useAppState()
|
||||||
const { addMessage } = useMessages()
|
const { addMessage } = useMessages()
|
||||||
const router = useRouter()
|
const router = useRouter()
|
||||||
|
|
||||||
@ -83,12 +84,14 @@ export const useChat = () => {
|
|||||||
builder.addUserMessage(message)
|
builder.addUserMessage(message)
|
||||||
|
|
||||||
let isCompleted = false
|
let isCompleted = false
|
||||||
|
const abortController = new AbortController()
|
||||||
|
setAbortController(activeThread.id, abortController)
|
||||||
while (!isCompleted) {
|
while (!isCompleted) {
|
||||||
const completion = await sendCompletion(
|
const completion = await sendCompletion(
|
||||||
activeThread,
|
activeThread,
|
||||||
provider,
|
provider,
|
||||||
builder.getMessages(),
|
builder.getMessages(),
|
||||||
|
abortController,
|
||||||
tools
|
tools
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -141,6 +144,7 @@ export const useChat = () => {
|
|||||||
setPrompt,
|
setPrompt,
|
||||||
selectedModel,
|
selectedModel,
|
||||||
tools,
|
tools,
|
||||||
|
setAbortController,
|
||||||
updateLoadingModel,
|
updateLoadingModel,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@ -110,6 +110,7 @@ export const sendCompletion = async (
|
|||||||
thread: Thread,
|
thread: Thread,
|
||||||
provider: ModelProvider,
|
provider: ModelProvider,
|
||||||
messages: ChatCompletionMessageParam[],
|
messages: ChatCompletionMessageParam[],
|
||||||
|
abortController: AbortController,
|
||||||
tools: MCPTool[] = []
|
tools: MCPTool[] = []
|
||||||
): Promise<StreamCompletionResponse | undefined> => {
|
): Promise<StreamCompletionResponse | undefined> => {
|
||||||
if (!thread?.model?.id || !provider) return undefined
|
if (!thread?.model?.id || !provider) return undefined
|
||||||
@ -126,14 +127,19 @@ export const sendCompletion = async (
|
|||||||
})
|
})
|
||||||
|
|
||||||
// TODO: Add message history
|
// TODO: Add message history
|
||||||
const completion = await tokenJS.chat.completions.create({
|
const completion = await tokenJS.chat.completions.create(
|
||||||
stream: true,
|
{
|
||||||
provider: providerName,
|
stream: true,
|
||||||
model: thread.model?.id,
|
provider: providerName,
|
||||||
messages,
|
model: thread.model?.id,
|
||||||
tools: normalizeTools(tools),
|
messages,
|
||||||
tool_choice: tools.length ? 'auto' : undefined,
|
tools: normalizeTools(tools),
|
||||||
})
|
tool_choice: tools.length ? 'auto' : undefined,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
signal: abortController.signal,
|
||||||
|
}
|
||||||
|
)
|
||||||
return completion
|
return completion
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -94,12 +94,16 @@ function ThreadDetail() {
|
|||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// Only auto-scroll when the user is not actively scrolling
|
// Only auto-scroll when the user is not actively scrolling
|
||||||
// AND either at the bottom OR there's streaming content
|
// AND either at the bottom OR there's streaming content
|
||||||
if (!isUserScrolling && (streamingContent || isAtBottom)) {
|
if (
|
||||||
|
!isUserScrolling &&
|
||||||
|
(streamingContent || isAtBottom) &&
|
||||||
|
messages?.length
|
||||||
|
) {
|
||||||
// Use non-smooth scrolling for auto-scroll to prevent jank
|
// Use non-smooth scrolling for auto-scroll to prevent jank
|
||||||
scrollToBottom(false)
|
scrollToBottom(false)
|
||||||
}
|
}
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
}, [streamingContent, isUserScrolling])
|
}, [streamingContent, isUserScrolling, messages])
|
||||||
|
|
||||||
const scrollToBottom = (smooth = false) => {
|
const scrollToBottom = (smooth = false) => {
|
||||||
if (scrollContainerRef.current) {
|
if (scrollContainerRef.current) {
|
||||||
@ -194,7 +198,7 @@ function ThreadDetail() {
|
|||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
})}
|
})}
|
||||||
<StreamingContent />
|
<StreamingContent threadId={threadId} />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="w-4/6 mx-auto pt-2 pb-3 shrink-0 relative">
|
<div className="w-4/6 mx-auto pt-2 pb-3 shrink-0 relative">
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user