feat: Allow saving the last message when interrupting the LLM response

Author: Vanalite
Date:   2025-10-01 15:43:05 +07:00
Parent: 0de5f17071
Commit: 2e86d4e421

4 changed files with 258 additions and 163 deletions
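In short, the change keeps the partially streamed assistant text when the user aborts a request and saves it as a regular message instead of dropping it. A minimal sketch of that pattern follows; streamCompletion and saveMessage are hypothetical stand-ins for the app's real completion stream and persistence helpers, which appear in the diffs below.

// Sketch only: keep and persist partial output when the stream is aborted.
// streamCompletion and saveMessage are hypothetical stand-ins.
async function runChat(
  prompt: string,
  abortController: AbortController,
  streamCompletion: (prompt: string, signal: AbortSignal) => AsyncIterable<string>,
  saveMessage: (text: string) => Promise<void>
): Promise<void> {
  // Accumulator lives outside the try block so abort/error paths can still read it.
  const accumulated = { value: '' }
  try {
    for await (const delta of streamCompletion(prompt, abortController.signal)) {
      if (abortController.signal.aborted) break // stop consuming once aborted
      accumulated.value += delta
    }
  } finally {
    // Whether the stream finished or was interrupted, save whatever was produced.
    if (accumulated.value.length > 0) {
      await saveMessage(accumulated.value)
    }
  }
}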


@@ -48,7 +48,9 @@ export const StreamingContent = memo(({ threadId }: Props) => {
     return extractReasoningSegment(text)
   }, [lastAssistant])
 
-  if (!streamingContent || streamingContent.thread_id !== threadId) return null
+  if (!streamingContent || streamingContent.thread_id !== threadId) {
+    return null
+  }
 
   if (streamingReasoning && streamingReasoning === lastAssistantReasoning) {
     return null


@@ -225,9 +225,25 @@ describe('useMessages', () => {
       })
     )
 
-    // Wait for async operation
+    // Message should be immediately available (optimistic update)
+    expect(result.current.messages['thread1']).toContainEqual(
+      expect.objectContaining({
+        id: messageToAdd.id,
+        thread_id: messageToAdd.thread_id,
+        role: messageToAdd.role,
+        content: messageToAdd.content,
+        metadata: expect.objectContaining({
+          assistant: expect.objectContaining({
+            id: expect.any(String),
+            name: expect.any(String),
+          }),
+        }),
+      })
+    )
+
+    // Verify persistence was attempted
     await vi.waitFor(() => {
-      expect(result.current.messages['thread1']).toContainEqual(mockCreatedMessage)
+      expect(mockCreateMessage).toHaveBeenCalled()
     })
   })


@@ -35,6 +35,156 @@ import { useAssistant } from './useAssistant'
 import { useShallow } from 'zustand/shallow'
 import { TEMPORARY_CHAT_QUERY_ID, TEMPORARY_CHAT_ID } from '@/constants/chat'
+
+// Helper to create thread content with consistent structure
+const createThreadContent = (
+  threadId: string,
+  text: string,
+  toolCalls: ChatCompletionMessageToolCall[]
+) => {
+  return newAssistantThreadContent(threadId, text, {
+    tool_calls: toolCalls.map((e) => ({
+      ...e,
+      state: 'pending',
+    })),
+  })
+}
+
+// Helper to cancel animation frame cross-platform
+const cancelFrame = (handle: number | undefined) => {
+  if (handle === undefined) return
+  if (typeof cancelAnimationFrame !== 'undefined') {
+    cancelAnimationFrame(handle)
+  } else {
+    clearTimeout(handle)
+  }
+}
+
+// Helper to finalize and save a message
+const finalizeMessage = (
+  finalContent: ThreadMessage,
+  addMessage: (message: ThreadMessage) => void,
+  updateStreamingContent: (content: ThreadMessage | undefined) => void,
+  updatePromptProgress: (progress: unknown) => void,
+  updateThreadTimestamp: (threadId: string) => void
+) => {
+  addMessage(finalContent)
+  updateStreamingContent(emptyThreadContent)
+  updatePromptProgress(undefined)
+  updateThreadTimestamp(finalContent.thread_id)
+}
+
+// Helper to process streaming completion
+const processStreamingCompletion = async (
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  completion: AsyncIterable<any>,
+  abortController: AbortController,
+  activeThread: Thread,
+  accumulatedText: { value: string },
+  toolCalls: ChatCompletionMessageToolCall[],
+  currentCall: ChatCompletionMessageToolCall | null,
+  updateStreamingContent: (content: ThreadMessage | undefined) => void,
+  updateTokenSpeed: (message: ThreadMessage, increment?: number) => void,
+  updatePromptProgress: (progress: unknown) => void
+) => {
+  // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
+  let rafScheduled = false
+  let rafHandle: number | undefined
+  let pendingDeltaCount = 0
+  const reasoningProcessor = new ReasoningProcessor()
+
+  const flushStreamingContent = () => {
+    const currentContent = createThreadContent(
+      activeThread.id,
+      accumulatedText.value,
+      toolCalls
+    )
+    updateStreamingContent(currentContent)
+    if (pendingDeltaCount > 0) {
+      updateTokenSpeed(currentContent, pendingDeltaCount)
+    }
+    pendingDeltaCount = 0
+    rafScheduled = false
+  }
+
+  const scheduleFlush = () => {
+    if (rafScheduled || abortController.signal.aborted) return
+    rafScheduled = true
+    const doSchedule = (cb: () => void) => {
+      if (typeof requestAnimationFrame !== 'undefined') {
+        rafHandle = requestAnimationFrame(() => cb())
+      } else {
+        // Fallback for non-browser test environments
+        const t = setTimeout(() => cb(), 0) as unknown as number
+        rafHandle = t
+      }
+    }
+    doSchedule(() => {
+      // Check abort status before executing the scheduled callback
+      if (abortController.signal.aborted) {
+        rafScheduled = false
+        return
+      }
+      flushStreamingContent()
+    })
+  }
+
+  try {
+    for await (const part of completion) {
+      // Check if aborted before processing each part
+      if (abortController.signal.aborted) {
+        break
+      }
+      // Handle prompt progress if available
+      if ('prompt_progress' in part && part.prompt_progress) {
+        // Force immediate state update to ensure we see intermediate values
+        flushSync(() => {
+          updatePromptProgress(part.prompt_progress)
+        })
+        // Add a small delay to make progress visible
+        await new Promise((resolve) => setTimeout(resolve, 100))
+      }
+      // Error message
+      if (!part.choices) {
+        throw new Error(
+          'message' in part
+            ? (part.message as string)
+            : (JSON.stringify(part) ?? '')
+        )
+      }
+      if (part.choices[0]?.delta?.tool_calls) {
+        extractToolCall(part, currentCall, toolCalls)
+        // Schedule a flush to reflect tool update
+        scheduleFlush()
+      }
+      const deltaReasoning = reasoningProcessor.processReasoningChunk(part)
+      if (deltaReasoning) {
+        accumulatedText.value += deltaReasoning
+        pendingDeltaCount += 1
+        // Schedule flush for reasoning updates
+        scheduleFlush()
+      }
+      const deltaContent = part.choices[0]?.delta?.content || ''
+      if (deltaContent) {
+        accumulatedText.value += deltaContent
+        pendingDeltaCount += 1
+        // Batch UI update on next animation frame
+        scheduleFlush()
+      }
+    }
+  } finally {
+    // Always clean up scheduled RAF when stream ends (either normally or via abort)
+    cancelFrame(rafHandle)
+    rafHandle = undefined
+    rafScheduled = false
+    // Finalize reasoning (close any open think tags)
+    accumulatedText.value += reasoningProcessor.finalize()
+  }
+}
+
 export const useChat = () => {
   const [
     updateTokenSpeed,
@@ -264,13 +414,18 @@ export const useChat = () => {
     updateThreadTimestamp(activeThread.id)
     usePrompt.getState().setPrompt('')
     const selectedModel = useModelProvider.getState().selectedModel
+
+    // Declare accumulatedTextRef BEFORE try block so it's accessible in catch block
+    const accumulatedTextRef = { value: '' }
+    let currentAssistant: Assistant | undefined
+
     try {
       if (selectedModel?.id) {
         updateLoadingModel(true)
         await serviceHub.models().startModel(activeProvider, selectedModel.id)
         updateLoadingModel(false)
       }
-      const currentAssistant = useAssistant.getState().currentAssistant
+      currentAssistant = useAssistant.getState().currentAssistant
       const builder = new CompletionMessagesBuilder(
         messages,
         currentAssistant
@@ -330,162 +485,35 @@ export const useChat = () => {
      )
 
      if (!completion) throw new Error('No completion received')
 
-      let accumulatedText = ''
      const currentCall: ChatCompletionMessageToolCall | null = null
      const toolCalls: ChatCompletionMessageToolCall[] = []
 
      try {
        if (isCompletionResponse(completion)) {
          const message = completion.choices[0]?.message
-          accumulatedText = (message?.content as string) || ''
+          accumulatedTextRef.value = (message?.content as string) || ''
          // Handle reasoning field if there is one
          const reasoning = extractReasoningFromMessage(message)
          if (reasoning) {
-            accumulatedText =
-              `<think>${reasoning}</think>` + accumulatedText
+            accumulatedTextRef.value =
+              `<think>${reasoning}</think>` + accumulatedTextRef.value
          }
          if (message?.tool_calls) {
            toolCalls.push(...message.tool_calls)
          }
        } else {
-          // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
-          let rafScheduled = false
-          let rafHandle: number | undefined
-          let pendingDeltaCount = 0
-          const reasoningProcessor = new ReasoningProcessor()
-          const scheduleFlush = () => {
-            if (rafScheduled || abortController.signal.aborted) return
-            rafScheduled = true
-            const doSchedule = (cb: () => void) => {
-              if (typeof requestAnimationFrame !== 'undefined') {
-                rafHandle = requestAnimationFrame(() => cb())
-              } else {
-                // Fallback for non-browser test environments
-                const t = setTimeout(() => cb(), 0) as unknown as number
-                rafHandle = t
-              }
-            }
-            doSchedule(() => {
-              // Check abort status before executing the scheduled callback
-              if (abortController.signal.aborted) {
-                rafScheduled = false
-                return
-              }
-              const currentContent = newAssistantThreadContent(
-                activeThread.id,
-                accumulatedText,
-                {
-                  tool_calls: toolCalls.map((e) => ({
-                    ...e,
-                    state: 'pending',
-                  })),
-                }
-              )
-              updateStreamingContent(currentContent)
-              if (pendingDeltaCount > 0) {
-                updateTokenSpeed(currentContent, pendingDeltaCount)
-              }
-              pendingDeltaCount = 0
-              rafScheduled = false
-            })
-          }
-          const flushIfPending = () => {
-            if (!rafScheduled) return
-            if (
-              typeof cancelAnimationFrame !== 'undefined' &&
-              rafHandle !== undefined
-            ) {
-              cancelAnimationFrame(rafHandle)
-            } else if (rafHandle !== undefined) {
-              clearTimeout(rafHandle)
-            }
-            // Do an immediate flush
-            const currentContent = newAssistantThreadContent(
-              activeThread.id,
-              accumulatedText,
-              {
-                tool_calls: toolCalls.map((e) => ({
-                  ...e,
-                  state: 'pending',
-                })),
-              }
-            )
-            updateStreamingContent(currentContent)
-            if (pendingDeltaCount > 0) {
-              updateTokenSpeed(currentContent, pendingDeltaCount)
-            }
-            pendingDeltaCount = 0
-            rafScheduled = false
-          }
-          try {
-            for await (const part of completion) {
-              // Check if aborted before processing each part
-              if (abortController.signal.aborted) {
-                break
-              }
-              // Handle prompt progress if available
-              if ('prompt_progress' in part && part.prompt_progress) {
-                // Force immediate state update to ensure we see intermediate values
-                flushSync(() => {
-                  updatePromptProgress(part.prompt_progress)
-                })
-                // Add a small delay to make progress visible
-                await new Promise((resolve) => setTimeout(resolve, 100))
-              }
-              // Error message
-              if (!part.choices) {
-                throw new Error(
-                  'message' in part
-                    ? (part.message as string)
-                    : (JSON.stringify(part) ?? '')
-                )
-              }
-              if (part.choices[0]?.delta?.tool_calls) {
-                extractToolCall(part, currentCall, toolCalls)
-                // Schedule a flush to reflect tool update
-                scheduleFlush()
-              }
-              const deltaReasoning =
-                reasoningProcessor.processReasoningChunk(part)
-              if (deltaReasoning) {
-                accumulatedText += deltaReasoning
-                pendingDeltaCount += 1
-                // Schedule flush for reasoning updates
-                scheduleFlush()
-              }
-              const deltaContent = part.choices[0]?.delta?.content || ''
-              if (deltaContent) {
-                accumulatedText += deltaContent
-                pendingDeltaCount += 1
-                // Batch UI update on next animation frame
-                scheduleFlush()
-              }
-            }
-          } finally {
-            // Always clean up scheduled RAF when stream ends (either normally or via abort)
-            if (rafHandle !== undefined) {
-              if (typeof cancelAnimationFrame !== 'undefined') {
-                cancelAnimationFrame(rafHandle)
-              } else {
-                clearTimeout(rafHandle)
-              }
-              rafHandle = undefined
-              rafScheduled = false
-            }
-            // Only finalize and flush if not aborted
-            if (!abortController.signal.aborted) {
-              // Finalize reasoning (close any open think tags)
-              accumulatedText += reasoningProcessor.finalize()
-              // Ensure any pending buffered content is rendered at the end
-              flushIfPending()
-            }
-          }
+          await processStreamingCompletion(
+            completion,
+            abortController,
+            activeThread,
+            accumulatedTextRef,
+            toolCalls,
+            currentCall,
+            updateStreamingContent,
+            updateTokenSpeed,
+            updatePromptProgress
+          )
        }
      } catch (error) {
        const errorMessage =
@@ -519,7 +547,7 @@ export const useChat = () => {
      }
      // TODO: Remove this check when integrating new llama.cpp extension
      if (
-        accumulatedText.length === 0 &&
+        accumulatedTextRef.value.length === 0 &&
        toolCalls.length === 0 &&
        activeThread.model?.id &&
        activeProvider?.provider === 'llamacpp'
@@ -533,14 +561,15 @@ export const useChat = () => {
      // Create a final content object for adding to the thread
      const finalContent = newAssistantThreadContent(
        activeThread.id,
-        accumulatedText,
+        accumulatedTextRef.value,
        {
          tokenSpeed: useAppState.getState().tokenSpeed,
          assistant: currentAssistant,
        }
      )
 
-      builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
+      // Normal completion flow (abort is handled after loop exits)
+      builder.addAssistantMessage(accumulatedTextRef.value, undefined, toolCalls)
      const updatedMessage = await postMessageProcessing(
        toolCalls,
        builder,
@@ -550,10 +579,13 @@ export const useChat = () => {
        allowAllMCPPermissions ? undefined : showApprovalModal,
        allowAllMCPPermissions
      )
-      addMessage(updatedMessage ?? finalContent)
-      updateStreamingContent(emptyThreadContent)
-      updatePromptProgress(undefined)
-      updateThreadTimestamp(activeThread.id)
+      finalizeMessage(
+        updatedMessage ?? finalContent,
+        addMessage,
+        updateStreamingContent,
+        updatePromptProgress,
+        updateThreadTimestamp
+      )
 
      isCompleted = !toolCalls.length
      // Do not create agent loop if there is no need for it
@@ -563,8 +595,48 @@ export const useChat = () => {
          availableTools = []
        }
      }
+
+      // IMPORTANT: Check if aborted AFTER the while loop exits
+      // The while loop exits when abort is true, so we handle it here
+      if (abortController.signal.aborted && accumulatedTextRef.value.length > 0) {
+        // Create final content for the partial message
+        const partialContent = newAssistantThreadContent(
+          activeThread.id,
+          accumulatedTextRef.value,
+          {
+            tokenSpeed: useAppState.getState().tokenSpeed,
+            assistant: currentAssistant,
+          }
+        )
+        // Save the partial message
+        addMessage(partialContent)
+        updatePromptProgress(undefined)
+        updateThreadTimestamp(activeThread.id)
+      }
    } catch (error) {
-      if (!abortController.signal.aborted) {
+      // If aborted, save the partial message even though an error occurred
+      // Check both accumulatedTextRef and streamingContent from app state
+      const streamingContent = useAppState.getState().streamingContent
+      const hasPartialContent =
+        accumulatedTextRef.value.length > 0 ||
+        (streamingContent && streamingContent.content?.[0]?.text?.value)
+
+      if (abortController.signal.aborted && hasPartialContent) {
+        // Use streaming content if available, otherwise use accumulatedTextRef
+        const contentText =
+          streamingContent?.content?.[0]?.text?.value || accumulatedTextRef.value
+        const partialContent = newAssistantThreadContent(
+          activeThread.id,
+          contentText,
+          {
+            tokenSpeed: useAppState.getState().tokenSpeed,
+            assistant: currentAssistant,
+          }
+        )
+        addMessage(partialContent)
+        updatePromptProgress(undefined)
+        updateThreadTimestamp(activeThread.id)
+      } else if (!abortController.signal.aborted) {
+        // Only show error if not aborted
        if (error && typeof error === 'object' && 'message' in error) {
          setModelLoadError(error as ErrorObject)
        } else {
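For context, the streaming loop this file now delegates to processStreamingCompletion batches UI updates on requestAnimationFrame so that each token delta does not force its own render. A stripped-down sketch of that scheduling idea, with hypothetical readText/onFlush callbacks standing in for the real store updates, looks like this:

// Sketch only: rAF-batched flushing of accumulated streaming text.
// readText and onFlush are hypothetical stand-ins for the real store updates.
function createFlushScheduler(
  readText: () => string,
  onFlush: (text: string) => void,
  signal: AbortSignal
) {
  let scheduled = false
  let handle: number | undefined

  const flush = () => {
    scheduled = false
    onFlush(readText())
  }

  const schedule = () => {
    if (scheduled || signal.aborted) return
    scheduled = true
    // Fall back to setTimeout outside the browser (e.g. in tests).
    handle =
      typeof requestAnimationFrame !== 'undefined'
        ? requestAnimationFrame(flush)
        : (setTimeout(flush, 0) as unknown as number)
  }

  const cancel = () => {
    if (handle === undefined) return
    if (typeof cancelAnimationFrame !== 'undefined') cancelAnimationFrame(handle)
    else clearTimeout(handle)
    scheduled = false
  }

  return { schedule, cancel }
}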


@@ -40,16 +40,21 @@ export const useMessages = create<MessageState>()((set, get) => ({
        assistant: selectedAssistant,
      },
    }
-    getServiceHub().messages().createMessage(newMessage).then((createdMessage) => {
-      set((state) => ({
-        messages: {
-          ...state.messages,
-          [message.thread_id]: [
-            ...(state.messages[message.thread_id] || []),
-            createdMessage,
-          ],
-        },
-      }))
+
+    // Optimistically update state immediately for instant UI feedback
+    set((state) => ({
+      messages: {
+        ...state.messages,
+        [message.thread_id]: [
+          ...(state.messages[message.thread_id] || []),
+          newMessage,
+        ],
+      },
+    }))
+
+    // Persist to storage asynchronously
+    getServiceHub().messages().createMessage(newMessage).catch((error) => {
+      console.error('Failed to persist message:', error)
    })
  },
  deleteMessage: (threadId, messageId) => {
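
The store now applies the new message to state synchronously and only afterwards persists it, so the UI never waits on storage. A minimal sketch of this optimistic-update pattern in a Zustand-style store follows; the Message shape and persistMessage service are simplified, hypothetical stand-ins for the app's actual types and service hub:

import { create } from 'zustand'

// Sketch only: optimistic insert followed by async persistence.
// Message and persistMessage are hypothetical stand-ins.
interface Message {
  id: string
  thread_id: string
  content: string
}

declare function persistMessage(message: Message): Promise<void>

interface MessageState {
  messages: Record<string, Message[]>
  addMessage: (message: Message) => void
}

export const useMessagesSketch = create<MessageState>()((set) => ({
  messages: {},
  addMessage: (message) => {
    // 1. Optimistic update: the message is visible immediately.
    set((state) => ({
      messages: {
        ...state.messages,
        [message.thread_id]: [...(state.messages[message.thread_id] || []), message],
      },
    }))
    // 2. Persist in the background; log (or roll back) on failure.
    persistMessage(message).catch((error) => {
      console.error('Failed to persist message:', error)
    })
  },
}))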