diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index 65daf450f..d316cdbee 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -2,26 +2,19 @@ import { useEffect, useRef } from 'react' import { - ChatCompletionMessage, ChatCompletionRole, - ContentType, - MessageRequest, MessageRequestType, - MessageStatus, ExtensionTypeEnum, Thread, ThreadMessage, Model, ConversationalExtension, InferenceEngine, - ChatCompletionMessageContentType, AssistantTool, EngineManager, } from '@janhq/core' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' -import { ulid } from 'ulidx' - import { selectedModelAtom } from '@/containers/DropdownListSidebar' import { currentPromptAtom, @@ -30,8 +23,11 @@ import { } from '@/containers/Providers/Jotai' import { compressImage, getBase64 } from '@/utils/base64' +import { MessageRequestBuilder } from '@/utils/messageRequestBuilder' import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' +import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder' + import { loadModelErrorAtom, useActiveModel } from './useActiveModel' import { extensionManager } from '@/extension/ExtensionManager' @@ -102,39 +98,13 @@ export default function useSendChatMessage() { return } updateThreadWaiting(activeThreadRef.current.id, true) - const messages: ChatCompletionMessage[] = [ - activeThreadRef.current.assistants[0]?.instructions, - ] - .filter((e) => e && e.trim() !== '') - .map((instructions) => { - const systemMessage: ChatCompletionMessage = { - role: ChatCompletionRole.System, - content: instructions, - } - return systemMessage - }) - .concat( - currentMessages - .filter( - (e) => - (currentMessage.role === ChatCompletionRole.User || - e.id !== currentMessage.id) && - e.status !== MessageStatus.Error - ) - .map((msg) => ({ - role: msg.role, - content: msg.content[0]?.text.value ?? '', - })) - ) - const messageRequest: MessageRequest = { - id: ulid(), - type: MessageRequestType.Thread, - messages: messages, - threadId: activeThreadRef.current.id, - model: - activeThreadRef.current.assistants[0].model ?? selectedModelRef.current, - } + const requestBuilder = new MessageRequestBuilder( + MessageRequestType.Thread, + activeThreadRef.current.assistants[0].model ?? selectedModelRef.current, + activeThreadRef.current, + currentMessages + ).addSystemMessage(activeThreadRef.current.assistants[0]?.instructions) const modelId = selectedModelRef.current?.id ?? @@ -143,7 +113,9 @@ export default function useSendChatMessage() { if (modelRef.current?.id !== modelId) { await startModel(modelId) } + setIsGeneratingResponse(true) + if (currentMessage.role !== ChatCompletionRole.User) { // Delete last response before regenerating deleteMessage(currentMessage.id ?? '') @@ -157,11 +129,13 @@ export default function useSendChatMessage() { } } const engine = EngineManager.instance()?.get( - messageRequest.model?.engine ?? selectedModelRef.current?.engine ?? '' + requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? '' ) - engine?.inference(messageRequest) + engine?.inference(requestBuilder.build()) } + // Define interface extending Array prototype + const sendChatMessage = async (message: string) => { if (!message || message.trim().length === 0) return @@ -186,8 +160,6 @@ export default function useSendChatMessage() { const fileContentType = fileUpload[0]?.type - const msgId = ulid() - const isDocumentInput = base64Blob && fileContentType === 'pdf' const isImageInput = base64Blob && fileContentType === 'image' @@ -196,56 +168,6 @@ export default function useSendChatMessage() { base64Blob = await compressImage(base64Blob, 512) } - const messages: ChatCompletionMessage[] = [ - activeThreadRef.current.assistants[0]?.instructions, - ] - .filter((e) => e && e.trim() !== '') - .map((instructions) => { - const systemMessage: ChatCompletionMessage = { - role: ChatCompletionRole.System, - content: instructions, - } - return systemMessage - }) - .concat( - currentMessages - .filter((e) => e.status !== MessageStatus.Error) - .map((msg) => ({ - role: msg.role, - content: msg.content[0]?.text.value ?? '', - })) - .concat([ - { - role: ChatCompletionRole.User, - content: - selectedModelRef.current && base64Blob - ? [ - { - type: ChatCompletionMessageContentType.Text, - text: prompt, - }, - isDocumentInput - ? { - type: ChatCompletionMessageContentType.Doc, - doc_url: { - url: `threads/${activeThreadRef.current.id}/files/${msgId}.pdf`, - }, - } - : null, - isImageInput - ? { - type: ChatCompletionMessageContentType.Image, - image_url: { - url: base64Blob, - }, - } - : null, - ].filter((e) => e !== null) - : prompt, - } as ChatCompletionMessage, - ]) - ) - let modelRequest = selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model @@ -277,86 +199,48 @@ export default function useSendChatMessage() { : {}), } } - const messageRequest: MessageRequest = { - id: msgId, - type: MessageRequestType.Thread, - threadId: activeThreadRef.current.id, - messages, - model: { + + // Build Message Request + const requestBuilder = new MessageRequestBuilder( + MessageRequestType.Thread, + { ...modelRequest, settings: settingParams, parameters: runtimeParams, }, - thread: activeThreadRef.current, - } + activeThreadRef.current, + currentMessages + ).addSystemMessage(activeThreadRef.current.assistants[0].instructions) - const timestamp = Date.now() - const content: any = [] + requestBuilder.pushMessage(prompt, base64Blob, fileUpload[0]?.type) - if (base64Blob && fileUpload[0]?.type === 'image') { - content.push({ - type: ContentType.Image, - text: { - value: prompt, - annotations: [base64Blob], - }, - }) - } + // Build Thread Message to persist + const threadMessageBuilder = new ThreadMessageBuilder( + requestBuilder + ).pushMessage(prompt, base64Blob, fileUpload) - if (base64Blob && fileUpload[0]?.type === 'pdf') { - content.push({ - type: ContentType.Pdf, - text: { - value: prompt, - annotations: [base64Blob], - name: fileUpload[0].file.name, - size: fileUpload[0].file.size, - }, - }) - } + const newMessage = threadMessageBuilder.build() - if (prompt && !base64Blob) { - content.push({ - type: ContentType.Text, - text: { - value: prompt, - annotations: [], - }, - }) - } - - const threadMessage: ThreadMessage = { - id: msgId, - thread_id: activeThreadRef.current.id, - role: ChatCompletionRole.User, - status: MessageStatus.Ready, - created: timestamp, - updated: timestamp, - object: 'thread.message', - content: content, - } - - addNewMessage(threadMessage) - if (base64Blob) { - setFileUpload([]) - } + // Push to states + addNewMessage(newMessage) + // Update thread state const updatedThread: Thread = { ...activeThreadRef.current, - updated: timestamp, + updated: newMessage.created, metadata: { ...(activeThreadRef.current.metadata ?? {}), lastMessage: prompt, }, } - - // change last update thread when send message updateThread(updatedThread) + // Add message await extensionManager .get(ExtensionTypeEnum.Conversational) - ?.addNewMessage(threadMessage) + ?.addNewMessage(newMessage) + // Start Model if not started const modelId = selectedModelRef.current?.id ?? activeThreadRef.current.assistants[0].model.id @@ -369,12 +253,17 @@ export default function useSendChatMessage() { setIsGeneratingResponse(true) const engine = EngineManager.instance()?.get( - messageRequest.model?.engine ?? modelRequest.engine ?? '' + requestBuilder.model?.engine ?? modelRequest.engine ?? '' ) - engine?.inference(messageRequest) + engine?.inference(requestBuilder.build()) + // Reset states setReloadModel(false) setEngineParamsUpdate(false) + + if (base64Blob) { + setFileUpload([]) + } } return { diff --git a/web/utils/messageRequestBuilder.ts b/web/utils/messageRequestBuilder.ts new file mode 100644 index 000000000..e214b03ea --- /dev/null +++ b/web/utils/messageRequestBuilder.ts @@ -0,0 +1,130 @@ +import { + ChatCompletionMessage, + ChatCompletionMessageContent, + ChatCompletionMessageContentText, + ChatCompletionMessageContentType, + ChatCompletionRole, + MessageRequest, + MessageRequestType, + MessageStatus, + ModelInfo, + Thread, + ThreadMessage, +} from '@janhq/core' +import { ulid } from 'ulidx' + +import { FileType } from '@/containers/Providers/Jotai' + +export class MessageRequestBuilder { + msgId: string + type: MessageRequestType + messages: ChatCompletionMessage[] + model: ModelInfo + thread: Thread + + constructor( + type: MessageRequestType, + model: ModelInfo, + thread: Thread, + messages: ThreadMessage[] + ) { + this.msgId = ulid() + this.type = type + this.model = model + this.thread = thread + this.messages = messages + .filter((e) => e.status !== MessageStatus.Error) + .map((msg) => ({ + role: msg.role, + content: msg.content[0]?.text.value ?? '', + })) + } + + // Chainable + pushMessage( + message: string, + base64Blob: string | undefined, + fileContentType: FileType + ) { + if (base64Blob && fileContentType === 'pdf') + return this.addDocMessage(message) + else if (base64Blob && fileContentType === 'image') { + return this.addImageMessage(message, base64Blob) + } + this.messages = [ + ...this.messages, + { + role: ChatCompletionRole.User, + content: message, + }, + ] + return this + } + + // Chainable + addSystemMessage(message: string | undefined) { + if (!message || message.trim() === '') return this + this.messages = [ + { + role: ChatCompletionRole.System, + content: message, + }, + ...this.messages, + ] + return this + } + + // Chainable + addDocMessage(prompt: string) { + const message: ChatCompletionMessage = { + role: ChatCompletionRole.User, + content: [ + { + type: ChatCompletionMessageContentType.Text, + text: prompt, + } as ChatCompletionMessageContentText, + { + type: ChatCompletionMessageContentType.Doc, + doc_url: { + url: `threads/${this.thread.id}/files/${this.msgId}.pdf`, + }, + }, + ] as ChatCompletionMessageContent, + } + this.messages = [message, ...this.messages] + return this + } + + // Chainable + addImageMessage(prompt: string, base64: string) { + const message: ChatCompletionMessage = { + role: ChatCompletionRole.User, + content: [ + { + type: ChatCompletionMessageContentType.Text, + text: prompt, + } as ChatCompletionMessageContentText, + { + type: ChatCompletionMessageContentType.Image, + image_url: { + url: base64, + }, + }, + ] as ChatCompletionMessageContent, + } + + this.messages = [message, ...this.messages] + return this + } + + build(): MessageRequest { + return { + id: this.msgId, + type: this.type, + threadId: this.thread.id, + messages: this.messages, + model: this.model, + thread: this.thread, + } + } +} diff --git a/web/utils/threadMessageBuilder.ts b/web/utils/threadMessageBuilder.ts new file mode 100644 index 000000000..92e51e574 --- /dev/null +++ b/web/utils/threadMessageBuilder.ts @@ -0,0 +1,74 @@ +import { + ChatCompletionRole, + ContentType, + MessageStatus, + ThreadContent, + ThreadMessage, +} from '@janhq/core' + +import { FileInfo } from '@/containers/Providers/Jotai' + +import { MessageRequestBuilder } from './messageRequestBuilder' + +export class ThreadMessageBuilder { + messageRequest: MessageRequestBuilder + + content: ThreadContent[] = [] + + constructor(messageRequest: MessageRequestBuilder) { + this.messageRequest = messageRequest + } + + build(): ThreadMessage { + const timestamp = Date.now() + return { + id: this.messageRequest.msgId, + thread_id: this.messageRequest.thread.id, + role: ChatCompletionRole.User, + status: MessageStatus.Ready, + created: timestamp, + updated: timestamp, + object: 'thread.message', + content: this.content, + } + } + + pushMessage( + prompt: string, + base64: string | undefined, + fileUpload: FileInfo[] + ) { + if (base64 && fileUpload[0]?.type === 'image') { + this.content.push({ + type: ContentType.Image, + text: { + value: prompt, + annotations: [base64], + }, + }) + } + + if (base64 && fileUpload[0]?.type === 'pdf') { + this.content.push({ + type: ContentType.Pdf, + text: { + value: prompt, + annotations: [base64], + name: fileUpload[0].file.name, + size: fileUpload[0].file.size, + }, + }) + } + + if (prompt && !base64) { + this.content.push({ + type: ContentType.Text, + text: { + value: prompt, + annotations: [], + }, + }) + } + return this + } +}