From cddaf61c99fbb56fed5ac0ed6d9f68b9e5c4ea39 Mon Sep 17 00:00:00 2001 From: Louis Date: Fri, 21 Feb 2025 00:45:11 +0700 Subject: [PATCH] feat: preserve token speed in the thread (#4711) * feat: preserve token speed in the thread * chore: lint fix --- web/containers/Providers/ModelHandler.tsx | 16 ++++++++++++++-- .../ThreadCenterPanel/TextMessage/index.tsx | 19 +++++++++++++------ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/web/containers/Providers/ModelHandler.tsx b/web/containers/Providers/ModelHandler.tsx index cceb88a4c..42ef8afbc 100644 --- a/web/containers/Providers/ModelHandler.tsx +++ b/web/containers/Providers/ModelHandler.tsx @@ -18,7 +18,7 @@ import { extractInferenceParams, ModelExtension, } from '@janhq/core' -import { useAtomValue, useSetAtom } from 'jotai' +import { useAtom, useAtomValue, useSetAtom } from 'jotai' import { ulid } from 'ulidx' import { activeModelAtom, stateModelAtom } from '@/hooks/useActiveModel' @@ -75,8 +75,10 @@ export default function ModelHandler() { const activeThreadRef = useRef(activeThread) const activeModelParams = useAtomValue(getActiveThreadModelParamsAtom) const activeModelParamsRef = useRef(activeModelParams) - const setTokenSpeed = useSetAtom(tokenSpeedAtom) + + const [tokenSpeed, setTokenSpeed] = useAtom(tokenSpeedAtom) const { engines } = useGetEngines() + const tokenSpeedRef = useRef(tokenSpeed) useEffect(() => { activeThreadRef.current = activeThread @@ -106,6 +108,10 @@ export default function ModelHandler() { messageGenerationSubscriber.current = subscribedGeneratingMessage }, [subscribedGeneratingMessage]) + useEffect(() => { + tokenSpeedRef.current = tokenSpeed + }, [tokenSpeed]) + const onNewMessageResponse = useCallback( async (message: ThreadMessage) => { if (message.type === MessageRequestType.Thread) { @@ -275,6 +281,12 @@ export default function ModelHandler() { metadata, }) + // Update message's metadata with token usage + message.metadata = { + ...message.metadata, + token_speed: tokenSpeedRef.current?.tokenSpeed, + } + if (message.status === MessageStatus.Error) { message.metadata = { ...message.metadata, diff --git a/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx b/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx index ab86a0142..94d923906 100644 --- a/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx +++ b/web/screens/Thread/ThreadCenterPanel/TextMessage/index.tsx @@ -122,13 +122,20 @@ const MessageContainer: React.FC< )} >
- {tokenSpeed && + {((!!tokenSpeed && tokenSpeed.message === props.id && - tokenSpeed.tokenSpeed > 0 && ( -

- Token Speed: {Number(tokenSpeed.tokenSpeed).toFixed(2)}t/s -

- )} + tokenSpeed.tokenSpeed > 0) || + (props.metadata && + 'token_speed' in props.metadata && + !!props.metadata?.token_speed)) && ( +

+ Token Speed:{' '} + {Number( + props.metadata?.token_speed ?? tokenSpeed?.tokenSpeed + ).toFixed(2)} + t/s +

+ )}