feat: Continue with AI response for llamacpp
This commit is contained in:
parent
4ea9d296ea
commit
f4b187ba11
@ -33,14 +33,18 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
|
||||
}, [messages])
|
||||
|
||||
const generateAIResponse = () => {
|
||||
// If continuing a partial response, delete the partial message first
|
||||
// If continuing a partial response, keep the message and continue from it
|
||||
if (isPartialResponse) {
|
||||
const partialMessage = messages[messages.length - 1]
|
||||
deleteMessage(partialMessage.thread_id, partialMessage.id ?? '')
|
||||
// Get the user message that prompted this partial response
|
||||
const userMessage = messages[messages.length - 2]
|
||||
if (userMessage?.content?.[0]?.text?.value) {
|
||||
sendMessage(userMessage.content[0].text.value, false)
|
||||
// Pass the partial message ID to continue from it
|
||||
sendMessage(
|
||||
userMessage.content[0].text.value,
|
||||
false,
|
||||
undefined,
|
||||
partialMessage.id
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -40,14 +40,20 @@ import { TEMPORARY_CHAT_QUERY_ID, TEMPORARY_CHAT_ID } from '@/constants/chat'
|
||||
// Builds the assistant thread message shown while a completion is streaming.
//
// - threadId:  thread the message belongs to.
// - text:      assistant text accumulated so far.
// - toolCalls: tool calls collected so far; each is attached with
//              `state: 'pending'`.
// - messageId: optional id of an existing message being continued. When
//              provided, the returned content carries that id instead of the
//              one assigned by `newAssistantThreadContent`, so downstream
//              updates target the existing message.
//
// Returns the content produced by `newAssistantThreadContent`, with `id`
// overridden only when `messageId` is truthy.
const createThreadContent = (
  threadId: string,
  text: string,
  toolCalls: ChatCompletionMessageToolCall[],
  messageId?: string
) => {
  const content = newAssistantThreadContent(threadId, text, {
    tool_calls: toolCalls.map((e) => ({
      ...e,
      state: 'pending',
    })),
  })
  // If continuing from a message, preserve the message ID
  return messageId ? { ...content, id: messageId } : content
}
|
||||
|
||||
// Helper to cancel animation frame cross-platform
|
||||
@ -66,9 +72,16 @@ const finalizeMessage = (
|
||||
addMessage: (message: ThreadMessage) => void,
|
||||
updateStreamingContent: (content: ThreadMessage | undefined) => void,
|
||||
updatePromptProgress: (progress: unknown) => void,
|
||||
updateThreadTimestamp: (threadId: string) => void
|
||||
updateThreadTimestamp: (threadId: string) => void,
|
||||
updateMessage?: (message: ThreadMessage) => void,
|
||||
continueFromMessageId?: string
|
||||
) => {
|
||||
addMessage(finalContent)
|
||||
// If continuing from a message, update it; otherwise add new message
|
||||
if (continueFromMessageId && updateMessage) {
|
||||
updateMessage({ ...finalContent, id: continueFromMessageId })
|
||||
} else {
|
||||
addMessage(finalContent)
|
||||
}
|
||||
updateStreamingContent(emptyThreadContent)
|
||||
updatePromptProgress(undefined)
|
||||
updateThreadTimestamp(finalContent.thread_id)
|
||||
@ -85,7 +98,9 @@ const processStreamingCompletion = async (
|
||||
currentCall: ChatCompletionMessageToolCall | null,
|
||||
updateStreamingContent: (content: ThreadMessage | undefined) => void,
|
||||
updateTokenSpeed: (message: ThreadMessage, increment?: number) => void,
|
||||
updatePromptProgress: (progress: unknown) => void
|
||||
updatePromptProgress: (progress: unknown) => void,
|
||||
continueFromMessageId?: string,
|
||||
updateMessage?: (message: ThreadMessage) => void
|
||||
) => {
|
||||
// High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
|
||||
let rafScheduled = false
|
||||
@ -97,9 +112,20 @@ const processStreamingCompletion = async (
|
||||
const currentContent = createThreadContent(
|
||||
activeThread.id,
|
||||
accumulatedText.value,
|
||||
toolCalls
|
||||
toolCalls,
|
||||
continueFromMessageId
|
||||
)
|
||||
updateStreamingContent(currentContent)
|
||||
|
||||
// When continuing, update the message directly instead of using streamingContent
|
||||
if (continueFromMessageId && updateMessage) {
|
||||
updateMessage({
|
||||
...currentContent,
|
||||
status: MessageStatus.Stopped, // Keep as Stopped while streaming
|
||||
})
|
||||
} else {
|
||||
updateStreamingContent(currentContent)
|
||||
}
|
||||
|
||||
if (pendingDeltaCount > 0) {
|
||||
updateTokenSpeed(currentContent, pendingDeltaCount)
|
||||
}
|
||||
@ -232,6 +258,7 @@ export const useChat = () => {
|
||||
|
||||
const getMessages = useMessages((state) => state.getMessages)
|
||||
const addMessage = useMessages((state) => state.addMessage)
|
||||
const updateMessage = useMessages((state) => state.updateMessage)
|
||||
const setMessages = useMessages((state) => state.setMessages)
|
||||
const setModelLoadError = useModelLoad((state) => state.setModelLoadError)
|
||||
const router = useRouter()
|
||||
@ -396,7 +423,8 @@ export const useChat = () => {
|
||||
size: number
|
||||
base64: string
|
||||
dataUrl: string
|
||||
}>
|
||||
}>,
|
||||
continueFromMessageId?: string
|
||||
) => {
|
||||
const activeThread = await getCurrentThread()
|
||||
const selectedProvider = useModelProvider.getState().selectedProvider
|
||||
@ -409,15 +437,24 @@ export const useChat = () => {
|
||||
setAbortController(activeThread.id, abortController)
|
||||
updateStreamingContent(emptyThreadContent)
|
||||
updatePromptProgress(undefined)
|
||||
// Do not add new message on retry
|
||||
if (troubleshooting)
|
||||
|
||||
// Find the message to continue from if provided
|
||||
const continueFromMessage = continueFromMessageId
|
||||
? messages.find((m) => m.id === continueFromMessageId)
|
||||
: undefined
|
||||
|
||||
// Do not add new message on retry or when continuing
|
||||
if (troubleshooting && !continueFromMessageId)
|
||||
addMessage(newUserThreadContent(activeThread.id, message, attachments))
|
||||
updateThreadTimestamp(activeThread.id)
|
||||
usePrompt.getState().setPrompt('')
|
||||
const selectedModel = useModelProvider.getState().selectedModel
|
||||
|
||||
// Declare accumulatedTextRef BEFORE try block so it's accessible in catch block
|
||||
const accumulatedTextRef = { value: '' }
|
||||
// If continuing, start with the previous content
|
||||
const accumulatedTextRef = {
|
||||
value: continueFromMessage?.content?.[0]?.text?.value || ''
|
||||
}
|
||||
let currentAssistant: Assistant | undefined
|
||||
|
||||
try {
|
||||
@ -427,13 +464,28 @@ export const useChat = () => {
|
||||
updateLoadingModel(false)
|
||||
}
|
||||
currentAssistant = useAssistant.getState().currentAssistant
|
||||
|
||||
// Filter out the stopped message from context if continuing
|
||||
const contextMessages = continueFromMessageId
|
||||
? messages.filter((m) => m.id !== continueFromMessageId)
|
||||
: messages
|
||||
|
||||
const builder = new CompletionMessagesBuilder(
|
||||
messages,
|
||||
contextMessages,
|
||||
currentAssistant
|
||||
? renderInstructions(currentAssistant.instructions)
|
||||
: undefined
|
||||
)
|
||||
if (troubleshooting) builder.addUserMessage(message, attachments)
|
||||
if (troubleshooting && !continueFromMessageId) {
|
||||
builder.addUserMessage(message, attachments)
|
||||
} else if (continueFromMessage) {
|
||||
// When continuing, add the partial assistant response to the context
|
||||
builder.addAssistantMessage(
|
||||
continueFromMessage.content?.[0]?.text?.value || '',
|
||||
undefined,
|
||||
[]
|
||||
)
|
||||
}
|
||||
|
||||
let isCompleted = false
|
||||
|
||||
@ -513,7 +565,9 @@ export const useChat = () => {
|
||||
currentCall,
|
||||
updateStreamingContent,
|
||||
updateTokenSpeed,
|
||||
updatePromptProgress
|
||||
updatePromptProgress,
|
||||
continueFromMessageId,
|
||||
updateMessage
|
||||
)
|
||||
}
|
||||
} catch (error) {
|
||||
@ -560,7 +614,7 @@ export const useChat = () => {
|
||||
}
|
||||
|
||||
// Create a final content object for adding to the thread
|
||||
const finalContent = newAssistantThreadContent(
|
||||
let finalContent = newAssistantThreadContent(
|
||||
activeThread.id,
|
||||
accumulatedTextRef.value,
|
||||
{
|
||||
@ -569,6 +623,15 @@ export const useChat = () => {
|
||||
}
|
||||
)
|
||||
|
||||
// If continuing from a message, preserve the ID and set status to Ready
|
||||
if (continueFromMessageId) {
|
||||
finalContent = {
|
||||
...finalContent,
|
||||
id: continueFromMessageId,
|
||||
status: MessageStatus.Ready,
|
||||
}
|
||||
}
|
||||
|
||||
// Normal completion flow (abort is handled after loop exits)
|
||||
builder.addAssistantMessage(accumulatedTextRef.value, undefined, toolCalls)
|
||||
const updatedMessage = await postMessageProcessing(
|
||||
@ -585,7 +648,9 @@ export const useChat = () => {
|
||||
addMessage,
|
||||
updateStreamingContent,
|
||||
updatePromptProgress,
|
||||
updateThreadTimestamp
|
||||
updateThreadTimestamp,
|
||||
updateMessage,
|
||||
continueFromMessageId
|
||||
)
|
||||
|
||||
isCompleted = !toolCalls.length
|
||||
@ -618,8 +683,13 @@ export const useChat = () => {
|
||||
),
|
||||
status: MessageStatus.Stopped,
|
||||
}
|
||||
// Save the partial message
|
||||
addMessage(partialContent)
|
||||
|
||||
// If continuing, update the existing message; otherwise add new
|
||||
if (continueFromMessageId) {
|
||||
updateMessage({ ...partialContent, id: continueFromMessageId })
|
||||
} else {
|
||||
addMessage(partialContent)
|
||||
}
|
||||
updatePromptProgress(undefined)
|
||||
updateThreadTimestamp(activeThread.id)
|
||||
}
|
||||
@ -649,7 +719,13 @@ export const useChat = () => {
|
||||
),
|
||||
status: MessageStatus.Stopped,
|
||||
}
|
||||
addMessage(partialContent)
|
||||
|
||||
// If continuing, update the existing message; otherwise add new
|
||||
if (continueFromMessageId) {
|
||||
updateMessage({ ...partialContent, id: continueFromMessageId })
|
||||
} else {
|
||||
addMessage(partialContent)
|
||||
}
|
||||
updatePromptProgress(undefined)
|
||||
updateThreadTimestamp(activeThread.id)
|
||||
} else if (!abortController.signal.aborted) {
|
||||
|
||||
@ -8,6 +8,7 @@ type MessageState = {
|
||||
getMessages: (threadId: string) => ThreadMessage[]
|
||||
setMessages: (threadId: string, messages: ThreadMessage[]) => void
|
||||
addMessage: (message: ThreadMessage) => void
|
||||
updateMessage: (message: ThreadMessage) => void
|
||||
deleteMessage: (threadId: string, messageId: string) => void
|
||||
clearAllMessages: () => void
|
||||
}
|
||||
@ -57,6 +58,36 @@ export const useMessages = create<MessageState>()((set, get) => ({
|
||||
console.error('Failed to persist message:', error)
|
||||
})
|
||||
},
|
||||
// Replaces an existing message (matched by `id` within `message.thread_id`)
// in the store and then persists it.
//
// The stored copy has `metadata.assistant` set to the currently selected
// assistant, falling back to the first assistant in the list when the current
// one is not found. If no message in the thread has a matching id, the
// in-memory list is left unchanged, but the persistence call is still made.
updateMessage: (message) => {
  const { assistants, currentAssistant } = useAssistant.getState()
  const assistant =
    assistants.find((candidate) => candidate.id === currentAssistant?.id) ||
    assistants[0]

  const nextMessage = {
    ...message,
    metadata: { ...message.metadata, assistant },
  }

  // Apply to in-memory state first so the UI reflects the change without
  // waiting on storage.
  set((state) => {
    const threadMessages = state.messages[message.thread_id] || []
    return {
      messages: {
        ...state.messages,
        [message.thread_id]: threadMessages.map((existing) =>
          existing.id === message.id ? nextMessage : existing
        ),
      },
    }
  })

  // Fire-and-forget persistence; failures are logged, not rethrown.
  // NOTE(review): this goes through `createMessage` — confirm the service
  // overwrites an existing record with the same id rather than duplicating it.
  getServiceHub()
    .messages()
    .createMessage(nextMessage)
    .catch((error) => {
      console.error('Failed to persist message update:', error)
    })
},
|
||||
deleteMessage: (threadId, messageId) => {
|
||||
getServiceHub().messages().deleteMessage(threadId, messageId)
|
||||
set((state) => ({
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user