feat: add token speed to each message and persist it

This commit is contained in:
LazyYuuki 2025-06-15 18:30:39 +08:00
parent 3ae4d12f60
commit 665de7df55
6 changed files with 236 additions and 200 deletions

View File

@ -48,7 +48,7 @@ type ChatInputProps = {
const ChatInput = ({
model,
className,
showSpeedToken = false,
showSpeedToken = true,
initialMessage,
}: ChatInputProps) => {
const textareaRef = useRef<HTMLTextAreaElement>(null)

View File

@ -34,6 +34,9 @@ import {
} from '@/components/ui/tooltip'
import { formatDate } from '@/utils/formatDate'
import { AvatarEmoji } from '@/containers/AvatarEmoji'
import TokenSpeedIndicator from '@/containers/TokenSpeedIndicator'
import CodeEditor from '@uiw/react-textarea-code-editor'
import '@uiw/react-textarea-code-editor/dist.css'
@ -360,8 +363,8 @@ export const ThreadContent = memo(
className={cn(
'flex items-center gap-2',
item.isLastMessage &&
streamingContent &&
'opacity-0 visibility-hidden pointer-events-none'
streamingContent &&
'opacity-0 visibility-hidden pointer-events-none'
)}
>
<CopyButton text={item.content?.[0]?.text.value || ''} />
@ -445,6 +448,11 @@ export const ThreadContent = memo(
</TooltipContent>
</Tooltip>
)}
<TokenSpeedIndicator
messageId={item.id}
metadata={item.metadata}
/>
</div>
</div>
)}

View File

@ -0,0 +1,22 @@
import { IconBrandSpeedtest } from '@tabler/icons-react'
interface TokenSpeedIndicatorProps {
metadata?: Record<string, unknown>
}
export const TokenSpeedIndicator = ({
metadata
}: TokenSpeedIndicatorProps) => {
const persistedTokenSpeed = (metadata?.tokenSpeed as { tokenSpeed: number })?.tokenSpeed
return (
<div className="flex items-center gap-1 text-main-view-fg/60 text-xs">
<IconBrandSpeedtest size={16} />
<span>
{Math.round(persistedTokenSpeed)} tokens/sec
</span>
</div>
)
}
export default TokenSpeedIndicator

View File

@ -1,36 +1,36 @@
import { create } from 'zustand'
import { ThreadMessage } from '@janhq/core'
import { MCPTool } from '@/types/completion'
import { useAssistant } from './useAssistant'
import { ChatCompletionMessageToolCall } from 'openai/resources'
import { create } from "zustand";
import { ThreadMessage } from "@janhq/core";
import { MCPTool } from "@/types/completion";
import { useAssistant } from "./useAssistant";
import { ChatCompletionMessageToolCall } from "openai/resources";
type AppState = {
streamingContent?: ThreadMessage
loadingModel?: boolean
tools: MCPTool[]
serverStatus: 'running' | 'stopped' | 'pending'
abortControllers: Record<string, AbortController>
tokenSpeed?: TokenSpeed
currentToolCall?: ChatCompletionMessageToolCall
showOutOfContextDialog?: boolean
setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
updateStreamingContent: (content: ThreadMessage | undefined) => void
streamingContent?: ThreadMessage;
loadingModel?: boolean;
tools: MCPTool[];
serverStatus: "running" | "stopped" | "pending";
abortControllers: Record<string, AbortController>;
tokenSpeed?: TokenSpeed;
currentToolCall?: ChatCompletionMessageToolCall;
showOutOfContextDialog?: boolean;
setServerStatus: (value: "running" | "stopped" | "pending") => void;
updateStreamingContent: (content: ThreadMessage | undefined) => void;
updateCurrentToolCall: (
toolCall: ChatCompletionMessageToolCall | undefined
) => void
updateLoadingModel: (loading: boolean) => void
updateTools: (tools: MCPTool[]) => void
setAbortController: (threadId: string, controller: AbortController) => void
updateTokenSpeed: (message: ThreadMessage) => void
resetTokenSpeed: () => void
setOutOfContextDialog: (show: boolean) => void
}
toolCall: ChatCompletionMessageToolCall | undefined,
) => void;
updateLoadingModel: (loading: boolean) => void;
updateTools: (tools: MCPTool[]) => void;
setAbortController: (threadId: string, controller: AbortController) => void;
updateTokenSpeed: (message: ThreadMessage) => void;
resetTokenSpeed: () => void;
setOutOfContextDialog: (show: boolean) => void;
};
export const useAppState = create<AppState>()((set) => ({
streamingContent: undefined,
loadingModel: false,
tools: [],
serverStatus: 'stopped',
serverStatus: "stopped",
abortControllers: {},
tokenSpeed: undefined,
currentToolCall: undefined,
@ -46,18 +46,19 @@ export const useAppState = create<AppState>()((set) => ({
},
}
: undefined,
}))
}));
console.log(useAppState.getState().streamingContent);
},
updateCurrentToolCall: (toolCall) => {
set(() => ({
currentToolCall: toolCall,
}))
}));
},
updateLoadingModel: (loading) => {
set({ loadingModel: loading })
set({ loadingModel: loading });
},
updateTools: (tools) => {
set({ tools })
set({ tools });
},
setServerStatus: (value) => set({ serverStatus: value }),
setAbortController: (threadId, controller) => {
@ -66,11 +67,11 @@ export const useAppState = create<AppState>()((set) => ({
...state.abortControllers,
[threadId]: controller,
},
}))
}));
},
updateTokenSpeed: (message) =>
set((state) => {
const currentTimestamp = new Date().getTime() // Get current time in milliseconds
const currentTimestamp = new Date().getTime(); // Get current time in milliseconds
if (!state.tokenSpeed) {
// If this is the first update, just set the lastTimestamp and return
return {
@ -80,14 +81,14 @@ export const useAppState = create<AppState>()((set) => ({
tokenCount: 1,
message: message.id,
},
}
};
}
const timeDiffInSeconds =
(currentTimestamp - state.tokenSpeed.lastTimestamp) / 1000 // Time difference in seconds
const totalTokenCount = state.tokenSpeed.tokenCount + 1
(currentTimestamp - state.tokenSpeed.lastTimestamp) / 1000; // Time difference in seconds
const totalTokenCount = state.tokenSpeed.tokenCount + 1;
const averageTokenSpeed =
totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1) // Calculate average token speed
totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1); // Calculate average token speed
return {
tokenSpeed: {
...state.tokenSpeed,
@ -95,7 +96,7 @@ export const useAppState = create<AppState>()((set) => ({
tokenCount: totalTokenCount,
message: message.id,
},
}
};
}),
resetTokenSpeed: () =>
set({
@ -104,6 +105,6 @@ export const useAppState = create<AppState>()((set) => ({
setOutOfContextDialog: (show) => {
set(() => ({
showOutOfContextDialog: show,
}))
}));
},
}))
}));

View File

@ -1,12 +1,12 @@
import { useCallback, useEffect, useMemo } from 'react'
import { usePrompt } from './usePrompt'
import { useModelProvider } from './useModelProvider'
import { useThreads } from './useThreads'
import { useAppState } from './useAppState'
import { useMessages } from './useMessages'
import { useRouter } from '@tanstack/react-router'
import { defaultModel } from '@/lib/models'
import { route } from '@/constants/routes'
import { useCallback, useEffect, useMemo } from "react";
import { usePrompt } from "./usePrompt";
import { useModelProvider } from "./useModelProvider";
import { useThreads } from "./useThreads";
import { useAppState } from "./useAppState";
import { useMessages } from "./useMessages";
import { useRouter } from "@tanstack/react-router";
import { defaultModel } from "@/lib/models";
import { route } from "@/constants/routes";
import {
emptyThreadContent,
extractToolCall,
@ -15,23 +15,23 @@ import {
newUserThreadContent,
postMessageProcessing,
sendCompletion,
} from '@/lib/completion'
import { CompletionMessagesBuilder } from '@/lib/messages'
import { ChatCompletionMessageToolCall } from 'openai/resources'
import { useAssistant } from './useAssistant'
import { toast } from 'sonner'
import { getTools } from '@/services/mcp'
import { MCPTool } from '@/types/completion'
import { listen } from '@tauri-apps/api/event'
import { SystemEvent } from '@/types/events'
import { stopModel, startModel, stopAllModels } from '@/services/models'
} from "@/lib/completion";
import { CompletionMessagesBuilder } from "@/lib/messages";
import { ChatCompletionMessageToolCall } from "openai/resources";
import { useAssistant } from "./useAssistant";
import { toast } from "sonner";
import { getTools } from "@/services/mcp";
import { MCPTool } from "@/types/completion";
import { listen } from "@tauri-apps/api/event";
import { SystemEvent } from "@/types/events";
import { stopModel, startModel, stopAllModels } from "@/services/models";
import { useToolApproval } from '@/hooks/useToolApproval'
import { useToolAvailable } from '@/hooks/useToolAvailable'
import { OUT_OF_CONTEXT_SIZE } from '@/utils/error'
import { useToolApproval } from "@/hooks/useToolApproval";
import { useToolAvailable } from "@/hooks/useToolAvailable";
import { OUT_OF_CONTEXT_SIZE } from "@/utils/error";
export const useChat = () => {
const { prompt, setPrompt } = usePrompt()
const { prompt, setPrompt } = usePrompt();
const {
tools,
updateTokenSpeed,
@ -40,51 +40,51 @@ export const useChat = () => {
updateStreamingContent,
updateLoadingModel,
setAbortController,
} = useAppState()
const { currentAssistant } = useAssistant()
const { updateProvider } = useModelProvider()
} = useAppState();
const { currentAssistant } = useAssistant();
const { updateProvider } = useModelProvider();
const { approvedTools, showApprovalModal, allowAllMCPPermissions } =
useToolApproval()
const { getDisabledToolsForThread } = useToolAvailable()
useToolApproval();
const { getDisabledToolsForThread } = useToolAvailable();
const { getProviderByName, selectedModel, selectedProvider } =
useModelProvider()
useModelProvider();
const {
getCurrentThread: retrieveThread,
createThread,
updateThreadTimestamp,
} = useThreads()
const { getMessages, addMessage } = useMessages()
const router = useRouter()
} = useThreads();
const { getMessages, addMessage } = useMessages();
const router = useRouter();
const provider = useMemo(() => {
return getProviderByName(selectedProvider)
}, [selectedProvider, getProviderByName])
return getProviderByName(selectedProvider);
}, [selectedProvider, getProviderByName]);
const currentProviderId = useMemo(() => {
return provider?.provider || selectedProvider
}, [provider, selectedProvider])
return provider?.provider || selectedProvider;
}, [provider, selectedProvider]);
useEffect(() => {
function setTools() {
getTools().then((data: MCPTool[]) => {
updateTools(data)
})
updateTools(data);
});
}
setTools()
setTools();
let unsubscribe = () => {}
let unsubscribe = () => {};
listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => {
// Unsubscribe from the event when the component unmounts
unsubscribe = unsub
})
return unsubscribe
}, [updateTools])
unsubscribe = unsub;
});
return unsubscribe;
}, [updateTools]);
const getCurrentThread = useCallback(async () => {
let currentThread = retrieveThread()
let currentThread = retrieveThread();
if (!currentThread) {
currentThread = await createThread(
{
@ -92,14 +92,14 @@ export const useChat = () => {
provider: selectedProvider,
},
prompt,
currentAssistant
)
currentAssistant,
);
router.navigate({
to: route.threadsDetail,
params: { threadId: currentThread.id },
})
});
}
return currentThread
return currentThread;
}, [
createThread,
prompt,
@ -108,7 +108,7 @@ export const useChat = () => {
selectedModel?.id,
selectedProvider,
currentAssistant,
])
]);
const increaseModelContextSize = useCallback(
(model: Model, provider: ProviderObject) => {
@ -118,12 +118,12 @@ export const useChat = () => {
*/
const ctxSize = Math.max(
model.settings?.ctx_len?.controller_props.value
? typeof model.settings.ctx_len.controller_props.value === 'string'
? typeof model.settings.ctx_len.controller_props.value === "string"
? parseInt(model.settings.ctx_len.controller_props.value as string)
: (model.settings.ctx_len.controller_props.value as number)
: 8192,
8192
)
8192,
);
const updatedModel = {
...model,
settings: {
@ -136,80 +136,80 @@ export const useChat = () => {
},
},
},
}
};
// Find the model index in the provider's models array
const modelIndex = provider.models.findIndex((m) => m.id === model.id)
const modelIndex = provider.models.findIndex((m) => m.id === model.id);
if (modelIndex !== -1) {
// Create a copy of the provider's models array
const updatedModels = [...provider.models]
const updatedModels = [...provider.models];
// Update the specific model in the array
updatedModels[modelIndex] = updatedModel as Model
updatedModels[modelIndex] = updatedModel as Model;
// Update the provider with the new models array
updateProvider(provider.provider, {
models: updatedModels,
})
});
}
stopAllModels()
stopAllModels();
},
[updateProvider]
)
[updateProvider],
);
const sendMessage = useCallback(
async (
message: string,
showModal?: () => Promise<unknown>,
troubleshooting = true
troubleshooting = true,
) => {
const activeThread = await getCurrentThread()
const activeThread = await getCurrentThread();
resetTokenSpeed()
resetTokenSpeed();
const activeProvider = currentProviderId
? getProviderByName(currentProviderId)
: provider
if (!activeThread || !activeProvider) return
const messages = getMessages(activeThread.id)
const abortController = new AbortController()
setAbortController(activeThread.id, abortController)
updateStreamingContent(emptyThreadContent)
: provider;
if (!activeThread || !activeProvider) return;
const messages = getMessages(activeThread.id);
const abortController = new AbortController();
setAbortController(activeThread.id, abortController);
updateStreamingContent(emptyThreadContent);
// Do not add new message on retry
if (troubleshooting)
addMessage(newUserThreadContent(activeThread.id, message))
updateThreadTimestamp(activeThread.id)
setPrompt('')
addMessage(newUserThreadContent(activeThread.id, message));
updateThreadTimestamp(activeThread.id);
setPrompt("");
try {
if (selectedModel?.id) {
updateLoadingModel(true)
updateLoadingModel(true);
await startModel(
activeProvider,
selectedModel.id,
abortController
).catch(console.error)
updateLoadingModel(false)
abortController,
).catch(console.error);
updateLoadingModel(false);
}
const builder = new CompletionMessagesBuilder(
messages,
currentAssistant?.instructions
)
currentAssistant?.instructions,
);
builder.addUserMessage(message)
builder.addUserMessage(message);
let isCompleted = false
let isCompleted = false;
// Filter tools based on model capabilities and available tools for this thread
let availableTools = selectedModel?.capabilities?.includes('tools')
let availableTools = selectedModel?.capabilities?.includes("tools")
? tools.filter((tool) => {
const disabledTools = getDisabledToolsForThread(activeThread.id)
return !disabledTools.includes(tool.name)
const disabledTools = getDisabledToolsForThread(activeThread.id);
return !disabledTools.includes(tool.name);
})
: []
: [];
// TODO: Later replaced by Agent setup?
const followUpWithToolUse = true
const followUpWithToolUse = true;
while (!isCompleted && !abortController.signal.aborted) {
const completion = await sendCompletion(
activeThread,
@ -218,51 +218,51 @@ export const useChat = () => {
abortController,
availableTools,
currentAssistant.parameters?.stream === false ? false : true,
currentAssistant.parameters as unknown as Record<string, object>
currentAssistant.parameters as unknown as Record<string, object>,
// TODO: replace it with according provider setting later on
// selectedProvider === 'llama.cpp' && availableTools.length > 0
// ? false
// : true
)
);
if (!completion) throw new Error('No completion received')
let accumulatedText = ''
const currentCall: ChatCompletionMessageToolCall | null = null
const toolCalls: ChatCompletionMessageToolCall[] = []
if (!completion) throw new Error("No completion received");
let accumulatedText = "";
const currentCall: ChatCompletionMessageToolCall | null = null;
const toolCalls: ChatCompletionMessageToolCall[] = [];
if (isCompletionResponse(completion)) {
accumulatedText = completion.choices[0]?.message?.content || ''
accumulatedText = completion.choices[0]?.message?.content || "";
if (completion.choices[0]?.message?.tool_calls) {
toolCalls.push(...completion.choices[0].message.tool_calls)
toolCalls.push(...completion.choices[0].message.tool_calls);
}
} else {
for await (const part of completion) {
// Error message
if (!part.choices) {
throw new Error(
'message' in part
"message" in part
? (part.message as string)
: (JSON.stringify(part) ?? '')
)
: (JSON.stringify(part) ?? ""),
);
}
const delta = part.choices[0]?.delta?.content || ''
const delta = part.choices[0]?.delta?.content || "";
if (part.choices[0]?.delta?.tool_calls) {
const calls = extractToolCall(part, currentCall, toolCalls)
const calls = extractToolCall(part, currentCall, toolCalls);
const currentContent = newAssistantThreadContent(
activeThread.id,
accumulatedText,
{
tool_calls: calls.map((e) => ({
...e,
state: 'pending',
state: "pending",
})),
}
)
updateStreamingContent(currentContent)
await new Promise((resolve) => setTimeout(resolve, 0))
},
);
updateStreamingContent(currentContent);
await new Promise((resolve) => setTimeout(resolve, 0));
}
if (delta) {
accumulatedText += delta
accumulatedText += delta;
// Create a new object each time to avoid reference issues
// Use a timeout to prevent React from batching updates too quickly
const currentContent = newAssistantThreadContent(
@ -271,13 +271,13 @@ export const useChat = () => {
{
tool_calls: toolCalls.map((e) => ({
...e,
state: 'pending',
state: "pending",
})),
}
)
updateStreamingContent(currentContent)
updateTokenSpeed(currentContent)
await new Promise((resolve) => setTimeout(resolve, 0))
},
);
updateStreamingContent(currentContent);
updateTokenSpeed(currentContent);
await new Promise((resolve) => setTimeout(resolve, 0));
}
}
}
@ -286,18 +286,22 @@ export const useChat = () => {
accumulatedText.length === 0 &&
toolCalls.length === 0 &&
activeThread.model?.id &&
activeProvider.provider === 'llama.cpp'
activeProvider.provider === "llama.cpp"
) {
await stopModel(activeThread.model.id, 'cortex')
throw new Error('No response received from the model')
await stopModel(activeThread.model.id, "cortex");
throw new Error("No response received from the model");
}
// Create a final content object for adding to the thread
const finalContent = newAssistantThreadContent(
activeThread.id,
accumulatedText
)
builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
accumulatedText,
{
tokenSpeed: useAppState.getState().tokenSpeed,
},
);
builder.addAssistantMessage(accumulatedText, undefined, toolCalls);
const updatedMessage = await postMessageProcessing(
toolCalls,
builder,
@ -305,41 +309,41 @@ export const useChat = () => {
abortController,
approvedTools,
allowAllMCPPermissions ? undefined : showApprovalModal,
allowAllMCPPermissions
)
addMessage(updatedMessage ?? finalContent)
updateStreamingContent(emptyThreadContent)
updateThreadTimestamp(activeThread.id)
allowAllMCPPermissions,
);
addMessage(updatedMessage ?? finalContent);
updateStreamingContent(emptyThreadContent);
updateThreadTimestamp(activeThread.id);
isCompleted = !toolCalls.length
isCompleted = !toolCalls.length;
// Do not create agent loop if there is no need for it
if (!followUpWithToolUse) availableTools = []
if (!followUpWithToolUse) availableTools = [];
}
} catch (error) {
const errorMessage =
error && typeof error === 'object' && 'message' in error
error && typeof error === "object" && "message" in error
? error.message
: error
: error;
if (
typeof errorMessage === 'string' &&
typeof errorMessage === "string" &&
errorMessage.includes(OUT_OF_CONTEXT_SIZE) &&
selectedModel &&
troubleshooting
) {
showModal?.().then((confirmed) => {
if (confirmed) {
increaseModelContextSize(selectedModel, activeProvider)
increaseModelContextSize(selectedModel, activeProvider);
setTimeout(() => {
sendMessage(message, showModal, false) // Retry sending the message without troubleshooting
}, 1000)
sendMessage(message, showModal, false); // Retry sending the message without troubleshooting
}, 1000);
}
})
});
}
toast.error(`Error sending message: ${errorMessage}`)
console.error('Error sending message:', error)
toast.error(`Error sending message: ${errorMessage}`);
console.error("Error sending message:", error);
} finally {
updateLoadingModel(false)
updateStreamingContent(undefined)
updateLoadingModel(false);
updateStreamingContent(undefined);
}
},
[
@ -364,8 +368,8 @@ export const useChat = () => {
showApprovalModal,
updateTokenSpeed,
increaseModelContextSize,
]
)
],
);
return { sendMessage }
}
return { sendMessage };
};

View File

@ -1,23 +1,23 @@
import { create } from 'zustand'
import { ThreadMessage } from '@janhq/core'
import { create } from "zustand";
import { ThreadMessage } from "@janhq/core";
import {
createMessage,
deleteMessage as deleteMessageExt,
} from '@/services/messages'
import { useAssistant } from './useAssistant'
} from "@/services/messages";
import { useAssistant } from "./useAssistant";
type MessageState = {
messages: Record<string, ThreadMessage[]>
getMessages: (threadId: string) => ThreadMessage[]
setMessages: (threadId: string, messages: ThreadMessage[]) => void
addMessage: (message: ThreadMessage) => void
deleteMessage: (threadId: string, messageId: string) => void
}
messages: Record<string, ThreadMessage[]>;
getMessages: (threadId: string) => ThreadMessage[];
setMessages: (threadId: string, messages: ThreadMessage[]) => void;
addMessage: (message: ThreadMessage) => void;
deleteMessage: (threadId: string, messageId: string) => void;
};
export const useMessages = create<MessageState>()((set, get) => ({
messages: {},
getMessages: (threadId) => {
return get().messages[threadId] || []
return get().messages[threadId] || [];
},
setMessages: (threadId, messages) => {
set((state) => ({
@ -25,10 +25,11 @@ export const useMessages = create<MessageState>()((set, get) => ({
...state.messages,
[threadId]: messages,
},
}))
}));
},
addMessage: (message) => {
const currentAssistant = useAssistant.getState().currentAssistant
console.log("addMessage: ", message);
const currentAssistant = useAssistant.getState().currentAssistant;
const newMessage = {
...message,
created_at: message.created_at || Date.now(),
@ -36,7 +37,7 @@ export const useMessages = create<MessageState>()((set, get) => ({
...message.metadata,
assistant: currentAssistant,
},
}
};
createMessage(newMessage).then((createdMessage) => {
set((state) => ({
messages: {
@ -46,19 +47,19 @@ export const useMessages = create<MessageState>()((set, get) => ({
createdMessage,
],
},
}))
})
}));
});
},
deleteMessage: (threadId, messageId) => {
deleteMessageExt(threadId, messageId)
deleteMessageExt(threadId, messageId);
set((state) => ({
messages: {
...state.messages,
[threadId]:
state.messages[threadId]?.filter(
(message) => message.id !== messageId
(message) => message.id !== messageId,
) || [],
},
}))
}));
},
}))
}));