fix: token speed should not be calculated based on state updates
This commit is contained in:
parent
0c37be302c
commit
0485134343
@ -30,6 +30,7 @@ import {
|
||||
getCurrentChatMessagesAtom,
|
||||
addNewMessageAtom,
|
||||
updateMessageAtom,
|
||||
tokenSpeedAtom,
|
||||
} from '@/helpers/atoms/ChatMessage.atom'
|
||||
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
|
||||
import {
|
||||
@ -62,6 +63,7 @@ export default function ModelHandler() {
|
||||
const activeModelRef = useRef(activeModel)
|
||||
const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
|
||||
const activeModelParamsRef = useRef(activeModelParams)
|
||||
const setTokenSpeed = useSetAtom(tokenSpeedAtom)
|
||||
|
||||
useEffect(() => {
|
||||
threadsRef.current = threads
|
||||
@ -179,6 +181,31 @@ export default function ModelHandler() {
|
||||
if (message.content.length) {
|
||||
setIsGeneratingResponse(false)
|
||||
}
|
||||
|
||||
setTokenSpeed((prev) => {
|
||||
const currentTimestamp = new Date().getTime() // Get current time in milliseconds
|
||||
if (!prev) {
|
||||
// If this is the first update, just set the lastTimestamp and return
|
||||
return {
|
||||
lastTimestamp: currentTimestamp,
|
||||
tokenSpeed: 1,
|
||||
tokenCount: 1,
|
||||
message: message.id,
|
||||
}
|
||||
}
|
||||
|
||||
const timeDiffInSeconds =
|
||||
(currentTimestamp - prev.lastTimestamp) / 1000 // Time difference in seconds
|
||||
const totalTokenCount = prev.tokenCount + 1
|
||||
const averageTokenSpeed =
|
||||
totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1) // Calculate average token speed
|
||||
return {
|
||||
...prev,
|
||||
tokenSpeed: averageTokenSpeed,
|
||||
tokenCount: totalTokenCount,
|
||||
message: message.id,
|
||||
}
|
||||
})
|
||||
return
|
||||
} else if (
|
||||
message.status === MessageStatus.Error &&
|
||||
|
||||
@ -11,13 +11,22 @@ import {
|
||||
updateThreadStateLastMessageAtom,
|
||||
} from './Thread.atom'
|
||||
|
||||
import { TokenSpeed } from '@/types/token'
|
||||
|
||||
/**
|
||||
* Stores all chat messages for all threads
|
||||
*/
|
||||
export const chatMessages = atom<Record<string, ThreadMessage[]>>({})
|
||||
|
||||
/**
|
||||
* Stores the status of the messages load for each thread
|
||||
*/
|
||||
export const readyThreadsMessagesAtom = atom<Record<string, boolean>>({})
|
||||
|
||||
/**
|
||||
* Store the token speed for current message
|
||||
*/
|
||||
export const tokenSpeedAtom = atom<TokenSpeed | undefined>(undefined)
|
||||
/**
|
||||
* Return the chat messages for the current active conversation
|
||||
*/
|
||||
|
||||
@ -34,6 +34,7 @@ import {
|
||||
addNewMessageAtom,
|
||||
deleteMessageAtom,
|
||||
getCurrentChatMessagesAtom,
|
||||
tokenSpeedAtom,
|
||||
} from '@/helpers/atoms/ChatMessage.atom'
|
||||
import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
|
||||
import {
|
||||
@ -45,7 +46,6 @@ import {
|
||||
updateThreadWaitingForResponseAtom,
|
||||
} from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
export const queuedMessageAtom = atom(false)
|
||||
export const reloadModelAtom = atom(false)
|
||||
|
||||
export default function useSendChatMessage() {
|
||||
@ -70,7 +70,7 @@ export default function useSendChatMessage() {
|
||||
const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
|
||||
const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
|
||||
const activeThreadRef = useRef<Thread | undefined>()
|
||||
const setQueuedMessage = useSetAtom(queuedMessageAtom)
|
||||
const setTokenSpeed = useSetAtom(tokenSpeedAtom)
|
||||
|
||||
const selectedModelRef = useRef<Model | undefined>()
|
||||
|
||||
@ -147,6 +147,7 @@ export default function useSendChatMessage() {
|
||||
}
|
||||
|
||||
if (engineParamsUpdate) setReloadModel(true)
|
||||
setTokenSpeed(undefined)
|
||||
|
||||
const runtimeParams = extractInferenceParams(activeModelParams)
|
||||
const settingParams = extractModelLoadParams(activeModelParams)
|
||||
@ -231,9 +232,7 @@ export default function useSendChatMessage() {
|
||||
}
|
||||
|
||||
if (modelRef.current?.id !== modelId) {
|
||||
setQueuedMessage(true)
|
||||
const error = await startModel(modelId).catch((error: Error) => error)
|
||||
setQueuedMessage(false)
|
||||
if (error) {
|
||||
updateThreadWaiting(activeThreadRef.current.id, false)
|
||||
return
|
||||
|
||||
@ -45,6 +45,7 @@ import { RelativeImage } from './RelativeImage'
|
||||
import {
|
||||
editMessageAtom,
|
||||
getCurrentChatMessagesAtom,
|
||||
tokenSpeedAtom,
|
||||
} from '@/helpers/atoms/ChatMessage.atom'
|
||||
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
|
||||
|
||||
@ -233,32 +234,9 @@ const SimpleTextMessage: React.FC<ThreadMessage> = (props) => {
|
||||
}
|
||||
|
||||
const { onViewFile, onViewFileContainer } = usePath()
|
||||
const [tokenCount, setTokenCount] = useState(0)
|
||||
const [lastTimestamp, setLastTimestamp] = useState<number | undefined>()
|
||||
const [tokenSpeed, setTokenSpeed] = useState(0)
|
||||
const tokenSpeed = useAtomValue(tokenSpeedAtom)
|
||||
const messages = useAtomValue(getCurrentChatMessagesAtom)
|
||||
|
||||
useEffect(() => {
|
||||
if (props.status !== MessageStatus.Pending) {
|
||||
return
|
||||
}
|
||||
const currentTimestamp = new Date().getTime() // Get current time in milliseconds
|
||||
if (!lastTimestamp) {
|
||||
// If this is the first update, just set the lastTimestamp and return
|
||||
if (props.content[0]?.text?.value !== '')
|
||||
setLastTimestamp(currentTimestamp)
|
||||
return
|
||||
}
|
||||
|
||||
const timeDiffInSeconds = (currentTimestamp - lastTimestamp) / 1000 // Time difference in seconds
|
||||
const totalTokenCount = tokenCount + 1
|
||||
const averageTokenSpeed = totalTokenCount / timeDiffInSeconds // Calculate average token speed
|
||||
|
||||
setTokenSpeed(averageTokenSpeed)
|
||||
setTokenCount(totalTokenCount)
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [props.content])
|
||||
|
||||
return (
|
||||
<div className="group relative mx-auto max-w-[700px] p-4">
|
||||
<div
|
||||
@ -308,10 +286,11 @@ const SimpleTextMessage: React.FC<ThreadMessage> = (props) => {
|
||||
>
|
||||
<MessageToolbar message={props} />
|
||||
</div>
|
||||
{messages[messages.length - 1]?.id === props.id &&
|
||||
(props.status === MessageStatus.Pending || tokenSpeed > 0) && (
|
||||
{tokenSpeed &&
|
||||
tokenSpeed.message === props.id &&
|
||||
tokenSpeed.tokenSpeed > 0 && (
|
||||
<p className="absolute right-8 text-xs font-medium text-[hsla(var(--text-secondary))]">
|
||||
Token Speed: {Number(tokenSpeed).toFixed(2)}t/s
|
||||
Token Speed: {Number(tokenSpeed.tokenSpeed).toFixed(2)}t/s
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
6
web/types/token.d.ts
vendored
Normal file
6
web/types/token.d.ts
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
export type TokenSpeed = {
|
||||
message: string
|
||||
tokenSpeed: number
|
||||
tokenCount: number
|
||||
lastTimestamp: number
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user