fix: #1096 yield error message upon thread switching (#1109)

This commit is contained in:
Louis 2023-12-20 14:58:47 +07:00 committed by GitHub
parent 9b448d478c
commit e0370210a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 111 additions and 23 deletions

View File

@ -59,6 +59,8 @@ export enum MessageStatus {
Ready = 'ready',
/** Message is not fully loaded. **/
Pending = 'pending',
/** Message loaded with error. **/
Error = 'error',
}
/**

View File

@ -223,12 +223,14 @@ export default class JanInferenceNitroExtension implements InferenceExtension {
events.emit(EventName.OnMessageUpdate, message);
},
complete: async () => {
message.status = MessageStatus.Ready;
message.status = message.content.length
? MessageStatus.Ready
: MessageStatus.Error;
events.emit(EventName.OnMessageUpdate, message);
},
error: async (err) => {
if (instance.isCancelled || message.content.length > 0) {
message.status = MessageStatus.Ready;
if (instance.isCancelled || message.content.length) {
message.status = MessageStatus.Error;
events.emit(EventName.OnMessageUpdate, message);
return;
}

View File

@ -213,12 +213,14 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension {
events.emit(EventName.OnMessageUpdate, message);
},
complete: async () => {
message.status = MessageStatus.Ready;
message.status = message.content.length
? MessageStatus.Ready
: MessageStatus.Error;
events.emit(EventName.OnMessageUpdate, message);
},
error: async (err) => {
if (instance.isCancelled || message.content.length > 0) {
message.status = MessageStatus.Ready;
message.status = MessageStatus.Error;
events.emit(EventName.OnMessageUpdate, message);
return;
}

View File

@ -237,10 +237,17 @@ export default class JanInferenceTritonTrtLLMExtension
events.emit(EventName.OnMessageUpdate, message);
},
complete: async () => {
message.status = MessageStatus.Ready;
message.status = message.content.length
? MessageStatus.Ready
: MessageStatus.Error;
events.emit(EventName.OnMessageUpdate, message);
},
error: async (err) => {
if (instance.isCancelled || message.content.length) {
message.status = MessageStatus.Error;
events.emit(EventName.OnMessageUpdate, message);
return;
}
const messageContent: ThreadContent = {
type: ContentType.Text,
text: {

View File

@ -92,7 +92,10 @@ export default function EventHandler({ children }: { children: ReactNode }) {
message.content,
message.status
)
if (message.status === MessageStatus.Ready) {
if (
message.status === MessageStatus.Ready ||
message.status === MessageStatus.Error
) {
// Mark the thread as not waiting for response
updateThreadWaiting(message.thread_id, false)

View File

@ -1,4 +1,4 @@
import { ExtensionType, Thread } from '@janhq/core'
import { EventName, ExtensionType, Thread, events } from '@janhq/core'
import { ConversationalExtension } from '@janhq/core'
@ -22,6 +22,8 @@ export default function useSetActiveThread() {
return
}
events.emit(EventName.OnInferenceStopped, thread.id)
// load the corresponding messages
const messages = await extensionManager
.get<ConversationalExtension>(ExtensionType.Conversational)

View File

@ -2,9 +2,17 @@ import { Fragment } from 'react'
import ScrollToBottom from 'react-scroll-to-bottom'
import { InferenceEngine } from '@janhq/core'
import {
ChatCompletionRole,
ConversationalExtension,
ExtensionType,
InferenceEngine,
MessageStatus,
} from '@janhq/core'
import { Button } from '@janhq/uikit'
import { useAtomValue } from 'jotai'
import { useAtomValue, useSetAtom } from 'jotai'
import { RefreshCcw } from 'lucide-react'
import LogoMark from '@/containers/Brand/Logo/Mark'
@ -14,14 +22,45 @@ import { useGetDownloadedModels } from '@/hooks/useGetDownloadedModels'
import { useMainViewState } from '@/hooks/useMainViewState'
import useSendChatMessage from '@/hooks/useSendChatMessage'
import ChatItem from '../ChatItem'
import { getCurrentChatMessagesAtom } from '@/helpers/atoms/ChatMessage.atom'
import { extensionManager } from '@/extension'
import {
deleteMessageAtom,
getCurrentChatMessagesAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
const ChatBody: React.FC = () => {
const messages = useAtomValue(getCurrentChatMessagesAtom)
const { downloadedModels } = useGetDownloadedModels()
const { setMainViewState } = useMainViewState()
const thread = useAtomValue(activeThreadAtom)
const deleteMessage = useSetAtom(deleteMessageAtom)
const { resendChatMessage } = useSendChatMessage()
const regenerateMessage = async () => {
const lastMessageIndex = messages.length - 1
const message = messages[lastMessageIndex]
if (message.role !== ChatCompletionRole.User) {
// Delete last response before regenerating
deleteMessage(message.id ?? '')
if (thread) {
await extensionManager
.get<ConversationalExtension>(ExtensionType.Conversational)
?.writeMessages(
thread.id,
messages.filter((msg) => msg.id !== message.id)
)
}
const targetMessage = messages[lastMessageIndex - 1]
if (targetMessage) resendChatMessage(targetMessage)
} else {
resendChatMessage(message)
}
}
if (downloadedModels.length === 0)
return (
@ -76,8 +115,35 @@ const ChatBody: React.FC = () => {
</div>
) : (
<ScrollToBottom className="flex h-full w-full flex-col">
{messages.map((message) => (
<ChatItem {...message} key={message.id} />
{messages.map((message, index) => (
<>
{message.content.length ? (
<ChatItem {...message} key={message.id} />
) : (
<></>
)}
{message.status === MessageStatus.Error &&
index === messages.length - 1 && (
<div
key={message.id}
className="mt-10 flex flex-col items-center"
>
<span className="mb-3 text-center text-sm font-medium text-gray-500">
Whoops! Jan&apos;s generation was interrupted. Let&apos;s
give it another go!
</span>
<Button
className="w-min"
themes="outline"
onClick={regenerateMessage}
>
<RefreshCcw size={14} className="" />
<span className="w-2" />
Regenerate
</Button>
</div>
)}
</>
))}
</ScrollToBottom>
)}

View File

@ -47,19 +47,20 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
resendChatMessage(message)
}
if (message.status !== MessageStatus.Ready) return null
if (message.status === MessageStatus.Pending) return null
return (
<div className={twMerge('flex flex-row items-center')}>
<div className="flex overflow-hidden rounded-md border border-border bg-background/20">
{message.id === messages[messages.length - 1]?.id && (
<div
className="cursor-pointer border-r border-border px-2 py-2 hover:bg-background/80"
onClick={onRegenerateClick}
>
<RefreshCcw size={14} />
</div>
)}
{message.id === messages[messages.length - 1]?.id &&
messages[messages.length - 1].status !== MessageStatus.Error && (
<div
className="cursor-pointer border-r border-border px-2 py-2 hover:bg-background/80"
onClick={onRegenerateClick}
>
<RefreshCcw size={14} />
</div>
)}
<div
className="cursor-pointer border-r border-border px-2 py-2 hover:bg-background/80"
onClick={() => {

View File

@ -99,7 +99,10 @@ const SimpleTextMessage: React.FC<ThreadMessage> = (props) => {
}, [])
useEffect(() => {
if (props.status === MessageStatus.Ready) {
if (
props.status === MessageStatus.Ready ||
props.status === MessageStatus.Error
) {
return
}
const currentTimestamp = new Date().getTime() // Get current time in milliseconds