199 lines
5.2 KiB
TypeScript
199 lines
5.2 KiB
TypeScript
import {
|
|
ChatCompletionRole,
|
|
MessageStatus,
|
|
ThreadContent,
|
|
ThreadMessage,
|
|
} from '@janhq/core'
|
|
import { atom } from 'jotai'
|
|
|
|
import { atomWithStorage } from 'jotai/utils'
|
|
|
|
import {
|
|
getActiveThreadIdAtom,
|
|
updateThreadStateLastMessageAtom,
|
|
} from './Thread.atom'
|
|
|
|
import { TokenSpeed } from '@/types/token'
|
|
|
|
/** Key passed to `atomWithStorage` for the persisted chat messages atom. */
const CHAT_MESSAGE_NAME = 'chatMessages'
|
|
/**
 * Persisted messages of every thread, keyed by thread id.
 *
 * Backed by `atomWithStorage` under the `CHAT_MESSAGE_NAME` key, starting from
 * an empty record, with `getOnInit` enabled.
 */
export const chatMessagesStorage = atomWithStorage<Record<string, ThreadMessage[]>>(
  CHAT_MESSAGE_NAME,
  {},
  undefined,
  { getOnInit: true }
)
|
|
|
|
/**
 * In-memory copy of the messages of every thread, keyed by thread id.
 *
 * Created without an initial value, so it reads as `undefined` until something
 * writes to it.
 */
export const cachedMessages = atom<Record<string, ThreadMessage[]>>()
|
|
/**
 * Read/write view over the chat messages of all threads, keyed by thread id.
 *
 * Reading returns the `cachedMessages` value when it is set and otherwise
 * falls back to the `chatMessagesStorage` value. Writing stores the new record
 * in both atoms.
 */
export const chatMessages = atom(
  (get) => {
    const cached = get(cachedMessages)
    return cached ?? get(chatMessagesStorage)
  },
  (_get, set, newValue: Record<string, ThreadMessage[]>) => {
    set(cachedMessages, newValue)
    set(chatMessagesStorage, newValue)
  }
)
|
|
|
|
/**
 * Holds the thread subscribed to for a generating message, as an object with
 * an optional `thread_id`. Starts as an empty object (no thread).
 */
export const subscribedGeneratingMessageAtom = atom<{ thread_id?: string }>({})
|
|
|
|
/**
 * Persisted per-thread flag recording the load status of that thread's
 * messages, keyed by thread id. Starts from an empty record.
 */
export const readyThreadsMessagesAtom = atomWithStorage<Record<string, boolean>>(
  'currentThreadMessages',
  {},
  undefined,
  { getOnInit: true }
)
|
|
|
|
/**
 * Token speed of the current message; `undefined` when none is recorded.
 */
export const tokenSpeedAtom = atom<TokenSpeed | undefined>(undefined)
|
|
/**
 * Derived atom yielding the messages of the active thread.
 *
 * Resolves to an empty array when there is no active thread id or when the
 * entry stored for that thread is not an array.
 */
export const getCurrentChatMessagesAtom = atom<ThreadMessage[]>((get) => {
  const threadId = get(getActiveThreadIdAtom)
  if (!threadId) return []

  const threadMessages = get(chatMessages)[threadId]
  return Array.isArray(threadMessages) ? threadMessages : []
})
|
|
|
|
/**
 * Write-only atom that replaces the messages of one thread and flags that
 * thread as ready in `readyThreadsMessagesAtom`.
 *
 * @param threadId - Thread whose messages are replaced.
 * @param messages - The new message list for that thread.
 */
export const setConvoMessagesAtom = atom(
  null,
  (get, set, threadId: string, messages: ThreadMessage[]) => {
    set(chatMessages, { ...get(chatMessages), [threadId]: messages })

    const readyThreads = { ...get(readyThreadsMessagesAtom), [threadId]: true }
    set(readyThreadsMessagesAtom, readyThreads)
  }
)
|
|
|
|
/**
 * Used for pagination: adds a batch of older messages to the active thread.
 *
 * The batch is placed after the messages already stored for the thread.
 * Does nothing when there is no active thread id.
 *
 * @param newMessages - Messages to add to the active thread.
 */
export const addOldMessagesAtom = atom(
  null,
  (get, set, newMessages: ThreadMessage[]) => {
    const threadId = get(getActiveThreadIdAtom)
    if (!threadId) return

    const allMessages = get(chatMessages)
    const existing = allMessages[threadId] ?? []

    set(chatMessages, {
      ...allMessages,
      [threadId]: [...existing, ...newMessages],
    })
  }
)
|
|
|
|
/**
 * Write-only atom that appends a message to the thread named by its
 * `thread_id`, creating the thread entry if it does not exist yet.
 *
 * When the message has non-empty content, the content is also forwarded to
 * `updateThreadStateLastMessageAtom` for that thread.
 *
 * @param newMessage - The message to append.
 */
export const addNewMessageAtom = atom(
  null,
  (get, set, newMessage: ThreadMessage) => {
    const { thread_id: threadId, content } = newMessage
    const allMessages = get(chatMessages)
    const threadMessages = allMessages[threadId] ?? []

    set(chatMessages, {
      ...allMessages,
      [threadId]: [...threadMessages, newMessage],
    })

    // Keep the thread's last-message state in sync
    if (content.length) {
      set(updateThreadStateLastMessageAtom, threadId, content)
    }
  }
)
|
|
|
|
/**
 * Write-only atom that empties the message list of a thread. The thread's key
 * is kept and mapped to an empty array.
 *
 * @param threadId - Thread whose messages are cleared.
 */
export const deleteChatMessageAtom = atom(
  null,
  (get, set, threadId: string) => {
    set(chatMessages, { ...get(chatMessages), [threadId]: [] })
  }
)
|
|
|
|
/**
 * Write-only atom that removes every message from a thread except those whose
 * role is `ChatCompletionRole.System`.
 *
 * A thread without a stored entry is treated as having no messages, so the
 * record always maps the thread id to an array and never to `undefined`.
 *
 * @param id - Id of the thread to clean.
 */
export const cleanChatMessageAtom = atom(null, (get, set, id: string) => {
  const allMessages = get(chatMessages)
  // `?? []` keeps the record's values arrays even for an unknown thread id
  const systemMessages = (allMessages[id] ?? []).filter(
    (e) => e.role === ChatCompletionRole.System
  )
  set(chatMessages, { ...allMessages, [id]: systemMessages })
})
|
|
|
|
/**
 * Write-only atom that removes one message from the active thread.
 *
 * Messages carrying `metadata.error` are removed as well, to clear out the
 * error state. Does nothing when there is no active thread id.
 *
 * @param id - Id of the message to delete.
 */
export const deleteMessageAtom = atom(null, (get, set, id: string) => {
  const threadId = get(getActiveThreadIdAtom)
  if (!threadId) return

  const allMessages = get(chatMessages)
  // The active thread may have no stored entry yet; treat it as empty instead
  // of calling `.filter` on `undefined`.
  const remaining = (allMessages[threadId] ?? []).filter(
    (e) => e.id !== id && !e.metadata?.error
  )
  set(chatMessages, { ...allMessages, [threadId]: remaining })
})
|
|
|
|
/**
 * String atom, initially empty.
 * NOTE(review): presumably identifies the message currently being edited —
 * confirm against the atom's consumers.
 */
export const editMessageAtom = atom('')
|
|
|
|
/**
 * Write-only atom that updates the content and status of a message, or adds
 * the message when the thread does not contain it yet.
 *
 * - If a message with `id` exists in the thread, it is replaced by a copy with
 *   the new content and status (the stored object is not mutated, so its
 *   reference changes with the update). When `text` is non-empty it is also
 *   forwarded to `updateThreadStateLastMessageAtom`.
 * - Otherwise a new assistant message is created through `addNewMessageAtom`,
 *   with `created_at` and `completed_at` set to the current time in seconds.
 *
 * @param id - Id of the message to update or create.
 * @param conversationId - Id of the thread that owns the message.
 * @param text - New content of the message.
 * @param status - New status of the message.
 */
export const updateMessageAtom = atom(
  null,
  (
    get,
    set,
    id: string,
    conversationId: string,
    text: ThreadContent[],
    status: MessageStatus
  ) => {
    const allMessages = get(chatMessages)
    const threadMessages = allMessages[conversationId] ?? []
    const targetIndex = threadMessages.findIndex((e) => e.id === id)

    if (targetIndex === -1) {
      // Unknown message: add it as a new assistant message
      const nowInSeconds = Date.now() / 1000
      set(addNewMessageAtom, {
        id,
        thread_id: conversationId,
        content: text,
        status,
        role: ChatCompletionRole.Assistant,
        created_at: nowInSeconds,
        completed_at: nowInSeconds,
        object: 'thread.message',
      })
      return
    }

    // Replace only the first matching message, leaving the others untouched
    const updatedMessages = threadMessages.map((message, index) =>
      index === targetIndex ? { ...message, content: text, status } : message
    )
    set(chatMessages, { ...allMessages, [conversationId]: updatedMessages })

    // Update thread last message
    if (text.length) {
      set(updateThreadStateLastMessageAtom, conversationId, text)
    }
  }
)
|