feat: Continue with AI response button if it got interrupted
This commit is contained in:
parent
2e86d4e421
commit
4ea9d296ea
@ -3,6 +3,8 @@ import { useMessages } from '@/hooks/useMessages'
|
||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
import { Play } from 'lucide-react'
|
||||
import { useShallow } from 'zustand/react/shallow'
|
||||
import { useMemo } from 'react'
|
||||
import { MessageStatus } from '@janhq/core'
|
||||
|
||||
export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
|
||||
const { t } = useTranslation()
|
||||
@ -13,7 +15,36 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
|
||||
}))
|
||||
)
|
||||
const sendMessage = useChat()
|
||||
|
||||
// Detect if last message is a partial assistant response (user stopped midway)
|
||||
// Only true if message has Stopped status (interrupted by user)
|
||||
const isPartialResponse = useMemo(() => {
|
||||
if (!messages || messages.length < 2) return false
|
||||
const lastMessage = messages[messages.length - 1]
|
||||
const secondLastMessage = messages[messages.length - 2]
|
||||
|
||||
// Partial if: last is assistant with Stopped status, second-last is user, no tool calls
|
||||
return (
|
||||
lastMessage?.role === 'assistant' &&
|
||||
lastMessage?.status === MessageStatus.Stopped &&
|
||||
secondLastMessage?.role === 'user' &&
|
||||
!lastMessage?.metadata?.tool_calls
|
||||
)
|
||||
}, [messages])
|
||||
|
||||
const generateAIResponse = () => {
|
||||
// If continuing a partial response, delete the partial message first
|
||||
if (isPartialResponse) {
|
||||
const partialMessage = messages[messages.length - 1]
|
||||
deleteMessage(partialMessage.thread_id, partialMessage.id ?? '')
|
||||
// Get the user message that prompted this partial response
|
||||
const userMessage = messages[messages.length - 2]
|
||||
if (userMessage?.content?.[0]?.text?.value) {
|
||||
sendMessage(userMessage.content[0].text.value, false)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const latestUserMessage = messages[messages.length - 1]
|
||||
if (
|
||||
latestUserMessage?.content?.[0]?.text?.value &&
|
||||
@ -39,7 +70,11 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
|
||||
className="mx-2 bg-main-view-fg/10 px-2 border border-main-view-fg/5 flex items-center justify-center rounded-xl gap-x-2 cursor-pointer pointer-events-auto"
|
||||
onClick={generateAIResponse}
|
||||
>
|
||||
<p className="text-xs">{t('common:generateAiResponse')}</p>
|
||||
<p className="text-xs">
|
||||
{isPartialResponse
|
||||
? t('common:continueAiResponse')
|
||||
: t('common:generateAiResponse')}
|
||||
</p>
|
||||
<Play size={12} />
|
||||
</div>
|
||||
)
|
||||
|
||||
@ -8,6 +8,7 @@ import { cn } from '@/lib/utils'
|
||||
import { ArrowDown } from 'lucide-react'
|
||||
import { useTranslation } from '@/i18n/react-i18next-compat'
|
||||
import { useAppState } from '@/hooks/useAppState'
|
||||
import { MessageStatus } from '@janhq/core'
|
||||
|
||||
const ScrollToBottom = ({
|
||||
threadId,
|
||||
@ -28,11 +29,21 @@ const ScrollToBottom = ({
|
||||
|
||||
const streamingContent = useAppState((state) => state.streamingContent)
|
||||
|
||||
// Check if last message is a partial assistant response (user interrupted)
|
||||
// Only show button if message has Stopped status (interrupted by user)
|
||||
const isPartialResponse =
|
||||
messages.length >= 2 &&
|
||||
messages[messages.length - 1]?.role === 'assistant' &&
|
||||
messages[messages.length - 1]?.status === MessageStatus.Stopped &&
|
||||
messages[messages.length - 2]?.role === 'user' &&
|
||||
!messages[messages.length - 1]?.metadata?.tool_calls
|
||||
|
||||
const showGenerateAIResponseBtn =
|
||||
(messages[messages.length - 1]?.role === 'user' ||
|
||||
((messages[messages.length - 1]?.role === 'user' ||
|
||||
(messages[messages.length - 1]?.metadata &&
|
||||
'tool_calls' in (messages[messages.length - 1].metadata ?? {}))) &&
|
||||
!streamingContent
|
||||
'tool_calls' in (messages[messages.length - 1].metadata ?? {})) ||
|
||||
isPartialResponse) &&
|
||||
!streamingContent)
|
||||
|
||||
return (
|
||||
<div
|
||||
|
||||
@ -2,6 +2,7 @@ import { useAppState } from '@/hooks/useAppState'
|
||||
import { ThreadContent } from './ThreadContent'
|
||||
import { memo, useMemo } from 'react'
|
||||
import { useMessages } from '@/hooks/useMessages'
|
||||
import { MessageStatus } from '@janhq/core'
|
||||
|
||||
type Props = {
|
||||
threadId: string
|
||||
@ -56,6 +57,12 @@ export const StreamingContent = memo(({ threadId }: Props) => {
|
||||
return null
|
||||
}
|
||||
|
||||
// Don't show streaming content if there's already a stopped message
|
||||
// (interrupted message that was just saved)
|
||||
if (lastAssistant?.status === MessageStatus.Stopped) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Pass a new object to ThreadContent to avoid reference issues
|
||||
// The streaming content is always the last message
|
||||
return (
|
||||
|
||||
@ -20,6 +20,7 @@ import {
|
||||
import { CompletionMessagesBuilder } from '@/lib/messages'
|
||||
import { renderInstructions } from '@/lib/instructionTemplate'
|
||||
import { ChatCompletionMessageToolCall } from 'openai/resources'
|
||||
import { MessageStatus } from '@janhq/core'
|
||||
|
||||
import { useServiceHub } from '@/hooks/useServiceHub'
|
||||
import { useToolApproval } from '@/hooks/useToolApproval'
|
||||
@ -598,16 +599,25 @@ export const useChat = () => {
|
||||
|
||||
// IMPORTANT: Check if aborted AFTER the while loop exits
|
||||
// The while loop exits when abort is true, so we handle it here
|
||||
if (abortController.signal.aborted && accumulatedTextRef.value.length > 0) {
|
||||
// Create final content for the partial message
|
||||
const partialContent = newAssistantThreadContent(
|
||||
activeThread.id,
|
||||
accumulatedTextRef.value,
|
||||
{
|
||||
tokenSpeed: useAppState.getState().tokenSpeed,
|
||||
assistant: currentAssistant,
|
||||
}
|
||||
)
|
||||
// Only save interrupted messages for llamacpp provider
|
||||
// Other providers (OpenAI, Claude, etc.) handle streaming differently
|
||||
if (
|
||||
abortController.signal.aborted &&
|
||||
accumulatedTextRef.value.length > 0 &&
|
||||
activeProvider?.provider === 'llamacpp'
|
||||
) {
|
||||
// Create final content for the partial message with Stopped status
|
||||
const partialContent = {
|
||||
...newAssistantThreadContent(
|
||||
activeThread.id,
|
||||
accumulatedTextRef.value,
|
||||
{
|
||||
tokenSpeed: useAppState.getState().tokenSpeed,
|
||||
assistant: currentAssistant,
|
||||
}
|
||||
),
|
||||
status: MessageStatus.Stopped,
|
||||
}
|
||||
// Save the partial message
|
||||
addMessage(partialContent)
|
||||
updatePromptProgress(undefined)
|
||||
@ -615,23 +625,30 @@ export const useChat = () => {
|
||||
}
|
||||
} catch (error) {
|
||||
// If aborted, save the partial message even though an error occurred
|
||||
// Check both accumulatedTextRef and streamingContent from app state
|
||||
// Only save for llamacpp provider - other providers handle streaming differently
|
||||
const streamingContent = useAppState.getState().streamingContent
|
||||
const hasPartialContent = accumulatedTextRef.value.length > 0 ||
|
||||
(streamingContent && streamingContent.content?.[0]?.text?.value)
|
||||
|
||||
if (abortController.signal.aborted && hasPartialContent) {
|
||||
if (
|
||||
abortController.signal.aborted &&
|
||||
hasPartialContent &&
|
||||
activeProvider?.provider === 'llamacpp'
|
||||
) {
|
||||
// Use streaming content if available, otherwise use accumulatedTextRef
|
||||
const contentText = streamingContent?.content?.[0]?.text?.value || accumulatedTextRef.value
|
||||
|
||||
const partialContent = newAssistantThreadContent(
|
||||
activeThread.id,
|
||||
contentText,
|
||||
{
|
||||
tokenSpeed: useAppState.getState().tokenSpeed,
|
||||
assistant: currentAssistant,
|
||||
}
|
||||
)
|
||||
const partialContent = {
|
||||
...newAssistantThreadContent(
|
||||
activeThread.id,
|
||||
contentText,
|
||||
{
|
||||
tokenSpeed: useAppState.getState().tokenSpeed,
|
||||
assistant: currentAssistant,
|
||||
}
|
||||
),
|
||||
status: MessageStatus.Stopped,
|
||||
}
|
||||
addMessage(partialContent)
|
||||
updatePromptProgress(undefined)
|
||||
updateThreadTimestamp(activeThread.id)
|
||||
|
||||
@ -135,6 +135,7 @@
|
||||
"enterApiKey": "Enter API Key",
|
||||
"scrollToBottom": "Scroll to bottom",
|
||||
"generateAiResponse": "Generate AI Response",
|
||||
"continueAiResponse": "Continue with AI Response",
|
||||
"addModel": {
|
||||
"title": "Add Model",
|
||||
"modelId": "Model ID",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user