fix: message should only be interrupted when i start another thread (#2053)
* fix: message should only be interrupted when i start another thread * fix: thread lost message streaming if navigate to another thread * fix: state issue with useThreads
This commit is contained in:
parent
6590ee7a6a
commit
47b890bba5
@ -177,7 +177,6 @@ export default function EventHandler({ children }: { children: ReactNode }) {
|
|||||||
)
|
)
|
||||||
if (message.status === MessageStatus.Pending) {
|
if (message.status === MessageStatus.Pending) {
|
||||||
if (message.content.length) {
|
if (message.content.length) {
|
||||||
updateThreadWaiting(message.thread_id, false)
|
|
||||||
setIsGeneratingResponse(false)
|
setIsGeneratingResponse(false)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@ -16,6 +16,8 @@ import {
|
|||||||
*/
|
*/
|
||||||
export const chatMessages = atom<Record<string, ThreadMessage[]>>({})
|
export const chatMessages = atom<Record<string, ThreadMessage[]>>({})
|
||||||
|
|
||||||
|
export const readyThreadsMessagesAtom = atom<Record<string, boolean>>({})
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return the chat messages for the current active conversation
|
* Return the chat messages for the current active conversation
|
||||||
*/
|
*/
|
||||||
@ -34,6 +36,10 @@ export const setConvoMessagesAtom = atom(
|
|||||||
}
|
}
|
||||||
newData[threadId] = messages
|
newData[threadId] = messages
|
||||||
set(chatMessages, newData)
|
set(chatMessages, newData)
|
||||||
|
set(readyThreadsMessagesAtom, {
|
||||||
|
...get(readyThreadsMessagesAtom),
|
||||||
|
[threadId]: true,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,8 @@ import {
|
|||||||
ThreadState,
|
ThreadState,
|
||||||
Model,
|
Model,
|
||||||
AssistantTool,
|
AssistantTool,
|
||||||
|
events,
|
||||||
|
InferenceEvent,
|
||||||
} from '@janhq/core'
|
} from '@janhq/core'
|
||||||
import { atom, useAtomValue, useSetAtom } from 'jotai'
|
import { atom, useAtomValue, useSetAtom } from 'jotai'
|
||||||
|
|
||||||
@ -30,6 +32,7 @@ import {
|
|||||||
threadStatesAtom,
|
threadStatesAtom,
|
||||||
updateThreadAtom,
|
updateThreadAtom,
|
||||||
setThreadModelParamsAtom,
|
setThreadModelParamsAtom,
|
||||||
|
isGeneratingResponseAtom,
|
||||||
} from '@/helpers/atoms/Thread.atom'
|
} from '@/helpers/atoms/Thread.atom'
|
||||||
|
|
||||||
const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
|
const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => {
|
||||||
@ -57,6 +60,7 @@ export const useCreateNewThread = () => {
|
|||||||
const setSelectedModel = useSetAtom(selectedModelAtom)
|
const setSelectedModel = useSetAtom(selectedModelAtom)
|
||||||
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
|
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
|
||||||
const { experimentalFeature } = useContext(FeatureToggleContext)
|
const { experimentalFeature } = useContext(FeatureToggleContext)
|
||||||
|
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
|
||||||
|
|
||||||
const { recommendedModel, downloadedModels } = useRecommendedModel()
|
const { recommendedModel, downloadedModels } = useRecommendedModel()
|
||||||
|
|
||||||
@ -66,6 +70,10 @@ export const useCreateNewThread = () => {
|
|||||||
assistant: Assistant,
|
assistant: Assistant,
|
||||||
model?: Model | undefined
|
model?: Model | undefined
|
||||||
) => {
|
) => {
|
||||||
|
// Stop generating if any
|
||||||
|
setIsGeneratingResponse(false)
|
||||||
|
events.emit(InferenceEvent.OnInferenceStopped, {})
|
||||||
|
|
||||||
const defaultModel = model ?? recommendedModel ?? downloadedModels[0]
|
const defaultModel = model ?? recommendedModel ?? downloadedModels[0]
|
||||||
|
|
||||||
// check last thread message, if there empty last message use can not create thread
|
// check last thread message, if there empty last message use can not create thread
|
||||||
|
|||||||
@ -1,20 +1,14 @@
|
|||||||
import { useCallback } from 'react'
|
import { ExtensionTypeEnum, Thread, ConversationalExtension } from '@janhq/core'
|
||||||
|
|
||||||
import {
|
import { useAtomValue, useSetAtom } from 'jotai'
|
||||||
InferenceEvent,
|
|
||||||
ExtensionTypeEnum,
|
|
||||||
Thread,
|
|
||||||
events,
|
|
||||||
ConversationalExtension,
|
|
||||||
} from '@janhq/core'
|
|
||||||
|
|
||||||
import { useSetAtom } from 'jotai'
|
|
||||||
|
|
||||||
import { extensionManager } from '@/extension'
|
import { extensionManager } from '@/extension'
|
||||||
import { setConvoMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
|
import {
|
||||||
|
readyThreadsMessagesAtom,
|
||||||
|
setConvoMessagesAtom,
|
||||||
|
} from '@/helpers/atoms/ChatMessage.atom'
|
||||||
import {
|
import {
|
||||||
ModelParams,
|
ModelParams,
|
||||||
isGeneratingResponseAtom,
|
|
||||||
setActiveThreadIdAtom,
|
setActiveThreadIdAtom,
|
||||||
setThreadModelParamsAtom,
|
setThreadModelParamsAtom,
|
||||||
} from '@/helpers/atoms/Thread.atom'
|
} from '@/helpers/atoms/Thread.atom'
|
||||||
@ -23,16 +17,14 @@ export default function useSetActiveThread() {
|
|||||||
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
|
const setActiveThreadId = useSetAtom(setActiveThreadIdAtom)
|
||||||
const setThreadMessage = useSetAtom(setConvoMessagesAtom)
|
const setThreadMessage = useSetAtom(setConvoMessagesAtom)
|
||||||
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
|
const setThreadModelParams = useSetAtom(setThreadModelParamsAtom)
|
||||||
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
|
const readyMessageThreads = useAtomValue(readyThreadsMessagesAtom)
|
||||||
|
|
||||||
const setActiveThread = useCallback(
|
const setActiveThread = async (thread: Thread) => {
|
||||||
async (thread: Thread) => {
|
// Load local messages only if there are no messages in the state
|
||||||
setIsGeneratingResponse(false)
|
if (!readyMessageThreads[thread.id]) {
|
||||||
events.emit(InferenceEvent.OnInferenceStopped, thread.id)
|
|
||||||
|
|
||||||
// load the corresponding messages
|
|
||||||
const messages = await getLocalThreadMessage(thread.id)
|
const messages = await getLocalThreadMessage(thread.id)
|
||||||
setThreadMessage(thread.id, messages)
|
setThreadMessage(thread.id, messages)
|
||||||
|
}
|
||||||
|
|
||||||
setActiveThreadId(thread.id)
|
setActiveThreadId(thread.id)
|
||||||
const modelParams: ModelParams = {
|
const modelParams: ModelParams = {
|
||||||
@ -40,14 +32,7 @@ export default function useSetActiveThread() {
|
|||||||
...thread.assistants[0]?.model?.settings,
|
...thread.assistants[0]?.model?.settings,
|
||||||
}
|
}
|
||||||
setThreadModelParams(thread.id, modelParams)
|
setThreadModelParams(thread.id, modelParams)
|
||||||
},
|
}
|
||||||
[
|
|
||||||
setActiveThreadId,
|
|
||||||
setThreadMessage,
|
|
||||||
setThreadModelParams,
|
|
||||||
setIsGeneratingResponse,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
return { setActiveThread }
|
return { setActiveThread }
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,8 +9,6 @@ import {
|
|||||||
|
|
||||||
import { useSetAtom } from 'jotai'
|
import { useSetAtom } from 'jotai'
|
||||||
|
|
||||||
import useSetActiveThread from './useSetActiveThread'
|
|
||||||
|
|
||||||
import { extensionManager } from '@/extension/ExtensionManager'
|
import { extensionManager } from '@/extension/ExtensionManager'
|
||||||
import {
|
import {
|
||||||
ModelParams,
|
ModelParams,
|
||||||
@ -24,7 +22,6 @@ const useThreads = () => {
|
|||||||
const setThreadStates = useSetAtom(threadStatesAtom)
|
const setThreadStates = useSetAtom(threadStatesAtom)
|
||||||
const setThreads = useSetAtom(threadsAtom)
|
const setThreads = useSetAtom(threadsAtom)
|
||||||
const setThreadModelRuntimeParams = useSetAtom(threadModelParamsAtom)
|
const setThreadModelRuntimeParams = useSetAtom(threadModelParamsAtom)
|
||||||
const { setActiveThread } = useSetActiveThread()
|
|
||||||
const setThreadDataReady = useSetAtom(threadDataReadyAtom)
|
const setThreadDataReady = useSetAtom(threadDataReadyAtom)
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@ -56,16 +53,11 @@ const useThreads = () => {
|
|||||||
setThreadStates(localThreadStates)
|
setThreadStates(localThreadStates)
|
||||||
setThreads(localThreads)
|
setThreads(localThreads)
|
||||||
setThreadModelRuntimeParams(threadModelParams)
|
setThreadModelRuntimeParams(threadModelParams)
|
||||||
|
|
||||||
if (localThreads.length > 0) {
|
|
||||||
setActiveThread(localThreads[0])
|
|
||||||
}
|
|
||||||
setThreadDataReady(true)
|
setThreadDataReady(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
getThreads()
|
getThreads()
|
||||||
}, [
|
}, [
|
||||||
setActiveThread,
|
|
||||||
setThreadModelRuntimeParams,
|
setThreadModelRuntimeParams,
|
||||||
setThreadStates,
|
setThreadStates,
|
||||||
setThreads,
|
setThreads,
|
||||||
|
|||||||
@ -38,6 +38,8 @@ import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
|
|||||||
import {
|
import {
|
||||||
activeThreadAtom,
|
activeThreadAtom,
|
||||||
getActiveThreadIdAtom,
|
getActiveThreadIdAtom,
|
||||||
|
isGeneratingResponseAtom,
|
||||||
|
threadStatesAtom,
|
||||||
waitingToSendMessage,
|
waitingToSendMessage,
|
||||||
} from '@/helpers/atoms/Thread.atom'
|
} from '@/helpers/atoms/Thread.atom'
|
||||||
|
|
||||||
@ -57,6 +59,12 @@ const ChatInput: React.FC = () => {
|
|||||||
const imageInputRef = useRef<HTMLInputElement>(null)
|
const imageInputRef = useRef<HTMLInputElement>(null)
|
||||||
const [showAttacmentMenus, setShowAttacmentMenus] = useState(false)
|
const [showAttacmentMenus, setShowAttacmentMenus] = useState(false)
|
||||||
const { experimentalFeature } = useContext(FeatureToggleContext)
|
const { experimentalFeature } = useContext(FeatureToggleContext)
|
||||||
|
const isGeneratingResponse = useAtomValue(isGeneratingResponseAtom)
|
||||||
|
const threadStates = useAtomValue(threadStatesAtom)
|
||||||
|
|
||||||
|
const isStreamingResponse = Object.values(threadStates).some(
|
||||||
|
(threadState) => threadState.waitingForResponse
|
||||||
|
)
|
||||||
|
|
||||||
const onPromptChange = (e: React.ChangeEvent<HTMLTextAreaElement>) => {
|
const onPromptChange = (e: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
setCurrentPrompt(e.target.value)
|
setCurrentPrompt(e.target.value)
|
||||||
@ -235,7 +243,9 @@ const ChatInput: React.FC = () => {
|
|||||||
accept="application/pdf"
|
accept="application/pdf"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
{messages[messages.length - 1]?.status !== MessageStatus.Pending ? (
|
{messages[messages.length - 1]?.status !== MessageStatus.Pending &&
|
||||||
|
!isGeneratingResponse &&
|
||||||
|
!isStreamingResponse ? (
|
||||||
<Button
|
<Button
|
||||||
size="lg"
|
size="lg"
|
||||||
disabled={
|
disabled={
|
||||||
|
|||||||
@ -49,8 +49,17 @@ export default function ThreadList() {
|
|||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (threadDataReady && assistants.length > 0 && threads.length === 0) {
|
if (threadDataReady && assistants.length > 0 && threads.length === 0) {
|
||||||
requestCreateNewThread(assistants[0])
|
requestCreateNewThread(assistants[0])
|
||||||
|
} else if (threadDataReady && !activeThreadId) {
|
||||||
|
setActiveThread(threads[0])
|
||||||
}
|
}
|
||||||
}, [assistants, threads, threadDataReady, requestCreateNewThread])
|
}, [
|
||||||
|
assistants,
|
||||||
|
threads,
|
||||||
|
threadDataReady,
|
||||||
|
requestCreateNewThread,
|
||||||
|
activeThreadId,
|
||||||
|
setActiveThread,
|
||||||
|
])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="px-3 py-4">
|
<div className="px-3 py-4">
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user