fix: token speed should not be calculated based on state updates

This commit is contained in:
Louis 2024-11-29 09:34:38 +07:00
parent 0c37be302c
commit 0485134343
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
5 changed files with 51 additions and 31 deletions

View File

@ -30,6 +30,7 @@ import {
getCurrentChatMessagesAtom, getCurrentChatMessagesAtom,
addNewMessageAtom, addNewMessageAtom,
updateMessageAtom, updateMessageAtom,
tokenSpeedAtom,
} from '@/helpers/atoms/ChatMessage.atom' } from '@/helpers/atoms/ChatMessage.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import { import {
@ -62,6 +63,7 @@ export default function ModelHandler() {
const activeModelRef = useRef(activeModel) const activeModelRef = useRef(activeModel)
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom) const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
const activeModelParamsRef = useRef(activeModelParams) const activeModelParamsRef = useRef(activeModelParams)
const setTokenSpeed = useSetAtom(tokenSpeedAtom)
useEffect(() => { useEffect(() => {
threadsRef.current = threads threadsRef.current = threads
@ -179,6 +181,31 @@ export default function ModelHandler() {
if (message.content.length) { if (message.content.length) {
setIsGeneratingResponse(false) setIsGeneratingResponse(false)
} }
// Update the rolling token-generation speed each time the streamed
// message content changes.
// NOTE(review): each state update is counted as exactly one token —
// assumes the runtime emits one token per update; confirm upstream.
setTokenSpeed((prev) => {
const currentTimestamp = new Date().getTime() // Current time in milliseconds
if (!prev) {
// First update for this message: record the baseline timestamp and
// start the counters at one token.
return {
lastTimestamp: currentTimestamp,
tokenSpeed: 1,
tokenCount: 1,
message: message.id,
}
}
// Elapsed time since the FIRST token (lastTimestamp is never
// overwritten below — `...prev` preserves the original baseline).
const timeDiffInSeconds =
(currentTimestamp - prev.lastTimestamp) / 1000 // Time difference in seconds
const totalTokenCount = prev.tokenCount + 1
// Average speed over the whole stream; guard against a zero/negative
// elapsed time to avoid division by zero.
const averageTokenSpeed =
totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1) // Calculate average token speed
return {
...prev,
tokenSpeed: averageTokenSpeed,
tokenCount: totalTokenCount,
message: message.id,
}
})
return return
} else if ( } else if (
message.status === MessageStatus.Error && message.status === MessageStatus.Error &&

View File

@ -11,13 +11,22 @@ import {
updateThreadStateLastMessageAtom, updateThreadStateLastMessageAtom,
} from './Thread.atom' } from './Thread.atom'
import { TokenSpeed } from '@/types/token'
/** /**
* Stores all chat messages for all threads * Stores all chat messages for all threads
*/ */
export const chatMessages = atom<Record<string, ThreadMessage[]>>({}) export const chatMessages = atom<Record<string, ThreadMessage[]>>({})
/**
* Stores the status of the messages load for each thread
*/
export const readyThreadsMessagesAtom = atom<Record<string, boolean>>({}) export const readyThreadsMessagesAtom = atom<Record<string, boolean>>({})
/**
 * Stores the token-generation speed stats for the message currently being
 * streamed. Reset to `undefined` when a new chat message is sent, so it is
 * only populated while a response is in flight.
 */
export const tokenSpeedAtom = atom<TokenSpeed | undefined>(undefined)
/** /**
* Return the chat messages for the current active conversation * Return the chat messages for the current active conversation
*/ */

View File

@ -34,6 +34,7 @@ import {
addNewMessageAtom, addNewMessageAtom,
deleteMessageAtom, deleteMessageAtom,
getCurrentChatMessagesAtom, getCurrentChatMessagesAtom,
tokenSpeedAtom,
} from '@/helpers/atoms/ChatMessage.atom' } from '@/helpers/atoms/ChatMessage.atom'
import { selectedModelAtom } from '@/helpers/atoms/Model.atom' import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
import { import {
@ -45,7 +46,6 @@ import {
updateThreadWaitingForResponseAtom, updateThreadWaitingForResponseAtom,
} from '@/helpers/atoms/Thread.atom' } from '@/helpers/atoms/Thread.atom'
export const queuedMessageAtom = atom(false)
export const reloadModelAtom = atom(false) export const reloadModelAtom = atom(false)
export default function useSendChatMessage() { export default function useSendChatMessage() {
@ -70,7 +70,7 @@ export default function useSendChatMessage() {
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom) const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom) const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
const activeThreadRef = useRef<Thread | undefined>() const activeThreadRef = useRef<Thread | undefined>()
const setQueuedMessage = useSetAtom(queuedMessageAtom) const setTokenSpeed = useSetAtom(tokenSpeedAtom)
const selectedModelRef = useRef<Model | undefined>() const selectedModelRef = useRef<Model | undefined>()
@ -147,6 +147,7 @@ export default function useSendChatMessage() {
} }
if (engineParamsUpdate) setReloadModel(true) if (engineParamsUpdate) setReloadModel(true)
setTokenSpeed(undefined)
const runtimeParams = extractInferenceParams(activeModelParams) const runtimeParams = extractInferenceParams(activeModelParams)
const settingParams = extractModelLoadParams(activeModelParams) const settingParams = extractModelLoadParams(activeModelParams)
@ -231,9 +232,7 @@ export default function useSendChatMessage() {
} }
if (modelRef.current?.id !== modelId) { if (modelRef.current?.id !== modelId) {
setQueuedMessage(true)
const error = await startModel(modelId).catch((error: Error) => error) const error = await startModel(modelId).catch((error: Error) => error)
setQueuedMessage(false)
if (error) { if (error) {
updateThreadWaiting(activeThreadRef.current.id, false) updateThreadWaiting(activeThreadRef.current.id, false)
return return

View File

@ -45,6 +45,7 @@ import { RelativeImage } from './RelativeImage'
import { import {
editMessageAtom, editMessageAtom,
getCurrentChatMessagesAtom, getCurrentChatMessagesAtom,
tokenSpeedAtom,
} from '@/helpers/atoms/ChatMessage.atom' } from '@/helpers/atoms/ChatMessage.atom'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
@ -233,32 +234,9 @@ const SimpleTextMessage: React.FC<ThreadMessage> = (props) => {
} }
const { onViewFile, onViewFileContainer } = usePath() const { onViewFile, onViewFileContainer } = usePath()
const [tokenCount, setTokenCount] = useState(0) const tokenSpeed = useAtomValue(tokenSpeedAtom)
const [lastTimestamp, setLastTimestamp] = useState<number | undefined>()
const [tokenSpeed, setTokenSpeed] = useState(0)
const messages = useAtomValue(getCurrentChatMessagesAtom) const messages = useAtomValue(getCurrentChatMessagesAtom)
useEffect(() => {
if (props.status !== MessageStatus.Pending) {
return
}
const currentTimestamp = new Date().getTime() // Get current time in milliseconds
if (!lastTimestamp) {
// If this is the first update, just set the lastTimestamp and return
if (props.content[0]?.text?.value !== '')
setLastTimestamp(currentTimestamp)
return
}
const timeDiffInSeconds = (currentTimestamp - lastTimestamp) / 1000 // Time difference in seconds
const totalTokenCount = tokenCount + 1
const averageTokenSpeed = totalTokenCount / timeDiffInSeconds // Calculate average token speed
setTokenSpeed(averageTokenSpeed)
setTokenCount(totalTokenCount)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [props.content])
return ( return (
<div className="group relative mx-auto max-w-[700px] p-4"> <div className="group relative mx-auto max-w-[700px] p-4">
<div <div
@ -308,10 +286,11 @@ const SimpleTextMessage: React.FC<ThreadMessage> = (props) => {
> >
<MessageToolbar message={props} /> <MessageToolbar message={props} />
</div> </div>
{messages[messages.length - 1]?.id === props.id && {tokenSpeed &&
(props.status === MessageStatus.Pending || tokenSpeed > 0) && ( tokenSpeed.message === props.id &&
tokenSpeed.tokenSpeed > 0 && (
<p className="absolute right-8 text-xs font-medium text-[hsla(var(--text-secondary))]"> <p className="absolute right-8 text-xs font-medium text-[hsla(var(--text-secondary))]">
Token Speed: {Number(tokenSpeed).toFixed(2)}t/s Token Speed: {Number(tokenSpeed.tokenSpeed).toFixed(2)}t/s
</p> </p>
)} )}
</div> </div>

6
web/types/token.d.ts vendored Normal file
View File

@ -0,0 +1,6 @@
/**
 * Snapshot of token-generation throughput for a single streamed message.
 */
export type TokenSpeed = {
  /** ID of the message these stats belong to. */
  message: string
  /** Running total of tokens received so far. */
  tokenCount: number
  /** Average tokens per second across the whole stream. */
  tokenSpeed: number
  /**
   * Epoch milliseconds recorded when the first token arrived; used as the
   * baseline for computing the average speed.
   */
  lastTimestamp: number
}