diff --git a/web-app/src/containers/ChatInput.tsx b/web-app/src/containers/ChatInput.tsx index 291642a60..a83adc59e 100644 --- a/web-app/src/containers/ChatInput.tsx +++ b/web-app/src/containers/ChatInput.tsx @@ -48,7 +48,7 @@ type ChatInputProps = { const ChatInput = ({ model, className, - showSpeedToken = false, + showSpeedToken = true, initialMessage, }: ChatInputProps) => { const textareaRef = useRef(null) diff --git a/web-app/src/containers/ThreadContent.tsx b/web-app/src/containers/ThreadContent.tsx index 833846db1..40c26993b 100644 --- a/web-app/src/containers/ThreadContent.tsx +++ b/web-app/src/containers/ThreadContent.tsx @@ -34,6 +34,9 @@ import { } from '@/components/ui/tooltip' import { formatDate } from '@/utils/formatDate' import { AvatarEmoji } from '@/containers/AvatarEmoji' + +import TokenSpeedIndicator from '@/containers/TokenSpeedIndicator' + import CodeEditor from '@uiw/react-textarea-code-editor' import '@uiw/react-textarea-code-editor/dist.css' @@ -360,8 +363,8 @@ export const ThreadContent = memo( className={cn( 'flex items-center gap-2', item.isLastMessage && - streamingContent && - 'opacity-0 visibility-hidden pointer-events-none' + streamingContent && + 'opacity-0 visibility-hidden pointer-events-none' )} > @@ -445,6 +448,11 @@ export const ThreadContent = memo( )} + + )} diff --git a/web-app/src/containers/TokenSpeedIndicator.tsx b/web-app/src/containers/TokenSpeedIndicator.tsx new file mode 100644 index 000000000..b1dfb841c --- /dev/null +++ b/web-app/src/containers/TokenSpeedIndicator.tsx @@ -0,0 +1,22 @@ +import { IconBrandSpeedtest } from '@tabler/icons-react' + +interface TokenSpeedIndicatorProps { + metadata?: Record +} + +export const TokenSpeedIndicator = ({ + metadata +}: TokenSpeedIndicatorProps) => { + const persistedTokenSpeed = (metadata?.tokenSpeed as { tokenSpeed: number })?.tokenSpeed + + return ( +
+ + + {Math.round(persistedTokenSpeed)} tokens/sec + +
+ ) +} + +export default TokenSpeedIndicator diff --git a/web-app/src/hooks/useAppState.ts b/web-app/src/hooks/useAppState.ts index dc29f7f8a..ccf044a4f 100644 --- a/web-app/src/hooks/useAppState.ts +++ b/web-app/src/hooks/useAppState.ts @@ -1,36 +1,36 @@ -import { create } from 'zustand' -import { ThreadMessage } from '@janhq/core' -import { MCPTool } from '@/types/completion' -import { useAssistant } from './useAssistant' -import { ChatCompletionMessageToolCall } from 'openai/resources' +import { create } from "zustand"; +import { ThreadMessage } from "@janhq/core"; +import { MCPTool } from "@/types/completion"; +import { useAssistant } from "./useAssistant"; +import { ChatCompletionMessageToolCall } from "openai/resources"; type AppState = { - streamingContent?: ThreadMessage - loadingModel?: boolean - tools: MCPTool[] - serverStatus: 'running' | 'stopped' | 'pending' - abortControllers: Record - tokenSpeed?: TokenSpeed - currentToolCall?: ChatCompletionMessageToolCall - showOutOfContextDialog?: boolean - setServerStatus: (value: 'running' | 'stopped' | 'pending') => void - updateStreamingContent: (content: ThreadMessage | undefined) => void + streamingContent?: ThreadMessage; + loadingModel?: boolean; + tools: MCPTool[]; + serverStatus: "running" | "stopped" | "pending"; + abortControllers: Record; + tokenSpeed?: TokenSpeed; + currentToolCall?: ChatCompletionMessageToolCall; + showOutOfContextDialog?: boolean; + setServerStatus: (value: "running" | "stopped" | "pending") => void; + updateStreamingContent: (content: ThreadMessage | undefined) => void; updateCurrentToolCall: ( - toolCall: ChatCompletionMessageToolCall | undefined - ) => void - updateLoadingModel: (loading: boolean) => void - updateTools: (tools: MCPTool[]) => void - setAbortController: (threadId: string, controller: AbortController) => void - updateTokenSpeed: (message: ThreadMessage) => void - resetTokenSpeed: () => void - setOutOfContextDialog: (show: boolean) => void -} + toolCall: ChatCompletionMessageToolCall | undefined, + ) => void; + updateLoadingModel: (loading: boolean) => void; + updateTools: (tools: MCPTool[]) => void; + setAbortController: (threadId: string, controller: AbortController) => void; + updateTokenSpeed: (message: ThreadMessage) => void; + resetTokenSpeed: () => void; + setOutOfContextDialog: (show: boolean) => void; +}; export const useAppState = create()((set) => ({ streamingContent: undefined, loadingModel: false, tools: [], - serverStatus: 'stopped', + serverStatus: "stopped", abortControllers: {}, tokenSpeed: undefined, currentToolCall: undefined, @@ -46,18 +46,19 @@ export const useAppState = create()((set) => ({ }, } : undefined, - })) + })); + console.log(useAppState.getState().streamingContent); }, updateCurrentToolCall: (toolCall) => { set(() => ({ currentToolCall: toolCall, - })) + })); }, updateLoadingModel: (loading) => { - set({ loadingModel: loading }) + set({ loadingModel: loading }); }, updateTools: (tools) => { - set({ tools }) + set({ tools }); }, setServerStatus: (value) => set({ serverStatus: value }), setAbortController: (threadId, controller) => { @@ -66,11 +67,11 @@ export const useAppState = create()((set) => ({ ...state.abortControllers, [threadId]: controller, }, - })) + })); }, updateTokenSpeed: (message) => set((state) => { - const currentTimestamp = new Date().getTime() // Get current time in milliseconds + const currentTimestamp = new Date().getTime(); // Get current time in milliseconds if (!state.tokenSpeed) { // If this is the first update, just set the lastTimestamp and return return { @@ -80,14 +81,14 @@ export const useAppState = create()((set) => ({ tokenCount: 1, message: message.id, }, - } + }; } const timeDiffInSeconds = - (currentTimestamp - state.tokenSpeed.lastTimestamp) / 1000 // Time difference in seconds - const totalTokenCount = state.tokenSpeed.tokenCount + 1 + (currentTimestamp - state.tokenSpeed.lastTimestamp) / 1000; // Time difference in seconds + const totalTokenCount = state.tokenSpeed.tokenCount + 1; const averageTokenSpeed = - totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1) // Calculate average token speed + totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1); // Calculate average token speed return { tokenSpeed: { ...state.tokenSpeed, @@ -95,7 +96,7 @@ export const useAppState = create()((set) => ({ tokenCount: totalTokenCount, message: message.id, }, - } + }; }), resetTokenSpeed: () => set({ @@ -104,6 +105,6 @@ export const useAppState = create()((set) => ({ setOutOfContextDialog: (show) => { set(() => ({ showOutOfContextDialog: show, - })) + })); }, -})) +})); diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts index 164555563..0fbfeb5d9 100644 --- a/web-app/src/hooks/useChat.ts +++ b/web-app/src/hooks/useChat.ts @@ -1,12 +1,12 @@ -import { useCallback, useEffect, useMemo } from 'react' -import { usePrompt } from './usePrompt' -import { useModelProvider } from './useModelProvider' -import { useThreads } from './useThreads' -import { useAppState } from './useAppState' -import { useMessages } from './useMessages' -import { useRouter } from '@tanstack/react-router' -import { defaultModel } from '@/lib/models' -import { route } from '@/constants/routes' +import { useCallback, useEffect, useMemo } from "react"; +import { usePrompt } from "./usePrompt"; +import { useModelProvider } from "./useModelProvider"; +import { useThreads } from "./useThreads"; +import { useAppState } from "./useAppState"; +import { useMessages } from "./useMessages"; +import { useRouter } from "@tanstack/react-router"; +import { defaultModel } from "@/lib/models"; +import { route } from "@/constants/routes"; import { emptyThreadContent, extractToolCall, @@ -15,23 +15,23 @@ import { newUserThreadContent, postMessageProcessing, sendCompletion, -} from '@/lib/completion' -import { CompletionMessagesBuilder } from '@/lib/messages' -import { ChatCompletionMessageToolCall } from 'openai/resources' -import { useAssistant } from './useAssistant' -import { toast } from 'sonner' -import { getTools } from '@/services/mcp' -import { MCPTool } from '@/types/completion' -import { listen } from '@tauri-apps/api/event' -import { SystemEvent } from '@/types/events' -import { stopModel, startModel, stopAllModels } from '@/services/models' +} from "@/lib/completion"; +import { CompletionMessagesBuilder } from "@/lib/messages"; +import { ChatCompletionMessageToolCall } from "openai/resources"; +import { useAssistant } from "./useAssistant"; +import { toast } from "sonner"; +import { getTools } from "@/services/mcp"; +import { MCPTool } from "@/types/completion"; +import { listen } from "@tauri-apps/api/event"; +import { SystemEvent } from "@/types/events"; +import { stopModel, startModel, stopAllModels } from "@/services/models"; -import { useToolApproval } from '@/hooks/useToolApproval' -import { useToolAvailable } from '@/hooks/useToolAvailable' -import { OUT_OF_CONTEXT_SIZE } from '@/utils/error' +import { useToolApproval } from "@/hooks/useToolApproval"; +import { useToolAvailable } from "@/hooks/useToolAvailable"; +import { OUT_OF_CONTEXT_SIZE } from "@/utils/error"; export const useChat = () => { - const { prompt, setPrompt } = usePrompt() + const { prompt, setPrompt } = usePrompt(); const { tools, updateTokenSpeed, @@ -40,51 +40,51 @@ export const useChat = () => { updateStreamingContent, updateLoadingModel, setAbortController, - } = useAppState() - const { currentAssistant } = useAssistant() - const { updateProvider } = useModelProvider() + } = useAppState(); + const { currentAssistant } = useAssistant(); + const { updateProvider } = useModelProvider(); const { approvedTools, showApprovalModal, allowAllMCPPermissions } = - useToolApproval() - const { getDisabledToolsForThread } = useToolAvailable() + useToolApproval(); + const { getDisabledToolsForThread } = useToolAvailable(); const { getProviderByName, selectedModel, selectedProvider } = - useModelProvider() + useModelProvider(); const { getCurrentThread: retrieveThread, createThread, updateThreadTimestamp, - } = useThreads() - const { getMessages, addMessage } = useMessages() - const router = useRouter() + } = useThreads(); + const { getMessages, addMessage } = useMessages(); + const router = useRouter(); const provider = useMemo(() => { - return getProviderByName(selectedProvider) - }, [selectedProvider, getProviderByName]) + return getProviderByName(selectedProvider); + }, [selectedProvider, getProviderByName]); const currentProviderId = useMemo(() => { - return provider?.provider || selectedProvider - }, [provider, selectedProvider]) + return provider?.provider || selectedProvider; + }, [provider, selectedProvider]); useEffect(() => { function setTools() { getTools().then((data: MCPTool[]) => { - updateTools(data) - }) + updateTools(data); + }); } - setTools() + setTools(); - let unsubscribe = () => {} + let unsubscribe = () => {}; listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => { // Unsubscribe from the event when the component unmounts - unsubscribe = unsub - }) - return unsubscribe - }, [updateTools]) + unsubscribe = unsub; + }); + return unsubscribe; + }, [updateTools]); const getCurrentThread = useCallback(async () => { - let currentThread = retrieveThread() + let currentThread = retrieveThread(); if (!currentThread) { currentThread = await createThread( { @@ -92,14 +92,14 @@ export const useChat = () => { provider: selectedProvider, }, prompt, - currentAssistant - ) + currentAssistant, + ); router.navigate({ to: route.threadsDetail, params: { threadId: currentThread.id }, - }) + }); } - return currentThread + return currentThread; }, [ createThread, prompt, @@ -108,7 +108,7 @@ export const useChat = () => { selectedModel?.id, selectedProvider, currentAssistant, - ]) + ]); const increaseModelContextSize = useCallback( (model: Model, provider: ProviderObject) => { @@ -118,12 +118,12 @@ export const useChat = () => { */ const ctxSize = Math.max( model.settings?.ctx_len?.controller_props.value - ? typeof model.settings.ctx_len.controller_props.value === 'string' + ? typeof model.settings.ctx_len.controller_props.value === "string" ? parseInt(model.settings.ctx_len.controller_props.value as string) : (model.settings.ctx_len.controller_props.value as number) : 8192, - 8192 - ) + 8192, + ); const updatedModel = { ...model, settings: { @@ -136,80 +136,80 @@ export const useChat = () => { }, }, }, - } + }; // Find the model index in the provider's models array - const modelIndex = provider.models.findIndex((m) => m.id === model.id) + const modelIndex = provider.models.findIndex((m) => m.id === model.id); if (modelIndex !== -1) { // Create a copy of the provider's models array - const updatedModels = [...provider.models] + const updatedModels = [...provider.models]; // Update the specific model in the array - updatedModels[modelIndex] = updatedModel as Model + updatedModels[modelIndex] = updatedModel as Model; // Update the provider with the new models array updateProvider(provider.provider, { models: updatedModels, - }) + }); } - stopAllModels() + stopAllModels(); }, - [updateProvider] - ) + [updateProvider], + ); const sendMessage = useCallback( async ( message: string, showModal?: () => Promise, - troubleshooting = true + troubleshooting = true, ) => { - const activeThread = await getCurrentThread() + const activeThread = await getCurrentThread(); - resetTokenSpeed() + resetTokenSpeed(); const activeProvider = currentProviderId ? getProviderByName(currentProviderId) - : provider - if (!activeThread || !activeProvider) return - const messages = getMessages(activeThread.id) - const abortController = new AbortController() - setAbortController(activeThread.id, abortController) - updateStreamingContent(emptyThreadContent) + : provider; + if (!activeThread || !activeProvider) return; + const messages = getMessages(activeThread.id); + const abortController = new AbortController(); + setAbortController(activeThread.id, abortController); + updateStreamingContent(emptyThreadContent); // Do not add new message on retry if (troubleshooting) - addMessage(newUserThreadContent(activeThread.id, message)) - updateThreadTimestamp(activeThread.id) - setPrompt('') + addMessage(newUserThreadContent(activeThread.id, message)); + updateThreadTimestamp(activeThread.id); + setPrompt(""); try { if (selectedModel?.id) { - updateLoadingModel(true) + updateLoadingModel(true); await startModel( activeProvider, selectedModel.id, - abortController - ).catch(console.error) - updateLoadingModel(false) + abortController, + ).catch(console.error); + updateLoadingModel(false); } const builder = new CompletionMessagesBuilder( messages, - currentAssistant?.instructions - ) + currentAssistant?.instructions, + ); - builder.addUserMessage(message) + builder.addUserMessage(message); - let isCompleted = false + let isCompleted = false; // Filter tools based on model capabilities and available tools for this thread - let availableTools = selectedModel?.capabilities?.includes('tools') + let availableTools = selectedModel?.capabilities?.includes("tools") ? tools.filter((tool) => { - const disabledTools = getDisabledToolsForThread(activeThread.id) - return !disabledTools.includes(tool.name) + const disabledTools = getDisabledToolsForThread(activeThread.id); + return !disabledTools.includes(tool.name); }) - : [] + : []; // TODO: Later replaced by Agent setup? - const followUpWithToolUse = true + const followUpWithToolUse = true; while (!isCompleted && !abortController.signal.aborted) { const completion = await sendCompletion( activeThread, @@ -218,51 +218,51 @@ export const useChat = () => { abortController, availableTools, currentAssistant.parameters?.stream === false ? false : true, - currentAssistant.parameters as unknown as Record + currentAssistant.parameters as unknown as Record, // TODO: replace it with according provider setting later on // selectedProvider === 'llama.cpp' && availableTools.length > 0 // ? false // : true - ) + ); - if (!completion) throw new Error('No completion received') - let accumulatedText = '' - const currentCall: ChatCompletionMessageToolCall | null = null - const toolCalls: ChatCompletionMessageToolCall[] = [] + if (!completion) throw new Error("No completion received"); + let accumulatedText = ""; + const currentCall: ChatCompletionMessageToolCall | null = null; + const toolCalls: ChatCompletionMessageToolCall[] = []; if (isCompletionResponse(completion)) { - accumulatedText = completion.choices[0]?.message?.content || '' + accumulatedText = completion.choices[0]?.message?.content || ""; if (completion.choices[0]?.message?.tool_calls) { - toolCalls.push(...completion.choices[0].message.tool_calls) + toolCalls.push(...completion.choices[0].message.tool_calls); } } else { for await (const part of completion) { // Error message if (!part.choices) { throw new Error( - 'message' in part + "message" in part ? (part.message as string) - : (JSON.stringify(part) ?? '') - ) + : (JSON.stringify(part) ?? ""), + ); } - const delta = part.choices[0]?.delta?.content || '' + const delta = part.choices[0]?.delta?.content || ""; if (part.choices[0]?.delta?.tool_calls) { - const calls = extractToolCall(part, currentCall, toolCalls) + const calls = extractToolCall(part, currentCall, toolCalls); const currentContent = newAssistantThreadContent( activeThread.id, accumulatedText, { tool_calls: calls.map((e) => ({ ...e, - state: 'pending', + state: "pending", })), - } - ) - updateStreamingContent(currentContent) - await new Promise((resolve) => setTimeout(resolve, 0)) + }, + ); + updateStreamingContent(currentContent); + await new Promise((resolve) => setTimeout(resolve, 0)); } if (delta) { - accumulatedText += delta + accumulatedText += delta; // Create a new object each time to avoid reference issues // Use a timeout to prevent React from batching updates too quickly const currentContent = newAssistantThreadContent( @@ -271,13 +271,13 @@ export const useChat = () => { { tool_calls: toolCalls.map((e) => ({ ...e, - state: 'pending', + state: "pending", })), - } - ) - updateStreamingContent(currentContent) - updateTokenSpeed(currentContent) - await new Promise((resolve) => setTimeout(resolve, 0)) + }, + ); + updateStreamingContent(currentContent); + updateTokenSpeed(currentContent); + await new Promise((resolve) => setTimeout(resolve, 0)); } } } @@ -286,18 +286,22 @@ export const useChat = () => { accumulatedText.length === 0 && toolCalls.length === 0 && activeThread.model?.id && - activeProvider.provider === 'llama.cpp' + activeProvider.provider === "llama.cpp" ) { - await stopModel(activeThread.model.id, 'cortex') - throw new Error('No response received from the model') + await stopModel(activeThread.model.id, "cortex"); + throw new Error("No response received from the model"); } // Create a final content object for adding to the thread const finalContent = newAssistantThreadContent( activeThread.id, - accumulatedText - ) - builder.addAssistantMessage(accumulatedText, undefined, toolCalls) + accumulatedText, + { + tokenSpeed: useAppState.getState().tokenSpeed, + }, + ); + + builder.addAssistantMessage(accumulatedText, undefined, toolCalls); const updatedMessage = await postMessageProcessing( toolCalls, builder, @@ -305,41 +309,41 @@ export const useChat = () => { abortController, approvedTools, allowAllMCPPermissions ? undefined : showApprovalModal, - allowAllMCPPermissions - ) - addMessage(updatedMessage ?? finalContent) - updateStreamingContent(emptyThreadContent) - updateThreadTimestamp(activeThread.id) + allowAllMCPPermissions, + ); + addMessage(updatedMessage ?? finalContent); + updateStreamingContent(emptyThreadContent); + updateThreadTimestamp(activeThread.id); - isCompleted = !toolCalls.length + isCompleted = !toolCalls.length; // Do not create agent loop if there is no need for it - if (!followUpWithToolUse) availableTools = [] + if (!followUpWithToolUse) availableTools = []; } } catch (error) { const errorMessage = - error && typeof error === 'object' && 'message' in error + error && typeof error === "object" && "message" in error ? error.message - : error + : error; if ( - typeof errorMessage === 'string' && + typeof errorMessage === "string" && errorMessage.includes(OUT_OF_CONTEXT_SIZE) && selectedModel && troubleshooting ) { showModal?.().then((confirmed) => { if (confirmed) { - increaseModelContextSize(selectedModel, activeProvider) + increaseModelContextSize(selectedModel, activeProvider); setTimeout(() => { - sendMessage(message, showModal, false) // Retry sending the message without troubleshooting - }, 1000) + sendMessage(message, showModal, false); // Retry sending the message without troubleshooting + }, 1000); } - }) + }); } - toast.error(`Error sending message: ${errorMessage}`) - console.error('Error sending message:', error) + toast.error(`Error sending message: ${errorMessage}`); + console.error("Error sending message:", error); } finally { - updateLoadingModel(false) - updateStreamingContent(undefined) + updateLoadingModel(false); + updateStreamingContent(undefined); } }, [ @@ -364,8 +368,8 @@ export const useChat = () => { showApprovalModal, updateTokenSpeed, increaseModelContextSize, - ] - ) + ], + ); - return { sendMessage } -} + return { sendMessage }; +}; diff --git a/web-app/src/hooks/useMessages.ts b/web-app/src/hooks/useMessages.ts index 3a83b5a48..251d67438 100644 --- a/web-app/src/hooks/useMessages.ts +++ b/web-app/src/hooks/useMessages.ts @@ -1,23 +1,23 @@ -import { create } from 'zustand' -import { ThreadMessage } from '@janhq/core' +import { create } from "zustand"; +import { ThreadMessage } from "@janhq/core"; import { createMessage, deleteMessage as deleteMessageExt, -} from '@/services/messages' -import { useAssistant } from './useAssistant' +} from "@/services/messages"; +import { useAssistant } from "./useAssistant"; type MessageState = { - messages: Record - getMessages: (threadId: string) => ThreadMessage[] - setMessages: (threadId: string, messages: ThreadMessage[]) => void - addMessage: (message: ThreadMessage) => void - deleteMessage: (threadId: string, messageId: string) => void -} + messages: Record; + getMessages: (threadId: string) => ThreadMessage[]; + setMessages: (threadId: string, messages: ThreadMessage[]) => void; + addMessage: (message: ThreadMessage) => void; + deleteMessage: (threadId: string, messageId: string) => void; +}; export const useMessages = create()((set, get) => ({ messages: {}, getMessages: (threadId) => { - return get().messages[threadId] || [] + return get().messages[threadId] || []; }, setMessages: (threadId, messages) => { set((state) => ({ @@ -25,10 +25,11 @@ export const useMessages = create()((set, get) => ({ ...state.messages, [threadId]: messages, }, - })) + })); }, addMessage: (message) => { - const currentAssistant = useAssistant.getState().currentAssistant + console.log("addMessage: ", message); + const currentAssistant = useAssistant.getState().currentAssistant; const newMessage = { ...message, created_at: message.created_at || Date.now(), @@ -36,7 +37,7 @@ export const useMessages = create()((set, get) => ({ ...message.metadata, assistant: currentAssistant, }, - } + }; createMessage(newMessage).then((createdMessage) => { set((state) => ({ messages: { @@ -46,19 +47,19 @@ export const useMessages = create()((set, get) => ({ createdMessage, ], }, - })) - }) + })); + }); }, deleteMessage: (threadId, messageId) => { - deleteMessageExt(threadId, messageId) + deleteMessageExt(threadId, messageId); set((state) => ({ messages: { ...state.messages, [threadId]: state.messages[threadId]?.filter( - (message) => message.id !== messageId + (message) => message.id !== messageId, ) || [], }, - })) + })); }, -})) +}));