diff --git a/web-app/src/hooks/useAppState.ts b/web-app/src/hooks/useAppState.ts index ccf044a4f..e3fa1753b 100644 --- a/web-app/src/hooks/useAppState.ts +++ b/web-app/src/hooks/useAppState.ts @@ -1,36 +1,36 @@ -import { create } from "zustand"; -import { ThreadMessage } from "@janhq/core"; -import { MCPTool } from "@/types/completion"; -import { useAssistant } from "./useAssistant"; -import { ChatCompletionMessageToolCall } from "openai/resources"; +import { create } from 'zustand' +import { ThreadMessage } from '@janhq/core' +import { MCPTool } from '@/types/completion' +import { useAssistant } from './useAssistant' +import { ChatCompletionMessageToolCall } from 'openai/resources' type AppState = { - streamingContent?: ThreadMessage; - loadingModel?: boolean; - tools: MCPTool[]; - serverStatus: "running" | "stopped" | "pending"; - abortControllers: Record; - tokenSpeed?: TokenSpeed; - currentToolCall?: ChatCompletionMessageToolCall; - showOutOfContextDialog?: boolean; - setServerStatus: (value: "running" | "stopped" | "pending") => void; - updateStreamingContent: (content: ThreadMessage | undefined) => void; + streamingContent?: ThreadMessage + loadingModel?: boolean + tools: MCPTool[] + serverStatus: 'running' | 'stopped' | 'pending' + abortControllers: Record + tokenSpeed?: TokenSpeed + currentToolCall?: ChatCompletionMessageToolCall + showOutOfContextDialog?: boolean + setServerStatus: (value: 'running' | 'stopped' | 'pending') => void + updateStreamingContent: (content: ThreadMessage | undefined) => void updateCurrentToolCall: ( - toolCall: ChatCompletionMessageToolCall | undefined, - ) => void; - updateLoadingModel: (loading: boolean) => void; - updateTools: (tools: MCPTool[]) => void; - setAbortController: (threadId: string, controller: AbortController) => void; - updateTokenSpeed: (message: ThreadMessage) => void; - resetTokenSpeed: () => void; - setOutOfContextDialog: (show: boolean) => void; -}; + toolCall: ChatCompletionMessageToolCall | undefined + ) => void + updateLoadingModel: (loading: boolean) => void + updateTools: (tools: MCPTool[]) => void + setAbortController: (threadId: string, controller: AbortController) => void + updateTokenSpeed: (message: ThreadMessage) => void + resetTokenSpeed: () => void + setOutOfContextDialog: (show: boolean) => void +} export const useAppState = create()((set) => ({ streamingContent: undefined, loadingModel: false, tools: [], - serverStatus: "stopped", + serverStatus: 'stopped', abortControllers: {}, tokenSpeed: undefined, currentToolCall: undefined, @@ -46,19 +46,19 @@ export const useAppState = create()((set) => ({ }, } : undefined, - })); - console.log(useAppState.getState().streamingContent); + })) + console.log(useAppState.getState().streamingContent) }, updateCurrentToolCall: (toolCall) => { set(() => ({ currentToolCall: toolCall, - })); + })) }, updateLoadingModel: (loading) => { - set({ loadingModel: loading }); + set({ loadingModel: loading }) }, updateTools: (tools) => { - set({ tools }); + set({ tools }) }, setServerStatus: (value) => set({ serverStatus: value }), setAbortController: (threadId, controller) => { @@ -67,11 +67,11 @@ export const useAppState = create()((set) => ({ ...state.abortControllers, [threadId]: controller, }, - })); + })) }, updateTokenSpeed: (message) => set((state) => { - const currentTimestamp = new Date().getTime(); // Get current time in milliseconds + const currentTimestamp = new Date().getTime() // Get current time in milliseconds if (!state.tokenSpeed) { // If this is the first update, just set the lastTimestamp and return return { @@ -81,14 +81,14 @@ export const useAppState = create()((set) => ({ tokenCount: 1, message: message.id, }, - }; + } } const timeDiffInSeconds = - (currentTimestamp - state.tokenSpeed.lastTimestamp) / 1000; // Time difference in seconds - const totalTokenCount = state.tokenSpeed.tokenCount + 1; + (currentTimestamp - state.tokenSpeed.lastTimestamp) / 1000 // Time difference in seconds + const totalTokenCount = state.tokenSpeed.tokenCount + 1 const averageTokenSpeed = - totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1); // Calculate average token speed + totalTokenCount / (timeDiffInSeconds > 0 ? timeDiffInSeconds : 1) // Calculate average token speed return { tokenSpeed: { ...state.tokenSpeed, @@ -96,7 +96,7 @@ export const useAppState = create()((set) => ({ tokenCount: totalTokenCount, message: message.id, }, - }; + } }), resetTokenSpeed: () => set({ @@ -105,6 +105,6 @@ export const useAppState = create()((set) => ({ setOutOfContextDialog: (show) => { set(() => ({ showOutOfContextDialog: show, - })); + })) }, -})); +})) diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts index 0fbfeb5d9..c8e0fe9f1 100644 --- a/web-app/src/hooks/useChat.ts +++ b/web-app/src/hooks/useChat.ts @@ -1,12 +1,12 @@ -import { useCallback, useEffect, useMemo } from "react"; -import { usePrompt } from "./usePrompt"; -import { useModelProvider } from "./useModelProvider"; -import { useThreads } from "./useThreads"; -import { useAppState } from "./useAppState"; -import { useMessages } from "./useMessages"; -import { useRouter } from "@tanstack/react-router"; -import { defaultModel } from "@/lib/models"; -import { route } from "@/constants/routes"; +import { useCallback, useEffect, useMemo } from 'react' +import { usePrompt } from './usePrompt' +import { useModelProvider } from './useModelProvider' +import { useThreads } from './useThreads' +import { useAppState } from './useAppState' +import { useMessages } from './useMessages' +import { useRouter } from '@tanstack/react-router' +import { defaultModel } from '@/lib/models' +import { route } from '@/constants/routes' import { emptyThreadContent, extractToolCall, @@ -15,23 +15,23 @@ import { newUserThreadContent, postMessageProcessing, sendCompletion, -} from "@/lib/completion"; -import { CompletionMessagesBuilder } from "@/lib/messages"; -import { ChatCompletionMessageToolCall } from "openai/resources"; -import { useAssistant } from "./useAssistant"; -import { toast } from "sonner"; -import { getTools } from "@/services/mcp"; -import { MCPTool } from "@/types/completion"; -import { listen } from "@tauri-apps/api/event"; -import { SystemEvent } from "@/types/events"; -import { stopModel, startModel, stopAllModels } from "@/services/models"; +} from '@/lib/completion' +import { CompletionMessagesBuilder } from '@/lib/messages' +import { ChatCompletionMessageToolCall } from 'openai/resources' +import { useAssistant } from './useAssistant' +import { toast } from 'sonner' +import { getTools } from '@/services/mcp' +import { MCPTool } from '@/types/completion' +import { listen } from '@tauri-apps/api/event' +import { SystemEvent } from '@/types/events' +import { stopModel, startModel, stopAllModels } from '@/services/models' -import { useToolApproval } from "@/hooks/useToolApproval"; -import { useToolAvailable } from "@/hooks/useToolAvailable"; -import { OUT_OF_CONTEXT_SIZE } from "@/utils/error"; +import { useToolApproval } from '@/hooks/useToolApproval' +import { useToolAvailable } from '@/hooks/useToolAvailable' +import { OUT_OF_CONTEXT_SIZE } from '@/utils/error' export const useChat = () => { - const { prompt, setPrompt } = usePrompt(); + const { prompt, setPrompt } = usePrompt() const { tools, updateTokenSpeed, @@ -40,51 +40,51 @@ export const useChat = () => { updateStreamingContent, updateLoadingModel, setAbortController, - } = useAppState(); - const { currentAssistant } = useAssistant(); - const { updateProvider } = useModelProvider(); + } = useAppState() + const { currentAssistant } = useAssistant() + const { updateProvider } = useModelProvider() const { approvedTools, showApprovalModal, allowAllMCPPermissions } = - useToolApproval(); - const { getDisabledToolsForThread } = useToolAvailable(); + useToolApproval() + const { getDisabledToolsForThread } = useToolAvailable() const { getProviderByName, selectedModel, selectedProvider } = - useModelProvider(); + useModelProvider() const { getCurrentThread: retrieveThread, createThread, updateThreadTimestamp, - } = useThreads(); - const { getMessages, addMessage } = useMessages(); - const router = useRouter(); + } = useThreads() + const { getMessages, addMessage } = useMessages() + const router = useRouter() const provider = useMemo(() => { - return getProviderByName(selectedProvider); - }, [selectedProvider, getProviderByName]); + return getProviderByName(selectedProvider) + }, [selectedProvider, getProviderByName]) const currentProviderId = useMemo(() => { - return provider?.provider || selectedProvider; - }, [provider, selectedProvider]); + return provider?.provider || selectedProvider + }, [provider, selectedProvider]) useEffect(() => { function setTools() { getTools().then((data: MCPTool[]) => { - updateTools(data); - }); + updateTools(data) + }) } - setTools(); + setTools() - let unsubscribe = () => {}; + let unsubscribe = () => {} listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => { // Unsubscribe from the event when the component unmounts - unsubscribe = unsub; - }); - return unsubscribe; - }, [updateTools]); + unsubscribe = unsub + }) + return unsubscribe + }, [updateTools]) const getCurrentThread = useCallback(async () => { - let currentThread = retrieveThread(); + let currentThread = retrieveThread() if (!currentThread) { currentThread = await createThread( { @@ -92,14 +92,14 @@ export const useChat = () => { provider: selectedProvider, }, prompt, - currentAssistant, - ); + currentAssistant + ) router.navigate({ to: route.threadsDetail, params: { threadId: currentThread.id }, - }); + }) } - return currentThread; + return currentThread }, [ createThread, prompt, @@ -108,7 +108,7 @@ export const useChat = () => { selectedModel?.id, selectedProvider, currentAssistant, - ]); + ]) const increaseModelContextSize = useCallback( (model: Model, provider: ProviderObject) => { @@ -118,12 +118,12 @@ export const useChat = () => { */ const ctxSize = Math.max( model.settings?.ctx_len?.controller_props.value - ? typeof model.settings.ctx_len.controller_props.value === "string" + ? typeof model.settings.ctx_len.controller_props.value === 'string' ? parseInt(model.settings.ctx_len.controller_props.value as string) : (model.settings.ctx_len.controller_props.value as number) : 8192, - 8192, - ); + 8192 + ) const updatedModel = { ...model, settings: { @@ -136,80 +136,80 @@ export const useChat = () => { }, }, }, - }; + } // Find the model index in the provider's models array - const modelIndex = provider.models.findIndex((m) => m.id === model.id); + const modelIndex = provider.models.findIndex((m) => m.id === model.id) if (modelIndex !== -1) { // Create a copy of the provider's models array - const updatedModels = [...provider.models]; + const updatedModels = [...provider.models] // Update the specific model in the array - updatedModels[modelIndex] = updatedModel as Model; + updatedModels[modelIndex] = updatedModel as Model // Update the provider with the new models array updateProvider(provider.provider, { models: updatedModels, - }); + }) } - stopAllModels(); + stopAllModels() }, - [updateProvider], - ); + [updateProvider] + ) const sendMessage = useCallback( async ( message: string, showModal?: () => Promise, - troubleshooting = true, + troubleshooting = true ) => { - const activeThread = await getCurrentThread(); + const activeThread = await getCurrentThread() - resetTokenSpeed(); + resetTokenSpeed() const activeProvider = currentProviderId ? getProviderByName(currentProviderId) - : provider; - if (!activeThread || !activeProvider) return; - const messages = getMessages(activeThread.id); - const abortController = new AbortController(); - setAbortController(activeThread.id, abortController); - updateStreamingContent(emptyThreadContent); + : provider + if (!activeThread || !activeProvider) return + const messages = getMessages(activeThread.id) + const abortController = new AbortController() + setAbortController(activeThread.id, abortController) + updateStreamingContent(emptyThreadContent) // Do not add new message on retry if (troubleshooting) - addMessage(newUserThreadContent(activeThread.id, message)); - updateThreadTimestamp(activeThread.id); - setPrompt(""); + addMessage(newUserThreadContent(activeThread.id, message)) + updateThreadTimestamp(activeThread.id) + setPrompt('') try { if (selectedModel?.id) { - updateLoadingModel(true); + updateLoadingModel(true) await startModel( activeProvider, selectedModel.id, - abortController, - ).catch(console.error); - updateLoadingModel(false); + abortController + ).catch(console.error) + updateLoadingModel(false) } const builder = new CompletionMessagesBuilder( messages, - currentAssistant?.instructions, - ); + currentAssistant?.instructions + ) - builder.addUserMessage(message); + builder.addUserMessage(message) - let isCompleted = false; + let isCompleted = false // Filter tools based on model capabilities and available tools for this thread - let availableTools = selectedModel?.capabilities?.includes("tools") + let availableTools = selectedModel?.capabilities?.includes('tools') ? tools.filter((tool) => { - const disabledTools = getDisabledToolsForThread(activeThread.id); - return !disabledTools.includes(tool.name); + const disabledTools = getDisabledToolsForThread(activeThread.id) + return !disabledTools.includes(tool.name) }) - : []; + : [] // TODO: Later replaced by Agent setup? - const followUpWithToolUse = true; + const followUpWithToolUse = true while (!isCompleted && !abortController.signal.aborted) { const completion = await sendCompletion( activeThread, @@ -218,51 +218,51 @@ export const useChat = () => { abortController, availableTools, currentAssistant.parameters?.stream === false ? false : true, - currentAssistant.parameters as unknown as Record, + currentAssistant.parameters as unknown as Record // TODO: replace it with according provider setting later on // selectedProvider === 'llama.cpp' && availableTools.length > 0 // ? false // : true - ); + ) - if (!completion) throw new Error("No completion received"); - let accumulatedText = ""; - const currentCall: ChatCompletionMessageToolCall | null = null; - const toolCalls: ChatCompletionMessageToolCall[] = []; + if (!completion) throw new Error('No completion received') + let accumulatedText = '' + const currentCall: ChatCompletionMessageToolCall | null = null + const toolCalls: ChatCompletionMessageToolCall[] = [] if (isCompletionResponse(completion)) { - accumulatedText = completion.choices[0]?.message?.content || ""; + accumulatedText = completion.choices[0]?.message?.content || '' if (completion.choices[0]?.message?.tool_calls) { - toolCalls.push(...completion.choices[0].message.tool_calls); + toolCalls.push(...completion.choices[0].message.tool_calls) } } else { for await (const part of completion) { // Error message if (!part.choices) { throw new Error( - "message" in part + 'message' in part ? (part.message as string) - : (JSON.stringify(part) ?? ""), - ); + : (JSON.stringify(part) ?? '') + ) } - const delta = part.choices[0]?.delta?.content || ""; + const delta = part.choices[0]?.delta?.content || '' if (part.choices[0]?.delta?.tool_calls) { - const calls = extractToolCall(part, currentCall, toolCalls); + const calls = extractToolCall(part, currentCall, toolCalls) const currentContent = newAssistantThreadContent( activeThread.id, accumulatedText, { tool_calls: calls.map((e) => ({ ...e, - state: "pending", + state: 'pending', })), - }, - ); - updateStreamingContent(currentContent); - await new Promise((resolve) => setTimeout(resolve, 0)); + } + ) + updateStreamingContent(currentContent) + await new Promise((resolve) => setTimeout(resolve, 0)) } if (delta) { - accumulatedText += delta; + accumulatedText += delta // Create a new object each time to avoid reference issues // Use a timeout to prevent React from batching updates too quickly const currentContent = newAssistantThreadContent( @@ -271,13 +271,13 @@ export const useChat = () => { { tool_calls: toolCalls.map((e) => ({ ...e, - state: "pending", + state: 'pending', })), - }, - ); - updateStreamingContent(currentContent); - updateTokenSpeed(currentContent); - await new Promise((resolve) => setTimeout(resolve, 0)); + } + ) + updateStreamingContent(currentContent) + updateTokenSpeed(currentContent) + await new Promise((resolve) => setTimeout(resolve, 0)) } } } @@ -286,10 +286,10 @@ export const useChat = () => { accumulatedText.length === 0 && toolCalls.length === 0 && activeThread.model?.id && - activeProvider.provider === "llama.cpp" + activeProvider.provider === 'llama.cpp' ) { - await stopModel(activeThread.model.id, "cortex"); - throw new Error("No response received from the model"); + await stopModel(activeThread.model.id, 'cortex') + throw new Error('No response received from the model') } // Create a final content object for adding to the thread @@ -298,10 +298,10 @@ export const useChat = () => { accumulatedText, { tokenSpeed: useAppState.getState().tokenSpeed, - }, - ); + } + ) - builder.addAssistantMessage(accumulatedText, undefined, toolCalls); + builder.addAssistantMessage(accumulatedText, undefined, toolCalls) const updatedMessage = await postMessageProcessing( toolCalls, builder, @@ -309,41 +309,41 @@ export const useChat = () => { abortController, approvedTools, allowAllMCPPermissions ? undefined : showApprovalModal, - allowAllMCPPermissions, - ); - addMessage(updatedMessage ?? finalContent); - updateStreamingContent(emptyThreadContent); - updateThreadTimestamp(activeThread.id); + allowAllMCPPermissions + ) + addMessage(updatedMessage ?? finalContent) + updateStreamingContent(emptyThreadContent) + updateThreadTimestamp(activeThread.id) - isCompleted = !toolCalls.length; + isCompleted = !toolCalls.length // Do not create agent loop if there is no need for it - if (!followUpWithToolUse) availableTools = []; + if (!followUpWithToolUse) availableTools = [] } } catch (error) { const errorMessage = - error && typeof error === "object" && "message" in error + error && typeof error === 'object' && 'message' in error ? error.message - : error; + : error if ( - typeof errorMessage === "string" && + typeof errorMessage === 'string' && errorMessage.includes(OUT_OF_CONTEXT_SIZE) && selectedModel && troubleshooting ) { showModal?.().then((confirmed) => { if (confirmed) { - increaseModelContextSize(selectedModel, activeProvider); + increaseModelContextSize(selectedModel, activeProvider) setTimeout(() => { - sendMessage(message, showModal, false); // Retry sending the message without troubleshooting - }, 1000); + sendMessage(message, showModal, false) // Retry sending the message without troubleshooting + }, 1000) } - }); + }) } - toast.error(`Error sending message: ${errorMessage}`); - console.error("Error sending message:", error); + toast.error(`Error sending message: ${errorMessage}`) + console.error('Error sending message:', error) } finally { - updateLoadingModel(false); - updateStreamingContent(undefined); + updateLoadingModel(false) + updateStreamingContent(undefined) } }, [ @@ -368,8 +368,8 @@ export const useChat = () => { showApprovalModal, updateTokenSpeed, increaseModelContextSize, - ], - ); + ] + ) - return { sendMessage }; -}; + return { sendMessage } +}