feat: Thread titles should auto-summarize Topic (#1976)

This commit is contained in:
0xgokuz 2024-02-10 19:16:42 +07:00 committed by GitHub
parent 5864f4989b
commit 875c2bc3c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 90 additions and 7 deletions

View File

@ -1,3 +1,4 @@
export * from './messageEntity' export * from './messageEntity'
export * from './messageInterface' export * from './messageInterface'
export * from './messageEvent' export * from './messageEvent'
export * from './messageRequestType'

View File

@ -27,6 +27,8 @@ export type ThreadMessage = {
updated: number updated: number
/** The additional metadata of this message. **/ /** The additional metadata of this message. **/
metadata?: Record<string, unknown> metadata?: Record<string, unknown>
type?: string
} }
/** /**
@ -56,6 +58,8 @@ export type MessageRequest = {
/** The thread of this message is belong to. **/ /** The thread of this message is belong to. **/
// TODO: deprecate threadId field // TODO: deprecate threadId field
thread?: Thread thread?: Thread
type?: string
} }
/** /**

View File

@ -0,0 +1,5 @@
export enum MessageRequestType {
Thread = 'Thread',
Assistant = 'Assistant',
Summary = 'Summary',
}

View File

@ -10,6 +10,7 @@ import {
ChatCompletionRole, ChatCompletionRole,
ContentType, ContentType,
MessageRequest, MessageRequest,
MessageRequestType,
MessageStatus, MessageStatus,
ThreadContent, ThreadContent,
ThreadMessage, ThreadMessage,
@ -250,6 +251,7 @@ export default class JanInferenceNitroExtension extends InferenceExtension {
const message: ThreadMessage = { const message: ThreadMessage = {
id: ulid(), id: ulid(),
thread_id: data.threadId, thread_id: data.threadId,
type: data.type,
assistant_id: data.assistantId, assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant, role: ChatCompletionRole.Assistant,
content: [], content: [],
@ -258,7 +260,10 @@ export default class JanInferenceNitroExtension extends InferenceExtension {
updated: timestamp, updated: timestamp,
object: "thread.message", object: "thread.message",
}; };
events.emit(MessageEvent.OnMessageResponse, message);
if (data.type !== MessageRequestType.Summary) {
events.emit(MessageEvent.OnMessageResponse, message);
}
this.isCancelled = false; this.isCancelled = false;
this.controller = new AbortController(); this.controller = new AbortController();

View File

@ -18,6 +18,7 @@ import {
InferenceEngine, InferenceEngine,
BaseExtension, BaseExtension,
MessageEvent, MessageEvent,
MessageRequestType,
ModelEvent, ModelEvent,
InferenceEvent, InferenceEvent,
AppConfigurationEventName, AppConfigurationEventName,
@ -157,6 +158,7 @@ export default class JanInferenceOpenAIExtension extends BaseExtension {
const message: ThreadMessage = { const message: ThreadMessage = {
id: ulid(), id: ulid(),
thread_id: data.threadId, thread_id: data.threadId,
type: data.type,
assistant_id: data.assistantId, assistant_id: data.assistantId,
role: ChatCompletionRole.Assistant, role: ChatCompletionRole.Assistant,
content: [], content: [],
@ -165,7 +167,10 @@ export default class JanInferenceOpenAIExtension extends BaseExtension {
updated: timestamp, updated: timestamp,
object: "thread.message", object: "thread.message",
}; };
events.emit(MessageEvent.OnMessageResponse, message);
if (data.type !== MessageRequestType.Summary) {
events.emit(MessageEvent.OnMessageResponse, message);
}
instance.isCancelled = false; instance.isCancelled = false;
instance.controller = new AbortController(); instance.controller = new AbortController();

View File

@ -2,16 +2,21 @@
import { ReactNode, useCallback, useEffect, useRef } from 'react' import { ReactNode, useCallback, useEffect, useRef } from 'react'
import { import {
ChatCompletionMessage,
ChatCompletionRole,
events, events,
ThreadMessage, ThreadMessage,
ExtensionTypeEnum, ExtensionTypeEnum,
MessageStatus, MessageStatus,
MessageRequest,
Model, Model,
ConversationalExtension, ConversationalExtension,
MessageEvent, MessageEvent,
MessageRequestType,
ModelEvent, ModelEvent,
} from '@janhq/core' } from '@janhq/core'
import { useAtomValue, useSetAtom } from 'jotai' import { useAtomValue, useSetAtom } from 'jotai'
import { ulid } from 'ulid'
import { import {
activeModelAtom, activeModelAtom,
@ -25,6 +30,7 @@ import { toaster } from '../Toast'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
import { import {
getCurrentChatMessagesAtom,
addNewMessageAtom, addNewMessageAtom,
updateMessageAtom, updateMessageAtom,
} from '@/helpers/atoms/ChatMessage.atom' } from '@/helpers/atoms/ChatMessage.atom'
@ -37,9 +43,11 @@ import {
} from '@/helpers/atoms/Thread.atom' } from '@/helpers/atoms/Thread.atom'
export default function EventHandler({ children }: { children: ReactNode }) { export default function EventHandler({ children }: { children: ReactNode }) {
const messages = useAtomValue(getCurrentChatMessagesAtom)
const addNewMessage = useSetAtom(addNewMessageAtom) const addNewMessage = useSetAtom(addNewMessageAtom)
const updateMessage = useSetAtom(updateMessageAtom) const updateMessage = useSetAtom(updateMessageAtom)
const downloadedModels = useAtomValue(downloadedModelsAtom) const downloadedModels = useAtomValue(downloadedModelsAtom)
const activeModel = useAtomValue(activeModelAtom)
const setActiveModel = useSetAtom(activeModelAtom) const setActiveModel = useSetAtom(activeModelAtom)
const setStateModel = useSetAtom(stateModelAtom) const setStateModel = useSetAtom(stateModelAtom)
const setQueuedMessage = useSetAtom(queuedMessageAtom) const setQueuedMessage = useSetAtom(queuedMessageAtom)
@ -51,6 +59,8 @@ export default function EventHandler({ children }: { children: ReactNode }) {
const threadsRef = useRef(threads) const threadsRef = useRef(threads)
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom) const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
const updateThread = useSetAtom(updateThreadAtom) const updateThread = useSetAtom(updateThreadAtom)
const messagesRef = useRef(messages)
const activeModelRef = useRef(activeModel)
useEffect(() => { useEffect(() => {
threadsRef.current = threads threadsRef.current = threads
@ -60,9 +70,51 @@ export default function EventHandler({ children }: { children: ReactNode }) {
modelsRef.current = downloadedModels modelsRef.current = downloadedModels
}, [downloadedModels]) }, [downloadedModels])
useEffect(() => {
messagesRef.current = messages
}, [messages])
useEffect(() => {
activeModelRef.current = activeModel
}, [activeModel])
const onNewMessageResponse = useCallback( const onNewMessageResponse = useCallback(
(message: ThreadMessage) => { (message: ThreadMessage) => {
addNewMessage(message) const thread = threadsRef.current?.find((e) => e.id == message.thread_id)
// If this is the first ever prompt in the thread
if (thread && thread.title.trim() == 'New Thread') {
// This is the first time message comes in on a new thread
// Summarize the first message, and make that the title of the Thread
// 1. Get the summary of the first prompt using whatever engine user is currently using
const firstPrompt = messagesRef?.current[0].content[0].text.value.trim()
const summarizeFirstPrompt =
'Summarize "' + firstPrompt + '" in 5 words as a title'
// Prompt: Given this query from user {query}, return to me the summary in 5 words as the title
const msgId = ulid()
const messages: ChatCompletionMessage[] = [
{
role: ChatCompletionRole.User,
content: summarizeFirstPrompt,
} as ChatCompletionMessage,
]
const firstPromptRequest: MessageRequest = {
id: msgId,
threadId: message.thread_id,
type: MessageRequestType.Summary,
messages,
model: activeModelRef?.current,
}
// 2. Update the title with the result of the inference
// the title will be updated as part of the `EventName.OnFirstPromptUpdate`
events.emit(MessageEvent.OnMessageSent, firstPromptRequest)
}
if (message.type !== MessageRequestType.Summary) {
addNewMessage(message)
}
}, },
[addNewMessage] [addNewMessage]
) )
@ -134,6 +186,11 @@ export default function EventHandler({ children }: { children: ReactNode }) {
...(messageContent && { lastMessage: messageContent }), ...(messageContent && { lastMessage: messageContent }),
} }
// Update the Thread title with the response of the inference on the 1st prompt
if (message.type === MessageRequestType.Summary) {
thread.title = messageContent
}
updateThread({ updateThread({
...thread, ...thread,
metadata, metadata,
@ -146,9 +203,12 @@ export default function EventHandler({ children }: { children: ReactNode }) {
metadata, metadata,
}) })
extensionManager // If this is not the summary of the Thread, don't need to add it to the Thread
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) if (message.type !== MessageRequestType.Summary) {
?.addNewMessage(message) extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.addNewMessage(message)
}
} }
}, },
[updateMessage, updateThreadWaiting, setIsGeneratingResponse, updateThread] [updateMessage, updateThreadWaiting, setIsGeneratingResponse, updateThread]

View File

@ -6,6 +6,7 @@ import {
ChatCompletionRole, ChatCompletionRole,
ContentType, ContentType,
MessageRequest, MessageRequest,
MessageRequestType,
MessageStatus, MessageStatus,
ExtensionTypeEnum, ExtensionTypeEnum,
Thread, Thread,
@ -112,6 +113,7 @@ export default function useSendChatMessage() {
const messageRequest: MessageRequest = { const messageRequest: MessageRequest = {
id: ulid(), id: ulid(),
type: MessageRequestType.Thread,
messages: messages, messages: messages,
threadId: activeThread.id, threadId: activeThread.id,
model: activeThread.assistants[0].model ?? selectedModel, model: activeThread.assistants[0].model ?? selectedModel,
@ -209,6 +211,7 @@ export default function useSendChatMessage() {
} }
const messageRequest: MessageRequest = { const messageRequest: MessageRequest = {
id: msgId, id: msgId,
type: MessageRequestType.Thread,
threadId: activeThread.id, threadId: activeThread.id,
messages, messages,
model: { model: {
@ -218,8 +221,8 @@ export default function useSendChatMessage() {
}, },
thread: activeThread, thread: activeThread,
} }
const timestamp = Date.now()
const timestamp = Date.now()
const content: any = [] const content: any = []
if (base64Blob && fileUpload[0]?.type === 'image') { if (base64Blob && fileUpload[0]?.type === 'image') {