diff --git a/web-app/src/containers/GenerateResponseButton.tsx b/web-app/src/containers/GenerateResponseButton.tsx
index 9f6df11f8..29b098776 100644
--- a/web-app/src/containers/GenerateResponseButton.tsx
+++ b/web-app/src/containers/GenerateResponseButton.tsx
@@ -3,6 +3,8 @@ import { useMessages } from '@/hooks/useMessages'
 import { useTranslation } from '@/i18n/react-i18next-compat'
 import { Play } from 'lucide-react'
 import { useShallow } from 'zustand/react/shallow'
+import { useMemo } from 'react'
+import { MessageStatus } from '@janhq/core'
 
 export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
   const { t } = useTranslation()
@@ -13,7 +15,36 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
     }))
   )
   const sendMessage = useChat()
+
+  // Detect if last message is a partial assistant response (user stopped midway)
+  // Only true if message has Stopped status (interrupted by user)
+  const isPartialResponse = useMemo(() => {
+    if (!messages || messages.length < 2) return false
+    const lastMessage = messages[messages.length - 1]
+    const secondLastMessage = messages[messages.length - 2]
+
+    // Partial if: last is assistant with Stopped status, second-last is user, no tool calls
+    return (
+      lastMessage?.role === 'assistant' &&
+      lastMessage?.status === MessageStatus.Stopped &&
+      secondLastMessage?.role === 'user' &&
+      !lastMessage?.metadata?.tool_calls
+    )
+  }, [messages])
+
   const generateAIResponse = () => {
+    // If continuing a partial response, delete the partial message first
+    if (isPartialResponse) {
+      const partialMessage = messages[messages.length - 1]
+      deleteMessage(partialMessage.thread_id, partialMessage.id ?? '')
+      // Get the user message that prompted this partial response
+      const userMessage = messages[messages.length - 2]
+      if (userMessage?.content?.[0]?.text?.value) {
+        sendMessage(userMessage.content[0].text.value, false)
+      }
+      return
+    }
+
     const latestUserMessage = messages[messages.length - 1]
     if (
       latestUserMessage?.content?.[0]?.text?.value &&
@@ -39,7 +70,11 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
         className="mx-2 bg-main-view-fg/10 px-2 border border-main-view-fg/5 flex items-center justify-center rounded-xl gap-x-2 cursor-pointer pointer-events-auto"
         onClick={generateAIResponse}
       >
-        <span>{t('common:generateAiResponse')}</span>
+        <span>
+          {isPartialResponse
+            ? t('common:continueAiResponse')
+            : t('common:generateAiResponse')}
+        </span>
       </div>
     </div>
   )
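The same three-part test (last message is an assistant reply with Stopped status, preceded by a user message, with no tool calls) reappears in ScrollToBottom below. As a reading aid, here is the predicate extracted into a standalone helper. This is a hypothetical sketch, not part of the patch; `ThreadMessage` and `MessageStatus` are the `@janhq/core` types the diff already imports.

```ts
import { MessageStatus, ThreadMessage } from '@janhq/core'

// Hypothetical extraction of the check duplicated in GenerateResponseButton
// and ScrollToBottom: the thread ends in an assistant message the user
// stopped mid-stream, directly after a user message, with no tool calls.
export const isPartialAssistantResponse = (
  messages: ThreadMessage[]
): boolean => {
  if (messages.length < 2) return false
  const last = messages[messages.length - 1]
  const prev = messages[messages.length - 2]
  return (
    last?.role === 'assistant' &&
    last?.status === MessageStatus.Stopped &&
    prev?.role === 'user' &&
    !last?.metadata?.tool_calls
  )
}
```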
diff --git a/web-app/src/containers/ScrollToBottom.tsx b/web-app/src/containers/ScrollToBottom.tsx
index b1259480f..48f04a3d7 100644
--- a/web-app/src/containers/ScrollToBottom.tsx
+++ b/web-app/src/containers/ScrollToBottom.tsx
@@ -8,6 +8,7 @@ import { cn } from '@/lib/utils'
 import { ArrowDown } from 'lucide-react'
 import { useTranslation } from '@/i18n/react-i18next-compat'
 import { useAppState } from '@/hooks/useAppState'
+import { MessageStatus } from '@janhq/core'
 
 const ScrollToBottom = ({
   threadId,
@@ -28,11 +29,21 @@ const ScrollToBottom = ({
 
   const streamingContent = useAppState((state) => state.streamingContent)
 
+  // Check if last message is a partial assistant response (user interrupted)
+  // Only show button if message has Stopped status (interrupted by user)
+  const isPartialResponse =
+    messages.length >= 2 &&
+    messages[messages.length - 1]?.role === 'assistant' &&
+    messages[messages.length - 1]?.status === MessageStatus.Stopped &&
+    messages[messages.length - 2]?.role === 'user' &&
+    !messages[messages.length - 1]?.metadata?.tool_calls
+
   const showGenerateAIResponseBtn =
-    (messages[messages.length - 1]?.role === 'user' ||
+    ((messages[messages.length - 1]?.role === 'user' ||
       (messages[messages.length - 1]?.metadata &&
-        'tool_calls' in (messages[messages.length - 1].metadata ?? {}))) &&
-    !streamingContent
+        'tool_calls' in (messages[messages.length - 1].metadata ?? {})) ||
+      isPartialResponse) &&
+    !streamingContent)
 
   return (
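Restated with the hypothetical helper from the sketch above (the import path is likewise assumed), the reworked visibility rule reads: show the button whenever the thread is waiting on the model and nothing is currently streaming.

```ts
import { ThreadMessage } from '@janhq/core'
// Hypothetical module from the earlier sketch
import { isPartialAssistantResponse } from './isPartialAssistantResponse'

// The three triggers from the patch, restated: the last message is from the
// user, or it carries tool_calls metadata, or it is a user-interrupted
// assistant reply, and nothing is currently streaming.
export const shouldShowGenerateButton = (
  messages: ThreadMessage[],
  isStreaming: boolean
): boolean => {
  const last = messages[messages.length - 1]
  const metadata = last?.metadata
  const lastHasToolCalls = metadata != null && 'tool_calls' in metadata
  return (
    (last?.role === 'user' ||
      lastHasToolCalls ||
      isPartialAssistantResponse(messages)) &&
    !isStreaming
  )
}
```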
diff --git a/web-app/src/containers/StreamingContent.tsx b/web-app/src/containers/StreamingContent.tsx
--- a/web-app/src/containers/StreamingContent.tsx
+++ b/web-app/src/containers/StreamingContent.tsx
@@ ... @@
   ) {
     return null
   }
 
+  // Don't show streaming content if there's already a stopped message
+  // (interrupted message that was just saved)
+  if (lastAssistant?.status === MessageStatus.Stopped) {
+    return null
+  }
+
   // Pass a new object to ThreadContent to avoid reference issues
   // The streaming content is always the last message
   return (
diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts
index a92269e96..292adb235 100644
--- a/web-app/src/hooks/useChat.ts
+++ b/web-app/src/hooks/useChat.ts
@@ -20,6 +20,7 @@
 import { CompletionMessagesBuilder } from '@/lib/messages'
 import { renderInstructions } from '@/lib/instructionTemplate'
 import { ChatCompletionMessageToolCall } from 'openai/resources'
+import { MessageStatus } from '@janhq/core'
 import { useServiceHub } from '@/hooks/useServiceHub'
 import { useToolApproval } from '@/hooks/useToolApproval'
 
@@ -598,16 +599,25 @@ export const useChat = () => {
       // IMPORTANT: Check if aborted AFTER the while loop exits
       // The while loop exits when abort is true, so we handle it here
-      if (abortController.signal.aborted && accumulatedTextRef.value.length > 0) {
-        // Create final content for the partial message
-        const partialContent = newAssistantThreadContent(
-          activeThread.id,
-          accumulatedTextRef.value,
-          {
-            tokenSpeed: useAppState.getState().tokenSpeed,
-            assistant: currentAssistant,
-          }
-        )
+      // Only save interrupted messages for llamacpp provider
+      // Other providers (OpenAI, Claude, etc.) handle streaming differently
+      if (
+        abortController.signal.aborted &&
+        accumulatedTextRef.value.length > 0 &&
+        activeProvider?.provider === 'llamacpp'
+      ) {
+        // Create final content for the partial message with Stopped status
+        const partialContent = {
+          ...newAssistantThreadContent(
+            activeThread.id,
+            accumulatedTextRef.value,
+            {
+              tokenSpeed: useAppState.getState().tokenSpeed,
+              assistant: currentAssistant,
+            }
+          ),
+          status: MessageStatus.Stopped,
+        }
         // Save the partial message
         addMessage(partialContent)
         updatePromptProgress(undefined)
       }
@@ -615,23 +625,30 @@ export const useChat = () => {
     } catch (error) {
       // If aborted, save the partial message even though an error occurred
-      // Check both accumulatedTextRef and streamingContent from app state
+      // Only save for llamacpp provider - other providers handle streaming differently
       const streamingContent = useAppState.getState().streamingContent
       const hasPartialContent =
         accumulatedTextRef.value.length > 0 ||
         (streamingContent && streamingContent.content?.[0]?.text?.value)
 
-      if (abortController.signal.aborted && hasPartialContent) {
+      if (
+        abortController.signal.aborted &&
+        hasPartialContent &&
+        activeProvider?.provider === 'llamacpp'
+      ) {
         // Use streaming content if available, otherwise use accumulatedTextRef
         const contentText =
           streamingContent?.content?.[0]?.text?.value || accumulatedTextRef.value
-        const partialContent = newAssistantThreadContent(
-          activeThread.id,
-          contentText,
-          {
-            tokenSpeed: useAppState.getState().tokenSpeed,
-            assistant: currentAssistant,
-          }
-        )
+        const partialContent = {
+          ...newAssistantThreadContent(
+            activeThread.id,
+            contentText,
+            {
+              tokenSpeed: useAppState.getState().tokenSpeed,
+              assistant: currentAssistant,
+            }
+          ),
+          status: MessageStatus.Stopped,
+        }
         addMessage(partialContent)
         updatePromptProgress(undefined)
         updateThreadTimestamp(activeThread.id)
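Both save paths in `useChat` now build the partial message the same way: spread the object from `newAssistantThreadContent` and override its `status`. A small helper could name that pattern. This is hypothetical, and it assumes `newAssistantThreadContent` returns a `ThreadMessage`, which the patch's spread-and-override suggests.

```ts
import { MessageStatus, ThreadMessage } from '@janhq/core'

// Hypothetical helper for the pattern the patch inlines twice: stamp the
// freshly built assistant message as Stopped so the UI (StreamingContent,
// ScrollToBottom, GenerateResponseButton) can recognize an interrupted reply.
const markStopped = (message: ThreadMessage): ThreadMessage => ({
  ...message,
  status: MessageStatus.Stopped,
})

// Usage at either save site, e.g.:
//   addMessage(markStopped(newAssistantThreadContent(activeThread.id, text, opts)))
```

Note that the new `activeProvider?.provider === 'llamacpp'` guard means remote providers simply discard partial output on abort, as they did before this patch.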
diff --git a/web-app/src/locales/en/common.json b/web-app/src/locales/en/common.json
index 2c8b8c09d..d0044dee3 100644
--- a/web-app/src/locales/en/common.json
+++ b/web-app/src/locales/en/common.json
@@ -135,6 +135,7 @@
   "enterApiKey": "Enter API Key",
   "scrollToBottom": "Scroll to bottom",
   "generateAiResponse": "Generate AI Response",
+  "continueAiResponse": "Continue with AI Response",
   "addModel": {
     "title": "Add Model",
     "modelId": "Model ID",