refactor: introduce message request builder (#2481)
This commit is contained in:
parent
9551996e34
commit
77cbdc2dcf
@ -2,26 +2,19 @@
|
||||
import { useEffect, useRef } from 'react'
|
||||
|
||||
import {
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionRole,
|
||||
ContentType,
|
||||
MessageRequest,
|
||||
MessageRequestType,
|
||||
MessageStatus,
|
||||
ExtensionTypeEnum,
|
||||
Thread,
|
||||
ThreadMessage,
|
||||
Model,
|
||||
ConversationalExtension,
|
||||
InferenceEngine,
|
||||
ChatCompletionMessageContentType,
|
||||
AssistantTool,
|
||||
EngineManager,
|
||||
} from '@janhq/core'
|
||||
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
|
||||
|
||||
import { ulid } from 'ulidx'
|
||||
|
||||
import { selectedModelAtom } from '@/containers/DropdownListSidebar'
|
||||
import {
|
||||
currentPromptAtom,
|
||||
@ -30,8 +23,11 @@ import {
|
||||
} from '@/containers/Providers/Jotai'
|
||||
|
||||
import { compressImage, getBase64 } from '@/utils/base64'
|
||||
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
|
||||
import { toRuntimeParams, toSettingParams } from '@/utils/modelParam'
|
||||
|
||||
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
|
||||
|
||||
import { loadModelErrorAtom, useActiveModel } from './useActiveModel'
|
||||
|
||||
import { extensionManager } from '@/extension/ExtensionManager'
|
||||
@ -102,39 +98,13 @@ export default function useSendChatMessage() {
|
||||
return
|
||||
}
|
||||
updateThreadWaiting(activeThreadRef.current.id, true)
|
||||
const messages: ChatCompletionMessage[] = [
|
||||
activeThreadRef.current.assistants[0]?.instructions,
|
||||
]
|
||||
.filter((e) => e && e.trim() !== '')
|
||||
.map<ChatCompletionMessage>((instructions) => {
|
||||
const systemMessage: ChatCompletionMessage = {
|
||||
role: ChatCompletionRole.System,
|
||||
content: instructions,
|
||||
}
|
||||
return systemMessage
|
||||
})
|
||||
.concat(
|
||||
currentMessages
|
||||
.filter(
|
||||
(e) =>
|
||||
(currentMessage.role === ChatCompletionRole.User ||
|
||||
e.id !== currentMessage.id) &&
|
||||
e.status !== MessageStatus.Error
|
||||
)
|
||||
.map<ChatCompletionMessage>((msg) => ({
|
||||
role: msg.role,
|
||||
content: msg.content[0]?.text.value ?? '',
|
||||
}))
|
||||
)
|
||||
|
||||
const messageRequest: MessageRequest = {
|
||||
id: ulid(),
|
||||
type: MessageRequestType.Thread,
|
||||
messages: messages,
|
||||
threadId: activeThreadRef.current.id,
|
||||
model:
|
||||
const requestBuilder = new MessageRequestBuilder(
|
||||
MessageRequestType.Thread,
|
||||
activeThreadRef.current.assistants[0].model ?? selectedModelRef.current,
|
||||
}
|
||||
activeThreadRef.current,
|
||||
currentMessages
|
||||
).addSystemMessage(activeThreadRef.current.assistants[0]?.instructions)
|
||||
|
||||
const modelId =
|
||||
selectedModelRef.current?.id ??
|
||||
@ -143,7 +113,9 @@ export default function useSendChatMessage() {
|
||||
if (modelRef.current?.id !== modelId) {
|
||||
await startModel(modelId)
|
||||
}
|
||||
|
||||
setIsGeneratingResponse(true)
|
||||
|
||||
if (currentMessage.role !== ChatCompletionRole.User) {
|
||||
// Delete last response before regenerating
|
||||
deleteMessage(currentMessage.id ?? '')
|
||||
@ -157,11 +129,13 @@ export default function useSendChatMessage() {
|
||||
}
|
||||
}
|
||||
const engine = EngineManager.instance()?.get(
|
||||
messageRequest.model?.engine ?? selectedModelRef.current?.engine ?? ''
|
||||
requestBuilder.model?.engine ?? selectedModelRef.current?.engine ?? ''
|
||||
)
|
||||
engine?.inference(messageRequest)
|
||||
engine?.inference(requestBuilder.build())
|
||||
}
|
||||
|
||||
// Define interface extending Array prototype
|
||||
|
||||
const sendChatMessage = async (message: string) => {
|
||||
if (!message || message.trim().length === 0) return
|
||||
|
||||
@ -186,8 +160,6 @@ export default function useSendChatMessage() {
|
||||
|
||||
const fileContentType = fileUpload[0]?.type
|
||||
|
||||
const msgId = ulid()
|
||||
|
||||
const isDocumentInput = base64Blob && fileContentType === 'pdf'
|
||||
const isImageInput = base64Blob && fileContentType === 'image'
|
||||
|
||||
@ -196,56 +168,6 @@ export default function useSendChatMessage() {
|
||||
base64Blob = await compressImage(base64Blob, 512)
|
||||
}
|
||||
|
||||
const messages: ChatCompletionMessage[] = [
|
||||
activeThreadRef.current.assistants[0]?.instructions,
|
||||
]
|
||||
.filter((e) => e && e.trim() !== '')
|
||||
.map<ChatCompletionMessage>((instructions) => {
|
||||
const systemMessage: ChatCompletionMessage = {
|
||||
role: ChatCompletionRole.System,
|
||||
content: instructions,
|
||||
}
|
||||
return systemMessage
|
||||
})
|
||||
.concat(
|
||||
currentMessages
|
||||
.filter((e) => e.status !== MessageStatus.Error)
|
||||
.map<ChatCompletionMessage>((msg) => ({
|
||||
role: msg.role,
|
||||
content: msg.content[0]?.text.value ?? '',
|
||||
}))
|
||||
.concat([
|
||||
{
|
||||
role: ChatCompletionRole.User,
|
||||
content:
|
||||
selectedModelRef.current && base64Blob
|
||||
? [
|
||||
{
|
||||
type: ChatCompletionMessageContentType.Text,
|
||||
text: prompt,
|
||||
},
|
||||
isDocumentInput
|
||||
? {
|
||||
type: ChatCompletionMessageContentType.Doc,
|
||||
doc_url: {
|
||||
url: `threads/${activeThreadRef.current.id}/files/${msgId}.pdf`,
|
||||
},
|
||||
}
|
||||
: null,
|
||||
isImageInput
|
||||
? {
|
||||
type: ChatCompletionMessageContentType.Image,
|
||||
image_url: {
|
||||
url: base64Blob,
|
||||
},
|
||||
}
|
||||
: null,
|
||||
].filter((e) => e !== null)
|
||||
: prompt,
|
||||
} as ChatCompletionMessage,
|
||||
])
|
||||
)
|
||||
|
||||
let modelRequest =
|
||||
selectedModelRef?.current ?? activeThreadRef.current.assistants[0].model
|
||||
|
||||
@ -277,86 +199,48 @@ export default function useSendChatMessage() {
|
||||
: {}),
|
||||
}
|
||||
}
|
||||
const messageRequest: MessageRequest = {
|
||||
id: msgId,
|
||||
type: MessageRequestType.Thread,
|
||||
threadId: activeThreadRef.current.id,
|
||||
messages,
|
||||
model: {
|
||||
|
||||
// Build Message Request
|
||||
const requestBuilder = new MessageRequestBuilder(
|
||||
MessageRequestType.Thread,
|
||||
{
|
||||
...modelRequest,
|
||||
settings: settingParams,
|
||||
parameters: runtimeParams,
|
||||
},
|
||||
thread: activeThreadRef.current,
|
||||
}
|
||||
activeThreadRef.current,
|
||||
currentMessages
|
||||
).addSystemMessage(activeThreadRef.current.assistants[0].instructions)
|
||||
|
||||
const timestamp = Date.now()
|
||||
const content: any = []
|
||||
requestBuilder.pushMessage(prompt, base64Blob, fileUpload[0]?.type)
|
||||
|
||||
if (base64Blob && fileUpload[0]?.type === 'image') {
|
||||
content.push({
|
||||
type: ContentType.Image,
|
||||
text: {
|
||||
value: prompt,
|
||||
annotations: [base64Blob],
|
||||
},
|
||||
})
|
||||
}
|
||||
// Build Thread Message to persist
|
||||
const threadMessageBuilder = new ThreadMessageBuilder(
|
||||
requestBuilder
|
||||
).pushMessage(prompt, base64Blob, fileUpload)
|
||||
|
||||
if (base64Blob && fileUpload[0]?.type === 'pdf') {
|
||||
content.push({
|
||||
type: ContentType.Pdf,
|
||||
text: {
|
||||
value: prompt,
|
||||
annotations: [base64Blob],
|
||||
name: fileUpload[0].file.name,
|
||||
size: fileUpload[0].file.size,
|
||||
},
|
||||
})
|
||||
}
|
||||
const newMessage = threadMessageBuilder.build()
|
||||
|
||||
if (prompt && !base64Blob) {
|
||||
content.push({
|
||||
type: ContentType.Text,
|
||||
text: {
|
||||
value: prompt,
|
||||
annotations: [],
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
const threadMessage: ThreadMessage = {
|
||||
id: msgId,
|
||||
thread_id: activeThreadRef.current.id,
|
||||
role: ChatCompletionRole.User,
|
||||
status: MessageStatus.Ready,
|
||||
created: timestamp,
|
||||
updated: timestamp,
|
||||
object: 'thread.message',
|
||||
content: content,
|
||||
}
|
||||
|
||||
addNewMessage(threadMessage)
|
||||
if (base64Blob) {
|
||||
setFileUpload([])
|
||||
}
|
||||
// Push to states
|
||||
addNewMessage(newMessage)
|
||||
|
||||
// Update thread state
|
||||
const updatedThread: Thread = {
|
||||
...activeThreadRef.current,
|
||||
updated: timestamp,
|
||||
updated: newMessage.created,
|
||||
metadata: {
|
||||
...(activeThreadRef.current.metadata ?? {}),
|
||||
lastMessage: prompt,
|
||||
},
|
||||
}
|
||||
|
||||
// change last update thread when send message
|
||||
updateThread(updatedThread)
|
||||
|
||||
// Add message
|
||||
await extensionManager
|
||||
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
|
||||
?.addNewMessage(threadMessage)
|
||||
?.addNewMessage(newMessage)
|
||||
|
||||
// Start Model if not started
|
||||
const modelId =
|
||||
selectedModelRef.current?.id ??
|
||||
activeThreadRef.current.assistants[0].model.id
|
||||
@ -369,12 +253,17 @@ export default function useSendChatMessage() {
|
||||
setIsGeneratingResponse(true)
|
||||
|
||||
const engine = EngineManager.instance()?.get(
|
||||
messageRequest.model?.engine ?? modelRequest.engine ?? ''
|
||||
requestBuilder.model?.engine ?? modelRequest.engine ?? ''
|
||||
)
|
||||
engine?.inference(messageRequest)
|
||||
engine?.inference(requestBuilder.build())
|
||||
|
||||
// Reset states
|
||||
setReloadModel(false)
|
||||
setEngineParamsUpdate(false)
|
||||
|
||||
if (base64Blob) {
|
||||
setFileUpload([])
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
130
web/utils/messageRequestBuilder.ts
Normal file
130
web/utils/messageRequestBuilder.ts
Normal file
@ -0,0 +1,130 @@
|
||||
import {
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageContent,
|
||||
ChatCompletionMessageContentText,
|
||||
ChatCompletionMessageContentType,
|
||||
ChatCompletionRole,
|
||||
MessageRequest,
|
||||
MessageRequestType,
|
||||
MessageStatus,
|
||||
ModelInfo,
|
||||
Thread,
|
||||
ThreadMessage,
|
||||
} from '@janhq/core'
|
||||
import { ulid } from 'ulidx'
|
||||
|
||||
import { FileType } from '@/containers/Providers/Jotai'
|
||||
|
||||
export class MessageRequestBuilder {
|
||||
msgId: string
|
||||
type: MessageRequestType
|
||||
messages: ChatCompletionMessage[]
|
||||
model: ModelInfo
|
||||
thread: Thread
|
||||
|
||||
constructor(
|
||||
type: MessageRequestType,
|
||||
model: ModelInfo,
|
||||
thread: Thread,
|
||||
messages: ThreadMessage[]
|
||||
) {
|
||||
this.msgId = ulid()
|
||||
this.type = type
|
||||
this.model = model
|
||||
this.thread = thread
|
||||
this.messages = messages
|
||||
.filter((e) => e.status !== MessageStatus.Error)
|
||||
.map<ChatCompletionMessage>((msg) => ({
|
||||
role: msg.role,
|
||||
content: msg.content[0]?.text.value ?? '',
|
||||
}))
|
||||
}
|
||||
|
||||
// Chainable
|
||||
pushMessage(
|
||||
message: string,
|
||||
base64Blob: string | undefined,
|
||||
fileContentType: FileType
|
||||
) {
|
||||
if (base64Blob && fileContentType === 'pdf')
|
||||
return this.addDocMessage(message)
|
||||
else if (base64Blob && fileContentType === 'image') {
|
||||
return this.addImageMessage(message, base64Blob)
|
||||
}
|
||||
this.messages = [
|
||||
...this.messages,
|
||||
{
|
||||
role: ChatCompletionRole.User,
|
||||
content: message,
|
||||
},
|
||||
]
|
||||
return this
|
||||
}
|
||||
|
||||
// Chainable
|
||||
addSystemMessage(message: string | undefined) {
|
||||
if (!message || message.trim() === '') return this
|
||||
this.messages = [
|
||||
{
|
||||
role: ChatCompletionRole.System,
|
||||
content: message,
|
||||
},
|
||||
...this.messages,
|
||||
]
|
||||
return this
|
||||
}
|
||||
|
||||
// Chainable
|
||||
addDocMessage(prompt: string) {
|
||||
const message: ChatCompletionMessage = {
|
||||
role: ChatCompletionRole.User,
|
||||
content: [
|
||||
{
|
||||
type: ChatCompletionMessageContentType.Text,
|
||||
text: prompt,
|
||||
} as ChatCompletionMessageContentText,
|
||||
{
|
||||
type: ChatCompletionMessageContentType.Doc,
|
||||
doc_url: {
|
||||
url: `threads/${this.thread.id}/files/${this.msgId}.pdf`,
|
||||
},
|
||||
},
|
||||
] as ChatCompletionMessageContent,
|
||||
}
|
||||
this.messages = [message, ...this.messages]
|
||||
return this
|
||||
}
|
||||
|
||||
// Chainable
|
||||
addImageMessage(prompt: string, base64: string) {
|
||||
const message: ChatCompletionMessage = {
|
||||
role: ChatCompletionRole.User,
|
||||
content: [
|
||||
{
|
||||
type: ChatCompletionMessageContentType.Text,
|
||||
text: prompt,
|
||||
} as ChatCompletionMessageContentText,
|
||||
{
|
||||
type: ChatCompletionMessageContentType.Image,
|
||||
image_url: {
|
||||
url: base64,
|
||||
},
|
||||
},
|
||||
] as ChatCompletionMessageContent,
|
||||
}
|
||||
|
||||
this.messages = [message, ...this.messages]
|
||||
return this
|
||||
}
|
||||
|
||||
build(): MessageRequest {
|
||||
return {
|
||||
id: this.msgId,
|
||||
type: this.type,
|
||||
threadId: this.thread.id,
|
||||
messages: this.messages,
|
||||
model: this.model,
|
||||
thread: this.thread,
|
||||
}
|
||||
}
|
||||
}
|
||||
74
web/utils/threadMessageBuilder.ts
Normal file
74
web/utils/threadMessageBuilder.ts
Normal file
@ -0,0 +1,74 @@
|
||||
import {
|
||||
ChatCompletionRole,
|
||||
ContentType,
|
||||
MessageStatus,
|
||||
ThreadContent,
|
||||
ThreadMessage,
|
||||
} from '@janhq/core'
|
||||
|
||||
import { FileInfo } from '@/containers/Providers/Jotai'
|
||||
|
||||
import { MessageRequestBuilder } from './messageRequestBuilder'
|
||||
|
||||
export class ThreadMessageBuilder {
|
||||
messageRequest: MessageRequestBuilder
|
||||
|
||||
content: ThreadContent[] = []
|
||||
|
||||
constructor(messageRequest: MessageRequestBuilder) {
|
||||
this.messageRequest = messageRequest
|
||||
}
|
||||
|
||||
build(): ThreadMessage {
|
||||
const timestamp = Date.now()
|
||||
return {
|
||||
id: this.messageRequest.msgId,
|
||||
thread_id: this.messageRequest.thread.id,
|
||||
role: ChatCompletionRole.User,
|
||||
status: MessageStatus.Ready,
|
||||
created: timestamp,
|
||||
updated: timestamp,
|
||||
object: 'thread.message',
|
||||
content: this.content,
|
||||
}
|
||||
}
|
||||
|
||||
pushMessage(
|
||||
prompt: string,
|
||||
base64: string | undefined,
|
||||
fileUpload: FileInfo[]
|
||||
) {
|
||||
if (base64 && fileUpload[0]?.type === 'image') {
|
||||
this.content.push({
|
||||
type: ContentType.Image,
|
||||
text: {
|
||||
value: prompt,
|
||||
annotations: [base64],
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if (base64 && fileUpload[0]?.type === 'pdf') {
|
||||
this.content.push({
|
||||
type: ContentType.Pdf,
|
||||
text: {
|
||||
value: prompt,
|
||||
annotations: [base64],
|
||||
name: fileUpload[0].file.name,
|
||||
size: fileUpload[0].file.size,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if (prompt && !base64) {
|
||||
this.content.push({
|
||||
type: ContentType.Text,
|
||||
text: {
|
||||
value: prompt,
|
||||
annotations: [],
|
||||
},
|
||||
})
|
||||
}
|
||||
return this
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user