feat: add token speed to each message that persist

This commit is contained in:
LazyYuuki 2025-06-15 18:30:39 +08:00
parent 3ae4d12f60
commit 665de7df55
6 changed files with 236 additions and 200 deletions

View File

@ -48,7 +48,7 @@ type ChatInputProps = {
const ChatInput = ({ const ChatInput = ({
model, model,
className, className,
showSpeedToken = false, showSpeedToken = true,
initialMessage, initialMessage,
}: ChatInputProps) => { }: ChatInputProps) => {
const textareaRef = useRef<HTMLTextAreaElement>(null) const textareaRef = useRef<HTMLTextAreaElement>(null)

View File

@ -34,6 +34,9 @@ import {
} from '@/components/ui/tooltip' } from '@/components/ui/tooltip'
import { formatDate } from '@/utils/formatDate' import { formatDate } from '@/utils/formatDate'
import { AvatarEmoji } from '@/containers/AvatarEmoji' import { AvatarEmoji } from '@/containers/AvatarEmoji'
import TokenSpeedIndicator from '@/containers/TokenSpeedIndicator'
import CodeEditor from '@uiw/react-textarea-code-editor' import CodeEditor from '@uiw/react-textarea-code-editor'
import '@uiw/react-textarea-code-editor/dist.css' import '@uiw/react-textarea-code-editor/dist.css'
@ -360,8 +363,8 @@ export const ThreadContent = memo(
className={cn( className={cn(
'flex items-center gap-2', 'flex items-center gap-2',
item.isLastMessage && item.isLastMessage &&
streamingContent && streamingContent &&
'opacity-0 visibility-hidden pointer-events-none' 'opacity-0 visibility-hidden pointer-events-none'
)} )}
> >
<CopyButton text={item.content?.[0]?.text.value || ''} /> <CopyButton text={item.content?.[0]?.text.value || ''} />
@ -445,6 +448,11 @@ export const ThreadContent = memo(
</TooltipContent> </TooltipContent>
</Tooltip> </Tooltip>
)} )}
<TokenSpeedIndicator
messageId={item.id}
metadata={item.metadata}
/>
</div> </div>
</div> </div>
)} )}

View File

@ -0,0 +1,22 @@
import { IconBrandSpeedtest } from '@tabler/icons-react'
interface TokenSpeedIndicatorProps {
metadata?: Record<string, unknown>
}
export const TokenSpeedIndicator = ({
metadata
}: TokenSpeedIndicatorProps) => {
const persistedTokenSpeed = (metadata?.tokenSpeed as { tokenSpeed: number })?.tokenSpeed
return (
<div className="flex items-center gap-1 text-main-view-fg/60 text-xs">
<IconBrandSpeedtest size={16} />
<span>
{Math.round(persistedTokenSpeed)} tokens/sec
</span>
</div>
)
}
export default TokenSpeedIndicator

View File

@ -1,36 +1,36 @@
import { create } from 'zustand' import { create } from "zustand";
import { ThreadMessage } from '@janhq/core' import { ThreadMessage } from "@janhq/core";
import { MCPTool } from '@/types/completion' import { MCPTool } from "@/types/completion";
import { useAssistant } from './useAssistant' import { useAssistant } from "./useAssistant";
import { ChatCompletionMessageToolCall } from 'openai/resources' import { ChatCompletionMessageToolCall } from "openai/resources";
type AppState = { type AppState = {
streamingContent?: ThreadMessage streamingContent?: ThreadMessage;
loadingModel?: boolean loadingModel?: boolean;
tools: MCPTool[] tools: MCPTool[];
serverStatus: 'running' | 'stopped' | 'pending' serverStatus: "running" | "stopped" | "pending";
abortControllers: Record<string, AbortController> abortControllers: Record<string, AbortController>;
tokenSpeed?: TokenSpeed tokenSpeed?: TokenSpeed;
currentToolCall?: ChatCompletionMessageToolCall currentToolCall?: ChatCompletionMessageToolCall;
showOutOfContextDialog?: boolean showOutOfContextDialog?: boolean;
setServerStatus: (value: 'running' | 'stopped' | 'pending') => void setServerStatus: (value: "running" | "stopped" | "pending") => void;
updateStreamingContent: (content: ThreadMessage | undefined) => void updateStreamingContent: (content: ThreadMessage | undefined) => void;
updateCurrentToolCall: ( updateCurrentToolCall: (
toolCall: ChatCompletionMessageToolCall | undefined toolCall: ChatCompletionMessageToolCall | undefined,
) => void ) => void;
updateLoadingModel: (loading: boolean) => void updateLoadingModel: (loading: boolean) => void;
updateTools: (tools: MCPTool[]) => void updateTools: (tools: MCPTool[]) => void;
setAbortController: (threadId: string, controller: AbortController) => void setAbortController: (threadId: string, controller: AbortController) => void;
updateTokenSpeed: (message: ThreadMessage) => void updateTokenSpeed: (message: ThreadMessage) => void;
resetTokenSpeed: () => void resetTokenSpeed: () => void;
setOutOfContextDialog: (show: boolean) => void setOutOfContextDialog: (show: boolean) => void;
} };
export const useAppState = create<AppState>()((set) => ({ export const useAppState = create<AppState>()((set) => ({
streamingContent: undefined, streamingContent: undefined,
loadingModel: false, loadingModel: false,
tools: [], tools: [],
serverStatus: 'stopped', serverStatus: "stopped",
abortControllers: {}, abortControllers: {},
tokenSpeed: undefined, tokenSpeed: undefined,
currentToolCall: undefined, currentToolCall: undefined,
@ -46,18 +46,19 @@ export const useAppState = create<AppState>()((set) => ({
}, },
} }
: undefined, : undefined,
})) }));
console.log(useAppState.getState().streamingContent);
}, },
updateCurrentToolCall: (toolCall) => { updateCurrentToolCall: (toolCall) => {
set(() => ({ set(() => ({
currentToolCall: toolCall, currentToolCall: toolCall,
})) }));
}, },
updateLoadingModel: (loading) => { updateLoadingModel: (loading) => {
set({ loadingModel: loading }) set({ loadingModel: loading });
}, },
updateTools: (tools) => { updateTools: (tools) => {
set({ tools }) set({ tools });
}, },
setServerStatus: (value) => set({ serverStatus: value }), setServerStatus: (value) => set({ serverStatus: value }),
setAbortController: (threadId, controller) => { setAbortController: (threadId, controller) => {
@ -66,11 +67,11 @@ export const useAppState = create<AppState>()((set) => ({
...state.abortControllers, ...state.abortControllers,
[threadId]: controller, [threadId]: controller,
}, },
})) }));
}, },
updateTokenSpeed: (message) => updateTokenSpeed: (message) =>
set((state) => { set((state) => {
const currentTimestamp = new Date().getTime() // Get current time in milliseconds const currentTimestamp = new Date().getTime(); // Get current time in milliseconds
if (!state.tokenSpeed) { if (!state.tokenSpeed) {
// If this is the first update, just set the lastTimestamp and return // If this is the first update, just set the lastTimestamp and return
return { return {
@ -80,14 +81,14 @@ export const useAppState = create<AppState>()((set) => ({
tokenCount: 1, tokenCount: 1,
message: message.id, message: message.id,
}, },
} };
} }
const timeDiffInSeconds = const timeDiffInSeconds =
(currentTimestamp - state.tokenSpeed.lastTimestamp) / 1000 // Time difference in seconds (currentTimestamp - state.tokenSpeed.lastTimestamp) / 1000; // Time difference in seconds
const totalTokenCount = state.tokenSpeed.tokenCount + 1 const totalTokenCount = state.tokenSpeed.tokenCount + 1;
const averageTokenSpeed = const averageTokenSpeed =
totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1) // Calculate average token speed totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1); // Calculate average token speed
return { return {
tokenSpeed: { tokenSpeed: {
...state.tokenSpeed, ...state.tokenSpeed,
@ -95,7 +96,7 @@ export const useAppState = create<AppState>()((set) => ({
tokenCount: totalTokenCount, tokenCount: totalTokenCount,
message: message.id, message: message.id,
}, },
} };
}), }),
resetTokenSpeed: () => resetTokenSpeed: () =>
set({ set({
@ -104,6 +105,6 @@ export const useAppState = create<AppState>()((set) => ({
setOutOfContextDialog: (show) => { setOutOfContextDialog: (show) => {
set(() => ({ set(() => ({
showOutOfContextDialog: show, showOutOfContextDialog: show,
})) }));
}, },
})) }));

View File

@ -1,12 +1,12 @@
import { useCallback, useEffect, useMemo } from 'react' import { useCallback, useEffect, useMemo } from "react";
import { usePrompt } from './usePrompt' import { usePrompt } from "./usePrompt";
import { useModelProvider } from './useModelProvider' import { useModelProvider } from "./useModelProvider";
import { useThreads } from './useThreads' import { useThreads } from "./useThreads";
import { useAppState } from './useAppState' import { useAppState } from "./useAppState";
import { useMessages } from './useMessages' import { useMessages } from "./useMessages";
import { useRouter } from '@tanstack/react-router' import { useRouter } from "@tanstack/react-router";
import { defaultModel } from '@/lib/models' import { defaultModel } from "@/lib/models";
import { route } from '@/constants/routes' import { route } from "@/constants/routes";
import { import {
emptyThreadContent, emptyThreadContent,
extractToolCall, extractToolCall,
@ -15,23 +15,23 @@ import {
newUserThreadContent, newUserThreadContent,
postMessageProcessing, postMessageProcessing,
sendCompletion, sendCompletion,
} from '@/lib/completion' } from "@/lib/completion";
import { CompletionMessagesBuilder } from '@/lib/messages' import { CompletionMessagesBuilder } from "@/lib/messages";
import { ChatCompletionMessageToolCall } from 'openai/resources' import { ChatCompletionMessageToolCall } from "openai/resources";
import { useAssistant } from './useAssistant' import { useAssistant } from "./useAssistant";
import { toast } from 'sonner' import { toast } from "sonner";
import { getTools } from '@/services/mcp' import { getTools } from "@/services/mcp";
import { MCPTool } from '@/types/completion' import { MCPTool } from "@/types/completion";
import { listen } from '@tauri-apps/api/event' import { listen } from "@tauri-apps/api/event";
import { SystemEvent } from '@/types/events' import { SystemEvent } from "@/types/events";
import { stopModel, startModel, stopAllModels } from '@/services/models' import { stopModel, startModel, stopAllModels } from "@/services/models";
import { useToolApproval } from '@/hooks/useToolApproval' import { useToolApproval } from "@/hooks/useToolApproval";
import { useToolAvailable } from '@/hooks/useToolAvailable' import { useToolAvailable } from "@/hooks/useToolAvailable";
import { OUT_OF_CONTEXT_SIZE } from '@/utils/error' import { OUT_OF_CONTEXT_SIZE } from "@/utils/error";
export const useChat = () => { export const useChat = () => {
const { prompt, setPrompt } = usePrompt() const { prompt, setPrompt } = usePrompt();
const { const {
tools, tools,
updateTokenSpeed, updateTokenSpeed,
@ -40,51 +40,51 @@ export const useChat = () => {
updateStreamingContent, updateStreamingContent,
updateLoadingModel, updateLoadingModel,
setAbortController, setAbortController,
} = useAppState() } = useAppState();
const { currentAssistant } = useAssistant() const { currentAssistant } = useAssistant();
const { updateProvider } = useModelProvider() const { updateProvider } = useModelProvider();
const { approvedTools, showApprovalModal, allowAllMCPPermissions } = const { approvedTools, showApprovalModal, allowAllMCPPermissions } =
useToolApproval() useToolApproval();
const { getDisabledToolsForThread } = useToolAvailable() const { getDisabledToolsForThread } = useToolAvailable();
const { getProviderByName, selectedModel, selectedProvider } = const { getProviderByName, selectedModel, selectedProvider } =
useModelProvider() useModelProvider();
const { const {
getCurrentThread: retrieveThread, getCurrentThread: retrieveThread,
createThread, createThread,
updateThreadTimestamp, updateThreadTimestamp,
} = useThreads() } = useThreads();
const { getMessages, addMessage } = useMessages() const { getMessages, addMessage } = useMessages();
const router = useRouter() const router = useRouter();
const provider = useMemo(() => { const provider = useMemo(() => {
return getProviderByName(selectedProvider) return getProviderByName(selectedProvider);
}, [selectedProvider, getProviderByName]) }, [selectedProvider, getProviderByName]);
const currentProviderId = useMemo(() => { const currentProviderId = useMemo(() => {
return provider?.provider || selectedProvider return provider?.provider || selectedProvider;
}, [provider, selectedProvider]) }, [provider, selectedProvider]);
useEffect(() => { useEffect(() => {
function setTools() { function setTools() {
getTools().then((data: MCPTool[]) => { getTools().then((data: MCPTool[]) => {
updateTools(data) updateTools(data);
}) });
} }
setTools() setTools();
let unsubscribe = () => {} let unsubscribe = () => {};
listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => { listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => {
// Unsubscribe from the event when the component unmounts // Unsubscribe from the event when the component unmounts
unsubscribe = unsub unsubscribe = unsub;
}) });
return unsubscribe return unsubscribe;
}, [updateTools]) }, [updateTools]);
const getCurrentThread = useCallback(async () => { const getCurrentThread = useCallback(async () => {
let currentThread = retrieveThread() let currentThread = retrieveThread();
if (!currentThread) { if (!currentThread) {
currentThread = await createThread( currentThread = await createThread(
{ {
@ -92,14 +92,14 @@ export const useChat = () => {
provider: selectedProvider, provider: selectedProvider,
}, },
prompt, prompt,
currentAssistant currentAssistant,
) );
router.navigate({ router.navigate({
to: route.threadsDetail, to: route.threadsDetail,
params: { threadId: currentThread.id }, params: { threadId: currentThread.id },
}) });
} }
return currentThread return currentThread;
}, [ }, [
createThread, createThread,
prompt, prompt,
@ -108,7 +108,7 @@ export const useChat = () => {
selectedModel?.id, selectedModel?.id,
selectedProvider, selectedProvider,
currentAssistant, currentAssistant,
]) ]);
const increaseModelContextSize = useCallback( const increaseModelContextSize = useCallback(
(model: Model, provider: ProviderObject) => { (model: Model, provider: ProviderObject) => {
@ -118,12 +118,12 @@ export const useChat = () => {
*/ */
const ctxSize = Math.max( const ctxSize = Math.max(
model.settings?.ctx_len?.controller_props.value model.settings?.ctx_len?.controller_props.value
? typeof model.settings.ctx_len.controller_props.value === 'string' ? typeof model.settings.ctx_len.controller_props.value === "string"
? parseInt(model.settings.ctx_len.controller_props.value as string) ? parseInt(model.settings.ctx_len.controller_props.value as string)
: (model.settings.ctx_len.controller_props.value as number) : (model.settings.ctx_len.controller_props.value as number)
: 8192, : 8192,
8192 8192,
) );
const updatedModel = { const updatedModel = {
...model, ...model,
settings: { settings: {
@ -136,80 +136,80 @@ export const useChat = () => {
}, },
}, },
}, },
} };
// Find the model index in the provider's models array // Find the model index in the provider's models array
const modelIndex = provider.models.findIndex((m) => m.id === model.id) const modelIndex = provider.models.findIndex((m) => m.id === model.id);
if (modelIndex !== -1) { if (modelIndex !== -1) {
// Create a copy of the provider's models array // Create a copy of the provider's models array
const updatedModels = [...provider.models] const updatedModels = [...provider.models];
// Update the specific model in the array // Update the specific model in the array
updatedModels[modelIndex] = updatedModel as Model updatedModels[modelIndex] = updatedModel as Model;
// Update the provider with the new models array // Update the provider with the new models array
updateProvider(provider.provider, { updateProvider(provider.provider, {
models: updatedModels, models: updatedModels,
}) });
} }
stopAllModels() stopAllModels();
}, },
[updateProvider] [updateProvider],
) );
const sendMessage = useCallback( const sendMessage = useCallback(
async ( async (
message: string, message: string,
showModal?: () => Promise<unknown>, showModal?: () => Promise<unknown>,
troubleshooting = true troubleshooting = true,
) => { ) => {
const activeThread = await getCurrentThread() const activeThread = await getCurrentThread();
resetTokenSpeed() resetTokenSpeed();
const activeProvider = currentProviderId const activeProvider = currentProviderId
? getProviderByName(currentProviderId) ? getProviderByName(currentProviderId)
: provider : provider;
if (!activeThread || !activeProvider) return if (!activeThread || !activeProvider) return;
const messages = getMessages(activeThread.id) const messages = getMessages(activeThread.id);
const abortController = new AbortController() const abortController = new AbortController();
setAbortController(activeThread.id, abortController) setAbortController(activeThread.id, abortController);
updateStreamingContent(emptyThreadContent) updateStreamingContent(emptyThreadContent);
// Do not add new message on retry // Do not add new message on retry
if (troubleshooting) if (troubleshooting)
addMessage(newUserThreadContent(activeThread.id, message)) addMessage(newUserThreadContent(activeThread.id, message));
updateThreadTimestamp(activeThread.id) updateThreadTimestamp(activeThread.id);
setPrompt('') setPrompt("");
try { try {
if (selectedModel?.id) { if (selectedModel?.id) {
updateLoadingModel(true) updateLoadingModel(true);
await startModel( await startModel(
activeProvider, activeProvider,
selectedModel.id, selectedModel.id,
abortController abortController,
).catch(console.error) ).catch(console.error);
updateLoadingModel(false) updateLoadingModel(false);
} }
const builder = new CompletionMessagesBuilder( const builder = new CompletionMessagesBuilder(
messages, messages,
currentAssistant?.instructions currentAssistant?.instructions,
) );
builder.addUserMessage(message) builder.addUserMessage(message);
let isCompleted = false let isCompleted = false;
// Filter tools based on model capabilities and available tools for this thread // Filter tools based on model capabilities and available tools for this thread
let availableTools = selectedModel?.capabilities?.includes('tools') let availableTools = selectedModel?.capabilities?.includes("tools")
? tools.filter((tool) => { ? tools.filter((tool) => {
const disabledTools = getDisabledToolsForThread(activeThread.id) const disabledTools = getDisabledToolsForThread(activeThread.id);
return !disabledTools.includes(tool.name) return !disabledTools.includes(tool.name);
}) })
: [] : [];
// TODO: Later replaced by Agent setup? // TODO: Later replaced by Agent setup?
const followUpWithToolUse = true const followUpWithToolUse = true;
while (!isCompleted && !abortController.signal.aborted) { while (!isCompleted && !abortController.signal.aborted) {
const completion = await sendCompletion( const completion = await sendCompletion(
activeThread, activeThread,
@ -218,51 +218,51 @@ export const useChat = () => {
abortController, abortController,
availableTools, availableTools,
currentAssistant.parameters?.stream === false ? false : true, currentAssistant.parameters?.stream === false ? false : true,
currentAssistant.parameters as unknown as Record<string, object> currentAssistant.parameters as unknown as Record<string, object>,
// TODO: replace it with according provider setting later on // TODO: replace it with according provider setting later on
// selectedProvider === 'llama.cpp' && availableTools.length > 0 // selectedProvider === 'llama.cpp' && availableTools.length > 0
// ? false // ? false
// : true // : true
) );
if (!completion) throw new Error('No completion received') if (!completion) throw new Error("No completion received");
let accumulatedText = '' let accumulatedText = "";
const currentCall: ChatCompletionMessageToolCall | null = null const currentCall: ChatCompletionMessageToolCall | null = null;
const toolCalls: ChatCompletionMessageToolCall[] = [] const toolCalls: ChatCompletionMessageToolCall[] = [];
if (isCompletionResponse(completion)) { if (isCompletionResponse(completion)) {
accumulatedText = completion.choices[0]?.message?.content || '' accumulatedText = completion.choices[0]?.message?.content || "";
if (completion.choices[0]?.message?.tool_calls) { if (completion.choices[0]?.message?.tool_calls) {
toolCalls.push(...completion.choices[0].message.tool_calls) toolCalls.push(...completion.choices[0].message.tool_calls);
} }
} else { } else {
for await (const part of completion) { for await (const part of completion) {
// Error message // Error message
if (!part.choices) { if (!part.choices) {
throw new Error( throw new Error(
'message' in part "message" in part
? (part.message as string) ? (part.message as string)
: (JSON.stringify(part) ?? '') : (JSON.stringify(part) ?? ""),
) );
} }
const delta = part.choices[0]?.delta?.content || '' const delta = part.choices[0]?.delta?.content || "";
if (part.choices[0]?.delta?.tool_calls) { if (part.choices[0]?.delta?.tool_calls) {
const calls = extractToolCall(part, currentCall, toolCalls) const calls = extractToolCall(part, currentCall, toolCalls);
const currentContent = newAssistantThreadContent( const currentContent = newAssistantThreadContent(
activeThread.id, activeThread.id,
accumulatedText, accumulatedText,
{ {
tool_calls: calls.map((e) => ({ tool_calls: calls.map((e) => ({
...e, ...e,
state: 'pending', state: "pending",
})), })),
} },
) );
updateStreamingContent(currentContent) updateStreamingContent(currentContent);
await new Promise((resolve) => setTimeout(resolve, 0)) await new Promise((resolve) => setTimeout(resolve, 0));
} }
if (delta) { if (delta) {
accumulatedText += delta accumulatedText += delta;
// Create a new object each time to avoid reference issues // Create a new object each time to avoid reference issues
// Use a timeout to prevent React from batching updates too quickly // Use a timeout to prevent React from batching updates too quickly
const currentContent = newAssistantThreadContent( const currentContent = newAssistantThreadContent(
@ -271,13 +271,13 @@ export const useChat = () => {
{ {
tool_calls: toolCalls.map((e) => ({ tool_calls: toolCalls.map((e) => ({
...e, ...e,
state: 'pending', state: "pending",
})), })),
} },
) );
updateStreamingContent(currentContent) updateStreamingContent(currentContent);
updateTokenSpeed(currentContent) updateTokenSpeed(currentContent);
await new Promise((resolve) => setTimeout(resolve, 0)) await new Promise((resolve) => setTimeout(resolve, 0));
} }
} }
} }
@ -286,18 +286,22 @@ export const useChat = () => {
accumulatedText.length === 0 && accumulatedText.length === 0 &&
toolCalls.length === 0 && toolCalls.length === 0 &&
activeThread.model?.id && activeThread.model?.id &&
activeProvider.provider === 'llama.cpp' activeProvider.provider === "llama.cpp"
) { ) {
await stopModel(activeThread.model.id, 'cortex') await stopModel(activeThread.model.id, "cortex");
throw new Error('No response received from the model') throw new Error("No response received from the model");
} }
// Create a final content object for adding to the thread // Create a final content object for adding to the thread
const finalContent = newAssistantThreadContent( const finalContent = newAssistantThreadContent(
activeThread.id, activeThread.id,
accumulatedText accumulatedText,
) {
builder.addAssistantMessage(accumulatedText, undefined, toolCalls) tokenSpeed: useAppState.getState().tokenSpeed,
},
);
builder.addAssistantMessage(accumulatedText, undefined, toolCalls);
const updatedMessage = await postMessageProcessing( const updatedMessage = await postMessageProcessing(
toolCalls, toolCalls,
builder, builder,
@ -305,41 +309,41 @@ export const useChat = () => {
abortController, abortController,
approvedTools, approvedTools,
allowAllMCPPermissions ? undefined : showApprovalModal, allowAllMCPPermissions ? undefined : showApprovalModal,
allowAllMCPPermissions allowAllMCPPermissions,
) );
addMessage(updatedMessage ?? finalContent) addMessage(updatedMessage ?? finalContent);
updateStreamingContent(emptyThreadContent) updateStreamingContent(emptyThreadContent);
updateThreadTimestamp(activeThread.id) updateThreadTimestamp(activeThread.id);
isCompleted = !toolCalls.length isCompleted = !toolCalls.length;
// Do not create agent loop if there is no need for it // Do not create agent loop if there is no need for it
if (!followUpWithToolUse) availableTools = [] if (!followUpWithToolUse) availableTools = [];
} }
} catch (error) { } catch (error) {
const errorMessage = const errorMessage =
error && typeof error === 'object' && 'message' in error error && typeof error === "object" && "message" in error
? error.message ? error.message
: error : error;
if ( if (
typeof errorMessage === 'string' && typeof errorMessage === "string" &&
errorMessage.includes(OUT_OF_CONTEXT_SIZE) && errorMessage.includes(OUT_OF_CONTEXT_SIZE) &&
selectedModel && selectedModel &&
troubleshooting troubleshooting
) { ) {
showModal?.().then((confirmed) => { showModal?.().then((confirmed) => {
if (confirmed) { if (confirmed) {
increaseModelContextSize(selectedModel, activeProvider) increaseModelContextSize(selectedModel, activeProvider);
setTimeout(() => { setTimeout(() => {
sendMessage(message, showModal, false) // Retry sending the message without troubleshooting sendMessage(message, showModal, false); // Retry sending the message without troubleshooting
}, 1000) }, 1000);
} }
}) });
} }
toast.error(`Error sending message: ${errorMessage}`) toast.error(`Error sending message: ${errorMessage}`);
console.error('Error sending message:', error) console.error("Error sending message:", error);
} finally { } finally {
updateLoadingModel(false) updateLoadingModel(false);
updateStreamingContent(undefined) updateStreamingContent(undefined);
} }
}, },
[ [
@ -364,8 +368,8 @@ export const useChat = () => {
showApprovalModal, showApprovalModal,
updateTokenSpeed, updateTokenSpeed,
increaseModelContextSize, increaseModelContextSize,
] ],
) );
return { sendMessage } return { sendMessage };
} };

View File

@ -1,23 +1,23 @@
import { create } from 'zustand' import { create } from "zustand";
import { ThreadMessage } from '@janhq/core' import { ThreadMessage } from "@janhq/core";
import { import {
createMessage, createMessage,
deleteMessage as deleteMessageExt, deleteMessage as deleteMessageExt,
} from '@/services/messages' } from "@/services/messages";
import { useAssistant } from './useAssistant' import { useAssistant } from "./useAssistant";
type MessageState = { type MessageState = {
messages: Record<string, ThreadMessage[]> messages: Record<string, ThreadMessage[]>;
getMessages: (threadId: string) => ThreadMessage[] getMessages: (threadId: string) => ThreadMessage[];
setMessages: (threadId: string, messages: ThreadMessage[]) => void setMessages: (threadId: string, messages: ThreadMessage[]) => void;
addMessage: (message: ThreadMessage) => void addMessage: (message: ThreadMessage) => void;
deleteMessage: (threadId: string, messageId: string) => void deleteMessage: (threadId: string, messageId: string) => void;
} };
export const useMessages = create<MessageState>()((set, get) => ({ export const useMessages = create<MessageState>()((set, get) => ({
messages: {}, messages: {},
getMessages: (threadId) => { getMessages: (threadId) => {
return get().messages[threadId] || [] return get().messages[threadId] || [];
}, },
setMessages: (threadId, messages) => { setMessages: (threadId, messages) => {
set((state) => ({ set((state) => ({
@ -25,10 +25,11 @@ export const useMessages = create<MessageState>()((set, get) => ({
...state.messages, ...state.messages,
[threadId]: messages, [threadId]: messages,
}, },
})) }));
}, },
addMessage: (message) => { addMessage: (message) => {
const currentAssistant = useAssistant.getState().currentAssistant console.log("addMessage: ", message);
const currentAssistant = useAssistant.getState().currentAssistant;
const newMessage = { const newMessage = {
...message, ...message,
created_at: message.created_at || Date.now(), created_at: message.created_at || Date.now(),
@ -36,7 +37,7 @@ export const useMessages = create<MessageState>()((set, get) => ({
...message.metadata, ...message.metadata,
assistant: currentAssistant, assistant: currentAssistant,
}, },
} };
createMessage(newMessage).then((createdMessage) => { createMessage(newMessage).then((createdMessage) => {
set((state) => ({ set((state) => ({
messages: { messages: {
@ -46,19 +47,19 @@ export const useMessages = create<MessageState>()((set, get) => ({
createdMessage, createdMessage,
], ],
}, },
})) }));
}) });
}, },
deleteMessage: (threadId, messageId) => { deleteMessage: (threadId, messageId) => {
deleteMessageExt(threadId, messageId) deleteMessageExt(threadId, messageId);
set((state) => ({ set((state) => ({
messages: { messages: {
...state.messages, ...state.messages,
[threadId]: [threadId]:
state.messages[threadId]?.filter( state.messages[threadId]?.filter(
(message) => message.id !== messageId (message) => message.id !== messageId,
) || [], ) || [],
}, },
})) }));
}, },
})) }));