diff --git a/core/src/types/message/messageEntity.ts b/core/src/types/message/messageEntity.ts index ba3399b24..a74f059a0 100644 --- a/core/src/types/message/messageEntity.ts +++ b/core/src/types/message/messageEntity.ts @@ -59,6 +59,8 @@ export enum MessageStatus { Ready = 'ready', /** Message is not fully loaded. **/ Pending = 'pending', + /** Message loaded with error. **/ + Error = 'error', } /** diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts index 4aa15a7a9..946d526dd 100644 --- a/extensions/inference-nitro-extension/src/index.ts +++ b/extensions/inference-nitro-extension/src/index.ts @@ -223,12 +223,14 @@ export default class JanInferenceNitroExtension implements InferenceExtension { events.emit(EventName.OnMessageUpdate, message); }, complete: async () => { - message.status = MessageStatus.Ready; + message.status = message.content.length + ? MessageStatus.Ready + : MessageStatus.Error; events.emit(EventName.OnMessageUpdate, message); }, error: async (err) => { - if (instance.isCancelled || message.content.length > 0) { - message.status = MessageStatus.Ready; + if (instance.isCancelled || message.content.length) { + message.status = MessageStatus.Error; events.emit(EventName.OnMessageUpdate, message); return; } diff --git a/extensions/inference-openai-extension/src/index.ts b/extensions/inference-openai-extension/src/index.ts index 64c429664..0f3a5064d 100644 --- a/extensions/inference-openai-extension/src/index.ts +++ b/extensions/inference-openai-extension/src/index.ts @@ -213,12 +213,14 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension { events.emit(EventName.OnMessageUpdate, message); }, complete: async () => { - message.status = MessageStatus.Ready; + message.status = message.content.length + ? MessageStatus.Ready + : MessageStatus.Error; events.emit(EventName.OnMessageUpdate, message); }, error: async (err) => { if (instance.isCancelled || message.content.length > 0) { - message.status = MessageStatus.Ready; + message.status = MessageStatus.Error; events.emit(EventName.OnMessageUpdate, message); return; } diff --git a/extensions/inference-triton-trtllm-extension/src/index.ts b/extensions/inference-triton-trtllm-extension/src/index.ts index 103d61f68..0720ed3ac 100644 --- a/extensions/inference-triton-trtllm-extension/src/index.ts +++ b/extensions/inference-triton-trtllm-extension/src/index.ts @@ -237,10 +237,17 @@ export default class JanInferenceTritonTrtLLMExtension events.emit(EventName.OnMessageUpdate, message); }, complete: async () => { - message.status = MessageStatus.Ready; + message.status = message.content.length + ? MessageStatus.Ready + : MessageStatus.Error; events.emit(EventName.OnMessageUpdate, message); }, error: async (err) => { + if (instance.isCancelled || message.content.length) { + message.status = MessageStatus.Error; + events.emit(EventName.OnMessageUpdate, message); + return; + } const messageContent: ThreadContent = { type: ContentType.Text, text: { diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx index 2502340ba..a828a02a1 100644 --- a/web/containers/Providers/EventHandler.tsx +++ b/web/containers/Providers/EventHandler.tsx @@ -92,7 +92,10 @@ export default function EventHandler({ children }: { children: ReactNode }) { message.content, message.status ) - if (message.status === MessageStatus.Ready) { + if ( + message.status === MessageStatus.Ready || + message.status === MessageStatus.Error + ) { // Mark the thread as not waiting for response updateThreadWaiting(message.thread_id, false) diff --git a/web/hooks/useSetActiveThread.ts b/web/hooks/useSetActiveThread.ts index 0705901c3..65618c7fa 100644 --- a/web/hooks/useSetActiveThread.ts +++ b/web/hooks/useSetActiveThread.ts @@ -1,4 +1,4 @@ -import { ExtensionType, Thread } from '@janhq/core' +import { EventName, ExtensionType, Thread, events } from '@janhq/core' import { ConversationalExtension } from '@janhq/core' @@ -22,6 +22,8 @@ export default function useSetActiveThread() { return } + events.emit(EventName.OnInferenceStopped, thread.id) + // load the corresponding messages const messages = await extensionManager .get(ExtensionType.Conversational) diff --git a/web/screens/Chat/ChatBody/index.tsx b/web/screens/Chat/ChatBody/index.tsx index 27f3eb8e6..2129c2554 100644 --- a/web/screens/Chat/ChatBody/index.tsx +++ b/web/screens/Chat/ChatBody/index.tsx @@ -2,9 +2,17 @@ import { Fragment } from 'react' import ScrollToBottom from 'react-scroll-to-bottom' -import { InferenceEngine } from '@janhq/core' +import { + ChatCompletionRole, + ConversationalExtension, + ExtensionType, + InferenceEngine, + MessageStatus, +} from '@janhq/core' import { Button } from '@janhq/uikit' -import { useAtomValue } from 'jotai' +import { useAtomValue, useSetAtom } from 'jotai' + +import { RefreshCcw } from 'lucide-react' import LogoMark from '@/containers/Brand/Logo/Mark' @@ -14,14 +22,45 @@ import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels' import { useMainViewState } from '@/hooks/useMainViewState' +import useSendChatMessage from '@/hooks/useSendChatMessage' + import ChatItem from '../ChatItem' -import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' +import { extensionManager } from '@/extension' +import { + deleteMessageAtom, + getCurrentChatMessagesAtom, +} from '@/helpers/atoms/ChatMessage.atom' +import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' const ChatBody: React.FC = () => { const messages = useAtomValue(getCurrentChatMessagesAtom) const { downloadedModels } = useGetDownloadedModels() const { setMainViewState } = useMainViewState() + const thread = useAtomValue(activeThreadAtom) + const deleteMessage = useSetAtom(deleteMessageAtom) + const { resendChatMessage } = useSendChatMessage() + + const regenerateMessage = async () => { + const lastMessageIndex = messages.length - 1 + const message = messages[lastMessageIndex] + if (message.role !== ChatCompletionRole.User) { + // Delete last response before regenerating + deleteMessage(message.id ?? '') + if (thread) { + await extensionManager + .get(ExtensionType.Conversational) + ?.writeMessages( + thread.id, + messages.filter((msg) => msg.id !== message.id) + ) + } + const targetMessage = messages[lastMessageIndex - 1] + if (targetMessage) resendChatMessage(targetMessage) + } else { + resendChatMessage(message) + } + } if (downloadedModels.length === 0) return ( @@ -76,8 +115,35 @@ const ChatBody: React.FC = () => { ) : ( - {messages.map((message) => ( - + {messages.map((message, index) => ( + <> + {message.content.length ? ( + + ) : ( + <> + )} + {message.status === MessageStatus.Error && + index === messages.length - 1 && ( +
+ + Whoops! Jan's generation was interrupted. Let's + give it another go! + + +
+ )} + ))}
)} diff --git a/web/screens/Chat/MessageToolbar/index.tsx b/web/screens/Chat/MessageToolbar/index.tsx index 87004cc3b..721b0476b 100644 --- a/web/screens/Chat/MessageToolbar/index.tsx +++ b/web/screens/Chat/MessageToolbar/index.tsx @@ -47,19 +47,20 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => { resendChatMessage(message) } - if (message.status !== MessageStatus.Ready) return null + if (message.status === MessageStatus.Pending) return null return (
- {message.id === messages[messages.length - 1]?.id && ( -
- -
- )} + {message.id === messages[messages.length - 1]?.id && + messages[messages.length - 1].status !== MessageStatus.Error && ( +
+ +
+ )}
{ diff --git a/web/screens/Chat/SimpleTextMessage/index.tsx b/web/screens/Chat/SimpleTextMessage/index.tsx index e9107d6af..75ce5b24a 100644 --- a/web/screens/Chat/SimpleTextMessage/index.tsx +++ b/web/screens/Chat/SimpleTextMessage/index.tsx @@ -99,7 +99,10 @@ const SimpleTextMessage: React.FC = (props) => { }, []) useEffect(() => { - if (props.status === MessageStatus.Ready) { + if ( + props.status === MessageStatus.Ready || + props.status === MessageStatus.Error + ) { return } const currentTimestamp = new Date().getTime() // Get current time in milliseconds