From 83579df3a40ff61eac25975da8295fceaec679dc Mon Sep 17 00:00:00 2001 From: NamH Date: Fri, 2 Aug 2024 16:23:12 +0700 Subject: [PATCH] fix: add back normalize message function (#3234) Signed-off-by: James --- web/hooks/useSendMessage.ts | 50 ++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/web/hooks/useSendMessage.ts b/web/hooks/useSendMessage.ts index a8f118767..3946c5fd8 100644 --- a/web/hooks/useSendMessage.ts +++ b/web/hooks/useSendMessage.ts @@ -22,6 +22,8 @@ import { inferenceErrorAtom } from '@/screens/HubScreen2/components/InferenceErr import { showWarningMultipleModelModalAtom } from '@/screens/HubScreen2/components/WarningMultipleModelModal' import { concurrentModelWarningThreshold } from '@/screens/Settings/MyModels/ModelItem' +import { Stack } from '@/utils/Stack' + import useCortex from './useCortex' import useEngineInit from './useEngineInit' @@ -47,28 +49,29 @@ import { updateThreadTitleAtom, } from '@/helpers/atoms/Thread.atom' -// TODO: NamH add this back -// const normalizeMessages = (messages: Message[]): Message[] => { -// const stack = new Stack() -// for (const message of messages) { -// if (stack.isEmpty()) { -// stack.push(message) -// continue -// } -// const topMessage = stack.peek() +const normalizeMessages = ( + messages: ChatCompletionMessageParam[] +): ChatCompletionMessageParam[] => { + const stack = new Stack() + for (const message of messages) { + if (stack.isEmpty()) { + stack.push(message) + continue + } + const topMessage = stack.peek() -// if (message.role === topMessage.role) { -// // add an empty message -// stack.push({ -// role: topMessage.role === 'user' ? 'assistant' : 'user', -// content: '.', // some model requires not empty message -// }) -// } -// stack.push(message) -// } + if (message.role === topMessage.role) { + // add an empty message + stack.push({ + role: topMessage.role === 'user' ? 'assistant' : 'user', + content: '.', // some model requires not empty message + }) + } + stack.push(message) + } -// return stack.reverseOutput() -// } + return stack.reverseOutput() +} const useSendMessage = () => { const createMessage = useMessageCreateMutation() @@ -285,7 +288,7 @@ const useSendMessage = () => { content: activeThread!.assistants[0].instructions ?? '', } - const messages: ChatCompletionMessageParam[] = currentMessages + let messages: ChatCompletionMessageParam[] = currentMessages .map((msg) => { switch (msg.role) { case 'user': @@ -305,7 +308,7 @@ const useSendMessage = () => { }) .filter((msg) => msg != null) as ChatCompletionMessageParam[] messages.unshift(systemMessage) - + messages = normalizeMessages(messages) const modelOptions: Record = {} if (selectedModel!.frequency_penalty) { modelOptions.frequency_penalty = selectedModel!.frequency_penalty @@ -540,7 +543,7 @@ const useSendMessage = () => { content: activeThread!.assistants[0].instructions ?? '', } - const messages: ChatCompletionMessageParam[] = currentMessages + let messages: ChatCompletionMessageParam[] = currentMessages .map((msg) => { switch (msg.role) { case 'user': @@ -564,6 +567,7 @@ const useSendMessage = () => { content: message, }) messages.unshift(systemMessage) + messages = normalizeMessages(messages) const modelOptions: Record = {} if (selectedModel!.frequency_penalty) { modelOptions.frequency_penalty = selectedModel!.frequency_penalty