fix: delete message break the entire thread (#869)

This commit is contained in:
Louis 2023-12-06 11:41:16 +07:00 committed by GitHub
parent 7656048e64
commit f3fbade975
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 11 deletions

View File

@ -7,7 +7,9 @@ import { join } from 'path'
* JSONConversationalExtension is a ConversationalExtension implementation that provides * JSONConversationalExtension is a ConversationalExtension implementation that provides
* functionality for managing threads. * functionality for managing threads.
*/ */
export default class JSONConversationalExtension implements ConversationalExtension { export default class JSONConversationalExtension
implements ConversationalExtension
{
private static readonly _homeDir = 'threads' private static readonly _homeDir = 'threads'
private static readonly _threadInfoFileName = 'thread.json' private static readonly _threadInfoFileName = 'thread.json'
private static readonly _threadMessagesFileName = 'messages.jsonl' private static readonly _threadMessagesFileName = 'messages.jsonl'
@ -67,7 +69,10 @@ export default class JSONConversationalExtension implements ConversationalExtens
*/ */
async saveThread(thread: Thread): Promise<void> { async saveThread(thread: Thread): Promise<void> {
try { try {
const threadDirPath = join(JSONConversationalExtension._homeDir, thread.id) const threadDirPath = join(
JSONConversationalExtension._homeDir,
thread.id
)
const threadJsonPath = join( const threadJsonPath = join(
threadDirPath, threadDirPath,
JSONConversationalExtension._threadInfoFileName JSONConversationalExtension._threadInfoFileName
@ -119,7 +124,7 @@ export default class JSONConversationalExtension implements ConversationalExtens
await fs.mkdir(threadDirPath) await fs.mkdir(threadDirPath)
await fs.writeFile( await fs.writeFile(
threadMessagePath, threadMessagePath,
messages.map((msg) => JSON.stringify(msg)).join('\n') messages.map((msg) => JSON.stringify(msg)).join('\n') + '\n'
) )
Promise.resolve() Promise.resolve()
} catch (err) { } catch (err) {
@ -153,7 +158,10 @@ export default class JSONConversationalExtension implements ConversationalExtens
const threadDirs: string[] = [] const threadDirs: string[] = []
for (let i = 0; i < fileInsideThread.length; i++) { for (let i = 0; i < fileInsideThread.length; i++) {
const path = join(JSONConversationalExtension._homeDir, fileInsideThread[i]) const path = join(
JSONConversationalExtension._homeDir,
fileInsideThread[i]
)
const isDirectory = await fs.isDirectory(path) const isDirectory = await fs.isDirectory(path)
if (!isDirectory) { if (!isDirectory) {
console.debug(`Ignore ${path} because it is not a directory`) console.debug(`Ignore ${path} because it is not a directory`)
@ -182,7 +190,9 @@ export default class JSONConversationalExtension implements ConversationalExtens
} }
const files: string[] = await fs.listFiles(threadDirPath) const files: string[] = await fs.listFiles(threadDirPath)
if (!files.includes(JSONConversationalExtension._threadMessagesFileName)) { if (
!files.includes(JSONConversationalExtension._threadMessagesFileName)
) {
throw Error(`${threadDirPath} not contains message file`) throw Error(`${threadDirPath} not contains message file`)
} }

View File

@ -55,7 +55,7 @@ export default function useSendChatMessage() {
modelRef.current = activeModel modelRef.current = activeModel
}, [activeModel]) }, [activeModel])
const resendChatMessage = async () => { const resendChatMessage = async (currentMessage: ThreadMessage) => {
if (!activeThread) { if (!activeThread) {
console.error('No active thread') console.error('No active thread')
return return
@ -75,10 +75,16 @@ export default function useSendChatMessage() {
return systemMessage return systemMessage
}) })
.concat( .concat(
currentMessages.map<ChatCompletionMessage>((msg) => ({ currentMessages
role: msg.role, .filter(
content: msg.content[0]?.text.value ?? '', (e) =>
})) currentMessage.role === ChatCompletionRole.User ||
e.id !== currentMessage.id
)
.map<ChatCompletionMessage>((msg) => ({
role: msg.role,
content: msg.content[0]?.text.value ?? '',
}))
) )
const messageRequest: MessageRequest = { const messageRequest: MessageRequest = {

View File

@ -4,6 +4,7 @@ import {
ExtensionType, ExtensionType,
ThreadMessage, ThreadMessage,
events, events,
ChatCompletionRole,
} from '@janhq/core' } from '@janhq/core'
import { ConversationalExtension, InferenceExtension } from '@janhq/core' import { ConversationalExtension, InferenceExtension } from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai' import { useAtomValue, useSetAtom } from 'jotai'
@ -52,6 +53,14 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
} }
} }
const onRegenerateClick = async () => {
if (message.role !== ChatCompletionRole.User) {
// Delete last response before regenerating
await onDeleteClick()
}
resendChatMessage(message)
}
return ( return (
<div className={twMerge('flex flex-row items-center')}> <div className={twMerge('flex flex-row items-center')}>
<div className="flex overflow-hidden rounded-md border border-border bg-background/20"> <div className="flex overflow-hidden rounded-md border border-border bg-background/20">
@ -67,7 +76,7 @@ const MessageToolbar = ({ message }: { message: ThreadMessage }) => {
message.id === messages[messages.length - 1]?.id && ( message.id === messages[messages.length - 1]?.id && (
<div <div
className="cursor-pointer border-r border-border px-2 py-2 hover:bg-background/80" className="cursor-pointer border-r border-border px-2 py-2 hover:bg-background/80"
onClick={resendChatMessage} onClick={onRegenerateClick}
> >
<RefreshCcw size={14} /> <RefreshCcw size={14} />
</div> </div>