feat: Allow saving the last message when interrupting the LLM response
parent 0de5f17071
commit 2e86d4e421
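Note: the change spans three areas — useChat accumulates streamed text in a shared ref object so a partial assistant message survives an abort, useMessages writes optimistically before persisting, and StreamingContent stops rendering content that has already been saved. A minimal sketch of the interrupt-save idea (the types and helpers below are simplified stand-ins for the real ThreadMessage / newAssistantThreadContent / addMessage in the diff):

// Simplified stand-ins; the real types live in the app code shown below.
type ThreadMessage = { thread_id: string; text: string }

const newAssistantThreadContent = (threadId: string, text: string): ThreadMessage => ({
  thread_id: threadId,
  text,
})

// Save whatever streamed so far when the user interrupts the response.
function saveOnInterrupt(
  signal: AbortSignal,
  threadId: string,
  accumulated: { value: string },
  addMessage: (m: ThreadMessage) => void
) {
  if (signal.aborted && accumulated.value.length > 0) {
    addMessage(newAssistantThreadContent(threadId, accumulated.value))
  }
}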
@@ -48,7 +48,9 @@ export const StreamingContent = memo(({ threadId }: Props) => {
     return extractReasoningSegment(text)
   }, [lastAssistant])

-  if (!streamingContent || streamingContent.thread_id !== threadId) return null
+  if (!streamingContent || streamingContent.thread_id !== threadId) {
+    return null
+  }
+
+  if (streamingReasoning && streamingReasoning === lastAssistantReasoning) {
+    return null
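Note: without this guard, a partial message saved on interrupt would briefly appear twice — once as the saved thread message and once in the still-mounted streaming view. Distilled, with illustrative names:

// Hide the streaming copy once the saved message already carries the same reasoning.
const shouldHideStreamingView = (
  streamingReasoning: string | undefined,
  lastAssistantReasoning: string | undefined
): boolean =>
  Boolean(streamingReasoning && streamingReasoning === lastAssistantReasoning)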
@@ -225,9 +225,25 @@ describe('useMessages', () => {
       })
     )

-    // Wait for async operation
+    // Message should be immediately available (optimistic update)
+    expect(result.current.messages['thread1']).toContainEqual(
+      expect.objectContaining({
+        id: messageToAdd.id,
+        thread_id: messageToAdd.thread_id,
+        role: messageToAdd.role,
+        content: messageToAdd.content,
+        metadata: expect.objectContaining({
+          assistant: expect.objectContaining({
+            id: expect.any(String),
+            name: expect.any(String),
+          }),
+        }),
+      })
+    )
+
+    // Verify persistence was attempted
     await vi.waitFor(() => {
       expect(result.current.messages['thread1']).toContainEqual(mockCreatedMessage)
       expect(mockCreateMessage).toHaveBeenCalled()
     })
   })
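Note: the reworked test asserts in two phases — synchronously right after the call, the optimistic copy must already be in state; then vi.waitFor polls until persistence has been attempted. A self-contained sketch of the same pattern, with a toy store standing in for useMessages:

import { expect, it, vi } from 'vitest'

it('shows the message immediately, then persists it', async () => {
  const persisted: string[] = []
  const persist = vi.fn(async (m: string) => {
    persisted.push(m)
  })

  const store: string[] = []
  const addMessage = (m: string) => {
    store.push(m) // optimistic: state updated first
    void persist(m) // persistence runs in the background
  }

  addMessage('hello')
  // Optimistic copy is visible before persistence settles
  expect(store).toContain('hello')

  // Persistence is eventually attempted
  await vi.waitFor(() => {
    expect(persist).toHaveBeenCalled()
    expect(persisted).toContain('hello')
  })
})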
@@ -35,6 +35,156 @@ import { useAssistant } from './useAssistant'
 import { useShallow } from 'zustand/shallow'
 import { TEMPORARY_CHAT_QUERY_ID, TEMPORARY_CHAT_ID } from '@/constants/chat'

+// Helper to create thread content with consistent structure
+const createThreadContent = (
+  threadId: string,
+  text: string,
+  toolCalls: ChatCompletionMessageToolCall[]
+) => {
+  return newAssistantThreadContent(threadId, text, {
+    tool_calls: toolCalls.map((e) => ({
+      ...e,
+      state: 'pending',
+    })),
+  })
+}
+
+// Helper to cancel animation frame cross-platform
+const cancelFrame = (handle: number | undefined) => {
+  if (handle === undefined) return
+  if (typeof cancelAnimationFrame !== 'undefined') {
+    cancelAnimationFrame(handle)
+  } else {
+    clearTimeout(handle)
+  }
+}
+
+// Helper to finalize and save a message
+const finalizeMessage = (
+  finalContent: ThreadMessage,
+  addMessage: (message: ThreadMessage) => void,
+  updateStreamingContent: (content: ThreadMessage | undefined) => void,
+  updatePromptProgress: (progress: unknown) => void,
+  updateThreadTimestamp: (threadId: string) => void
+) => {
+  addMessage(finalContent)
+  updateStreamingContent(emptyThreadContent)
+  updatePromptProgress(undefined)
+  updateThreadTimestamp(finalContent.thread_id)
+}
+
+// Helper to process streaming completion
+const processStreamingCompletion = async (
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  completion: AsyncIterable<any>,
+  abortController: AbortController,
+  activeThread: Thread,
+  accumulatedText: { value: string },
+  toolCalls: ChatCompletionMessageToolCall[],
+  currentCall: ChatCompletionMessageToolCall | null,
+  updateStreamingContent: (content: ThreadMessage | undefined) => void,
+  updateTokenSpeed: (message: ThreadMessage, increment?: number) => void,
+  updatePromptProgress: (progress: unknown) => void
+) => {
+  // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
+  let rafScheduled = false
+  let rafHandle: number | undefined
+  let pendingDeltaCount = 0
+  const reasoningProcessor = new ReasoningProcessor()
+
+  const flushStreamingContent = () => {
+    const currentContent = createThreadContent(
+      activeThread.id,
+      accumulatedText.value,
+      toolCalls
+    )
+    updateStreamingContent(currentContent)
+    if (pendingDeltaCount > 0) {
+      updateTokenSpeed(currentContent, pendingDeltaCount)
+    }
+    pendingDeltaCount = 0
+    rafScheduled = false
+  }
+
+  const scheduleFlush = () => {
+    if (rafScheduled || abortController.signal.aborted) return
+    rafScheduled = true
+    const doSchedule = (cb: () => void) => {
+      if (typeof requestAnimationFrame !== 'undefined') {
+        rafHandle = requestAnimationFrame(() => cb())
+      } else {
+        // Fallback for non-browser test environments
+        const t = setTimeout(() => cb(), 0) as unknown as number
+        rafHandle = t
+      }
+    }
+    doSchedule(() => {
+      // Check abort status before executing the scheduled callback
+      if (abortController.signal.aborted) {
+        rafScheduled = false
+        return
+      }
+      flushStreamingContent()
+    })
+  }
+
+  try {
+    for await (const part of completion) {
+      // Check if aborted before processing each part
+      if (abortController.signal.aborted) {
+        break
+      }
+
+      // Handle prompt progress if available
+      if ('prompt_progress' in part && part.prompt_progress) {
+        // Force immediate state update to ensure we see intermediate values
+        flushSync(() => {
+          updatePromptProgress(part.prompt_progress)
+        })
+        // Add a small delay to make progress visible
+        await new Promise((resolve) => setTimeout(resolve, 100))
+      }
+
+      // Error message
+      if (!part.choices) {
+        throw new Error(
+          'message' in part
+            ? (part.message as string)
+            : (JSON.stringify(part) ?? '')
+        )
+      }
+
+      if (part.choices[0]?.delta?.tool_calls) {
+        extractToolCall(part, currentCall, toolCalls)
+        // Schedule a flush to reflect tool update
+        scheduleFlush()
+      }
+      const deltaReasoning = reasoningProcessor.processReasoningChunk(part)
+      if (deltaReasoning) {
+        accumulatedText.value += deltaReasoning
+        pendingDeltaCount += 1
+        // Schedule flush for reasoning updates
+        scheduleFlush()
+      }
+      const deltaContent = part.choices[0]?.delta?.content || ''
+      if (deltaContent) {
+        accumulatedText.value += deltaContent
+        pendingDeltaCount += 1
+        // Batch UI update on next animation frame
+        scheduleFlush()
+      }
+    }
+  } finally {
+    // Always clean up scheduled RAF when stream ends (either normally or via abort)
+    cancelFrame(rafHandle)
+    rafHandle = undefined
+    rafScheduled = false
+
+    // Finalize reasoning (close any open think tags)
+    accumulatedText.value += reasoningProcessor.finalize()
+  }
+}
+
 export const useChat = () => {
   const [
     updateTokenSpeed,
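Note: processStreamingCompletion keeps the rAF batching the old inline loop used — each delta only bumps a counter, and the store update runs at most once per frame. The scheduler pattern in isolation (illustrative names, with the same setTimeout fallback for non-browser test environments):

// Many deltas per frame collapse into a single UI update.
let scheduled = false
let pending = 0

const flush = (render: (count: number) => void) => {
  render(pending) // one store/UI update for everything accumulated this frame
  pending = 0
  scheduled = false
}

export const onDelta = (render: (count: number) => void) => {
  pending += 1
  if (scheduled) return
  scheduled = true
  if (typeof requestAnimationFrame !== 'undefined') {
    requestAnimationFrame(() => flush(render))
  } else {
    // Fallback for environments without requestAnimationFrame
    setTimeout(() => flush(render), 0)
  }
}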
@@ -264,13 +414,18 @@ export const useChat = () => {
       updateThreadTimestamp(activeThread.id)
       usePrompt.getState().setPrompt('')
       const selectedModel = useModelProvider.getState().selectedModel

+      // Declare accumulatedTextRef BEFORE try block so it's accessible in catch block
+      const accumulatedTextRef = { value: '' }
+      let currentAssistant: Assistant | undefined
+
       try {
         if (selectedModel?.id) {
           updateLoadingModel(true)
           await serviceHub.models().startModel(activeProvider, selectedModel.id)
           updateLoadingModel(false)
         }
-        const currentAssistant = useAssistant.getState().currentAssistant
+        currentAssistant = useAssistant.getState().currentAssistant
         const builder = new CompletionMessagesBuilder(
           messages,
           currentAssistant
@@ -330,162 +485,35 @@ export const useChat = () => {
         )

         if (!completion) throw new Error('No completion received')
-        let accumulatedText = ''
         const currentCall: ChatCompletionMessageToolCall | null = null
         const toolCalls: ChatCompletionMessageToolCall[] = []
         try {
           if (isCompletionResponse(completion)) {
             const message = completion.choices[0]?.message
-            accumulatedText = (message?.content as string) || ''
+            accumulatedTextRef.value = (message?.content as string) || ''

             // Handle reasoning field if there is one
             const reasoning = extractReasoningFromMessage(message)
             if (reasoning) {
-              accumulatedText =
-                `<think>${reasoning}</think>` + accumulatedText
+              accumulatedTextRef.value =
+                `<think>${reasoning}</think>` + accumulatedTextRef.value
             }

             if (message?.tool_calls) {
               toolCalls.push(...message.tool_calls)
             }
           } else {
-            // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
-            let rafScheduled = false
-            let rafHandle: number | undefined
-            let pendingDeltaCount = 0
-            const reasoningProcessor = new ReasoningProcessor()
-            const scheduleFlush = () => {
-              if (rafScheduled || abortController.signal.aborted) return
-              rafScheduled = true
-              const doSchedule = (cb: () => void) => {
-                if (typeof requestAnimationFrame !== 'undefined') {
-                  rafHandle = requestAnimationFrame(() => cb())
-                } else {
-                  // Fallback for non-browser test environments
-                  const t = setTimeout(() => cb(), 0) as unknown as number
-                  rafHandle = t
-                }
-              }
-              doSchedule(() => {
-                // Check abort status before executing the scheduled callback
-                if (abortController.signal.aborted) {
-                  rafScheduled = false
-                  return
-                }
-
-                const currentContent = newAssistantThreadContent(
-                  activeThread.id,
-                  accumulatedText,
-                  {
-                    tool_calls: toolCalls.map((e) => ({
-                      ...e,
-                      state: 'pending',
-                    })),
-                  }
-                )
-                updateStreamingContent(currentContent)
-                if (pendingDeltaCount > 0) {
-                  updateTokenSpeed(currentContent, pendingDeltaCount)
-                }
-                pendingDeltaCount = 0
-                rafScheduled = false
-              })
-            }
-            const flushIfPending = () => {
-              if (!rafScheduled) return
-              if (
-                typeof cancelAnimationFrame !== 'undefined' &&
-                rafHandle !== undefined
-              ) {
-                cancelAnimationFrame(rafHandle)
-              } else if (rafHandle !== undefined) {
-                clearTimeout(rafHandle)
-              }
-              // Do an immediate flush
-              const currentContent = newAssistantThreadContent(
-                activeThread.id,
-                accumulatedText,
-                {
-                  tool_calls: toolCalls.map((e) => ({
-                    ...e,
-                    state: 'pending',
-                  })),
-                }
-              )
-              updateStreamingContent(currentContent)
-              if (pendingDeltaCount > 0) {
-                updateTokenSpeed(currentContent, pendingDeltaCount)
-              }
-              pendingDeltaCount = 0
-              rafScheduled = false
-            }
-            try {
-              for await (const part of completion) {
-                // Check if aborted before processing each part
-                if (abortController.signal.aborted) {
-                  break
-                }
-
-                // Handle prompt progress if available
-                if ('prompt_progress' in part && part.prompt_progress) {
-                  // Force immediate state update to ensure we see intermediate values
-                  flushSync(() => {
-                    updatePromptProgress(part.prompt_progress)
-                  })
-                  // Add a small delay to make progress visible
-                  await new Promise((resolve) => setTimeout(resolve, 100))
-                }
-
-                // Error message
-                if (!part.choices) {
-                  throw new Error(
-                    'message' in part
-                      ? (part.message as string)
-                      : (JSON.stringify(part) ?? '')
-                  )
-                }
-
-                if (part.choices[0]?.delta?.tool_calls) {
-                  extractToolCall(part, currentCall, toolCalls)
-                  // Schedule a flush to reflect tool update
-                  scheduleFlush()
-                }
-                const deltaReasoning =
-                  reasoningProcessor.processReasoningChunk(part)
-                if (deltaReasoning) {
-                  accumulatedText += deltaReasoning
-                  pendingDeltaCount += 1
-                  // Schedule flush for reasoning updates
-                  scheduleFlush()
-                }
-                const deltaContent = part.choices[0]?.delta?.content || ''
-                if (deltaContent) {
-                  accumulatedText += deltaContent
-                  pendingDeltaCount += 1
-                  // Batch UI update on next animation frame
-                  scheduleFlush()
-                }
-              }
-            } finally {
-              // Always clean up scheduled RAF when stream ends (either normally or via abort)
-              if (rafHandle !== undefined) {
-                if (typeof cancelAnimationFrame !== 'undefined') {
-                  cancelAnimationFrame(rafHandle)
-                } else {
-                  clearTimeout(rafHandle)
-                }
-                rafHandle = undefined
-                rafScheduled = false
-              }
-
-              // Only finalize and flush if not aborted
-              if (!abortController.signal.aborted) {
-                // Finalize reasoning (close any open think tags)
-                accumulatedText += reasoningProcessor.finalize()
-                // Ensure any pending buffered content is rendered at the end
-                flushIfPending()
-              }
-            }
+            await processStreamingCompletion(
+              completion,
+              abortController,
+              activeThread,
+              accumulatedTextRef,
+              toolCalls,
+              currentCall,
+              updateStreamingContent,
+              updateTokenSpeed,
+              updatePromptProgress
+            )
           }
         } catch (error) {
           const errorMessage =
@@ -519,7 +547,7 @@ export const useChat = () => {
         }
         // TODO: Remove this check when integrating new llama.cpp extension
         if (
-          accumulatedText.length === 0 &&
+          accumulatedTextRef.value.length === 0 &&
           toolCalls.length === 0 &&
           activeThread.model?.id &&
           activeProvider?.provider === 'llamacpp'
@@ -533,14 +561,15 @@ export const useChat = () => {
         // Create a final content object for adding to the thread
         const finalContent = newAssistantThreadContent(
           activeThread.id,
-          accumulatedText,
+          accumulatedTextRef.value,
           {
             tokenSpeed: useAppState.getState().tokenSpeed,
             assistant: currentAssistant,
           }
         )

-        builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
+        // Normal completion flow (abort is handled after loop exits)
+        builder.addAssistantMessage(accumulatedTextRef.value, undefined, toolCalls)
         const updatedMessage = await postMessageProcessing(
           toolCalls,
           builder,
@@ -550,10 +579,13 @@ export const useChat = () => {
           allowAllMCPPermissions ? undefined : showApprovalModal,
           allowAllMCPPermissions
         )
-        addMessage(updatedMessage ?? finalContent)
-        updateStreamingContent(emptyThreadContent)
-        updatePromptProgress(undefined)
-        updateThreadTimestamp(activeThread.id)
+        finalizeMessage(
+          updatedMessage ?? finalContent,
+          addMessage,
+          updateStreamingContent,
+          updatePromptProgress,
+          updateThreadTimestamp
+        )

         isCompleted = !toolCalls.length
         // Do not create agent loop if there is no need for it
@@ -563,8 +595,48 @@ export const useChat = () => {
             availableTools = []
           }
         }

+        // IMPORTANT: Check if aborted AFTER the while loop exits
+        // The while loop exits when abort is true, so we handle it here
+        if (abortController.signal.aborted && accumulatedTextRef.value.length > 0) {
+          // Create final content for the partial message
+          const partialContent = newAssistantThreadContent(
+            activeThread.id,
+            accumulatedTextRef.value,
+            {
+              tokenSpeed: useAppState.getState().tokenSpeed,
+              assistant: currentAssistant,
+            }
+          )
+          // Save the partial message
+          addMessage(partialContent)
+          updatePromptProgress(undefined)
+          updateThreadTimestamp(activeThread.id)
+        }
       } catch (error) {
-        if (!abortController.signal.aborted) {
+        // If aborted, save the partial message even though an error occurred
+        // Check both accumulatedTextRef and streamingContent from app state
+        const streamingContent = useAppState.getState().streamingContent
+        const hasPartialContent =
+          accumulatedTextRef.value.length > 0 ||
+          (streamingContent && streamingContent.content?.[0]?.text?.value)
+
+        if (abortController.signal.aborted && hasPartialContent) {
+          // Use streaming content if available, otherwise use accumulatedTextRef
+          const contentText =
+            streamingContent?.content?.[0]?.text?.value || accumulatedTextRef.value
+
+          const partialContent = newAssistantThreadContent(
+            activeThread.id,
+            contentText,
+            {
+              tokenSpeed: useAppState.getState().tokenSpeed,
+              assistant: currentAssistant,
+            }
+          )
+          addMessage(partialContent)
+          updatePromptProgress(undefined)
+          updateThreadTimestamp(activeThread.id)
+        } else if (!abortController.signal.aborted) {
+          // Only show error if not aborted
           if (error && typeof error === 'object' && 'message' in error) {
             setModelLoadError(error as ErrorObject)
           } else {
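Note: the partial save is handled in two places because an abort can surface in two ways depending on the completion client: the loop observes signal.aborted and breaks (handled after the try), or the underlying request rejects and control lands in catch. A sketch of both paths over a generic stream:

// An abort can end the stream quietly or via a thrown error; both keep the partial text.
async function consume(stream: AsyncIterable<string>, signal: AbortSignal): Promise<string> {
  let text = ''
  try {
    for await (const chunk of stream) {
      if (signal.aborted) break // quiet path: handled after the loop
      text += chunk
    }
  } catch (err) {
    if (!signal.aborted) throw err // real error: rethrow
    // aborted path: fall through and keep the partial text
  }
  return text // caller saves non-empty partials
}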
@@ -40,16 +40,21 @@ export const useMessages = create<MessageState>()((set, get) => ({
         assistant: selectedAssistant,
       },
     }
-    getServiceHub().messages().createMessage(newMessage).then((createdMessage) => {
-      set((state) => ({
-        messages: {
-          ...state.messages,
-          [message.thread_id]: [
-            ...(state.messages[message.thread_id] || []),
-            createdMessage,
-          ],
-        },
-      }))
-    })
+
+    // Optimistically update state immediately for instant UI feedback
+    set((state) => ({
+      messages: {
+        ...state.messages,
+        [message.thread_id]: [
+          ...(state.messages[message.thread_id] || []),
+          newMessage,
+        ],
+      },
+    }))
+
+    // Persist to storage asynchronously
+    getServiceHub().messages().createMessage(newMessage).catch((error) => {
+      console.error('Failed to persist message:', error)
+    })
   },
   deleteMessage: (threadId, messageId) => {
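Note: addMessage now updates state before persistence, so an interrupted message appears instantly even when storage is slow; a persistence failure is logged rather than rolled back. A minimal zustand-style sketch (persistMessage is an illustrative stand-in for getServiceHub().messages().createMessage):

import { create } from 'zustand'

type Msg = { id: string; thread_id: string; content: string }
type MessagesState = {
  messages: Record<string, Msg[]>
  addMessage: (m: Msg) => void
}

// Stand-in for the real persistence service
const persistMessage = async (_m: Msg): Promise<void> => {}

export const useMessagesSketch = create<MessagesState>()((set) => ({
  messages: {},
  addMessage: (m) => {
    // 1. Update state first so the UI shows the message instantly
    set((s) => ({
      messages: {
        ...s.messages,
        [m.thread_id]: [...(s.messages[m.thread_id] || []), m],
      },
    }))
    // 2. Persist in the background; a failure is logged, not rolled back
    persistMessage(m).catch((e) => console.error('Failed to persist message:', e))
  },
}))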