fix: token speed should not be calculated based on state updates
parent 0c37be302c
commit 0485134343
ModelHandler.tsx

@@ -30,6 +30,7 @@ import {
   getCurrentChatMessagesAtom,
   addNewMessageAtom,
   updateMessageAtom,
+  tokenSpeedAtom,
 } from '@/helpers/atoms/ChatMessage.atom'
 import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
 import {
@@ -62,6 +63,7 @@ export default function ModelHandler() {
   const activeModelRef = useRef(activeModel)
   const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom)
   const activeModelParamsRef = useRef(activeModelParams)
+  const setTokenSpeed = useSetAtom(tokenSpeedAtom)

   useEffect(() => {
     threadsRef.current = threads
@@ -179,6 +181,31 @@
       if (message.content.length) {
         setIsGeneratingResponse(false)
       }
+
+      setTokenSpeed((prev) => {
+        const currentTimestamp = new Date().getTime() // Get current time in milliseconds
+        if (!prev) {
+          // If this is the first update, just set the lastTimestamp and return
+          return {
+            lastTimestamp: currentTimestamp,
+            tokenSpeed: 1,
+            tokenCount: 1,
+            message: message.id,
+          }
+        }
+
+        const timeDiffInSeconds =
+          (currentTimestamp - prev.lastTimestamp) / 1000 // Time difference in seconds
+        const totalTokenCount = prev.tokenCount + 1
+        const averageTokenSpeed =
+          totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1) // Calculate average token speed
+        return {
+          ...prev,
+          tokenSpeed: averageTokenSpeed,
+          tokenCount: totalTokenCount,
+          message: message.id,
+        }
+      })
       return
     } else if (
       message.status === MessageStatus.Error &&
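The updater passed to setTokenSpeed is a pure function of the previous atom value, so the arithmetic can be lifted out and tested on its own. A minimal sketch of that extraction — the standalone helper updateTokenSpeed is hypothetical, not part of this commit:

import { TokenSpeed } from '@/types/token'

// One call per streamed token. lastTimestamp stays anchored at the first
// token, so tokenSpeed is the running average over the whole message,
// independent of how often React happens to re-render.
export const updateTokenSpeed = (
  prev: TokenSpeed | undefined,
  messageId: string,
  now: number = new Date().getTime()
): TokenSpeed => {
  if (!prev) {
    // First token: anchor the clock and start counting.
    return { lastTimestamp: now, tokenSpeed: 1, tokenCount: 1, message: messageId }
  }
  const elapsedSeconds = (now - prev.lastTimestamp) / 1000
  const tokenCount = prev.tokenCount + 1
  return {
    ...prev,
    // Guard against a zero interval when tokens arrive within the same millisecond.
    tokenSpeed: tokenCount / (elapsedSeconds > 0 ? elapsedSeconds : 1),
    tokenCount,
    message: messageId,
  }
}

With a helper like this, the handler above would reduce to setTokenSpeed((prev) => updateTokenSpeed(prev, message.id)).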
ChatMessage.atom.ts

@@ -11,13 +11,22 @@ import {
   updateThreadStateLastMessageAtom,
 } from './Thread.atom'

+import { TokenSpeed } from '@/types/token'
+
 /**
  * Stores all chat messages for all threads
  */
 export const chatMessages = atom<Record<string, ThreadMessage[]>>({})

+/**
+ * Stores the status of the messages load for each thread
+ */
 export const readyThreadsMessagesAtom = atom<Record<string, boolean>>({})

+/**
+ * Store the token speed for current message
+ */
+export const tokenSpeedAtom = atom<TokenSpeed | undefined>(undefined)
 /**
  * Return the chat messages for the current active conversation
  */
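Any component can now subscribe to the reading through the usual Jotai hooks without owning timing state itself. A sketch, with a hypothetical TokenSpeedBadge component for illustration:

import { useAtomValue } from 'jotai'
import { tokenSpeedAtom } from '@/helpers/atoms/ChatMessage.atom'

// Re-renders only when the atom value changes, not on unrelated state updates.
const TokenSpeedBadge = () => {
  const tokenSpeed = useAtomValue(tokenSpeedAtom)
  if (!tokenSpeed) return null
  return <span>{tokenSpeed.tokenSpeed.toFixed(2)} t/s</span>
}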
useSendChatMessage.ts

@@ -34,6 +34,7 @@ import {
   addNewMessageAtom,
   deleteMessageAtom,
   getCurrentChatMessagesAtom,
+  tokenSpeedAtom,
 } from '@/helpers/atoms/ChatMessage.atom'
 import { selectedModelAtom } from '@/helpers/atoms/Model.atom'
 import {
@@ -45,7 +46,6 @@ import {
   updateThreadWaitingForResponseAtom,
 } from '@/helpers/atoms/Thread.atom'

-export const queuedMessageAtom = atom(false)
 export const reloadModelAtom = atom(false)

 export default function useSendChatMessage() {
@@ -70,7 +70,7 @@ export default function useSendChatMessage() {
   const [fileUpload, setFileUpload] = useAtom(fileUploadAtom)
   const setIsGeneratingResponse = useSetAtom(isGeneratingResponseAtom)
   const activeThreadRef = useRef<Thread | undefined>()
-  const setQueuedMessage = useSetAtom(queuedMessageAtom)
+  const setTokenSpeed = useSetAtom(tokenSpeedAtom)

   const selectedModelRef = useRef<Model | undefined>()

@@ -147,6 +147,7 @@ export default function useSendChatMessage() {
     }

     if (engineParamsUpdate) setReloadModel(true)
+    setTokenSpeed(undefined)

     const runtimeParams = extractInferenceParams(activeModelParams)
     const settingParams = extractModelLoadParams(activeModelParams)
@@ -231,9 +232,7 @@
     }

     if (modelRef.current?.id !== modelId) {
-      setQueuedMessage(true)
       const error = await startModel(modelId).catch((error: Error) => error)
-      setQueuedMessage(false)
       if (error) {
         updateThreadWaiting(activeThreadRef.current.id, false)
         return
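Taken together, the call sites give the atom a clear lifecycle per message; condensed below with names from the diff (updateTokenSpeed refers to the hypothetical sketch above):

// 1. useSendChatMessage: clear the previous reading before inference starts,
//    so a finished message's speed never bleeds into the next one.
setTokenSpeed(undefined)

// 2. ModelHandler: fold in a timestamp for every streamed token.
setTokenSpeed((prev) => updateTokenSpeed(prev, message.id))

// 3. SimpleTextMessage: show the badge only while the atom tracks this message:
//    tokenSpeed && tokenSpeed.message === props.id && tokenSpeed.tokenSpeed > 0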
SimpleTextMessage.tsx

@@ -45,6 +45,7 @@ import { RelativeImage } from './RelativeImage'
 import {
   editMessageAtom,
   getCurrentChatMessagesAtom,
+  tokenSpeedAtom,
 } from '@/helpers/atoms/ChatMessage.atom'
 import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'

@@ -233,32 +234,9 @@ const SimpleTextMessage: React.FC<ThreadMessage> = (props) => {
   }

   const { onViewFile, onViewFileContainer } = usePath()
-  const [tokenCount, setTokenCount] = useState(0)
-  const [lastTimestamp, setLastTimestamp] = useState<number | undefined>()
-  const [tokenSpeed, setTokenSpeed] = useState(0)
+  const tokenSpeed = useAtomValue(tokenSpeedAtom)
   const messages = useAtomValue(getCurrentChatMessagesAtom)

-  useEffect(() => {
-    if (props.status !== MessageStatus.Pending) {
-      return
-    }
-    const currentTimestamp = new Date().getTime() // Get current time in milliseconds
-    if (!lastTimestamp) {
-      // If this is the first update, just set the lastTimestamp and return
-      if (props.content[0]?.text?.value !== '')
-        setLastTimestamp(currentTimestamp)
-      return
-    }
-
-    const timeDiffInSeconds = (currentTimestamp - lastTimestamp) / 1000 // Time difference in seconds
-    const totalTokenCount = tokenCount + 1
-    const averageTokenSpeed = totalTokenCount / timeDiffInSeconds // Calculate average token speed
-
-    setTokenSpeed(averageTokenSpeed)
-    setTokenCount(totalTokenCount)
-    // eslint-disable-next-line react-hooks/exhaustive-deps
-  }, [props.content])
-
   return (
     <div className="group relative mx-auto max-w-[700px] p-4">
       <div
@@ -308,10 +286,11 @@
           >
             <MessageToolbar message={props} />
           </div>
-          {messages[messages.length - 1]?.id === props.id &&
-            (props.status === MessageStatus.Pending || tokenSpeed > 0) && (
+          {tokenSpeed &&
+            tokenSpeed.message === props.id &&
+            tokenSpeed.tokenSpeed > 0 && (
               <p className="absolute right-8 text-xs font-medium text-[hsla(var(--text-secondary))]">
-                Token Speed: {Number(tokenSpeed).toFixed(2)}t/s
+                Token Speed: {Number(tokenSpeed.tokenSpeed).toFixed(2)}t/s
               </p>
             )}
         </div>
web/types/token.d.ts (new file, +6)

@@ -0,0 +1,6 @@
+export type TokenSpeed = {
+  message: string
+  tokenSpeed: number
+  tokenCount: number
+  lastTimestamp: number
+}
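This shape makes the math easy to sanity-check; a quick worked example against the hypothetical updateTokenSpeed sketch above:

// 50 tokens, one every 50 ms: the clock anchors at t = 0, the 50th token
// lands at t = 2450 ms, so the average is 50 / 2.45 ≈ 20.4 t/s.
let speed = updateTokenSpeed(undefined, 'msg-1', 0)
for (let i = 1; i < 50; i++) {
  speed = updateTokenSpeed(speed, 'msg-1', i * 50)
}
console.log(speed.tokenCount) // 50
console.log(speed.tokenSpeed.toFixed(1)) // "20.4"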