diff --git a/core/src/types/message/index.ts b/core/src/types/message/index.ts index e8d78deda..ebb4c363d 100644 --- a/core/src/types/message/index.ts +++ b/core/src/types/message/index.ts @@ -1,3 +1,4 @@ export * from './messageEntity' export * from './messageInterface' export * from './messageEvent' +export * from './messageRequestType' diff --git a/core/src/types/message/messageEntity.ts b/core/src/types/message/messageEntity.ts index 87e4b1997..e9211d550 100644 --- a/core/src/types/message/messageEntity.ts +++ b/core/src/types/message/messageEntity.ts @@ -27,6 +27,8 @@ export type ThreadMessage = { updated: number /** The additional metadata of this message. **/ metadata?: Record + + type?: string } /** @@ -56,6 +58,8 @@ export type MessageRequest = { /** The thread of this message is belong to. **/ // TODO: deprecate threadId field thread?: Thread + + type?: string } /** diff --git a/core/src/types/message/messageRequestType.ts b/core/src/types/message/messageRequestType.ts new file mode 100644 index 000000000..51be51996 --- /dev/null +++ b/core/src/types/message/messageRequestType.ts @@ -0,0 +1,5 @@ +export enum MessageRequestType { + Thread = 'Thread', + Assistant = 'Assistant', + Summary = 'Summary', +} \ No newline at end of file diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts index 9e96ad93f..7374b6977 100644 --- a/extensions/inference-nitro-extension/src/index.ts +++ b/extensions/inference-nitro-extension/src/index.ts @@ -10,6 +10,7 @@ import { ChatCompletionRole, ContentType, MessageRequest, + MessageRequestType, MessageStatus, ThreadContent, ThreadMessage, @@ -250,6 +251,7 @@ export default class JanInferenceNitroExtension extends InferenceExtension { const message: ThreadMessage = { id: ulid(), thread_id: data.threadId, + type: data.type, assistant_id: data.assistantId, role: ChatCompletionRole.Assistant, content: [], @@ -258,7 +260,10 @@ export default class JanInferenceNitroExtension extends InferenceExtension { updated: timestamp, object: "thread.message", }; - events.emit(MessageEvent.OnMessageResponse, message); + + if (data.type !== MessageRequestType.Summary) { + events.emit(MessageEvent.OnMessageResponse, message); + } this.isCancelled = false; this.controller = new AbortController(); diff --git a/extensions/inference-openai-extension/src/index.ts b/extensions/inference-openai-extension/src/index.ts index fd1230bc7..23fd8983e 100644 --- a/extensions/inference-openai-extension/src/index.ts +++ b/extensions/inference-openai-extension/src/index.ts @@ -18,6 +18,7 @@ import { InferenceEngine, BaseExtension, MessageEvent, + MessageRequestType, ModelEvent, InferenceEvent, AppConfigurationEventName, @@ -157,6 +158,7 @@ export default class JanInferenceOpenAIExtension extends BaseExtension { const message: ThreadMessage = { id: ulid(), thread_id: data.threadId, + type: data.type, assistant_id: data.assistantId, role: ChatCompletionRole.Assistant, content: [], @@ -165,7 +167,10 @@ export default class JanInferenceOpenAIExtension extends BaseExtension { updated: timestamp, object: "thread.message", }; - events.emit(MessageEvent.OnMessageResponse, message); + + if (data.type !== MessageRequestType.Summary) { + events.emit(MessageEvent.OnMessageResponse, message); + } instance.isCancelled = false; instance.controller = new AbortController(); diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx index 170ec5e64..7f8bd261c 100644 --- a/web/containers/Providers/EventHandler.tsx +++ b/web/containers/Providers/EventHandler.tsx @@ -2,16 +2,21 @@ import { ReactNode, useCallback, useEffect, useRef } from 'react' import { + ChatCompletionMessage, + ChatCompletionRole, events, ThreadMessage, ExtensionTypeEnum, MessageStatus, + MessageRequest, Model, ConversationalExtension, MessageEvent, + MessageRequestType, ModelEvent, } from '@janhq/core' import { useAtomValue, useSetAtom } from 'jotai' +import { ulid } from 'ulid' import { activeModelAtom, @@ -25,6 +30,7 @@ import { toaster } from '../Toast' import { extensionManager } from '@/extension' import { + getCurrentChatMessagesAtom, addNewMessageAtom, updateMessageAtom, } from '@/helpers/atoms/ChatMessage.atom' @@ -37,9 +43,11 @@ import { } from '@/helpers/atoms/Thread.atom' export default function EventHandler({ children }: { children: ReactNode }) { + const messages = useAtomValue(getCurrentChatMessagesAtom) const addNewMessage = useSetAtom(addNewMessageAtom) const updateMessage = useSetAtom(updateMessageAtom) const downloadedModels = useAtomValue(downloadedModelsAtom) + const activeModel = useAtomValue(activeModelAtom) const setActiveModel = useSetAtom(activeModelAtom) const setStateModel = useSetAtom(stateModelAtom) const setQueuedMessage = useSetAtom(queuedMessageAtom) @@ -51,6 +59,8 @@ export default function EventHandler({ children }: { children: ReactNode }) { const threadsRef = useRef(threads) const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom) const updateThread = useSetAtom(updateThreadAtom) + const messagesRef = useRef(messages) + const activeModelRef = useRef(activeModel) useEffect(() => { threadsRef.current = threads @@ -60,9 +70,51 @@ export default function EventHandler({ children }: { children: ReactNode }) { modelsRef.current = downloadedModels }, [downloadedModels]) + useEffect(() => { + messagesRef.current = messages + }, [messages]) + + useEffect(() => { + activeModelRef.current = activeModel + }, [activeModel]) + const onNewMessageResponse = useCallback( (message: ThreadMessage) => { - addNewMessage(message) + const thread = threadsRef.current?.find((e) => e.id == message.thread_id) + // If this is the first ever prompt in the thread + if (thread && thread.title.trim() == 'New Thread') { + // This is the first time message comes in on a new thread + // Summarize the first message, and make that the title of the Thread + // 1. Get the summary of the first prompt using whatever engine user is currently using + const firstPrompt = messagesRef?.current[0].content[0].text.value.trim() + const summarizeFirstPrompt = + 'Summarize "' + firstPrompt + '" in 5 words as a title' + + // Prompt: Given this query from user {query}, return to me the summary in 5 words as the title + const msgId = ulid() + const messages: ChatCompletionMessage[] = [ + { + role: ChatCompletionRole.User, + content: summarizeFirstPrompt, + } as ChatCompletionMessage, + ] + + const firstPromptRequest: MessageRequest = { + id: msgId, + threadId: message.thread_id, + type: MessageRequestType.Summary, + messages, + model: activeModelRef?.current, + } + + // 2. Update the title with the result of the inference + // the title will be updated as part of the `EventName.OnFirstPromptUpdate` + events.emit(MessageEvent.OnMessageSent, firstPromptRequest) + } + + if (message.type !== MessageRequestType.Summary) { + addNewMessage(message) + } }, [addNewMessage] ) @@ -134,6 +186,11 @@ export default function EventHandler({ children }: { children: ReactNode }) { ...(messageContent && { lastMessage: messageContent }), } + // Update the Thread title with the response of the inference on the 1st prompt + if (message.type === MessageRequestType.Summary) { + thread.title = messageContent + } + updateThread({ ...thread, metadata, @@ -146,9 +203,12 @@ export default function EventHandler({ children }: { children: ReactNode }) { metadata, }) - extensionManager - .get(ExtensionTypeEnum.Conversational) - ?.addNewMessage(message) + // If this is not the summary of the Thread, don't need to add it to the Thread + if (message.type !== MessageRequestType.Summary) { + extensionManager + .get(ExtensionTypeEnum.Conversational) + ?.addNewMessage(message) + } } }, [updateMessage, updateThreadWaiting, setIsGeneratingResponse, updateThread] diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index 7d89764db..d7c2d10fd 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -6,6 +6,7 @@ import { ChatCompletionRole, ContentType, MessageRequest, + MessageRequestType, MessageStatus, ExtensionTypeEnum, Thread, @@ -112,6 +113,7 @@ export default function useSendChatMessage() { const messageRequest: MessageRequest = { id: ulid(), + type: MessageRequestType.Thread, messages: messages, threadId: activeThread.id, model: activeThread.assistants[0].model ?? selectedModel, @@ -209,6 +211,7 @@ export default function useSendChatMessage() { } const messageRequest: MessageRequest = { id: msgId, + type: MessageRequestType.Thread, threadId: activeThread.id, messages, model: { @@ -218,8 +221,8 @@ export default function useSendChatMessage() { }, thread: activeThread, } - const timestamp = Date.now() + const timestamp = Date.now() const content: any = [] if (base64Blob && fileUpload[0]?.type === 'image') {