feat: Allow saving the last message when interrupting an LLM response

Vanalite 2025-10-01 15:43:05 +07:00
parent 0de5f17071
commit 2e86d4e421
4 changed files with 258 additions and 163 deletions

View File

@@ -48,7 +48,9 @@ export const StreamingContent = memo(({ threadId }: Props) => {
     return extractReasoningSegment(text)
   }, [lastAssistant])
-  if (!streamingContent || streamingContent.thread_id !== threadId) return null
+  if (!streamingContent || streamingContent.thread_id !== threadId) {
+    return null
+  }
   if (streamingReasoning && streamingReasoning === lastAssistantReasoning) {
     return null

View File

@@ -225,9 +225,25 @@ describe('useMessages', () => {
       })
     )
-    // Wait for async operation
+    // Message should be immediately available (optimistic update)
+    expect(result.current.messages['thread1']).toContainEqual(
+      expect.objectContaining({
+        id: messageToAdd.id,
+        thread_id: messageToAdd.thread_id,
+        role: messageToAdd.role,
+        content: messageToAdd.content,
+        metadata: expect.objectContaining({
+          assistant: expect.objectContaining({
+            id: expect.any(String),
+            name: expect.any(String),
+          }),
+        }),
+      })
+    )
+    // Verify persistence was attempted
     await vi.waitFor(() => {
-      expect(result.current.messages['thread1']).toContainEqual(mockCreatedMessage)
+      expect(mockCreateMessage).toHaveBeenCalled()
     })
   })
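
The rewritten test pins down the contract introduced by the store change at the end of this commit: the message must be in state synchronously, while persistence is only observable asynchronously. A minimal self-contained sketch of that two-phase assertion pattern, using an invented `useStore`/`persistItem` pair rather than the app's real store:

```ts
import { renderHook, act } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import { create } from 'zustand'

// Invented stand-ins for getServiceHub().messages().createMessage and useMessages.
const persistItem = vi.fn().mockResolvedValue(undefined)

interface Item { id: string; text: string }
interface Store {
  items: Item[]
  addItem: (item: Item) => void
}

const useStore = create<Store>((set) => ({
  items: [],
  addItem: (item) => {
    // Optimistic insert first; persistence is fire-and-forget.
    set((state) => ({ items: [...state.items, item] }))
    persistItem(item).catch(console.error)
  },
}))

describe('optimistic add', () => {
  it('shows the item synchronously and persists it asynchronously', async () => {
    const { result } = renderHook(() => useStore())
    act(() => result.current.addItem({ id: '1', text: 'hi' }))
    // Phase 1: no await — the optimistic insert must already be visible.
    expect(result.current.items).toContainEqual(expect.objectContaining({ id: '1' }))
    // Phase 2: poll until the mocked persistence call has fired.
    await vi.waitFor(() => expect(persistItem).toHaveBeenCalled())
  })
})
```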

View File

@@ -35,6 +35,156 @@ import { useAssistant } from './useAssistant'
 import { useShallow } from 'zustand/shallow'
 import { TEMPORARY_CHAT_QUERY_ID, TEMPORARY_CHAT_ID } from '@/constants/chat'
+// Helper to create thread content with consistent structure
+const createThreadContent = (
+  threadId: string,
+  text: string,
+  toolCalls: ChatCompletionMessageToolCall[]
+) => {
+  return newAssistantThreadContent(threadId, text, {
+    tool_calls: toolCalls.map((e) => ({
+      ...e,
+      state: 'pending',
+    })),
+  })
+}
+// Helper to cancel animation frame cross-platform
+const cancelFrame = (handle: number | undefined) => {
+  if (handle === undefined) return
+  if (typeof cancelAnimationFrame !== 'undefined') {
+    cancelAnimationFrame(handle)
+  } else {
+    clearTimeout(handle)
+  }
+}
+// Helper to finalize and save a message
+const finalizeMessage = (
+  finalContent: ThreadMessage,
+  addMessage: (message: ThreadMessage) => void,
+  updateStreamingContent: (content: ThreadMessage | undefined) => void,
+  updatePromptProgress: (progress: unknown) => void,
+  updateThreadTimestamp: (threadId: string) => void
+) => {
+  addMessage(finalContent)
+  updateStreamingContent(emptyThreadContent)
+  updatePromptProgress(undefined)
+  updateThreadTimestamp(finalContent.thread_id)
+}
+// Helper to process streaming completion
+const processStreamingCompletion = async (
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  completion: AsyncIterable<any>,
+  abortController: AbortController,
+  activeThread: Thread,
+  accumulatedText: { value: string },
+  toolCalls: ChatCompletionMessageToolCall[],
+  currentCall: ChatCompletionMessageToolCall | null,
+  updateStreamingContent: (content: ThreadMessage | undefined) => void,
+  updateTokenSpeed: (message: ThreadMessage, increment?: number) => void,
+  updatePromptProgress: (progress: unknown) => void
+) => {
+  // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
+  let rafScheduled = false
+  let rafHandle: number | undefined
+  let pendingDeltaCount = 0
+  const reasoningProcessor = new ReasoningProcessor()
+  const flushStreamingContent = () => {
+    const currentContent = createThreadContent(
+      activeThread.id,
+      accumulatedText.value,
+      toolCalls
+    )
+    updateStreamingContent(currentContent)
+    if (pendingDeltaCount > 0) {
+      updateTokenSpeed(currentContent, pendingDeltaCount)
+    }
+    pendingDeltaCount = 0
+    rafScheduled = false
+  }
+  const scheduleFlush = () => {
+    if (rafScheduled || abortController.signal.aborted) return
+    rafScheduled = true
+    const doSchedule = (cb: () => void) => {
+      if (typeof requestAnimationFrame !== 'undefined') {
+        rafHandle = requestAnimationFrame(() => cb())
+      } else {
+        // Fallback for non-browser test environments
+        const t = setTimeout(() => cb(), 0) as unknown as number
+        rafHandle = t
+      }
+    }
+    doSchedule(() => {
+      // Check abort status before executing the scheduled callback
+      if (abortController.signal.aborted) {
+        rafScheduled = false
+        return
+      }
+      flushStreamingContent()
+    })
+  }
+  try {
+    for await (const part of completion) {
+      // Check if aborted before processing each part
+      if (abortController.signal.aborted) {
+        break
+      }
+      // Handle prompt progress if available
+      if ('prompt_progress' in part && part.prompt_progress) {
+        // Force immediate state update to ensure we see intermediate values
+        flushSync(() => {
+          updatePromptProgress(part.prompt_progress)
+        })
+        // Add a small delay to make progress visible
+        await new Promise((resolve) => setTimeout(resolve, 100))
+      }
+      // Error message
+      if (!part.choices) {
+        throw new Error(
+          'message' in part
+            ? (part.message as string)
+            : (JSON.stringify(part) ?? '')
+        )
+      }
+      if (part.choices[0]?.delta?.tool_calls) {
+        extractToolCall(part, currentCall, toolCalls)
+        // Schedule a flush to reflect tool update
+        scheduleFlush()
+      }
+      const deltaReasoning = reasoningProcessor.processReasoningChunk(part)
+      if (deltaReasoning) {
+        accumulatedText.value += deltaReasoning
+        pendingDeltaCount += 1
+        // Schedule flush for reasoning updates
+        scheduleFlush()
+      }
+      const deltaContent = part.choices[0]?.delta?.content || ''
+      if (deltaContent) {
+        accumulatedText.value += deltaContent
+        pendingDeltaCount += 1
+        // Batch UI update on next animation frame
+        scheduleFlush()
+      }
+    }
+  } finally {
+    // Always clean up scheduled RAF when stream ends (either normally or via abort)
+    cancelFrame(rafHandle)
+    rafHandle = undefined
+    rafScheduled = false
+    // Finalize reasoning (close any open think tags)
+    accumulatedText.value += reasoningProcessor.finalize()
+  }
+}
 export const useChat = () => {
   const [
     updateTokenSpeed,
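
The extracted `processStreamingCompletion` keeps the earlier requestAnimationFrame batching: however many deltas arrive within one frame, only a single `updateStreamingContent` call runs. A stripped-down sketch of that scheduler in isolation, with invented names and the same setTimeout fallback for non-browser environments:

```ts
// Coalesces bursts of schedule() calls into one flush per animation frame.
function createFrameBatcher(flush: () => void) {
  let scheduled = false
  let handle: number | undefined
  const schedule = () => {
    if (scheduled) return
    scheduled = true
    const run = () => {
      scheduled = false
      flush()
    }
    handle =
      typeof requestAnimationFrame !== 'undefined'
        ? requestAnimationFrame(run)
        : (setTimeout(run, 0) as unknown as number)
  }
  const cancel = () => {
    if (handle === undefined) return
    if (typeof cancelAnimationFrame !== 'undefined') cancelAnimationFrame(handle)
    else clearTimeout(handle)
    handle = undefined
    scheduled = false
  }
  return { schedule, cancel }
}

// Usage: three deltas, but at most one flush per frame.
let text = ''
const batcher = createFrameBatcher(() => console.log('render', text.length))
for (const delta of ['Hel', 'lo', ' world']) {
  text += delta
  batcher.schedule()
}
```

One behavioral difference worth noting: the old inline code flushed any pending buffered content on normal completion (`flushIfPending`), whereas the helper's `finally` only cancels the frame; the final text still reaches the UI because the caller builds `finalContent` from the accumulator afterwards.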
@@ -264,13 +414,18 @@ export const useChat = () => {
       updateThreadTimestamp(activeThread.id)
       usePrompt.getState().setPrompt('')
       const selectedModel = useModelProvider.getState().selectedModel
+      // Declare accumulatedTextRef BEFORE try block so it's accessible in catch block
+      const accumulatedTextRef = { value: '' }
+      let currentAssistant: Assistant | undefined
       try {
         if (selectedModel?.id) {
           updateLoadingModel(true)
           await serviceHub.models().startModel(activeProvider, selectedModel.id)
           updateLoadingModel(false)
         }
-        const currentAssistant = useAssistant.getState().currentAssistant
+        currentAssistant = useAssistant.getState().currentAssistant
         const builder = new CompletionMessagesBuilder(
           messages,
           currentAssistant
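
Hoisting `accumulatedTextRef` above the `try` block is what makes the new interrupt handling possible: the streaming helper mutates a shared `{ value }` box instead of a local string, so the `catch` path can still read everything accumulated before the failure. A minimal sketch of the pattern, names invented:

```ts
// The box is declared before try, so catch still sees partial progress.
async function collect(parts: AsyncIterable<string>): Promise<string> {
  const acc = { value: '' }
  try {
    await append(parts, acc)
  } catch (err) {
    console.warn(`stream failed after ${acc.value.length} chars`, err)
    return acc.value // partial output survives the failure
  }
  return acc.value
}

async function append(parts: AsyncIterable<string>, acc: { value: string }) {
  for await (const part of parts) {
    acc.value += part // mutates the caller's box through the shared reference
  }
}
```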
@@ -330,162 +485,35 @@ export const useChat = () => {
         )
         if (!completion) throw new Error('No completion received')
-        let accumulatedText = ''
         const currentCall: ChatCompletionMessageToolCall | null = null
         const toolCalls: ChatCompletionMessageToolCall[] = []
         try {
           if (isCompletionResponse(completion)) {
             const message = completion.choices[0]?.message
-            accumulatedText = (message?.content as string) || ''
+            accumulatedTextRef.value = (message?.content as string) || ''
             // Handle reasoning field if there is one
             const reasoning = extractReasoningFromMessage(message)
             if (reasoning) {
-              accumulatedText =
-                `<think>${reasoning}</think>` + accumulatedText
+              accumulatedTextRef.value =
+                `<think>${reasoning}</think>` + accumulatedTextRef.value
             }
             if (message?.tool_calls) {
               toolCalls.push(...message.tool_calls)
             }
           } else {
-            // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
-            let rafScheduled = false
-            let rafHandle: number | undefined
-            let pendingDeltaCount = 0
-            const reasoningProcessor = new ReasoningProcessor()
-            const scheduleFlush = () => {
-              if (rafScheduled || abortController.signal.aborted) return
-              rafScheduled = true
-              const doSchedule = (cb: () => void) => {
-                if (typeof requestAnimationFrame !== 'undefined') {
-                  rafHandle = requestAnimationFrame(() => cb())
-                } else {
-                  // Fallback for non-browser test environments
-                  const t = setTimeout(() => cb(), 0) as unknown as number
-                  rafHandle = t
-                }
-              }
-              doSchedule(() => {
-                // Check abort status before executing the scheduled callback
-                if (abortController.signal.aborted) {
-                  rafScheduled = false
-                  return
-                }
-                const currentContent = newAssistantThreadContent(
-                  activeThread.id,
-                  accumulatedText,
-                  {
-                    tool_calls: toolCalls.map((e) => ({
-                      ...e,
-                      state: 'pending',
-                    })),
-                  }
-                )
-                updateStreamingContent(currentContent)
-                if (pendingDeltaCount > 0) {
-                  updateTokenSpeed(currentContent, pendingDeltaCount)
-                }
-                pendingDeltaCount = 0
-                rafScheduled = false
-              })
-            }
-            const flushIfPending = () => {
-              if (!rafScheduled) return
-              if (
-                typeof cancelAnimationFrame !== 'undefined' &&
-                rafHandle !== undefined
-              ) {
-                cancelAnimationFrame(rafHandle)
-              } else if (rafHandle !== undefined) {
-                clearTimeout(rafHandle)
-              }
-              // Do an immediate flush
-              const currentContent = newAssistantThreadContent(
-                activeThread.id,
-                accumulatedText,
-                {
-                  tool_calls: toolCalls.map((e) => ({
-                    ...e,
-                    state: 'pending',
-                  })),
-                }
-              )
-              updateStreamingContent(currentContent)
-              if (pendingDeltaCount > 0) {
-                updateTokenSpeed(currentContent, pendingDeltaCount)
-              }
-              pendingDeltaCount = 0
-              rafScheduled = false
-            }
-            try {
-              for await (const part of completion) {
-                // Check if aborted before processing each part
-                if (abortController.signal.aborted) {
-                  break
-                }
-                // Handle prompt progress if available
-                if ('prompt_progress' in part && part.prompt_progress) {
-                  // Force immediate state update to ensure we see intermediate values
-                  flushSync(() => {
-                    updatePromptProgress(part.prompt_progress)
-                  })
-                  // Add a small delay to make progress visible
-                  await new Promise((resolve) => setTimeout(resolve, 100))
-                }
-                // Error message
-                if (!part.choices) {
-                  throw new Error(
-                    'message' in part
-                      ? (part.message as string)
-                      : (JSON.stringify(part) ?? '')
-                  )
-                }
-                if (part.choices[0]?.delta?.tool_calls) {
-                  extractToolCall(part, currentCall, toolCalls)
-                  // Schedule a flush to reflect tool update
-                  scheduleFlush()
-                }
-                const deltaReasoning =
-                  reasoningProcessor.processReasoningChunk(part)
-                if (deltaReasoning) {
-                  accumulatedText += deltaReasoning
-                  pendingDeltaCount += 1
-                  // Schedule flush for reasoning updates
-                  scheduleFlush()
-                }
-                const deltaContent = part.choices[0]?.delta?.content || ''
-                if (deltaContent) {
-                  accumulatedText += deltaContent
-                  pendingDeltaCount += 1
-                  // Batch UI update on next animation frame
-                  scheduleFlush()
-                }
-              }
-            } finally {
-              // Always clean up scheduled RAF when stream ends (either normally or via abort)
-              if (rafHandle !== undefined) {
-                if (typeof cancelAnimationFrame !== 'undefined') {
-                  cancelAnimationFrame(rafHandle)
-                } else {
-                  clearTimeout(rafHandle)
-                }
-                rafHandle = undefined
-                rafScheduled = false
-              }
-              // Only finalize and flush if not aborted
-              if (!abortController.signal.aborted) {
-                // Finalize reasoning (close any open think tags)
-                accumulatedText += reasoningProcessor.finalize()
-                // Ensure any pending buffered content is rendered at the end
-                flushIfPending()
-              }
-            }
+            await processStreamingCompletion(
+              completion,
+              abortController,
+              activeThread,
+              accumulatedTextRef,
+              toolCalls,
+              currentCall,
+              updateStreamingContent,
+              updateTokenSpeed,
+              updatePromptProgress
+            )
           }
         } catch (error) {
           const errorMessage =
@@ -519,7 +547,7 @@ export const useChat = () => {
         }
         // TODO: Remove this check when integrating new llama.cpp extension
         if (
-          accumulatedText.length === 0 &&
+          accumulatedTextRef.value.length === 0 &&
           toolCalls.length === 0 &&
           activeThread.model?.id &&
           activeProvider?.provider === 'llamacpp'
@@ -533,14 +561,15 @@ export const useChat = () => {
         // Create a final content object for adding to the thread
         const finalContent = newAssistantThreadContent(
           activeThread.id,
-          accumulatedText,
+          accumulatedTextRef.value,
           {
             tokenSpeed: useAppState.getState().tokenSpeed,
             assistant: currentAssistant,
           }
         )
-        builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
+        // Normal completion flow (abort is handled after loop exits)
+        builder.addAssistantMessage(accumulatedTextRef.value, undefined, toolCalls)
         const updatedMessage = await postMessageProcessing(
           toolCalls,
           builder,
@@ -550,10 +579,13 @@ export const useChat = () => {
           allowAllMCPPermissions ? undefined : showApprovalModal,
           allowAllMCPPermissions
         )
-        addMessage(updatedMessage ?? finalContent)
-        updateStreamingContent(emptyThreadContent)
-        updatePromptProgress(undefined)
-        updateThreadTimestamp(activeThread.id)
+        finalizeMessage(
+          updatedMessage ?? finalContent,
+          addMessage,
+          updateStreamingContent,
+          updatePromptProgress,
+          updateThreadTimestamp
+        )
         isCompleted = !toolCalls.length
         // Do not create agent loop if there is no need for it
@@ -563,8 +595,48 @@ export const useChat = () => {
           availableTools = []
         }
       }
+      // IMPORTANT: Check if aborted AFTER the while loop exits
+      // The while loop exits when abort is true, so we handle it here
+      if (abortController.signal.aborted && accumulatedTextRef.value.length > 0) {
+        // Create final content for the partial message
+        const partialContent = newAssistantThreadContent(
+          activeThread.id,
+          accumulatedTextRef.value,
+          {
+            tokenSpeed: useAppState.getState().tokenSpeed,
+            assistant: currentAssistant,
+          }
+        )
+        // Save the partial message
+        addMessage(partialContent)
+        updatePromptProgress(undefined)
+        updateThreadTimestamp(activeThread.id)
+      }
     } catch (error) {
-      if (!abortController.signal.aborted) {
+      // If aborted, save the partial message even though an error occurred
+      // Check both accumulatedTextRef and streamingContent from app state
+      const streamingContent = useAppState.getState().streamingContent
+      const hasPartialContent =
+        accumulatedTextRef.value.length > 0 ||
+        (streamingContent && streamingContent.content?.[0]?.text?.value)
+      if (abortController.signal.aborted && hasPartialContent) {
+        // Use streaming content if available, otherwise use accumulatedTextRef
+        const contentText =
+          streamingContent?.content?.[0]?.text?.value ||
+          accumulatedTextRef.value
+        const partialContent = newAssistantThreadContent(
+          activeThread.id,
+          contentText,
+          {
+            tokenSpeed: useAppState.getState().tokenSpeed,
+            assistant: currentAssistant,
+          }
+        )
+        addMessage(partialContent)
+        updatePromptProgress(undefined)
+        updateThreadTimestamp(activeThread.id)
+      } else if (!abortController.signal.aborted) {
+        // Only show error if not aborted
        if (error && typeof error === 'object' && 'message' in error) {
          setModelLoadError(error as ErrorObject)
        } else {
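
The partial save above appears twice — once after the agent loop and once inside `catch` — because an interrupt can surface in two ways: the loop sees `signal.aborted` and exits cleanly, or the cancelled transport throws mid-iteration. A hedged sketch of both paths, all names invented:

```ts
// Both exits funnel the accumulated text into save() when aborted.
async function streamWithInterruptSave(
  stream: AsyncIterable<string>,
  controller: AbortController,
  save: (text: string) => void
) {
  const acc = { value: '' }
  try {
    for await (const delta of stream) {
      if (controller.signal.aborted) break // clean exit: loop breaks on the signal
      acc.value += delta
    }
    // Path 1: normal or signal-checked exit.
    if (controller.signal.aborted && acc.value.length > 0) save(acc.value)
  } catch (err) {
    // Path 2: the cancelled transport threw (e.g. an aborted fetch stream).
    if (controller.signal.aborted && acc.value.length > 0) save(acc.value)
    else throw err // real errors still propagate
  }
}
```

The commit's catch path additionally falls back to the last rendered streaming snapshot (`streamingContent.content?.[0]?.text?.value`) in case the accumulator is empty but the UI has already shown text.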

View File

@@ -40,16 +40,21 @@ export const useMessages = create<MessageState>()((set, get) => ({
           assistant: selectedAssistant,
         },
       }
-      getServiceHub().messages().createMessage(newMessage).then((createdMessage) => {
-        set((state) => ({
-          messages: {
-            ...state.messages,
-            [message.thread_id]: [
-              ...(state.messages[message.thread_id] || []),
-              createdMessage,
-            ],
-          },
-        }))
-      })
+      // Optimistically update state immediately for instant UI feedback
+      set((state) => ({
+        messages: {
+          ...state.messages,
+          [message.thread_id]: [
+            ...(state.messages[message.thread_id] || []),
+            newMessage,
+          ],
+        },
+      }))
+      // Persist to storage asynchronously
+      getServiceHub().messages().createMessage(newMessage).catch((error) => {
+        console.error('Failed to persist message:', error)
+      })
     },
     deleteMessage: (threadId, messageId) => {
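
This store change trades consistency for responsiveness: a persistence failure is only logged, so UI state and storage can diverge. A common extension — shown here purely as a hypothetical sketch, not something this commit does — is to roll the optimistic insert back when the write fails:

```ts
import { create } from 'zustand'

interface Message { id: string; thread_id: string; text: string }
interface Store {
  messages: Record<string, Message[]>
  addMessage: (m: Message) => void
}

// Invented stand-in for getServiceHub().messages().createMessage.
const persistMessage = async (_m: Message): Promise<void> => {
  /* write to storage */
}

const useMessages = create<Store>((set) => ({
  messages: {},
  addMessage: (m) => {
    // Optimistic insert, as in the diff above.
    set((s) => ({
      messages: {
        ...s.messages,
        [m.thread_id]: [...(s.messages[m.thread_id] || []), m],
      },
    }))
    persistMessage(m).catch((error) => {
      console.error('Failed to persist message:', error)
      // Hypothetical rollback so state and storage agree again.
      set((s) => ({
        messages: {
          ...s.messages,
          [m.thread_id]: (s.messages[m.thread_id] || []).filter(
            (x) => x.id !== m.id
          ),
        },
      }))
    })
  },
}))
```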