// jan/web/hooks/useSendChatMessage.ts
import { useEffect, useRef } from 'react'
import {
MessageRequestType,
ExtensionTypeEnum,
Thread,
ThreadMessage,
Model,
ConversationalExtension,
ThreadAssistantInfo,
events,
MessageEvent,
ContentType,
EngineManager,
InferenceEngine,
MessageStatus,
ChatCompletionRole,
  extractInferenceParams,
  extractModelLoadParams,
} from '@janhq/core'
import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai'
import { OpenAI } from 'openai'
import {
ChatCompletionMessageParam,
ChatCompletionTool,
ChatCompletionMessageToolCall,
} from 'openai/resources/chat'
import { Stream } from 'openai/streaming'
import { ulid } from 'ulidx'
import { modelDropdownStateAtom } from '@/containers/ModelDropdown'
import {
currentPromptAtom,
editPromptAtom,
fileUploadAtom,
} from '@/containers/Providers/Jotai'
import { compressImage, getBase64 } from '@/utils/base64'
import { MessageRequestBuilder } from '@/utils/messageRequestBuilder'
import { ThreadMessageBuilder } from '@/utils/threadMessageBuilder'
import { useActiveModel } from './useActiveModel'
import { extensionManager } from '@/extension/ExtensionManager'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
import {
addNewMessageAtom,
deleteMessageAtom,
getCurrentChatMessagesAtom,
tokenSpeedAtom,
} from '@/helpers/atoms/ChatMessage.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import {
activeThreadAtom,
engineParamsUpdateAtom,
getActiveThreadModelParamsAtom,
isGeneratingResponseAtom,
updateThreadAtom,
updateThreadWaitingForResponseAtom,
} from '@/helpers/atoms/Thread.atom'
import { ModelTool } from '@/types/model'
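
// Signals that engine/model parameters changed since the last run, so the
// model is reloaded before the next inference (reset after each send)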
export const reloadModelAtom = atom(false)
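
/**
 * Hook that owns the send/resend pipeline for the active thread: it builds
 * the message request, persists the user message, starts the model if it is
 * not running, then routes the completion either through the tool-call loop
 * (via the local OpenAI-compatible endpoint) or the inference engine.
 *
 * Usage (sketch):
 *   const { sendChatMessage, resendChatMessage } = useSendChatMessage()
 *   await sendChatMessage('Hello!')
 */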
export default function useSendChatMessage() {
const activeThread = useAtomValue(activeThreadAtom)
const activeAssistant = useAtomValue(activeAssistantAtom)
const addNewMessage = useSetAtom(addNewMessageAtom)
const updateThread = useSetAtom(updateThreadAtom)
const updateThreadWaiting = useSetAtom(updateThreadWaitingForResponseAtom)
const setCurrentPrompt = useSetAtom(currentPromptAtom)
const deleteMessage = useSetAtom(deleteMessageAtom)
const setEditPrompt = useSetAtom(editPromptAtom)
const currentMessages = useAtomValue(getCurrentChatMessagesAtom)
const selectedModel = useAtomValue(selectedModelAtom)
const { activeModel, startModel } = useActiveModel()
const modelRef = useRef<Model | undefined>()
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
const engineParamsUpdate = useAtomValue(engineParamsUpdateAtom)
const setEngineParamsUpdate = useSetAtom(engineParamsUpdateAtom)
const setReloadModel = useSetAtom(reloadModelAtom)
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
const activeThreadRef = useRef<Thread | undefined>()
const activeAssistantRef = useRef<ThreadAssistantInfo | undefined>()
const setTokenSpeed = useSetAtom(tokenSpeedAtom)
const setModelDropdownState = useSetAtom(modelDropdownStateAtom)
const selectedModelRef = useRef<Model | undefined>()
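
  // Mirror fast-changing atom values into refs so the async closures below
  // always read the latest values when they finally run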
useEffect(() => {
modelRef.current = activeModel
}, [activeModel])
useEffect(() => {
activeThreadRef.current = activeThread
}, [activeThread])
useEffect(() => {
selectedModelRef.current = selectedModel
}, [selectedModel])
useEffect(() => {
activeAssistantRef.current = activeAssistant
}, [activeAssistant])
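
  /**
   * Regenerates the last response: removes trailing non-user messages from
   * storage and local state, then re-sends the most recent user message.
   */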
const resendChatMessage = async () => {
    // Pop messages off the end until the last user message surfaces,
    // deleting each removed response from storage and local state
    const newConvoData = Array.from(currentMessages)
    let toSendMessage = newConvoData.pop()
    while (toSendMessage && toSendMessage.role !== 'user') {
await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.deleteMessage(toSendMessage.thread_id, toSendMessage.id)
.catch(console.error)
deleteMessage(toSendMessage.id ?? '')
toSendMessage = newConvoData.pop()
}
    if (toSendMessage?.content[0]?.text?.value) {
      await sendChatMessage(
        toSendMessage.content[0].text.value,
        true,
        newConvoData
      )
    }
}
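
  /**
   * Sends a chat message in the active thread.
   * @param message  prompt text; ignored when empty
   * @param isResend true when regenerating, so the user message is not persisted again
   * @param messages optional history override (used by resendChatMessage)
   */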
const sendChatMessage = async (
message: string,
isResend: boolean = false,
messages?: ThreadMessage[]
) => {
if (!message || message.trim().length === 0) return
if (!activeThreadRef.current || !activeAssistantRef.current) {
console.error('No active thread or assistant')
return
}
if (selectedModelRef.current?.id === undefined) {
setModelDropdownState(true)
return
}
if (engineParamsUpdate) setReloadModel(true)
setTokenSpeed(undefined)
const runtimeParams = extractInferenceParams(activeModelParams)
const settingParams = extractModelLoadParams(activeModelParams)
const prompt = message.trim()
updateThreadWaiting(activeThreadRef.current.id, true)
setCurrentPrompt('')
setEditPrompt('')
let base64Blob = fileUpload ? await getBase64(fileUpload.file) : undefined
if (base64Blob && fileUpload?.type === 'image') {
// Compress image
base64Blob = await compressImage(base64Blob, 512)
}
const modelRequest =
selectedModelRef?.current ?? activeAssistantRef.current?.model
    // Fallback for older broken threads whose assistant was persisted with a
    // wildcard model id ('*'): pin it to the resolved model instead
    if (activeAssistantRef.current?.model?.id === '*') {
activeAssistantRef.current.model = {
id: modelRequest.id,
settings: modelRequest.settings,
parameters: modelRequest.parameters,
}
}
    // Default to streaming responses unless explicitly disabled
    if (runtimeParams.stream == null) {
      runtimeParams.stream = true
    }
    // Build the message request: thread history plus the runtime's tools
    // mapped into OpenAI function-calling format (descriptions truncated
    // to 1024 characters)
const requestBuilder = new MessageRequestBuilder(
MessageRequestType.Thread,
{
...modelRequest,
settings: settingParams,
parameters: runtimeParams,
},
activeThreadRef.current,
messages ?? currentMessages,
(await window.core.api.getTools())?.map((tool: ModelTool) => ({
type: 'function' as const,
function: {
name: tool.name,
description: tool.description?.slice(0, 1024),
parameters: tool.inputSchema,
strict: false,
},
}))
).addSystemMessage(activeAssistantRef.current?.instructions)
requestBuilder.pushMessage(prompt, base64Blob, fileUpload)
// Build Thread Message to persist
const threadMessageBuilder = new ThreadMessageBuilder(
requestBuilder
).pushMessage(prompt, base64Blob, fileUpload)
const newMessage = threadMessageBuilder.build()
// Update thread state
const updatedThread: Thread = {
...activeThreadRef.current,
updated: newMessage.created_at,
metadata: {
...activeThreadRef.current.metadata,
lastMessage: prompt,
},
}
updateThread(updatedThread)
if (
!isResend &&
(newMessage.content.length || newMessage.attachments?.length)
) {
// Add message
const createdMessage = await extensionManager
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.createMessage(newMessage)
.catch(() => undefined)
if (!createdMessage) return
// Push to states
addNewMessage(createdMessage)
}
    // Start the selected model if it is not the one currently running
const modelId =
selectedModelRef.current?.id ?? activeAssistantRef.current?.model.id
if (base64Blob) {
setFileUpload(undefined)
}
if (modelRef.current?.id !== modelId && modelId) {
const error = await startModel(modelId).catch((error: Error) => error)
if (error) {
updateThreadWaiting(activeThreadRef.current.id, false)
return
}
}
setIsGeneratingResponse(true)
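
    // With tools available, talk to the local OpenAI-compatible endpoint
    // directly and loop until the model stops requesting tool calls;
    // otherwise hand the request straight to the inference engine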
    if (requestBuilder.tools?.length) {
let isDone = false
const openai = new OpenAI({
apiKey: await window.core.api.appToken(),
baseURL: `${API_BASE_URL}/v1`,
dangerouslyAllowBrowser: true,
})
let parentMessageId: string | undefined
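      // The first assistant message of the loop becomes the parent; each
      // follow-up round gets a fresh id linked back to it via
      // metadata.parent_id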
while (!isDone) {
let messageId = ulid()
if (!parentMessageId) {
parentMessageId = ulid()
messageId = parentMessageId
}
const data = requestBuilder.build()
const message: ThreadMessage = {
id: messageId,
object: 'message',
thread_id: activeThreadRef.current.id,
assistant_id: activeAssistantRef.current.assistant_id,
role: ChatCompletionRole.Assistant,
content: [],
metadata: {
...(messageId !== parentMessageId
? { parent_id: parentMessageId }
: {}),
},
status: MessageStatus.Pending,
created_at: Date.now() / 1000,
completed_at: Date.now() / 1000,
}
events.emit(MessageEvent.OnMessageResponse, message)
const response = await openai.chat.completions.create({
messages: requestBuilder.messages as ChatCompletionMessageParam[],
model: data.model?.id ?? '',
tools: data.tools as ChatCompletionTool[],
stream: data.model?.parameters?.stream ?? false,
tool_choice: 'auto',
})
        // Seed an empty text block so there is content to update as the
        // response arrives
if (!message.content.length) {
message.content = [
{
type: ContentType.Text,
text: {
value: '',
annotations: [],
},
},
]
}
        if (data.model?.parameters?.stream) {
          isDone = await processStreamingResponse(
            response as Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
            requestBuilder,
            message
          )
        } else {
          isDone = await processNonStreamingResponse(
            response as OpenAI.Chat.Completions.ChatCompletion,
            requestBuilder,
            message
          )
        }
}
    } else {
      // No tools advertised: delegate the request to the inference engine
EngineManager.instance()
.get(InferenceEngine.cortex)
?.inference(requestBuilder.build())
}
// Reset states
setReloadModel(false)
setEngineParamsUpdate(false)
}
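
  /**
   * Handles a non-streaming completion: copies the reply text into the
   * message, runs any requested tool calls, and returns true when no tool
   * calls were requested (i.e. the loop can stop).
   */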
const processNonStreamingResponse = async (
response: OpenAI.Chat.Completions.ChatCompletion,
requestBuilder: MessageRequestBuilder,
message: ThreadMessage
): Promise<boolean> => {
// Handle tool calls in the response
const toolCalls: ChatCompletionMessageToolCall[] =
response.choices[0]?.message?.tool_calls ?? []
    const content = response.choices[0]?.message?.content
message.content = [
{
type: ContentType.Text,
text: {
value: content ?? '',
annotations: [],
},
},
]
events.emit(MessageEvent.OnMessageUpdate, message)
    await postMessageProcessing(
      toolCalls,
      requestBuilder,
      message,
      content ?? ''
    )
    // Done when the model requested no further tool calls
    return toolCalls.length === 0
}
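
  /**
   * Consumes a streaming completion chunk by chunk: text deltas are
   * accumulated into the message (emitting an update per chunk) and
   * fragmented tool-call deltas are stitched back into complete calls.
   * Returns true when the stream contained no tool calls.
   */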
const processStreamingResponse = async (
response: Stream<OpenAI.Chat.Completions.ChatCompletionChunk>,
requestBuilder: MessageRequestBuilder,
message: ThreadMessage
): Promise<boolean> => {
// Variables to track and accumulate streaming content
let currentToolCall: {
id: string
function: { name: string; arguments: string }
} | null = null
let accumulatedContent = ''
const toolCalls: ChatCompletionMessageToolCall[] = []
// Process the streaming chunks
for await (const chunk of response) {
// Handle tool calls in the chunk
if (chunk.choices[0]?.delta?.tool_calls) {
const deltaToolCalls = chunk.choices[0].delta.tool_calls
// Handle the beginning of a new tool call
if (
deltaToolCalls[0]?.index !== undefined &&
deltaToolCalls[0]?.function
) {
const index = deltaToolCalls[0].index
// Create new tool call if this is the first chunk for it
if (!toolCalls[index]) {
toolCalls[index] = {
id: deltaToolCalls[0]?.id || '',
function: {
name: deltaToolCalls[0]?.function?.name || '',
arguments: deltaToolCalls[0]?.function?.arguments || '',
},
type: 'function',
}
currentToolCall = toolCalls[index]
          } else {
            // Continuation of an existing tool call: append any name or
            // argument fragments delivered in this chunk
            currentToolCall = toolCalls[index]
            if (deltaToolCalls[0]?.function?.name) {
              currentToolCall.function.name += deltaToolCalls[0].function.name
            }
            if (deltaToolCalls[0]?.function?.arguments) {
              currentToolCall.function.arguments +=
                deltaToolCalls[0].function.arguments
            }
}
}
}
// Handle regular content in the chunk
if (chunk.choices[0]?.delta?.content) {
const content = chunk.choices[0].delta.content
accumulatedContent += content
message.content = [
{
type: ContentType.Text,
text: {
value: accumulatedContent,
annotations: [],
},
},
]
events.emit(MessageEvent.OnMessageUpdate, message)
}
}
    await postMessageProcessing(
      toolCalls,
      requestBuilder,
      message,
      accumulatedContent
    )
    // Done when the stream contained no tool calls
    return toolCalls.length === 0
}
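
  /**
   * Appends the assistant turn to the request history, executes each tool
   * call through the desktop runtime, tracks pending/ready tool state in the
   * message metadata, feeds tool results back into the conversation, and
   * finally marks the message Ready.
   */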
const postMessageProcessing = async (
toolCalls: ChatCompletionMessageToolCall[],
requestBuilder: MessageRequestBuilder,
message: ThreadMessage,
content: string
) => {
requestBuilder.pushAssistantMessage({
content,
role: 'assistant',
refusal: null,
tool_calls: toolCalls,
})
// Handle completed tool calls
if (toolCalls.length > 0) {
for (const toolCall of toolCalls) {
const toolId = ulid()
const toolCallsMetadata =
message.metadata?.tool_calls &&
Array.isArray(message.metadata?.tool_calls)
? message.metadata?.tool_calls
: []
message.metadata = {
...(message.metadata ?? {}),
tool_calls: [
...toolCallsMetadata,
{
tool: {
...toolCall,
id: toolId,
},
response: undefined,
state: 'pending',
},
],
}
events.emit(MessageEvent.OnMessageUpdate, message)
        const result = await window.core.api.callTool({
          toolName: toolCall.function.name,
          // Tool arguments arrive as a JSON string; JSON.parse throws on
          // malformed model output
          arguments: JSON.parse(toolCall.function.arguments),
        })
        // Stop executing the remaining tool calls if one fails
        if (result.error) break
message.metadata = {
...(message.metadata ?? {}),
tool_calls: [
...toolCallsMetadata,
{
tool: {
...toolCall,
id: toolId,
},
response: result,
state: 'ready',
},
],
}
requestBuilder.pushToolMessage(
result.content[0]?.text ?? '',
toolCall.id
)
events.emit(MessageEvent.OnMessageUpdate, message)
}
}
message.status = MessageStatus.Ready
events.emit(MessageEvent.OnMessageUpdate, message)
}
return {
sendChatMessage,
resendChatMessage,
}
}