feat: Continue with AI response for llamacpp
This commit is contained in:
parent
4ea9d296ea
commit
f4b187ba11
@ -33,14 +33,18 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
|
|||||||
}, [messages])
|
}, [messages])
|
||||||
|
|
||||||
const generateAIResponse = () => {
|
const generateAIResponse = () => {
|
||||||
// If continuing a partial response, delete the partial message first
|
// If continuing a partial response, keep the message and continue from it
|
||||||
if (isPartialResponse) {
|
if (isPartialResponse) {
|
||||||
const partialMessage = messages[messages.length - 1]
|
const partialMessage = messages[messages.length - 1]
|
||||||
deleteMessage(partialMessage.thread_id, partialMessage.id ?? '')
|
|
||||||
// Get the user message that prompted this partial response
|
|
||||||
const userMessage = messages[messages.length - 2]
|
const userMessage = messages[messages.length - 2]
|
||||||
if (userMessage?.content?.[0]?.text?.value) {
|
if (userMessage?.content?.[0]?.text?.value) {
|
||||||
sendMessage(userMessage.content[0].text.value, false)
|
// Pass the partial message ID to continue from it
|
||||||
|
sendMessage(
|
||||||
|
userMessage.content[0].text.value,
|
||||||
|
false,
|
||||||
|
undefined,
|
||||||
|
partialMessage.id
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -40,14 +40,20 @@ import { TEMPORARY_CHAT_QUERY_ID, TEMPORARY_CHAT_ID } from '@/constants/chat'
|
|||||||
const createThreadContent = (
|
const createThreadContent = (
|
||||||
threadId: string,
|
threadId: string,
|
||||||
text: string,
|
text: string,
|
||||||
toolCalls: ChatCompletionMessageToolCall[]
|
toolCalls: ChatCompletionMessageToolCall[],
|
||||||
|
messageId?: string
|
||||||
) => {
|
) => {
|
||||||
return newAssistantThreadContent(threadId, text, {
|
const content = newAssistantThreadContent(threadId, text, {
|
||||||
tool_calls: toolCalls.map((e) => ({
|
tool_calls: toolCalls.map((e) => ({
|
||||||
...e,
|
...e,
|
||||||
state: 'pending',
|
state: 'pending',
|
||||||
})),
|
})),
|
||||||
})
|
})
|
||||||
|
// If continuing from a message, preserve the message ID
|
||||||
|
if (messageId) {
|
||||||
|
return { ...content, id: messageId }
|
||||||
|
}
|
||||||
|
return content
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper to cancel animation frame cross-platform
|
// Helper to cancel animation frame cross-platform
|
||||||
@ -66,9 +72,16 @@ const finalizeMessage = (
|
|||||||
addMessage: (message: ThreadMessage) => void,
|
addMessage: (message: ThreadMessage) => void,
|
||||||
updateStreamingContent: (content: ThreadMessage | undefined) => void,
|
updateStreamingContent: (content: ThreadMessage | undefined) => void,
|
||||||
updatePromptProgress: (progress: unknown) => void,
|
updatePromptProgress: (progress: unknown) => void,
|
||||||
updateThreadTimestamp: (threadId: string) => void
|
updateThreadTimestamp: (threadId: string) => void,
|
||||||
|
updateMessage?: (message: ThreadMessage) => void,
|
||||||
|
continueFromMessageId?: string
|
||||||
) => {
|
) => {
|
||||||
addMessage(finalContent)
|
// If continuing from a message, update it; otherwise add new message
|
||||||
|
if (continueFromMessageId && updateMessage) {
|
||||||
|
updateMessage({ ...finalContent, id: continueFromMessageId })
|
||||||
|
} else {
|
||||||
|
addMessage(finalContent)
|
||||||
|
}
|
||||||
updateStreamingContent(emptyThreadContent)
|
updateStreamingContent(emptyThreadContent)
|
||||||
updatePromptProgress(undefined)
|
updatePromptProgress(undefined)
|
||||||
updateThreadTimestamp(finalContent.thread_id)
|
updateThreadTimestamp(finalContent.thread_id)
|
||||||
@ -85,7 +98,9 @@ const processStreamingCompletion = async (
|
|||||||
currentCall: ChatCompletionMessageToolCall | null,
|
currentCall: ChatCompletionMessageToolCall | null,
|
||||||
updateStreamingContent: (content: ThreadMessage | undefined) => void,
|
updateStreamingContent: (content: ThreadMessage | undefined) => void,
|
||||||
updateTokenSpeed: (message: ThreadMessage, increment?: number) => void,
|
updateTokenSpeed: (message: ThreadMessage, increment?: number) => void,
|
||||||
updatePromptProgress: (progress: unknown) => void
|
updatePromptProgress: (progress: unknown) => void,
|
||||||
|
continueFromMessageId?: string,
|
||||||
|
updateMessage?: (message: ThreadMessage) => void
|
||||||
) => {
|
) => {
|
||||||
// High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
|
// High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
|
||||||
let rafScheduled = false
|
let rafScheduled = false
|
||||||
@ -97,9 +112,20 @@ const processStreamingCompletion = async (
|
|||||||
const currentContent = createThreadContent(
|
const currentContent = createThreadContent(
|
||||||
activeThread.id,
|
activeThread.id,
|
||||||
accumulatedText.value,
|
accumulatedText.value,
|
||||||
toolCalls
|
toolCalls,
|
||||||
|
continueFromMessageId
|
||||||
)
|
)
|
||||||
updateStreamingContent(currentContent)
|
|
||||||
|
// When continuing, update the message directly instead of using streamingContent
|
||||||
|
if (continueFromMessageId && updateMessage) {
|
||||||
|
updateMessage({
|
||||||
|
...currentContent,
|
||||||
|
status: MessageStatus.Stopped, // Keep as Stopped while streaming
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
updateStreamingContent(currentContent)
|
||||||
|
}
|
||||||
|
|
||||||
if (pendingDeltaCount > 0) {
|
if (pendingDeltaCount > 0) {
|
||||||
updateTokenSpeed(currentContent, pendingDeltaCount)
|
updateTokenSpeed(currentContent, pendingDeltaCount)
|
||||||
}
|
}
|
||||||
@ -232,6 +258,7 @@ export const useChat = () => {
|
|||||||
|
|
||||||
const getMessages = useMessages((state) => state.getMessages)
|
const getMessages = useMessages((state) => state.getMessages)
|
||||||
const addMessage = useMessages((state) => state.addMessage)
|
const addMessage = useMessages((state) => state.addMessage)
|
||||||
|
const updateMessage = useMessages((state) => state.updateMessage)
|
||||||
const setMessages = useMessages((state) => state.setMessages)
|
const setMessages = useMessages((state) => state.setMessages)
|
||||||
const setModelLoadError = useModelLoad((state) => state.setModelLoadError)
|
const setModelLoadError = useModelLoad((state) => state.setModelLoadError)
|
||||||
const router = useRouter()
|
const router = useRouter()
|
||||||
@ -396,7 +423,8 @@ export const useChat = () => {
|
|||||||
size: number
|
size: number
|
||||||
base64: string
|
base64: string
|
||||||
dataUrl: string
|
dataUrl: string
|
||||||
}>
|
}>,
|
||||||
|
continueFromMessageId?: string
|
||||||
) => {
|
) => {
|
||||||
const activeThread = await getCurrentThread()
|
const activeThread = await getCurrentThread()
|
||||||
const selectedProvider = useModelProvider.getState().selectedProvider
|
const selectedProvider = useModelProvider.getState().selectedProvider
|
||||||
@ -409,15 +437,24 @@ export const useChat = () => {
|
|||||||
setAbortController(activeThread.id, abortController)
|
setAbortController(activeThread.id, abortController)
|
||||||
updateStreamingContent(emptyThreadContent)
|
updateStreamingContent(emptyThreadContent)
|
||||||
updatePromptProgress(undefined)
|
updatePromptProgress(undefined)
|
||||||
// Do not add new message on retry
|
|
||||||
if (troubleshooting)
|
// Find the message to continue from if provided
|
||||||
|
const continueFromMessage = continueFromMessageId
|
||||||
|
? messages.find((m) => m.id === continueFromMessageId)
|
||||||
|
: undefined
|
||||||
|
|
||||||
|
// Do not add new message on retry or when continuing
|
||||||
|
if (troubleshooting && !continueFromMessageId)
|
||||||
addMessage(newUserThreadContent(activeThread.id, message, attachments))
|
addMessage(newUserThreadContent(activeThread.id, message, attachments))
|
||||||
updateThreadTimestamp(activeThread.id)
|
updateThreadTimestamp(activeThread.id)
|
||||||
usePrompt.getState().setPrompt('')
|
usePrompt.getState().setPrompt('')
|
||||||
const selectedModel = useModelProvider.getState().selectedModel
|
const selectedModel = useModelProvider.getState().selectedModel
|
||||||
|
|
||||||
// Declare accumulatedTextRef BEFORE try block so it's accessible in catch block
|
// Declare accumulatedTextRef BEFORE try block so it's accessible in catch block
|
||||||
const accumulatedTextRef = { value: '' }
|
// If continuing, start with the previous content
|
||||||
|
const accumulatedTextRef = {
|
||||||
|
value: continueFromMessage?.content?.[0]?.text?.value || ''
|
||||||
|
}
|
||||||
let currentAssistant: Assistant | undefined
|
let currentAssistant: Assistant | undefined
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@ -427,13 +464,28 @@ export const useChat = () => {
|
|||||||
updateLoadingModel(false)
|
updateLoadingModel(false)
|
||||||
}
|
}
|
||||||
currentAssistant = useAssistant.getState().currentAssistant
|
currentAssistant = useAssistant.getState().currentAssistant
|
||||||
|
|
||||||
|
// Filter out the stopped message from context if continuing
|
||||||
|
const contextMessages = continueFromMessageId
|
||||||
|
? messages.filter((m) => m.id !== continueFromMessageId)
|
||||||
|
: messages
|
||||||
|
|
||||||
const builder = new CompletionMessagesBuilder(
|
const builder = new CompletionMessagesBuilder(
|
||||||
messages,
|
contextMessages,
|
||||||
currentAssistant
|
currentAssistant
|
||||||
? renderInstructions(currentAssistant.instructions)
|
? renderInstructions(currentAssistant.instructions)
|
||||||
: undefined
|
: undefined
|
||||||
)
|
)
|
||||||
if (troubleshooting) builder.addUserMessage(message, attachments)
|
if (troubleshooting && !continueFromMessageId) {
|
||||||
|
builder.addUserMessage(message, attachments)
|
||||||
|
} else if (continueFromMessage) {
|
||||||
|
// When continuing, add the partial assistant response to the context
|
||||||
|
builder.addAssistantMessage(
|
||||||
|
continueFromMessage.content?.[0]?.text?.value || '',
|
||||||
|
undefined,
|
||||||
|
[]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
let isCompleted = false
|
let isCompleted = false
|
||||||
|
|
||||||
@ -513,7 +565,9 @@ export const useChat = () => {
|
|||||||
currentCall,
|
currentCall,
|
||||||
updateStreamingContent,
|
updateStreamingContent,
|
||||||
updateTokenSpeed,
|
updateTokenSpeed,
|
||||||
updatePromptProgress
|
updatePromptProgress,
|
||||||
|
continueFromMessageId,
|
||||||
|
updateMessage
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@ -560,7 +614,7 @@ export const useChat = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a final content object for adding to the thread
|
// Create a final content object for adding to the thread
|
||||||
const finalContent = newAssistantThreadContent(
|
let finalContent = newAssistantThreadContent(
|
||||||
activeThread.id,
|
activeThread.id,
|
||||||
accumulatedTextRef.value,
|
accumulatedTextRef.value,
|
||||||
{
|
{
|
||||||
@ -569,6 +623,15 @@ export const useChat = () => {
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// If continuing from a message, preserve the ID and set status to Ready
|
||||||
|
if (continueFromMessageId) {
|
||||||
|
finalContent = {
|
||||||
|
...finalContent,
|
||||||
|
id: continueFromMessageId,
|
||||||
|
status: MessageStatus.Ready,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Normal completion flow (abort is handled after loop exits)
|
// Normal completion flow (abort is handled after loop exits)
|
||||||
builder.addAssistantMessage(accumulatedTextRef.value, undefined, toolCalls)
|
builder.addAssistantMessage(accumulatedTextRef.value, undefined, toolCalls)
|
||||||
const updatedMessage = await postMessageProcessing(
|
const updatedMessage = await postMessageProcessing(
|
||||||
@ -585,7 +648,9 @@ export const useChat = () => {
|
|||||||
addMessage,
|
addMessage,
|
||||||
updateStreamingContent,
|
updateStreamingContent,
|
||||||
updatePromptProgress,
|
updatePromptProgress,
|
||||||
updateThreadTimestamp
|
updateThreadTimestamp,
|
||||||
|
updateMessage,
|
||||||
|
continueFromMessageId
|
||||||
)
|
)
|
||||||
|
|
||||||
isCompleted = !toolCalls.length
|
isCompleted = !toolCalls.length
|
||||||
@ -618,8 +683,13 @@ export const useChat = () => {
|
|||||||
),
|
),
|
||||||
status: MessageStatus.Stopped,
|
status: MessageStatus.Stopped,
|
||||||
}
|
}
|
||||||
// Save the partial message
|
|
||||||
addMessage(partialContent)
|
// If continuing, update the existing message; otherwise add new
|
||||||
|
if (continueFromMessageId) {
|
||||||
|
updateMessage({ ...partialContent, id: continueFromMessageId })
|
||||||
|
} else {
|
||||||
|
addMessage(partialContent)
|
||||||
|
}
|
||||||
updatePromptProgress(undefined)
|
updatePromptProgress(undefined)
|
||||||
updateThreadTimestamp(activeThread.id)
|
updateThreadTimestamp(activeThread.id)
|
||||||
}
|
}
|
||||||
@ -649,7 +719,13 @@ export const useChat = () => {
|
|||||||
),
|
),
|
||||||
status: MessageStatus.Stopped,
|
status: MessageStatus.Stopped,
|
||||||
}
|
}
|
||||||
addMessage(partialContent)
|
|
||||||
|
// If continuing, update the existing message; otherwise add new
|
||||||
|
if (continueFromMessageId) {
|
||||||
|
updateMessage({ ...partialContent, id: continueFromMessageId })
|
||||||
|
} else {
|
||||||
|
addMessage(partialContent)
|
||||||
|
}
|
||||||
updatePromptProgress(undefined)
|
updatePromptProgress(undefined)
|
||||||
updateThreadTimestamp(activeThread.id)
|
updateThreadTimestamp(activeThread.id)
|
||||||
} else if (!abortController.signal.aborted) {
|
} else if (!abortController.signal.aborted) {
|
||||||
|
|||||||
@ -8,6 +8,7 @@ type MessageState = {
|
|||||||
getMessages: (threadId: string) => ThreadMessage[]
|
getMessages: (threadId: string) => ThreadMessage[]
|
||||||
setMessages: (threadId: string, messages: ThreadMessage[]) => void
|
setMessages: (threadId: string, messages: ThreadMessage[]) => void
|
||||||
addMessage: (message: ThreadMessage) => void
|
addMessage: (message: ThreadMessage) => void
|
||||||
|
updateMessage: (message: ThreadMessage) => void
|
||||||
deleteMessage: (threadId: string, messageId: string) => void
|
deleteMessage: (threadId: string, messageId: string) => void
|
||||||
clearAllMessages: () => void
|
clearAllMessages: () => void
|
||||||
}
|
}
|
||||||
@ -57,6 +58,36 @@ export const useMessages = create<MessageState>()((set, get) => ({
|
|||||||
console.error('Failed to persist message:', error)
|
console.error('Failed to persist message:', error)
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
|
updateMessage: (message) => {
|
||||||
|
const assistants = useAssistant.getState().assistants
|
||||||
|
const currentAssistant = useAssistant.getState().currentAssistant
|
||||||
|
|
||||||
|
const selectedAssistant =
|
||||||
|
assistants.find((a) => a.id === currentAssistant?.id) || assistants[0]
|
||||||
|
|
||||||
|
const updatedMessage = {
|
||||||
|
...message,
|
||||||
|
metadata: {
|
||||||
|
...message.metadata,
|
||||||
|
assistant: selectedAssistant,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optimistically update state immediately for instant UI feedback
|
||||||
|
set((state) => ({
|
||||||
|
messages: {
|
||||||
|
...state.messages,
|
||||||
|
[message.thread_id]: (state.messages[message.thread_id] || []).map((m) =>
|
||||||
|
m.id === message.id ? updatedMessage : m
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Persist to storage asynchronously
|
||||||
|
getServiceHub().messages().createMessage(updatedMessage).catch((error) => {
|
||||||
|
console.error('Failed to persist message update:', error)
|
||||||
|
})
|
||||||
|
},
|
||||||
deleteMessage: (threadId, messageId) => {
|
deleteMessage: (threadId, messageId) => {
|
||||||
getServiceHub().messages().deleteMessage(threadId, messageId)
|
getServiceHub().messages().deleteMessage(threadId, messageId)
|
||||||
set((state) => ({
|
set((state) => ({
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user