diff --git a/web-app/src/containers/GenerateResponseButton.tsx b/web-app/src/containers/GenerateResponseButton.tsx index 29b098776..477fc1e58 100644 --- a/web-app/src/containers/GenerateResponseButton.tsx +++ b/web-app/src/containers/GenerateResponseButton.tsx @@ -33,14 +33,18 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => { }, [messages]) const generateAIResponse = () => { - // If continuing a partial response, delete the partial message first + // If continuing a partial response, keep the message and continue from it if (isPartialResponse) { const partialMessage = messages[messages.length - 1] - deleteMessage(partialMessage.thread_id, partialMessage.id ?? '') - // Get the user message that prompted this partial response const userMessage = messages[messages.length - 2] if (userMessage?.content?.[0]?.text?.value) { - sendMessage(userMessage.content[0].text.value, false) + // Pass the partial message ID to continue from it + sendMessage( + userMessage.content[0].text.value, + false, + undefined, + partialMessage.id + ) } return } diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts index 292adb235..6f636cbec 100644 --- a/web-app/src/hooks/useChat.ts +++ b/web-app/src/hooks/useChat.ts @@ -40,14 +40,20 @@ import { TEMPORARY_CHAT_QUERY_ID, TEMPORARY_CHAT_ID } from '@/constants/chat' const createThreadContent = ( threadId: string, text: string, - toolCalls: ChatCompletionMessageToolCall[] + toolCalls: ChatCompletionMessageToolCall[], + messageId?: string ) => { - return newAssistantThreadContent(threadId, text, { + const content = newAssistantThreadContent(threadId, text, { tool_calls: toolCalls.map((e) => ({ ...e, state: 'pending', })), }) + // If continuing from a message, preserve the message ID + if (messageId) { + return { ...content, id: messageId } + } + return content } // Helper to cancel animation frame cross-platform @@ -66,9 +72,16 @@ const finalizeMessage = ( addMessage: (message: ThreadMessage) => void, updateStreamingContent: (content: ThreadMessage | undefined) => void, updatePromptProgress: (progress: unknown) => void, - updateThreadTimestamp: (threadId: string) => void + updateThreadTimestamp: (threadId: string) => void, + updateMessage?: (message: ThreadMessage) => void, + continueFromMessageId?: string ) => { - addMessage(finalContent) + // If continuing from a message, update it; otherwise add new message + if (continueFromMessageId && updateMessage) { + updateMessage({ ...finalContent, id: continueFromMessageId }) + } else { + addMessage(finalContent) + } updateStreamingContent(emptyThreadContent) updatePromptProgress(undefined) updateThreadTimestamp(finalContent.thread_id) @@ -85,7 +98,9 @@ const processStreamingCompletion = async ( currentCall: ChatCompletionMessageToolCall | null, updateStreamingContent: (content: ThreadMessage | undefined) => void, updateTokenSpeed: (message: ThreadMessage, increment?: number) => void, - updatePromptProgress: (progress: unknown) => void + updatePromptProgress: (progress: unknown) => void, + continueFromMessageId?: string, + updateMessage?: (message: ThreadMessage) => void ) => { // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame) let rafScheduled = false @@ -97,9 +112,20 @@ const processStreamingCompletion = async ( const currentContent = createThreadContent( activeThread.id, accumulatedText.value, - toolCalls + toolCalls, + continueFromMessageId ) - updateStreamingContent(currentContent) + + // When continuing, update the message directly instead of using streamingContent + if (continueFromMessageId && updateMessage) { + updateMessage({ + ...currentContent, + status: MessageStatus.Stopped, // Keep as Stopped while streaming + }) + } else { + updateStreamingContent(currentContent) + } + if (pendingDeltaCount > 0) { updateTokenSpeed(currentContent, pendingDeltaCount) } @@ -232,6 +258,7 @@ export const useChat = () => { const getMessages = useMessages((state) => state.getMessages) const addMessage = useMessages((state) => state.addMessage) + const updateMessage = useMessages((state) => state.updateMessage) const setMessages = useMessages((state) => state.setMessages) const setModelLoadError = useModelLoad((state) => state.setModelLoadError) const router = useRouter() @@ -396,7 +423,8 @@ export const useChat = () => { size: number base64: string dataUrl: string - }> + }>, + continueFromMessageId?: string ) => { const activeThread = await getCurrentThread() const selectedProvider = useModelProvider.getState().selectedProvider @@ -409,15 +437,24 @@ export const useChat = () => { setAbortController(activeThread.id, abortController) updateStreamingContent(emptyThreadContent) updatePromptProgress(undefined) - // Do not add new message on retry - if (troubleshooting) + + // Find the message to continue from if provided + const continueFromMessage = continueFromMessageId + ? messages.find((m) => m.id === continueFromMessageId) + : undefined + + // Do not add new message on retry or when continuing + if (troubleshooting && !continueFromMessageId) addMessage(newUserThreadContent(activeThread.id, message, attachments)) updateThreadTimestamp(activeThread.id) usePrompt.getState().setPrompt('') const selectedModel = useModelProvider.getState().selectedModel // Declare accumulatedTextRef BEFORE try block so it's accessible in catch block - const accumulatedTextRef = { value: '' } + // If continuing, start with the previous content + const accumulatedTextRef = { + value: continueFromMessage?.content?.[0]?.text?.value || '' + } let currentAssistant: Assistant | undefined try { @@ -427,13 +464,28 @@ export const useChat = () => { updateLoadingModel(false) } currentAssistant = useAssistant.getState().currentAssistant + + // Filter out the stopped message from context if continuing + const contextMessages = continueFromMessageId + ? messages.filter((m) => m.id !== continueFromMessageId) + : messages + const builder = new CompletionMessagesBuilder( - messages, + contextMessages, currentAssistant ? renderInstructions(currentAssistant.instructions) : undefined ) - if (troubleshooting) builder.addUserMessage(message, attachments) + if (troubleshooting && !continueFromMessageId) { + builder.addUserMessage(message, attachments) + } else if (continueFromMessage) { + // When continuing, add the partial assistant response to the context + builder.addAssistantMessage( + continueFromMessage.content?.[0]?.text?.value || '', + undefined, + [] + ) + } let isCompleted = false @@ -513,7 +565,9 @@ export const useChat = () => { currentCall, updateStreamingContent, updateTokenSpeed, - updatePromptProgress + updatePromptProgress, + continueFromMessageId, + updateMessage ) } } catch (error) { @@ -560,7 +614,7 @@ export const useChat = () => { } // Create a final content object for adding to the thread - const finalContent = newAssistantThreadContent( + let finalContent = newAssistantThreadContent( activeThread.id, accumulatedTextRef.value, { @@ -569,6 +623,15 @@ export const useChat = () => { } ) + // If continuing from a message, preserve the ID and set status to Ready + if (continueFromMessageId) { + finalContent = { + ...finalContent, + id: continueFromMessageId, + status: MessageStatus.Ready, + } + } + // Normal completion flow (abort is handled after loop exits) builder.addAssistantMessage(accumulatedTextRef.value, undefined, toolCalls) const updatedMessage = await postMessageProcessing( @@ -585,7 +648,9 @@ export const useChat = () => { addMessage, updateStreamingContent, updatePromptProgress, - updateThreadTimestamp + updateThreadTimestamp, + updateMessage, + continueFromMessageId ) isCompleted = !toolCalls.length @@ -618,8 +683,13 @@ export const useChat = () => { ), status: MessageStatus.Stopped, } - // Save the partial message - addMessage(partialContent) + + // If continuing, update the existing message; otherwise add new + if (continueFromMessageId) { + updateMessage({ ...partialContent, id: continueFromMessageId }) + } else { + addMessage(partialContent) + } updatePromptProgress(undefined) updateThreadTimestamp(activeThread.id) } @@ -649,7 +719,13 @@ export const useChat = () => { ), status: MessageStatus.Stopped, } - addMessage(partialContent) + + // If continuing, update the existing message; otherwise add new + if (continueFromMessageId) { + updateMessage({ ...partialContent, id: continueFromMessageId }) + } else { + addMessage(partialContent) + } updatePromptProgress(undefined) updateThreadTimestamp(activeThread.id) } else if (!abortController.signal.aborted) { diff --git a/web-app/src/hooks/useMessages.ts b/web-app/src/hooks/useMessages.ts index 71fd0c4e0..7db28ff34 100644 --- a/web-app/src/hooks/useMessages.ts +++ b/web-app/src/hooks/useMessages.ts @@ -8,6 +8,7 @@ type MessageState = { getMessages: (threadId: string) => ThreadMessage[] setMessages: (threadId: string, messages: ThreadMessage[]) => void addMessage: (message: ThreadMessage) => void + updateMessage: (message: ThreadMessage) => void deleteMessage: (threadId: string, messageId: string) => void clearAllMessages: () => void } @@ -57,6 +58,36 @@ export const useMessages = create()((set, get) => ({ console.error('Failed to persist message:', error) }) }, + updateMessage: (message) => { + const assistants = useAssistant.getState().assistants + const currentAssistant = useAssistant.getState().currentAssistant + + const selectedAssistant = + assistants.find((a) => a.id === currentAssistant?.id) || assistants[0] + + const updatedMessage = { + ...message, + metadata: { + ...message.metadata, + assistant: selectedAssistant, + }, + } + + // Optimistically update state immediately for instant UI feedback + set((state) => ({ + messages: { + ...state.messages, + [message.thread_id]: (state.messages[message.thread_id] || []).map((m) => + m.id === message.id ? updatedMessage : m + ), + }, + })) + + // Persist to storage asynchronously + getServiceHub().messages().createMessage(updatedMessage).catch((error) => { + console.error('Failed to persist message update:', error) + }) + }, deleteMessage: (threadId, messageId) => { getServiceHub().messages().deleteMessage(threadId, messageId) set((state) => ({