diff --git a/web-app/src/containers/ThreadContent.tsx b/web-app/src/containers/ThreadContent.tsx
index 8116ad8d8..3a2cf4681 100644
--- a/web-app/src/containers/ThreadContent.tsx
+++ b/web-app/src/containers/ThreadContent.tsx
@@ -116,10 +116,13 @@ export const ThreadContent = memo(
       // Only regenerate assistant message is allowed
       deleteMessage(item.thread_id, item.id)
       const threadMessages = getMessages(item.thread_id)
-      const lastMessage = threadMessages[threadMessages.length - 1]
-      if (!lastMessage) return
-      deleteMessage(lastMessage.thread_id, lastMessage.id)
-      sendMessage(lastMessage.content?.[0]?.text?.value || '')
+      let toSendMessage = threadMessages.pop()
+      while (toSendMessage && toSendMessage?.role !== 'user') {
+        deleteMessage(toSendMessage.thread_id, toSendMessage.id ?? '')
+        toSendMessage = threadMessages.pop()
+      }
+      if (toSendMessage)
+        sendMessage(toSendMessage.content?.[0]?.text?.value || '')
     }, [deleteMessage, getMessages, item, sendMessage])
 
     const editMessage = useCallback(
diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts
index dd46d3789..4af528dcf 100644
--- a/web-app/src/hooks/useChat.ts
+++ b/web-app/src/hooks/useChat.ts
@@ -10,6 +10,7 @@ import { route } from '@/constants/routes'
 import {
   emptyThreadContent,
   extractToolCall,
+  isCompletionResponse,
   newAssistantThreadContent,
   newUserThreadContent,
   postMessageProcessing,
@@ -19,6 +20,7 @@ import {
 import { CompletionMessagesBuilder } from '@/lib/messages'
 import { ChatCompletionMessageToolCall } from 'openai/resources'
 import { useAssistant } from './useAssistant'
+import { toast } from 'sonner'
 
 export const useChat = () => {
   const { prompt, setPrompt } = usePrompt()
@@ -78,9 +80,7 @@
       try {
         if (selectedModel?.id) {
           updateLoadingModel(true)
-          await startModel(provider, selectedModel.id).catch(
-            console.error
-          )
+          await startModel(provider, selectedModel.id).catch(console.error)
           updateLoadingModel(false)
         }
 
@@ -100,29 +100,38 @@
           provider,
           builder.getMessages(),
           abortController,
-          tools
+          tools,
+          // TODO: replace it with according provider setting later on
+          selectedProvider === 'llama.cpp' && tools.length > 0 ? false : true
         )
         if (!completion) throw new Error('No completion received')
 
         let accumulatedText = ''
         const currentCall: ChatCompletionMessageToolCall | null = null
         const toolCalls: ChatCompletionMessageToolCall[] = []
-        for await (const part of completion) {
-          const delta = part.choices[0]?.delta?.content || ''
-          if (part.choices[0]?.delta?.tool_calls) {
-            extractToolCall(part, currentCall, toolCalls)
+        if (isCompletionResponse(completion)) {
+          accumulatedText = completion.choices[0]?.message?.content || ''
+          if (completion.choices[0]?.message?.tool_calls) {
+            toolCalls.push(...completion.choices[0].message.tool_calls)
           }
-          if (delta) {
-            accumulatedText += delta
-            // Create a new object each time to avoid reference issues
-            // Use a timeout to prevent React from batching updates too quickly
-            const currentContent = newAssistantThreadContent(
-              activeThread.id,
-              accumulatedText
-            )
-            updateStreamingContent(currentContent)
-            updateTokenSpeed(currentContent)
-            await new Promise((resolve) => setTimeout(resolve, 0))
+        } else {
+          for await (const part of completion) {
+            const delta = part.choices[0]?.delta?.content || ''
+            if (part.choices[0]?.delta?.tool_calls) {
+              extractToolCall(part, currentCall, toolCalls)
+            }
+            if (delta) {
+              accumulatedText += delta
+              // Create a new object each time to avoid reference issues
+              // Use a timeout to prevent React from batching updates too quickly
+              const currentContent = newAssistantThreadContent(
+                activeThread.id,
+                accumulatedText
+              )
+              updateStreamingContent(currentContent)
+              updateTokenSpeed(currentContent)
+              await new Promise((resolve) => setTimeout(resolve, 0))
+            }
           }
         }
         // Create a final content object for adding to the thread
@@ -141,9 +150,14 @@
           isCompleted = !toolCalls.length
         }
       } catch (error) {
+        toast.error(
+          `Error sending message: ${error && typeof error === 'object' && 'message' in error ? error.message : error}`
+        )
         console.error('Error sending message:', error)
+      } finally {
+        updateLoadingModel(false)
+        updateStreamingContent(undefined)
       }
-      updateStreamingContent(undefined)
     },
     [
       getCurrentThread,
@@ -157,6 +171,7 @@
       setAbortController,
       updateLoadingModel,
       tools,
+      selectedProvider,
       updateTokenSpeed,
     ]
   )
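
Note on the regenerate change above: instead of resending only the literal
last message, the handler now pops trailing messages until it reaches the
most recent user message, deleting the intermediate assistant output, and
resends that user message. A minimal sketch of the walk, using a hypothetical
Msg shape in place of the app's real message type:

  type Msg = {
    id?: string
    thread_id: string
    role: 'user' | 'assistant'
    content?: { text?: { value: string } }[]
  }

  // Pop trailing non-user messages (collected for deletion) and return the
  // most recent user message, if any, so the caller can resend it.
  const splitForRegenerate = (messages: Msg[]) => {
    const toDelete: Msg[] = []
    let candidate = messages.pop()
    while (candidate && candidate.role !== 'user') {
      toDelete.push(candidate)
      candidate = messages.pop()
    }
    return { toResend: candidate, toDelete }
  }
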
diff --git a/web-app/src/lib/completion.ts b/web-app/src/lib/completion.ts
index b7fa59068..92a17f321 100644
--- a/web-app/src/lib/completion.ts
+++ b/web-app/src/lib/completion.ts
@@ -10,6 +10,7 @@ import { invoke } from '@tauri-apps/api/core'
 import {
   ChatCompletionMessageParam,
   ChatCompletionTool,
+  CompletionResponse,
   CompletionResponseChunk,
   models,
   StreamCompletionResponse,
@@ -111,8 +112,9 @@ export const sendCompletion = async (
   provider: ModelProvider,
   messages: ChatCompletionMessageParam[],
   abortController: AbortController,
-  tools: MCPTool[] = []
-): Promise<StreamCompletionResponse | undefined> => {
+  tools: MCPTool[] = [],
+  stream: boolean = true
+): Promise<StreamCompletionResponse | CompletionResponse | undefined> => {
   if (!thread?.model?.id || !provider) return undefined
 
   let providerName = provider.provider as unknown as keyof typeof models
@@ -127,22 +129,37 @@
   })
 
   // TODO: Add message history
-  const completion = await tokenJS.chat.completions.create(
-    {
-      stream: true,
-      provider: providerName,
-      model: thread.model?.id,
-      messages,
-      tools: normalizeTools(tools),
-      tool_choice: tools.length ? 'auto' : undefined,
-    },
-    {
-      signal: abortController.signal,
-    }
-  )
+  const completion = stream
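
Note on completion.ts: sendCompletion now returns a union of a complete
response and an async chunk stream, and callers branch with the
isCompletionResponse type guard. A sketch of the pattern with simplified
stand-in types (Complete and Chunk are assumptions, not the real
CompletionResponse/CompletionResponseChunk):

  type Complete = { choices: { message: { content: string | null } }[] }
  type Chunk = { choices: { delta: { content?: string } }[] }

  // The guard works because the complete response carries 'choices' as an
  // own property, while the stream is an async iterable of chunks.
  const isComplete = (r: Complete | AsyncIterable<Chunk>): r is Complete =>
    'choices' in r

  async function readAll(r: Complete | AsyncIterable<Chunk>): Promise<string> {
    if (isComplete(r)) return r.choices[0]?.message.content ?? ''
    let text = ''
    for await (const part of r) text += part.choices[0]?.delta.content ?? ''
    return text
  }

This relies on the stream object never exposing a 'choices' property of its
own; that holds for typical SDK stream wrappers, but it is worth keeping in
mind if the return types change.
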
+    ? await tokenJS.chat.completions.create(
+        {
+          stream: true,
+          provider: providerName,
+          model: thread.model?.id,
+          messages,
+          tools: normalizeTools(tools),
+          tool_choice: tools.length ? 'auto' : undefined,
+        },
+        {
+          signal: abortController.signal,
+        }
+      )
+    : await tokenJS.chat.completions.create({
+        stream: false,
+        provider: providerName,
+        model: thread.model?.id,
+        messages,
+        tools: normalizeTools(tools),
+        tool_choice: tools.length ? 'auto' : undefined,
+      })
   return completion
 }
 
+export const isCompletionResponse = (
+  response: StreamCompletionResponse | CompletionResponse
+): response is CompletionResponse => {
+  return 'choices' in response
+}
+
 /**
  * @fileoverview Helper function to start a model.
  * This function loads the model from the provider.
diff --git a/web-app/src/routes/settings/mcp-servers.tsx b/web-app/src/routes/settings/mcp-servers.tsx
index 752a15f83..9cce36d25 100644
--- a/web-app/src/routes/settings/mcp-servers.tsx
+++ b/web-app/src/routes/settings/mcp-servers.tsx
@@ -137,6 +137,12 @@ function MCPServers() {
 
   useEffect(() => {
     getConnectedServers().then(setConnectedServers)
+
+    const intervalId = setInterval(() => {
+      getConnectedServers().then(setConnectedServers)
+    }, 5000)
+
+    return () => clearInterval(intervalId)
   }, [setConnectedServers])
 
   return (
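
Note on mcp-servers.tsx: the one-shot fetch of connected servers becomes a
5-second poll whose interval is cleared on unmount. The same pattern as a
small custom hook, with a hypothetical fetchStatus callback standing in for
getConnectedServers:

  import { useEffect, useState } from 'react'

  export function usePolledServers(fetchStatus: () => Promise<string[]>) {
    const [servers, setServers] = useState<string[]>([])
    useEffect(() => {
      fetchStatus().then(setServers) // fetch once on mount
      const id = setInterval(() => {
        fetchStatus().then(setServers) // then refresh every 5 seconds
      }, 5000)
      return () => clearInterval(id) // cleanup prevents leaked timers
    }, [fetchStatus])
    return servers
  }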