refactor: introduce message request builder (#2481)

This commit is contained in:
Louis 2024-03-25 12:27:41 +07:00
parent 9551996e34
commit 77cbdc2dcf
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
3 changed files with 247 additions and 154 deletions

View File

@ -2,26 +2,19 @@
import { useEffect, useRef } from 'react' import { useEffect, useRef } from 'react'
import { import {
ChatCompletionMessage,
ChatCompletionRole, ChatCompletionRole,
ContentType,
MessageRequest,
MessageRequestType, MessageRequestType,
MessageStatus,
ExtensionTypeEnum, ExtensionTypeEnum,
Thread, Thread,
ThreadMessage, ThreadMessage,
Model, Model,
ConversationalExtension, ConversationalExtension,
InferenceEngine, InferenceEngine,
ChatCompletionMessageContentType,
AssistantTool, AssistantTool,
EngineManager, EngineManager,
} from '@janhq/core' } from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { ulid } from 'ulidx'
import { selectedModelAtom } from '@/containers/DropdownListSidebar' import { selectedModelAtom } from '@/containers/DropdownListSidebar'
import { import {
currentPromptAtom, currentPromptAtom,
@ -30,8 +23,11 @@ import {
} from '@/containers/Providers/Jotai' } from '@/containers/Providers/Jotai'
import { compressImage, getBase64 } from '@/utils/base64' import { compressImage, getBase64 } from '@/utils/base64'
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam' import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
import { loadModelErrorAtom, useActiveModel } from './useActiveModel' import { loadModelErrorAtom, useActiveModel } from './useActiveModel'
import { extensionManager } from '@/extension/ExtensionManager' import { extensionManager } from '@/extension/ExtensionManager'
@ -102,39 +98,13 @@ export default function useSendChatMessage() {
return return
} }
updateThreadWaiting(activeThreadRef.current.id, true) updateThreadWaiting(activeThreadRef.current.id, true)
const messages: ChatCompletionMessage[] = [
activeThreadRef.current.assistants[0]?.instructions,
]
.filter((e) => e && e.trim() !== '')
.map<ChatCompletionMessage>((instructions) => {
const systemMessage: ChatCompletionMessage = {
role: ChatCompletionRole.System,
content: instructions,
}
return systemMessage
})
.concat(
currentMessages
.filter(
(e) =>
(currentMessage.role === ChatCompletionRole.User ||
e.id !== currentMessage.id) &&
e.status !== MessageStatus.Error
)
.map<ChatCompletionMessage>((msg) => ({
role: msg.role,
content: msg.content[0]?.text.value ?? '',
}))
)
const messageRequest: MessageRequest = { const requestBuilder = new MessageRequestBuilder(
id: ulid(), MessageRequestType.Thread,
type: MessageRequestType.Thread, activeThreadRef.current.assistants[0].model ?? selectedModelRef.current,
messages: messages, activeThreadRef.current,
threadId: activeThreadRef.current.id, currentMessages
model: ).addSystemMessage(activeThreadRef.current.assistants[0]?.instructions)
activeThreadRef.current.assistants[0].model ?? selectedModelRef.current,
}
const modelId = const modelId =
selectedModelRef.current?.id ?? selectedModelRef.current?.id ??
@ -143,7 +113,9 @@ export default function useSendChatMessage() {
if (modelRef.current?.id !== modelId) { if (modelRef.current?.id !== modelId) {
await startModel(modelId) await startModel(modelId)
} }
setIsGeneratingResponse(true) setIsGeneratingResponse(true)
if (currentMessage.role !== ChatCompletionRole.User) { if (currentMessage.role !== ChatCompletionRole.User) {
// Delete last response before regenerating // Delete last response before regenerating
deleteMessage(currentMessage.id ?? '') deleteMessage(currentMessage.id ?? '')
@ -157,11 +129,13 @@ export default function useSendChatMessage() {
} }
} }
const engine = EngineManager.instance()?.get( const engine = EngineManager.instance()?.get(
messageRequest.model?.engine ?? selectedModelRef.current?.engine ?? '' requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? ''
) )
engine?.inference(messageRequest) engine?.inference(requestBuilder.build())
} }
// Define interface extending Array prototype
const sendChatMessage = async (message: string) => { const sendChatMessage = async (message: string) => {
if (!message || message.trim().length === 0) return if (!message || message.trim().length === 0) return
@ -186,8 +160,6 @@ export default function useSendChatMessage() {
const fileContentType = fileUpload[0]?.type const fileContentType = fileUpload[0]?.type
const msgId = ulid()
const isDocumentInput = base64Blob && fileContentType === 'pdf' const isDocumentInput = base64Blob && fileContentType === 'pdf'
const isImageInput = base64Blob && fileContentType === 'image' const isImageInput = base64Blob && fileContentType === 'image'
@ -196,56 +168,6 @@ export default function useSendChatMessage() {
base64Blob = await compressImage(base64Blob, 512) base64Blob = await compressImage(base64Blob, 512)
} }
const messages: ChatCompletionMessage[] = [
activeThreadRef.current.assistants[0]?.instructions,
]
.filter((e) => e && e.trim() !== '')
.map<ChatCompletionMessage>((instructions) => {
const systemMessage: ChatCompletionMessage = {
role: ChatCompletionRole.System,
content: instructions,
}
return systemMessage
})
.concat(
currentMessages
.filter((e) => e.status !== MessageStatus.Error)
.map<ChatCompletionMessage>((msg) => ({
role: msg.role,
content: msg.content[0]?.text.value ?? '',
}))
.concat([
{
role: ChatCompletionRole.User,
content:
selectedModelRef.current && base64Blob
? [
{
type: ChatCompletionMessageContentType.Text,
text: prompt,
},
isDocumentInput
? {
type: ChatCompletionMessageContentType.Doc,
doc_url: {
url: `threads/${activeThreadRef.current.id}/files/${msgId}.pdf`,
},
}
: null,
isImageInput
? {
type: ChatCompletionMessageContentType.Image,
image_url: {
url: base64Blob,
},
}
: null,
].filter((e) => e !== null)
: prompt,
} as ChatCompletionMessage,
])
)
let modelRequest = let modelRequest =
selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model
@ -277,86 +199,48 @@ export default function useSendChatMessage() {
: {}), : {}),
} }
} }
const messageRequest: MessageRequest = {
id: msgId, // Build Message Request
type: MessageRequestType.Thread, const requestBuilder = new MessageRequestBuilder(
threadId: activeThreadRef.current.id, MessageRequestType.Thread,
messages, {
model: {
...modelRequest, ...modelRequest,
settings: settingParams, settings: settingParams,
parameters: runtimeParams, parameters: runtimeParams,
}, },
thread: activeThreadRef.current, activeThreadRef.current,
} currentMessages
).addSystemMessage(activeThreadRef.current.assistants[0].instructions)
const timestamp = Date.now() requestBuilder.pushMessage(prompt, base64Blob, fileUpload[0]?.type)
const content: any = []
if (base64Blob && fileUpload[0]?.type === 'image') { // Build Thread Message to persist
content.push({ const threadMessageBuilder = new ThreadMessageBuilder(
type: ContentType.Image, requestBuilder
text: { ).pushMessage(prompt, base64Blob, fileUpload)
value: prompt,
annotations: [base64Blob],
},
})
}
if (base64Blob && fileUpload[0]?.type === 'pdf') { const newMessage = threadMessageBuilder.build()
content.push({
type: ContentType.Pdf,
text: {
value: prompt,
annotations: [base64Blob],
name: fileUpload[0].file.name,
size: fileUpload[0].file.size,
},
})
}
if (prompt && !base64Blob) { // Push to states
content.push({ addNewMessage(newMessage)
type: ContentType.Text,
text: {
value: prompt,
annotations: [],
},
})
}
const threadMessage: ThreadMessage = {
id: msgId,
thread_id: activeThreadRef.current.id,
role: ChatCompletionRole.User,
status: MessageStatus.Ready,
created: timestamp,
updated: timestamp,
object: 'thread.message',
content: content,
}
addNewMessage(threadMessage)
if (base64Blob) {
setFileUpload([])
}
// Update thread state
const updatedThread: Thread = { const updatedThread: Thread = {
...activeThreadRef.current, ...activeThreadRef.current,
updated: timestamp, updated: newMessage.created,
metadata: { metadata: {
...(activeThreadRef.current.metadata ?? {}), ...(activeThreadRef.current.metadata ?? {}),
lastMessage: prompt, lastMessage: prompt,
}, },
} }
// change last update thread when send message
updateThread(updatedThread) updateThread(updatedThread)
// Add message
await extensionManager await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational) .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.addNewMessage(threadMessage) ?.addNewMessage(newMessage)
// Start Model if not started
const modelId = const modelId =
selectedModelRef.current?.id ?? selectedModelRef.current?.id ??
activeThreadRef.current.assistants[0].model.id activeThreadRef.current.assistants[0].model.id
@ -369,12 +253,17 @@ export default function useSendChatMessage() {
setIsGeneratingResponse(true) setIsGeneratingResponse(true)
const engine = EngineManager.instance()?.get( const engine = EngineManager.instance()?.get(
messageRequest.model?.engine ?? modelRequest.engine ?? '' requestBuilder.model?.engine ?? modelRequest.engine ?? ''
) )
engine?.inference(messageRequest) engine?.inference(requestBuilder.build())
// Reset states
setReloadModel(false) setReloadModel(false)
setEngineParamsUpdate(false) setEngineParamsUpdate(false)
if (base64Blob) {
setFileUpload([])
}
} }
return { return {

View File

@ -0,0 +1,130 @@
import {
ChatCompletionMessage,
ChatCompletionMessageContent,
ChatCompletionMessageContentText,
ChatCompletionMessageContentType,
ChatCompletionRole,
MessageRequest,
MessageRequestType,
MessageStatus,
ModelInfo,
Thread,
ThreadMessage,
} from '@janhq/core'
import { ulid } from 'ulidx'
import { FileType } from '@/containers/Providers/Jotai'
export class MessageRequestBuilder {
msgId: string
type: MessageRequestType
messages: ChatCompletionMessage[]
model: ModelInfo
thread: Thread
constructor(
type: MessageRequestType,
model: ModelInfo,
thread: Thread,
messages: ThreadMessage[]
) {
this.msgId = ulid()
this.type = type
this.model = model
this.thread = thread
this.messages = messages
.filter((e) => e.status !== MessageStatus.Error)
.map<ChatCompletionMessage>((msg) => ({
role: msg.role,
content: msg.content[0]?.text.value ?? '',
}))
}
// Chainable
pushMessage(
message: string,
base64Blob: string | undefined,
fileContentType: FileType
) {
if (base64Blob && fileContentType === 'pdf')
return this.addDocMessage(message)
else if (base64Blob && fileContentType === 'image') {
return this.addImageMessage(message, base64Blob)
}
this.messages = [
...this.messages,
{
role: ChatCompletionRole.User,
content: message,
},
]
return this
}
// Chainable
addSystemMessage(message: string | undefined) {
if (!message || message.trim() === '') return this
this.messages = [
{
role: ChatCompletionRole.System,
content: message,
},
...this.messages,
]
return this
}
// Chainable
addDocMessage(prompt: string) {
const message: ChatCompletionMessage = {
role: ChatCompletionRole.User,
content: [
{
type: ChatCompletionMessageContentType.Text,
text: prompt,
} as ChatCompletionMessageContentText,
{
type: ChatCompletionMessageContentType.Doc,
doc_url: {
url: `threads/${this.thread.id}/files/${this.msgId}.pdf`,
},
},
] as ChatCompletionMessageContent,
}
this.messages = [message, ...this.messages]
return this
}
// Chainable
addImageMessage(prompt: string, base64: string) {
const message: ChatCompletionMessage = {
role: ChatCompletionRole.User,
content: [
{
type: ChatCompletionMessageContentType.Text,
text: prompt,
} as ChatCompletionMessageContentText,
{
type: ChatCompletionMessageContentType.Image,
image_url: {
url: base64,
},
},
] as ChatCompletionMessageContent,
}
this.messages = [message, ...this.messages]
return this
}
build(): MessageRequest {
return {
id: this.msgId,
type: this.type,
threadId: this.thread.id,
messages: this.messages,
model: this.model,
thread: this.thread,
}
}
}

View File

@ -0,0 +1,74 @@
import {
ChatCompletionRole,
ContentType,
MessageStatus,
ThreadContent,
ThreadMessage,
} from '@janhq/core'
import { FileInfo } from '@/containers/Providers/Jotai'
import { MessageRequestBuilder } from './messageRequestBuilder'
export class ThreadMessageBuilder {
messageRequest: MessageRequestBuilder
content: ThreadContent[] = []
constructor(messageRequest: MessageRequestBuilder) {
this.messageRequest = messageRequest
}
build(): ThreadMessage {
const timestamp = Date.now()
return {
id: this.messageRequest.msgId,
thread_id: this.messageRequest.thread.id,
role: ChatCompletionRole.User,
status: MessageStatus.Ready,
created: timestamp,
updated: timestamp,
object: 'thread.message',
content: this.content,
}
}
pushMessage(
prompt: string,
base64: string | undefined,
fileUpload: FileInfo[]
) {
if (base64 && fileUpload[0]?.type === 'image') {
this.content.push({
type: ContentType.Image,
text: {
value: prompt,
annotations: [base64],
},
})
}
if (base64 && fileUpload[0]?.type === 'pdf') {
this.content.push({
type: ContentType.Pdf,
text: {
value: prompt,
annotations: [base64],
name: fileUpload[0].file.name,
size: fileUpload[0].file.size,
},
})
}
if (prompt && !base64) {
this.content.push({
type: ContentType.Text,
text: {
value: prompt,
annotations: [],
},
})
}
return this
}
}