import { create } from 'zustand'
import { ThreadMessage } from '@janhq/core'
import { MCPTool } from '@/types/completion'
import { useAssistant } from './useAssistant'
import { ChatCompletionMessageToolCall } from 'openai/resources'

/** Numeric prompt-progress counters; stored in the app state as given. */
type PromptProgress = {
  cache: number
  processed: number
  time_ms: number
  total: number
}

/** Error payload kept in the app state; only `subtitle` is mandatory. */
type AppErrorMessage = {
  message?: string
  title?: string
  subtitle: string
}

/** Shape of the global app store: state fields followed by their setters. */
type AppState = {
  streamingContent?: ThreadMessage
  loadingModel?: boolean
  tools: MCPTool[]
  serverStatus: 'running' | 'stopped' | 'pending'
  /** Abort controllers keyed by thread id (see `setAbortController`). */
  abortControllers: Record<string, AbortController>
  // NOTE(review): `TokenSpeed` is not declared or imported in this file;
  // presumably an ambient/global type — confirm. The store populates it with
  // `lastTimestamp`, `tokenSpeed`, `tokenCount` and `message`.
  tokenSpeed?: TokenSpeed
  currentToolCall?: ChatCompletionMessageToolCall
  showOutOfContextDialog?: boolean
  errorMessage?: AppErrorMessage
  promptProgress?: PromptProgress
  cancelToolCall?: () => void
  setServerStatus: (value: 'running' | 'stopped' | 'pending') => void
  updateStreamingContent: (content: ThreadMessage | undefined) => void
  updateCurrentToolCall: (
    toolCall: ChatCompletionMessageToolCall | undefined
  ) => void
  updateLoadingModel: (loading: boolean) => void
  updateTools: (tools: MCPTool[]) => void
  setAbortController: (threadId: string, controller: AbortController) => void
  updateTokenSpeed: (message: ThreadMessage, increment?: number) => void
  resetTokenSpeed: () => void
  setOutOfContextDialog: (show: boolean) => void
  setCancelToolCall: (cancel: (() => void) | undefined) => void
  setErrorMessage: (error: AppErrorMessage | undefined) => void
  updatePromptProgress: (progress: PromptProgress | undefined) => void
}

/**
 * Global application store.
 *
 * The state type is supplied through the curried `create<AppState>()(...)`
 * form so that the initializer and every action are checked against
 * `AppState`.
 */
export const useAppState = create<AppState>()((set) => ({
  streamingContent: undefined,
  loadingModel: false,
  tools: [],
  serverStatus: 'stopped',
  abortControllers: {},
  tokenSpeed: undefined,
  currentToolCall: undefined,
  promptProgress: undefined,
  cancelToolCall: undefined,

  /**
   * Stores the in-flight message, or clears it when `content` is undefined.
   * The stored copy gets a `created_at` fallback and has the selected
   * assistant attached under `metadata.assistant`.
   */
  updateStreamingContent: (content: ThreadMessage | undefined) => {
    const { assistants, currentAssistant } = useAssistant.getState()
    // Prefer the assistant matching the current one by id; otherwise fall
    // back to the first entry (which is undefined when the list is empty).
    const selectedAssistant =
      assistants.find((a) => a.id === currentAssistant?.id) || assistants[0]

    if (!content) {
      set({ streamingContent: undefined })
      return
    }

    set({
      streamingContent: {
        ...content,
        // `||` on purpose: any falsy created_at (including 0) is replaced
        // with the current time.
        created_at: content.created_at || Date.now(),
        metadata: {
          ...content.metadata,
          assistant: selectedAssistant,
        },
      },
    })
  },

  /** Sets or clears the current tool call. */
  updateCurrentToolCall: (toolCall) => {
    set({ currentToolCall: toolCall })
  },

  /** Sets the model-loading flag. */
  updateLoadingModel: (loading) => {
    set({ loadingModel: loading })
  },

  /** Replaces the tool list. */
  updateTools: (tools) => {
    set({ tools })
  },

  /** Sets the server status. */
  setServerStatus: (value) => set({ serverStatus: value }),

  /**
   * Registers `controller` under `threadId`, replacing any controller that
   * was previously stored for the same id. Other entries are kept.
   */
  setAbortController: (threadId, controller) => {
    set((state) => ({
      abortControllers: {
        ...state.abortControllers,
        [threadId]: controller,
      },
    }))
  },

  /**
   * Adds `increment` tokens (default 1) to the running count and recomputes
   * the average speed in tokens per second.
   *
   * The first call after a reset only records the start timestamp and the
   * initial count, with a speed of 0. Later calls keep that timestamp, so the
   * speed is the total count divided by the time elapsed since the first call.
   */
  updateTokenSpeed: (message, increment = 1) =>
    set((state) => {
      const now = Date.now() // milliseconds
      const previous = state.tokenSpeed

      if (!previous) {
        // First update: nothing to average yet, just start the clock.
        return {
          tokenSpeed: {
            lastTimestamp: now,
            tokenSpeed: 0,
            tokenCount: increment,
            message: message.id,
          },
        }
      }

      const elapsedSeconds = (now - previous.lastTimestamp) / 1000
      const totalTokenCount = previous.tokenCount + increment
      // A non-positive elapsed time is treated as one second so the division
      // never produces Infinity or a negative speed.
      const averageTokenSpeed =
        totalTokenCount / (elapsedSeconds > 0 ? elapsedSeconds : 1)

      return {
        tokenSpeed: {
          ...previous,
          tokenSpeed: averageTokenSpeed,
          tokenCount: totalTokenCount,
          message: message.id,
        },
      }
    }),

  /** Clears the token-speed stats so the next update starts a new measurement. */
  resetTokenSpeed: () => set({ tokenSpeed: undefined }),

  /** Sets the out-of-context dialog flag. */
  setOutOfContextDialog: (show) => {
    set({ showOutOfContextDialog: show })
  },

  /** Stores or clears the tool-call cancel callback. */
  setCancelToolCall: (cancel) => {
    set({ cancelToolCall: cancel })
  },

  /** Sets or clears the current error message. */
  setErrorMessage: (error) => {
    set({ errorMessage: error })
  },

  /** Sets or clears the prompt-progress counters. */
  updatePromptProgress: (progress) => {
    set({ promptProgress: progress })
  },
}))