diff --git a/web-app/src/containers/ChatInput.tsx b/web-app/src/containers/ChatInput.tsx index 1a047710a..825716260 100644 --- a/web-app/src/containers/ChatInput.tsx +++ b/web-app/src/containers/ChatInput.tsx @@ -3,7 +3,7 @@ import TextareaAutosize from 'react-textarea-autosize' import { cn } from '@/lib/utils' import { usePrompt } from '@/hooks/usePrompt' -import { useCallback, useEffect, useMemo, useRef, useState } from 'react' +import { useEffect, useRef, useState } from 'react' import { Button } from '@/components/ui/button' import { ArrowRight } from 'lucide-react' import { @@ -20,28 +20,14 @@ import { import { useTranslation } from 'react-i18next' import { useGeneralSetting } from '@/hooks/useGeneralSetting' import { useModelProvider } from '@/hooks/useModelProvider' -import { - emptyThreadContent, - extractToolCall, - newAssistantThreadContent, - newUserThreadContent, - postMessageProcessing, - sendCompletion, - startModel, -} from '@/lib/completion' -import { useThreads } from '@/hooks/useThreads' -import { defaultModel } from '@/lib/models' -import { useMessages } from '@/hooks/useMessages' -import { useRouter } from '@tanstack/react-router' -import { route } from '@/constants/routes' + import { useAppState } from '@/hooks/useAppState' import { MovingBorder } from './MovingBorder' import { MCPTool } from '@/types/completion' import { listen } from '@tauri-apps/api/event' import { SystemEvent } from '@/types/events' -import { CompletionMessagesBuilder } from '@/lib/messages' -import { ChatCompletionMessageToolCall } from 'openai/resources' import { getTools } from '@/services/mcp' +import { useChat } from '@/hooks/useChat' type ChatInputProps = { className?: string @@ -52,24 +38,14 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => { const textareaRef = useRef(null) const [isFocused, setIsFocused] = useState(false) const [rows, setRows] = useState(1) - const [tools, setTools] = useState([]) + const { streamingContent, updateTools } = useAppState() const { prompt, setPrompt } = usePrompt() const { t } = useTranslation() const { spellCheckChatInput } = useGeneralSetting() const maxRows = 10 - const { getProviderByName, selectedModel, selectedProvider } = - useModelProvider() - - const { getCurrentThread: retrieveThread, createThread } = useThreads() - const { streamingContent, updateStreamingContent, updateLoadingModel } = - useAppState() - const { addMessage } = useMessages() - const router = useRouter() - - const provider = useMemo(() => { - return getProviderByName(selectedProvider) - }, [selectedProvider, getProviderByName]) + const { selectedModel } = useModelProvider() + const { sendMessage } = useChat() useEffect(() => { const handleFocusIn = () => { @@ -94,20 +70,20 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => { }, []) useEffect(() => { - function updateTools() { + function setTools() { getTools().then((data: MCPTool[]) => { - setTools(data) + updateTools(data) }) } - updateTools() + setTools() let unsubscribe = () => {} - listen(SystemEvent.MCP_UPDATE, updateTools).then((unsub) => { + listen(SystemEvent.MCP_UPDATE, setTools).then((unsub) => { // Unsubscribe from the event when the component unmounts unsubscribe = unsub }) return unsubscribe - }, []) + }, [updateTools]) useEffect(() => { if (textareaRef.current) { @@ -115,115 +91,6 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => { } }, []) - const getCurrentThread = useCallback(async () => { - let currentThread = retrieveThread() - if (!currentThread) { - currentThread = await createThread( - { - id: selectedModel?.id ?? defaultModel(selectedProvider), - provider: selectedProvider, - }, - prompt - ) - router.navigate({ - to: route.threadsDetail, - params: { threadId: currentThread.id }, - }) - } - return currentThread - }, [ - createThread, - prompt, - retrieveThread, - router, - selectedModel?.id, - selectedProvider, - ]) - - const sendMessage = useCallback(async () => { - const activeThread = await getCurrentThread() - - if (!activeThread || !provider) return - - updateStreamingContent(emptyThreadContent) - addMessage(newUserThreadContent(activeThread.id, prompt)) - setPrompt('') - try { - if (selectedModel?.id) { - updateLoadingModel(true) - await startModel(provider.provider, selectedModel.id).catch( - console.error - ) - updateLoadingModel(false) - } - - const builder = new CompletionMessagesBuilder() - // REMARK: Would it possible to not attach the entire message history to the request? - // TODO: If not amend messages history here - builder.addUserMessage(prompt) - - let isCompleted = false - - while (!isCompleted) { - const completion = await sendCompletion( - activeThread, - provider, - builder.getMessages(), - tools - ) - - if (!completion) throw new Error('No completion received') - let accumulatedText = '' - const currentCall: ChatCompletionMessageToolCall | null = null - const toolCalls: ChatCompletionMessageToolCall[] = [] - for await (const part of completion) { - const delta = part.choices[0]?.delta?.content || '' - if (part.choices[0]?.delta?.tool_calls) { - extractToolCall(part, currentCall, toolCalls) - } - if (delta) { - accumulatedText += delta - // Create a new object each time to avoid reference issues - // Use a timeout to prevent React from batching updates too quickly - const currentContent = newAssistantThreadContent( - activeThread.id, - accumulatedText - ) - updateStreamingContent(currentContent) - await new Promise((resolve) => setTimeout(resolve, 0)) - } - } - // Create a final content object for adding to the thread - const finalContent = newAssistantThreadContent( - activeThread.id, - accumulatedText - ) - builder.addAssistantMessage(accumulatedText, undefined, toolCalls) - const updatedMessage = await postMessageProcessing( - toolCalls, - builder, - finalContent - ) - addMessage(updatedMessage ?? finalContent) - - isCompleted = !toolCalls.length - } - } catch (error) { - console.error('Error sending message:', error) - } - updateStreamingContent(undefined) - }, [ - getCurrentThread, - provider, - updateStreamingContent, - addMessage, - prompt, - setPrompt, - selectedModel, - tools, - updateLoadingModel, - ]) - return (
{ if (e.key === 'Enter' && !e.shiftKey && prompt) { e.preventDefault() // Submit the message when Enter is pressed without Shift - sendMessage() + sendMessage(prompt) // When Shift+Enter is pressed, a new line is added (default behavior) } }} @@ -351,7 +218,7 @@ const ChatInput = ({ className, showSpeedToken = true }: ChatInputProps) => { variant={!prompt ? null : 'default'} size="icon" disabled={!prompt} - onClick={sendMessage} + onClick={() => sendMessage(prompt)} > {streamingContent ? ( diff --git a/web-app/src/containers/DownloadManegement.tsx b/web-app/src/containers/DownloadManegement.tsx index 79d90aeed..4aecfc5fd 100644 --- a/web-app/src/containers/DownloadManegement.tsx +++ b/web-app/src/containers/DownloadManegement.tsx @@ -7,7 +7,7 @@ import { Progress } from '@/components/ui/progress' import { useDownloadStore } from '@/hooks/useDownloadStore' import { abortDownload } from '@/services/models' import { DownloadEvent, DownloadState, events } from '@janhq/core' -import { IconPlayerPauseFilled, IconX } from '@tabler/icons-react' +import { IconX } from '@tabler/icons-react' import { useCallback, useEffect, useMemo } from 'react' export function DownloadManagement() { diff --git a/web-app/src/containers/ThreadContent.tsx b/web-app/src/containers/ThreadContent.tsx index f7ae1bec4..e9f0602c1 100644 --- a/web-app/src/containers/ThreadContent.tsx +++ b/web-app/src/containers/ThreadContent.tsx @@ -1,6 +1,6 @@ import { ThreadMessage } from '@janhq/core' import { RenderMarkdown } from './RenderMarkdown' -import { Fragment, memo, useMemo, useState } from 'react' +import { Fragment, memo, useCallback, useMemo, useState } from 'react' import { IconCopy, IconCopyCheck, @@ -13,6 +13,7 @@ import { cn } from '@/lib/utils' import { useMessages } from '@/hooks/useMessages' import ThinkingBlock from '@/containers/ThinkingBlock' import ToolCallBlock from '@/containers/ToolCallBlock' +import { useChat } from '@/hooks/useChat' const CopyButton = ({ text }: { text: string }) => { const [copied, setCopied] = useState(false) @@ -25,7 +26,7 @@ const CopyButton = ({ text }: { text: string }) => { return ( - + {item.isLastMessage && ( + + )}
)} diff --git a/web-app/src/hooks/useAppState.ts b/web-app/src/hooks/useAppState.ts index 93629dc73..f1f3f2864 100644 --- a/web-app/src/hooks/useAppState.ts +++ b/web-app/src/hooks/useAppState.ts @@ -1,20 +1,27 @@ import { create } from 'zustand' import { ThreadMessage } from '@janhq/core' +import { MCPTool } from '@/types/completion' type AppState = { streamingContent?: ThreadMessage loadingModel?: boolean + tools: MCPTool[] updateStreamingContent: (content: ThreadMessage | undefined) => void updateLoadingModel: (loading: boolean) => void + updateTools: (tools: MCPTool[]) => void } export const useAppState = create()((set) => ({ streamingContent: undefined, loadingModel: false, + tools: [], updateStreamingContent: (content) => { set({ streamingContent: content }) }, updateLoadingModel: (loading) => { set({ loadingModel: loading }) }, + updateTools: (tools) => { + set({ tools }) + }, })) diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts new file mode 100644 index 000000000..92d27bf54 --- /dev/null +++ b/web-app/src/hooks/useChat.ts @@ -0,0 +1,149 @@ +import { useCallback, useMemo } from 'react' +import { usePrompt } from './usePrompt' +import { useModelProvider } from './useModelProvider' +import { useThreads } from './useThreads' +import { useAppState } from './useAppState' +import { useMessages } from './useMessages' +import { useRouter } from '@tanstack/react-router' +import { defaultModel } from '@/lib/models' +import { route } from '@/constants/routes' +import { + emptyThreadContent, + extractToolCall, + newAssistantThreadContent, + newUserThreadContent, + postMessageProcessing, + sendCompletion, + startModel, +} from '@/lib/completion' +import { CompletionMessagesBuilder } from '@/lib/messages' +import { ChatCompletionMessageToolCall } from 'openai/resources' + +export const useChat = () => { + const { prompt, setPrompt } = usePrompt() + const { tools } = useAppState() + + const { getProviderByName, selectedModel, selectedProvider } = + useModelProvider() + + const { getCurrentThread: retrieveThread, createThread } = useThreads() + const { updateStreamingContent, updateLoadingModel } = useAppState() + const { addMessage } = useMessages() + const router = useRouter() + + const provider = useMemo(() => { + return getProviderByName(selectedProvider) + }, [selectedProvider, getProviderByName]) + const getCurrentThread = useCallback(async () => { + let currentThread = retrieveThread() + if (!currentThread) { + currentThread = await createThread( + { + id: selectedModel?.id ?? defaultModel(selectedProvider), + provider: selectedProvider, + }, + prompt + ) + router.navigate({ + to: route.threadsDetail, + params: { threadId: currentThread.id }, + }) + } + return currentThread + }, [ + createThread, + prompt, + retrieveThread, + router, + selectedModel?.id, + selectedProvider, + ]) + + const sendMessage = useCallback( + async (message: string) => { + const activeThread = await getCurrentThread() + + if (!activeThread || !provider) return + + updateStreamingContent(emptyThreadContent) + addMessage(newUserThreadContent(activeThread.id, message)) + setPrompt('') + try { + if (selectedModel?.id) { + updateLoadingModel(true) + await startModel(provider.provider, selectedModel.id).catch( + console.error + ) + updateLoadingModel(false) + } + + const builder = new CompletionMessagesBuilder() + // REMARK: Would it possible to not attach the entire message history to the request? + // TODO: If not amend messages history here + builder.addUserMessage(message) + + let isCompleted = false + + while (!isCompleted) { + const completion = await sendCompletion( + activeThread, + provider, + builder.getMessages(), + tools + ) + + if (!completion) throw new Error('No completion received') + let accumulatedText = '' + const currentCall: ChatCompletionMessageToolCall | null = null + const toolCalls: ChatCompletionMessageToolCall[] = [] + for await (const part of completion) { + const delta = part.choices[0]?.delta?.content || '' + if (part.choices[0]?.delta?.tool_calls) { + extractToolCall(part, currentCall, toolCalls) + } + if (delta) { + accumulatedText += delta + // Create a new object each time to avoid reference issues + // Use a timeout to prevent React from batching updates too quickly + const currentContent = newAssistantThreadContent( + activeThread.id, + accumulatedText + ) + updateStreamingContent(currentContent) + await new Promise((resolve) => setTimeout(resolve, 0)) + } + } + // Create a final content object for adding to the thread + const finalContent = newAssistantThreadContent( + activeThread.id, + accumulatedText + ) + builder.addAssistantMessage(accumulatedText, undefined, toolCalls) + const updatedMessage = await postMessageProcessing( + toolCalls, + builder, + finalContent + ) + addMessage(updatedMessage ?? finalContent) + + isCompleted = !toolCalls.length + } + } catch (error) { + console.error('Error sending message:', error) + } + updateStreamingContent(undefined) + }, + [ + getCurrentThread, + provider, + updateStreamingContent, + addMessage, + setPrompt, + selectedModel, + tools, + updateLoadingModel, + ] + ) + + return { sendMessage } +} diff --git a/web-app/src/hooks/useMessages.ts b/web-app/src/hooks/useMessages.ts index aa24f6729..10da6e099 100644 --- a/web-app/src/hooks/useMessages.ts +++ b/web-app/src/hooks/useMessages.ts @@ -9,6 +9,7 @@ import { type MessageState = { messages: Record + getMessages: (threadId: string) => ThreadMessage[] setMessages: (threadId: string, messages: ThreadMessage[]) => void addMessage: (message: ThreadMessage) => void deleteMessage: (threadId: string, messageId: string) => void @@ -16,8 +17,11 @@ type MessageState = { export const useMessages = create()( persist( - (set) => ({ + (set, get) => ({ messages: {}, + getMessages: (threadId) => { + return get().messages[threadId] || [] + }, setMessages: (threadId, messages) => { set((state) => ({ messages: {