From 2e86d4e42180e7be6ac59858e2a159c3b41178ef Mon Sep 17 00:00:00 2001
From: Vanalite
Date: Wed, 1 Oct 2025 15:43:05 +0700
Subject: [PATCH] feat: Save the last message when interrupting the LLM response
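
Interrupting a streaming completion previously discarded any partial
assistant output. Extract the streaming loop into a reusable
processStreamingCompletion helper that accumulates text in a shared
ref declared outside the try block, then save that partial text as a
regular thread message when the abort signal fires, both on the
normal exit path and on the error path (some providers throw when a
request is aborted). In useMessages, make addMessage update the store
optimistically before persisting, so the saved partial message shows
up in the UI immediately instead of waiting for createMessage to
resolve.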
---
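
Reviewer note: the extracted processStreamingCompletion helper keeps the
existing per-frame batching behaviour. A minimal, self-contained sketch of
that pattern (all identifiers below are illustrative stand-ins, not the
app's API):

  // Batching sketch: deltas mutate shared state at once; rendering is
  // coalesced so many chunks cost at most one flush per frame.
  type Flush = () => void

  const makeScheduler = (flush: Flush) => {
    let scheduled = false
    let handle: number | undefined
    const schedule = () => {
      if (scheduled) return
      scheduled = true
      const run = () => {
        scheduled = false
        flush()
      }
      handle =
        typeof requestAnimationFrame !== 'undefined'
          ? requestAnimationFrame(run)
          : (setTimeout(run, 0) as unknown as number) // non-browser fallback
    }
    const cancel = () => {
      if (handle === undefined) return
      if (typeof cancelAnimationFrame !== 'undefined') {
        cancelAnimationFrame(handle)
      } else {
        clearTimeout(handle)
      }
      scheduled = false
    }
    return { schedule, cancel }
  }

  // Usage: accumulate deltas eagerly, render at most once per frame.
  let text = ''
  const { schedule, cancel } = makeScheduler(() => console.log(text))
  for (const delta of ['Hel', 'lo', ' world']) {
    text += delta
    schedule()
  }
  // call cancel() on abort so no stale flush fires afterwards

Deltas mutate shared state immediately; only the rendering work is
deferred, and cancel() keeps a stale flush from firing after an abort.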

 web-app/src/containers/StreamingContent.tsx     |   4 +-
 .../src/hooks/__tests__/useMessages.test.ts     |  20 +-
 web-app/src/hooks/useChat.ts                    | 372 +++++++++++-------
 web-app/src/hooks/useMessages.ts                |  25 +-
 4 files changed, 258 insertions(+), 163 deletions(-)
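
Reviewer note: the heart of the fix is that the text accumulator now lives
outside the try block, so partial output survives an interrupt whether the
stream exits cleanly or the provider throws on abort. A sketch of that
control flow (runWithInterrupt and saveMessage are hypothetical stand-ins
for the real sendMessage/addMessage):

  async function runWithInterrupt(
    stream: AsyncIterable<string>,
    abort: AbortController,
    saveMessage: (text: string) => void
  ) {
    const acc = { value: '' } // declared before try, still readable in catch
    try {
      for await (const delta of stream) {
        if (abort.signal.aborted) break
        acc.value += delta
      }
      // Aborted after the loop exits: keep the partial text
      // (the normal-completion save path is omitted from this sketch)
      if (abort.signal.aborted && acc.value.length > 0) {
        saveMessage(acc.value)
      }
    } catch (err) {
      if (abort.signal.aborted && acc.value.length > 0) {
        saveMessage(acc.value) // provider threw on abort: still keep it
      } else if (!abort.signal.aborted) {
        throw err // a real error, surface it
      }
    }
  }

useChat.ts additionally falls back to the last streamingContent snapshot
from app state as a second source for the partial text.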

diff --git a/web-app/src/containers/StreamingContent.tsx b/web-app/src/containers/StreamingContent.tsx
index 57aebe61e..e40a46f17 100644
--- a/web-app/src/containers/StreamingContent.tsx
+++ b/web-app/src/containers/StreamingContent.tsx
@@ -48,7 +48,9 @@ export const StreamingContent = memo(({ threadId }: Props) => {
     return extractReasoningSegment(text)
   }, [lastAssistant])
 
-  if (!streamingContent || streamingContent.thread_id !== threadId) return null
+  if (!streamingContent || streamingContent.thread_id !== threadId) {
+    return null
+  }
 
   if (streamingReasoning && streamingReasoning === lastAssistantReasoning) {
     return null
diff --git a/web-app/src/hooks/__tests__/useMessages.test.ts b/web-app/src/hooks/__tests__/useMessages.test.ts
index 503806e38..4b75c4d58 100644
--- a/web-app/src/hooks/__tests__/useMessages.test.ts
+++ b/web-app/src/hooks/__tests__/useMessages.test.ts
@@ -225,9 +225,25 @@ describe('useMessages', () => {
       })
     )
 
-    // Wait for async operation
+    // Message should be immediately available (optimistic update)
+    expect(result.current.messages['thread1']).toContainEqual(
+      expect.objectContaining({
+        id: messageToAdd.id,
+        thread_id: messageToAdd.thread_id,
+        role: messageToAdd.role,
+        content: messageToAdd.content,
+        metadata: expect.objectContaining({
+          assistant: expect.objectContaining({
+            id: expect.any(String),
+            name: expect.any(String),
+          }),
+        }),
+      })
+    )
+
+    // Verify persistence was attempted
     await vi.waitFor(() => {
-      expect(result.current.messages['thread1']).toContainEqual(mockCreatedMessage)
+      expect(mockCreateMessage).toHaveBeenCalled()
     })
   })
diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts
index 935458326..a92269e96 100644
--- a/web-app/src/hooks/useChat.ts
+++ b/web-app/src/hooks/useChat.ts
@@ -35,6 +35,156 @@ import { useAssistant } from './useAssistant'
 import { useShallow } from 'zustand/shallow'
 import { TEMPORARY_CHAT_QUERY_ID, TEMPORARY_CHAT_ID } from '@/constants/chat'
 
+// Helper to create thread content with consistent structure
+const createThreadContent = (
+  threadId: string,
+  text: string,
+  toolCalls: ChatCompletionMessageToolCall[]
+) => {
+  return newAssistantThreadContent(threadId, text, {
+    tool_calls: toolCalls.map((e) => ({
+      ...e,
+      state: 'pending',
+    })),
+  })
+}
+
+// Helper to cancel a scheduled animation frame, cross-platform
+const cancelFrame = (handle: number | undefined) => {
+  if (handle === undefined) return
+  if (typeof cancelAnimationFrame !== 'undefined') {
+    cancelAnimationFrame(handle)
+  } else {
+    clearTimeout(handle)
+  }
+}
+
+// Helper to finalize and save a message
+const finalizeMessage = (
+  finalContent: ThreadMessage,
+  addMessage: (message: ThreadMessage) => void,
+  updateStreamingContent: (content: ThreadMessage | undefined) => void,
+  updatePromptProgress: (progress: unknown) => void,
+  updateThreadTimestamp: (threadId: string) => void
+) => {
+  addMessage(finalContent)
+  updateStreamingContent(emptyThreadContent)
+  updatePromptProgress(undefined)
+  updateThreadTimestamp(finalContent.thread_id)
+}
+
+// Helper to process a streaming completion
+const processStreamingCompletion = async (
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  completion: AsyncIterable<any>,
+  abortController: AbortController,
+  activeThread: Thread,
+  accumulatedText: { value: string },
+  toolCalls: ChatCompletionMessageToolCall[],
+  currentCall: ChatCompletionMessageToolCall | null,
+  updateStreamingContent: (content: ThreadMessage | undefined) => void,
+  updateTokenSpeed: (message: ThreadMessage, increment?: number) => void,
+  updatePromptProgress: (progress: unknown) => void
+) => {
+  // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
+  let rafScheduled = false
+  let rafHandle: number | undefined
+  let pendingDeltaCount = 0
+  const reasoningProcessor = new ReasoningProcessor()
+
+  const flushStreamingContent = () => {
+    const currentContent = createThreadContent(
+      activeThread.id,
+      accumulatedText.value,
+      toolCalls
+    )
+    updateStreamingContent(currentContent)
+    if (pendingDeltaCount > 0) {
+      updateTokenSpeed(currentContent, pendingDeltaCount)
+    }
+    pendingDeltaCount = 0
+    rafScheduled = false
+  }
+
+  const scheduleFlush = () => {
+    if (rafScheduled || abortController.signal.aborted) return
+    rafScheduled = true
+    const doSchedule = (cb: () => void) => {
+      if (typeof requestAnimationFrame !== 'undefined') {
+        rafHandle = requestAnimationFrame(() => cb())
+      } else {
+        // Fallback for non-browser test environments
+        const t = setTimeout(() => cb(), 0) as unknown as number
+        rafHandle = t
+      }
+    }
+    doSchedule(() => {
+      // Check abort status before executing the scheduled callback
+      if (abortController.signal.aborted) {
+        rafScheduled = false
+        return
+      }
+      flushStreamingContent()
+    })
+  }
+
+  try {
+    for await (const part of completion) {
+      // Check if aborted before processing each part
+      if (abortController.signal.aborted) {
+        break
+      }
+
+      // Handle prompt progress if available
+      if ('prompt_progress' in part && part.prompt_progress) {
+        // Force immediate state update to ensure we see intermediate values
+        flushSync(() => {
+          updatePromptProgress(part.prompt_progress)
+        })
+        // Add a small delay to make progress visible
+        await new Promise((resolve) => setTimeout(resolve, 100))
+      }
+
+      // Error message
+      if (!part.choices) {
+        throw new Error(
+          'message' in part
+            ? (part.message as string)
+            : (JSON.stringify(part) ?? '')
+        )
+      }
+
+      if (part.choices[0]?.delta?.tool_calls) {
+        extractToolCall(part, currentCall, toolCalls)
+        // Schedule a flush to reflect the tool update
+        scheduleFlush()
+      }
+      const deltaReasoning = reasoningProcessor.processReasoningChunk(part)
+      if (deltaReasoning) {
+        accumulatedText.value += deltaReasoning
+        pendingDeltaCount += 1
+        // Schedule a flush for reasoning updates
+        scheduleFlush()
+      }
+      const deltaContent = part.choices[0]?.delta?.content || ''
+      if (deltaContent) {
+        accumulatedText.value += deltaContent
+        pendingDeltaCount += 1
+        // Batch the UI update on the next animation frame
+        scheduleFlush()
+      }
+    }
+  } finally {
+    // Always clean up the scheduled frame when the stream ends (normally or via abort)
+    cancelFrame(rafHandle)
+    rafHandle = undefined
+    rafScheduled = false
+
+    // Finalize reasoning (close any open think tags)
+    accumulatedText.value += reasoningProcessor.finalize()
+  }
+}
+
 export const useChat = () => {
   const [
     updateTokenSpeed,
@@ -264,13 +414,18 @@ export const useChat = () => {
       updateThreadTimestamp(activeThread.id)
       usePrompt.getState().setPrompt('')
       const selectedModel = useModelProvider.getState().selectedModel
+
+      // Declare accumulatedTextRef BEFORE the try block so it is accessible in the catch block
+      const accumulatedTextRef = { value: '' }
+      let currentAssistant: Assistant | undefined
+
       try {
         if (selectedModel?.id) {
          updateLoadingModel(true)
           await serviceHub.models().startModel(activeProvider, selectedModel.id)
           updateLoadingModel(false)
         }
-        const currentAssistant = useAssistant.getState().currentAssistant
+        currentAssistant = useAssistant.getState().currentAssistant
         const builder = new CompletionMessagesBuilder(
           messages,
           currentAssistant
@@ -330,162 +485,35 @@ export const useChat = () => {
          )
 
          if (!completion) throw new Error('No completion received')
-          let accumulatedText = ''
          const currentCall: ChatCompletionMessageToolCall | null = null
          const toolCalls: ChatCompletionMessageToolCall[] = []
 
          try {
            if (isCompletionResponse(completion)) {
              const message = completion.choices[0]?.message
-              accumulatedText = (message?.content as string) || ''
+              accumulatedTextRef.value = (message?.content as string) || ''
 
              // Handle reasoning field if there is one
              const reasoning = extractReasoningFromMessage(message)
              if (reasoning) {
-                accumulatedText =
-                  `<think>${reasoning}</think>` + accumulatedText
+                accumulatedTextRef.value =
+                  `<think>${reasoning}</think>` + accumulatedTextRef.value
              }
 
              if (message?.tool_calls) {
                toolCalls.push(...message.tool_calls)
              }
            } else {
-              // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
-              let rafScheduled = false
-              let rafHandle: number | undefined
-              let pendingDeltaCount = 0
-              const reasoningProcessor = new ReasoningProcessor()
-              const scheduleFlush = () => {
-                if (rafScheduled || abortController.signal.aborted) return
-                rafScheduled = true
-                const doSchedule = (cb: () => void) => {
-                  if (typeof requestAnimationFrame !== 'undefined') {
-                    rafHandle = requestAnimationFrame(() => cb())
-                  } else {
-                    // Fallback for non-browser test environments
-                    const t = setTimeout(() => cb(), 0) as unknown as number
-                    rafHandle = t
-                  }
-                }
-                doSchedule(() => {
-                  // Check abort status before executing the scheduled callback
-                  if (abortController.signal.aborted) {
-                    rafScheduled = false
-                    return
-                  }
-
-                  const currentContent = newAssistantThreadContent(
-                    activeThread.id,
-                    accumulatedText,
-                    {
-                      tool_calls: toolCalls.map((e) => ({
-                        ...e,
-                        state: 'pending',
-                      })),
-                    }
-                  )
-                  updateStreamingContent(currentContent)
-                  if (pendingDeltaCount > 0) {
-                    updateTokenSpeed(currentContent, pendingDeltaCount)
-                  }
-                  pendingDeltaCount = 0
-                  rafScheduled = false
-                })
-              }
-              const flushIfPending = () => {
-                if (!rafScheduled) return
-                if (
-                  typeof cancelAnimationFrame !== 'undefined' &&
-                  rafHandle !== undefined
-                ) {
-                  cancelAnimationFrame(rafHandle)
-                } else if (rafHandle !== undefined) {
-                  clearTimeout(rafHandle)
-                }
-                // Do an immediate flush
-                const currentContent = newAssistantThreadContent(
-                  activeThread.id,
-                  accumulatedText,
-                  {
-                    tool_calls: toolCalls.map((e) => ({
-                      ...e,
-                      state: 'pending',
-                    })),
-                  }
-                )
-                updateStreamingContent(currentContent)
-                if (pendingDeltaCount > 0) {
-                  updateTokenSpeed(currentContent, pendingDeltaCount)
-                }
-                pendingDeltaCount = 0
-                rafScheduled = false
-              }
-
-              try {
-                for await (const part of completion) {
-                  // Check if aborted before processing each part
-                  if (abortController.signal.aborted) {
-                    break
-                  }
-
-                  // Handle prompt progress if available
-                  if ('prompt_progress' in part && part.prompt_progress) {
-                    // Force immediate state update to ensure we see intermediate values
-                    flushSync(() => {
-                      updatePromptProgress(part.prompt_progress)
-                    })
-                    // Add a small delay to make progress visible
-                    await new Promise((resolve) => setTimeout(resolve, 100))
-                  }
-
-                  // Error message
-                  if (!part.choices) {
-                    throw new Error(
-                      'message' in part
-                        ? (part.message as string)
-                        : (JSON.stringify(part) ?? '')
-                    )
-                  }
-
-                  if (part.choices[0]?.delta?.tool_calls) {
-                    extractToolCall(part, currentCall, toolCalls)
-                    // Schedule a flush to reflect tool update
-                    scheduleFlush()
-                  }
-                  const deltaReasoning =
-                    reasoningProcessor.processReasoningChunk(part)
-                  if (deltaReasoning) {
-                    accumulatedText += deltaReasoning
-                    pendingDeltaCount += 1
-                    // Schedule flush for reasoning updates
-                    scheduleFlush()
-                  }
-                  const deltaContent = part.choices[0]?.delta?.content || ''
-                  if (deltaContent) {
-                    accumulatedText += deltaContent
-                    pendingDeltaCount += 1
-                    // Batch UI update on next animation frame
-                    scheduleFlush()
-                  }
-                }
-              } finally {
-                // Always clean up scheduled RAF when stream ends (either normally or via abort)
-                if (rafHandle !== undefined) {
-                  if (typeof cancelAnimationFrame !== 'undefined') {
-                    cancelAnimationFrame(rafHandle)
-                  } else {
-                    clearTimeout(rafHandle)
-                  }
-                  rafHandle = undefined
-                  rafScheduled = false
-                }
-
-                // Only finalize and flush if not aborted
-                if (!abortController.signal.aborted) {
-                  // Finalize reasoning (close any open think tags)
-                  accumulatedText += reasoningProcessor.finalize()
-                  // Ensure any pending buffered content is rendered at the end
-                  flushIfPending()
-                }
-              }
+              await processStreamingCompletion(
+                completion,
+                abortController,
+                activeThread,
+                accumulatedTextRef,
+                toolCalls,
+                currentCall,
+                updateStreamingContent,
+                updateTokenSpeed,
+                updatePromptProgress
+              )
            }
          } catch (error) {
            const errorMessage =
@@ -519,7 +547,7 @@ export const useChat = () => {
          }
          // TODO: Remove this check when integrating new llama.cpp extension
          if (
-            accumulatedText.length === 0 &&
+            accumulatedTextRef.value.length === 0 &&
            toolCalls.length === 0 &&
            activeThread.model?.id &&
            activeProvider?.provider === 'llamacpp'
@@ -533,14 +561,15 @@ export const useChat = () => {
          // Create a final content object for adding to the thread
          const finalContent = newAssistantThreadContent(
            activeThread.id,
-            accumulatedText,
+            accumulatedTextRef.value,
            {
              tokenSpeed: useAppState.getState().tokenSpeed,
              assistant: currentAssistant,
            }
          )
 
-          builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
+          // Normal completion flow (abort is handled after the loop exits)
+          builder.addAssistantMessage(accumulatedTextRef.value, undefined, toolCalls)
          const updatedMessage = await postMessageProcessing(
            toolCalls,
            builder,
@@ -550,10 +579,13 @@ export const useChat = () => {
            allowAllMCPPermissions ? undefined : showApprovalModal,
            allowAllMCPPermissions
          )
-          addMessage(updatedMessage ?? finalContent)
-          updateStreamingContent(emptyThreadContent)
-          updatePromptProgress(undefined)
-          updateThreadTimestamp(activeThread.id)
+          finalizeMessage(
+            updatedMessage ?? finalContent,
+            addMessage,
+            updateStreamingContent,
+            updatePromptProgress,
+            updateThreadTimestamp
+          )
 
          isCompleted = !toolCalls.length
          // Do not create agent loop if there is no need for it
@@ -563,8 +595,48 @@ export const useChat = () => {
            availableTools = []
          }
        }
+
+        // IMPORTANT: Check if aborted AFTER the while loop exits;
+        // the loop stops when abort is requested, so the partial message is saved here
+        if (abortController.signal.aborted && accumulatedTextRef.value.length > 0) {
+          // Create final content for the partial message
+          const partialContent = newAssistantThreadContent(
+            activeThread.id,
+            accumulatedTextRef.value,
+            {
+              tokenSpeed: useAppState.getState().tokenSpeed,
+              assistant: currentAssistant,
+            }
+          )
+          // Save the partial message
+          addMessage(partialContent)
+          updatePromptProgress(undefined)
+          updateThreadTimestamp(activeThread.id)
+        }
      } catch (error) {
-        if (!abortController.signal.aborted) {
+        // If aborted, save the partial message even though an error occurred.
+        // Check both accumulatedTextRef and streamingContent from app state.
+        const streamingContent = useAppState.getState().streamingContent
+        const hasPartialContent =
+          accumulatedTextRef.value.length > 0 ||
+          (streamingContent && streamingContent.content?.[0]?.text?.value)
+
+        if (abortController.signal.aborted && hasPartialContent) {
+          // Use streaming content if available, otherwise use accumulatedTextRef
+          const contentText =
+            streamingContent?.content?.[0]?.text?.value ||
+            accumulatedTextRef.value
+
+          const partialContent = newAssistantThreadContent(
+            activeThread.id,
+            contentText,
+            {
+              tokenSpeed: useAppState.getState().tokenSpeed,
+              assistant: currentAssistant,
+            }
+          )
+          addMessage(partialContent)
+          updatePromptProgress(undefined)
+          updateThreadTimestamp(activeThread.id)
+        } else if (!abortController.signal.aborted) {
+          // Only show the error if the request was not aborted
          if (error && typeof error === 'object' && 'message' in error) {
            setModelLoadError(error as ErrorObject)
          } else {
diff --git a/web-app/src/hooks/useMessages.ts b/web-app/src/hooks/useMessages.ts
index 8c011a900..71fd0c4e0 100644
--- a/web-app/src/hooks/useMessages.ts
+++ b/web-app/src/hooks/useMessages.ts
@@ -40,16 +40,21 @@ export const useMessages = create()((set, get) => ({
          assistant: selectedAssistant,
        },
      }
-      getServiceHub().messages().createMessage(newMessage).then((createdMessage) => {
-        set((state) => ({
-          messages: {
-            ...state.messages,
-            [message.thread_id]: [
-              ...(state.messages[message.thread_id] || []),
-              createdMessage,
-            ],
-          },
-        }))
+
+      // Optimistically update state immediately for instant UI feedback
+      set((state) => ({
+        messages: {
+          ...state.messages,
+          [message.thread_id]: [
+            ...(state.messages[message.thread_id] || []),
+            newMessage,
+          ],
+        },
+      }))
+
+      // Persist to storage asynchronously
+      getServiceHub().messages().createMessage(newMessage).catch((error) => {
+        console.error('Failed to persist message:', error)
      })
    },
    deleteMessage: (threadId, messageId) => {
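
A final note on the useMessages change: addMessage now follows the standard
optimistic-update split, namely a synchronous store update first and
asynchronous persistence second, with failures logged rather than blocking
the UI. A stand-alone sketch with hypothetical types (not the app's store):

  type Message = { id: string; text: string }

  const store: { messages: Message[] } = { messages: [] }

  // Stand-in for getServiceHub().messages().createMessage(message)
  const persist = async (message: Message): Promise<void> => {}

  const addMessageOptimistically = (message: Message) => {
    store.messages = [...store.messages, message] // visible immediately
    persist(message).catch((error) => {
      console.error('Failed to persist message:', error)
    })
  }

  addMessageOptimistically({ id: 'm1', text: 'partial answer' })
  console.log(store.messages.length) // 1, before persistence settles

The trade-off, as in the patch, is that a failed write leaves the message in
memory only; the error is logged rather than rolled back.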