feat: add chunk count (#3290)

* feat: add chunk count

* bump cortex version
This commit is contained in:
NamH 2024-08-07 13:58:51 +07:00 committed by GitHub
parent 2201e6c5f8
commit bad481bf05
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 105 additions and 57 deletions

View File

@ -1 +1 @@
0.5.0-30
0.5.0-31

View File

@ -5,6 +5,10 @@ import { getActiveThreadIdAtom } from './Thread.atom'
// Per-thread message store — presumably keyed by thread id (see the
// "current active thread" accessor below); TODO confirm against writers.
const chatMessages = atom<Record<string, Message[]>>({})
// When true, the stop-inference button is disabled; cleared once
// stopping the in-flight inference is allowed again.
export const disableStopInferenceAtom = atom(false)
// Number of streamed chunks received so far, keyed by message id.
export const chunkCountAtom = atom<Record<string, number>>({})
/**
* Return the chat messages for the current active thread
*/

View File

@ -36,6 +36,8 @@ import useModelStart from './useModelStart'
import {
addNewMessageAtom,
chunkCountAtom,
disableStopInferenceAtom,
getCurrentChatMessagesAtom,
updateMessageAtom,
} from '@/helpers/atoms/ChatMessage.atom'
@ -104,6 +106,9 @@ const useSendMessage = () => {
showWarningMultipleModelModalAtom
)
const setDisableStopInference = useSetAtom(disableStopInferenceAtom)
const setChunkCount = useSetAtom(chunkCountAtom)
const validatePrerequisite = useCallback(async (): Promise<boolean> => {
const errorTitle = 'Failed to send message'
if (!activeThread) {
@ -361,7 +366,12 @@ const useSendMessage = () => {
addNewMessage(responseMessage)
let chunkCount = 1
for await (const chunk of stream) {
setChunkCount((prev) => ({
...prev,
[responseMessage.id]: chunkCount++,
}))
const content = chunk.choices[0]?.delta?.content || ''
assistantResponseMessage += content
const messageContent: MessageContent = {
@ -579,6 +589,7 @@ const useSendMessage = () => {
let assistantResponseMessage = ''
try {
if (selectedModel!.stream === true) {
setDisableStopInference(true)
const stream = await chatCompletionStreaming({
messages,
model: selectedModel!.model,
@ -623,7 +634,14 @@ const useSendMessage = () => {
addNewMessage(responseMessage)
let chunkCount = 1
for await (const chunk of stream) {
setChunkCount((prev) => ({
...prev,
[responseMessage.id]: chunkCount++,
}))
// we have first chunk, enable the inference button
setDisableStopInference(false)
const content = chunk.choices[0]?.delta?.content || ''
assistantResponseMessage += content
const messageContent: MessageContent = {
@ -737,6 +755,7 @@ const useSendMessage = () => {
})
}
setDisableStopInference(false)
setIsGeneratingResponse(false)
shouldSummarize = false
@ -780,6 +799,8 @@ const useSendMessage = () => {
chatCompletionStreaming,
summarizeThread,
setShowWarningMultipleModelModal,
setDisableStopInference,
setChunkCount,
]
)

View File

@ -2,20 +2,28 @@ import React from 'react'
import { Button } from '@janhq/joi'
import { useAtomValue } from 'jotai'
import { StopCircle } from 'lucide-react'
import { disableStopInferenceAtom } from '@/helpers/atoms/ChatMessage.atom'
type Props = {
onStopInferenceClick: () => void
}
// Destructive icon button that lets the user abort the running inference.
const StopInferenceButton: React.FC<Props> = ({ onStopInferenceClick }) => {
  return (
    <Button
      theme="destructive"
      className="h-8 w-8 rounded-lg p-0"
      onClick={onStopInferenceClick}
    >
      <StopCircle size={20} />
    </Button>
  )
}
// Destructive icon button that aborts the running inference.
// Its enabled state mirrors disableStopInferenceAtom, which the send-message
// flow toggles around the streaming lifecycle.
const StopInferenceButton: React.FC<Props> = ({ onStopInferenceClick }) => {
  const isDisabled = useAtomValue(disableStopInferenceAtom)

  return (
    <Button
      theme="destructive"
      disabled={isDisabled}
      className="h-8 w-8 rounded-lg p-0"
      onClick={onStopInferenceClick}
    >
      <StopCircle size={20} />
    </Button>
  )
}

export default React.memo(StopInferenceButton)

View File

@ -0,0 +1,59 @@
import { useEffect, useMemo, useState } from 'react'
import { Message, TextContentBlock } from '@janhq/core'
import { useAtomValue } from 'jotai'
import { chunkCountAtom } from '@/helpers/atoms/ChatMessage.atom'
type Props = {
message: Message
}
/**
 * Shows the number of streamed chunks received for a message (used as a
 * token-count proxy) and the average generation speed in chunks/second.
 *
 * The start time is captured on the first non-empty text chunk while the
 * message is `in_progress`; subsequent updates recompute the average speed
 * against that fixed start time. Renders nothing until a speed is known.
 */
const TokenCount: React.FC<Props> = ({ message }) => {
  const chunkCountMap = useAtomValue(chunkCountAtom)
  // Timestamp (ms) of the first observed non-empty chunk; undefined until then.
  const [lastTimestamp, setLastTimestamp] = useState<number | undefined>()
  const [tokenSpeed, setTokenSpeed] = useState(0)

  const receivedChunkCount = useMemo(
    () => chunkCountMap[message.id] ?? 0,
    [chunkCountMap, message.id]
  )

  useEffect(() => {
    // Only measure while the message is actively streaming.
    if (message.status !== 'in_progress') {
      return
    }
    const currentTimestamp = Date.now()
    if (!lastTimestamp) {
      // First update: record the start time once real text has arrived.
      if (message.content && message.content.length > 0) {
        const messageContent = message.content[0]
        if (messageContent && messageContent.type === 'text') {
          const textContentBlock = messageContent as TextContentBlock
          if (textContentBlock.text.value !== '') {
            setLastTimestamp(currentTimestamp)
          }
        }
      }
      return
    }
    const timeDiffInSeconds = (currentTimestamp - lastTimestamp) / 1000
    // Guard: the effect can re-fire within the same millisecond as
    // lastTimestamp, which would make the division yield Infinity/NaN.
    if (timeDiffInSeconds <= 0) {
      return
    }
    setTokenSpeed(receivedChunkCount / timeDiffInSeconds)
  }, [message.content, lastTimestamp, receivedChunkCount, message.status])

  // Nothing meaningful to show before the first speed sample.
  if (tokenSpeed === 0) return null

  return (
    <div className="absolute right-8 flex flex-row text-xs font-medium text-[hsla(var(--text-secondary))]">
      <p>
        Token count: {receivedChunkCount}, speed:{' '}
        {tokenSpeed.toFixed(2)}t/s
      </p>
    </div>
  )
}

export default TokenCount

View File

@ -28,6 +28,8 @@ import { openFileTitle } from '@/utils/titleUtils'
import EditChatInput from '../EditChatInput'
import MessageToolbar from '../MessageToolbar'
import TokenCount from './components/TokenCount'
import { editMessageAtom } from '@/helpers/atoms/ChatMessage.atom'
type Props = {
@ -114,9 +116,6 @@ const SimpleTextMessage: React.FC<Props> = ({
const isUser = msg.role === 'user'
const { onViewFileContainer } = usePath()
const parsedText = useMemo(() => marked.parse(text), [marked, text])
const [tokenCount, setTokenCount] = useState(0)
const [lastTimestamp, setLastTimestamp] = useState<number | undefined>()
const [tokenSpeed, setTokenSpeed] = useState(0)
const codeBlockCopyEvent = useRef((e: Event) => {
const target: HTMLElement = e.target as HTMLElement
@ -138,34 +137,6 @@ const SimpleTextMessage: React.FC<Props> = ({
}
}, [])
useEffect(() => {
if (msg.status !== 'in_progress') {
return
}
const currentTimestamp = new Date().getTime() // Get current time in milliseconds
if (!lastTimestamp) {
// If this is the first update, just set the lastTimestamp and return
if (msg.content && msg.content.length > 0) {
const message = msg.content[0]
if (message && message.type === 'text') {
const textContentBlock = message as TextContentBlock
if (textContentBlock.text.value !== '') {
setLastTimestamp(currentTimestamp)
}
}
}
return
}
const timeDiffInSeconds = (currentTimestamp - lastTimestamp) / 1000 // Time difference in seconds
const totalTokenCount = tokenCount + 1
const averageTokenSpeed = totalTokenCount / timeDiffInSeconds // Calculate average token speed
setTokenSpeed(averageTokenSpeed)
setTokenCount(totalTokenCount)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [msg.content])
return (
<div className="group relative mx-auto max-w-[700px] p-4">
<div
@ -175,7 +146,6 @@ const SimpleTextMessage: React.FC<Props> = ({
)}
>
{isUser ? <UserAvatar /> : <LogoMark width={32} height={32} />}
<div
className={twMerge(
'font-extrabold capitalize',
@ -201,12 +171,7 @@ const SimpleTextMessage: React.FC<Props> = ({
onResendMessage={onResendMessage}
/>
</div>
{isLatestMessage &&
(msg.status === 'in_progress' || tokenSpeed > 0) && (
<p className="absolute right-8 text-xs font-medium text-[hsla(var(--text-secondary))]">
Token Speed: {Number(tokenSpeed).toFixed(2)}t/s
</p>
)}
{isLatestMessage && <TokenCount message={msg} />}
</div>
<div
@ -218,15 +183,6 @@ const SimpleTextMessage: React.FC<Props> = ({
<Fragment>
{msg.content[0]?.type === 'image_file' && (
<div className="group/image relative mb-2 inline-flex cursor-pointer overflow-hidden rounded-xl">
<div className="left-0 top-0 z-20 h-full w-full group-hover/image:inline-block">
{/* <RelativeImage */}
{/* src={msg.content[0]?.text.annotations[0]} */}
{/* id={msg.id} */}
{/* onClick={() => */}
{/* onViewFile(`${msg.content[0]?.text.annotations[0]}`) */}
{/* } */}
{/* /> */}
</div>
<Tooltip
trigger={
<div