Compare commits


12 Commits

Author SHA1 Message Date
Nghia Doan
35264e9a22
Merge branch 'dev' into feat/retain-interruption-message 2025-10-03 14:10:43 +07:00
Vanalite
e7c9275488 fix: Fix tests on useChat 2025-10-03 10:17:27 +07:00
Vanalite
4ac45aba23 Merge remote-tracking branch 'origin/dev' into feat/retain-interruption-message 2025-10-03 09:47:08 +07:00
Vanalite
34036d895a Merge remote-tracking branch 'origin/dev' into feat/retain-interruption-message
# Conflicts:
#	web-app/src/containers/ChatInput.tsx
#	web-app/src/hooks/useChat.ts
2025-10-02 10:59:19 +07:00
Vanalite
7127ff1244 fix: Exposing PromptProgress to be passed as param 2025-10-01 21:52:30 +07:00
Vanalite
1c0e135077 Merge remote-tracking branch 'origin/dev' into feat/retain-interruption-message
# Conflicts:
#	web-app/src/hooks/useChat.ts
2025-10-01 19:35:23 +07:00
Vanalite
99473ed568 fix: Consolidate comments 2025-10-01 19:21:34 +07:00
Vanalite
52f73af08c feat: Add tests for the Continuing with AI response 2025-10-01 19:13:10 +07:00
Vanalite
ccca331d6c feat: Modify on-going response instead of creating new message to avoid message ID duplication 2025-10-01 17:14:59 +07:00
Vanalite
f4b187ba11 feat: Continue with AI response for llamacpp 2025-10-01 16:43:27 +07:00
Vanalite
4ea9d296ea feat: Continue with AI response button if it got interrupted 2025-10-01 16:05:58 +07:00
Vanalite
2e86d4e421 feat: Allow to save the last message upon interrupting llm response 2025-10-01 15:43:05 +07:00
14 changed files with 967 additions and 256 deletions

View File

@@ -20,6 +20,13 @@ export interface MessageInterface {
    */
   listMessages(threadId: string): Promise<ThreadMessage[]>
+  /**
+   * Updates an existing message in a thread.
+   * @param {ThreadMessage} message - The message to be updated (must have existing ID).
+   * @returns {Promise<ThreadMessage>} A promise that resolves to the updated message.
+   */
+  modifyMessage(message: ThreadMessage): Promise<ThreadMessage>
   /**
    * Deletes a specific message from a thread.
    * @param {string} threadId - The ID of the thread from which the message will be deleted.
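A hedged sketch of what a caller of the new method might look like; messageInterface, stoppedMessage, and fullText are stand-ins invented for illustration, with the content and status shapes taken from the diffs further down:

import { ContentType, MessageStatus, ThreadMessage } from '@janhq/core'

// Resume a stopped reply: copy the message, swap the content, keep the ID.
const resumed: ThreadMessage = {
  ...stoppedMessage, // must carry the existing ID
  status: MessageStatus.Ready,
  content: [
    {
      type: ContentType.Text,
      text: { value: fullText, annotations: [] },
    },
  ],
}
// Resolves to the updated message, per the JSDoc above.
await messageInterface.modifyMessage(resumed)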

View File

@@ -176,7 +176,6 @@ const ChatInput = ({
   const mcpExtension = extensionManager.get<MCPExtension>(ExtensionTypeEnum.MCP)
   const MCPToolComponent = mcpExtension?.getToolComponent?.()
-
   const handleSendMesage = async (prompt: string) => {
     if (!selectedModel) {
       setMessage('Please select a model to start chatting.')

View File

@@ -3,6 +3,8 @@ import { useMessages } from '@/hooks/useMessages'
 import { useTranslation } from '@/i18n/react-i18next-compat'
 import { Play } from 'lucide-react'
 import { useShallow } from 'zustand/react/shallow'
+import { useMemo } from 'react'
+import { MessageStatus } from '@janhq/core'

 export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
   const { t } = useTranslation()
@@ -13,7 +15,36 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
     }))
   )
   const sendMessage = useChat()

+  // Detect if last message is a partial assistant response (user stopped midway)
+  const isPartialResponse = useMemo(() => {
+    if (!messages || messages.length < 2) return false
+    const lastMessage = messages[messages.length - 1]
+    const secondLastMessage = messages[messages.length - 2]
+    return (
+      lastMessage?.role === 'assistant' &&
+      lastMessage?.status === MessageStatus.Stopped &&
+      secondLastMessage?.role === 'user' &&
+      !lastMessage?.metadata?.tool_calls
+    )
+  }, [messages])
+
   const generateAIResponse = () => {
+    if (isPartialResponse) {
+      const partialMessage = messages[messages.length - 1]
+      const userMessage = messages[messages.length - 2]
+      if (userMessage?.content?.[0]?.text?.value) {
+        sendMessage(
+          userMessage.content[0].text.value,
+          false,
+          undefined,
+          undefined,
+          partialMessage.id
+        )
+      }
+      return
+    }
     const latestUserMessage = messages[messages.length - 1]
     if (
       latestUserMessage?.content?.[0]?.text?.value &&
@@ -39,7 +70,11 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
       className="mx-2 bg-main-view-fg/10 px-2 border border-main-view-fg/5 flex items-center justify-center rounded-xl gap-x-2 cursor-pointer pointer-events-auto"
       onClick={generateAIResponse}
     >
-      <p className="text-xs">{t('common:generateAiResponse')}</p>
+      <p className="text-xs">
+        {isPartialResponse
+          ? t('common:continueAiResponse')
+          : t('common:generateAiResponse')}
+      </p>
       <Play size={12} />
     </div>
   )
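The same five-condition check recurs inline in ScrollToBottom (next file). A sketch of extracting it into a shared pure helper; isPartialAssistantResponse is a name invented here, not part of this change:

import { MessageStatus, ThreadMessage } from '@janhq/core'

// Hypothetical shared helper mirroring the inline checks in both components.
export const isPartialAssistantResponse = (messages: ThreadMessage[]): boolean => {
  if (!messages || messages.length < 2) return false
  const last = messages[messages.length - 1]
  const beforeLast = messages[messages.length - 2]
  return (
    last?.role === 'assistant' &&
    last?.status === MessageStatus.Stopped &&
    beforeLast?.role === 'user' &&
    !last?.metadata?.tool_calls
  )
}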

View File

@@ -8,6 +8,7 @@ import { cn } from '@/lib/utils'
 import { ArrowDown } from 'lucide-react'
 import { useTranslation } from '@/i18n/react-i18next-compat'
 import { useAppState } from '@/hooks/useAppState'
+import { MessageStatus } from '@janhq/core'

 const ScrollToBottom = ({
   threadId,
@@ -28,11 +29,20 @@ const ScrollToBottom = ({
   const streamingContent = useAppState((state) => state.streamingContent)

+  // Check if last message is a partial assistant response and show continue button (user interrupted)
+  const isPartialResponse =
+    messages.length >= 2 &&
+    messages[messages.length - 1]?.role === 'assistant' &&
+    messages[messages.length - 1]?.status === MessageStatus.Stopped &&
+    messages[messages.length - 2]?.role === 'user' &&
+    !messages[messages.length - 1]?.metadata?.tool_calls
+
   const showGenerateAIResponseBtn =
-    (messages[messages.length - 1]?.role === 'user' ||
+    ((messages[messages.length - 1]?.role === 'user' ||
       (messages[messages.length - 1]?.metadata &&
-        'tool_calls' in (messages[messages.length - 1].metadata ?? {}))) &&
-    !streamingContent
+        'tool_calls' in (messages[messages.length - 1].metadata ?? {})) ||
+      isPartialResponse) &&
+    !streamingContent)

   return (
     <div

View File

@@ -2,6 +2,7 @@ import { useAppState } from '@/hooks/useAppState'
 import { ThreadContent } from './ThreadContent'
 import { memo, useMemo } from 'react'
 import { useMessages } from '@/hooks/useMessages'
+import { MessageStatus } from '@janhq/core'

 type Props = {
   threadId: string
@@ -48,12 +49,19 @@ export const StreamingContent = memo(({ threadId }: Props) => {
     return extractReasoningSegment(text)
   }, [lastAssistant])

-  if (!streamingContent || streamingContent.thread_id !== threadId) return null
+  if (!streamingContent || streamingContent.thread_id !== threadId) {
+    return null
+  }

   if (streamingReasoning && streamingReasoning === lastAssistantReasoning) {
     return null
   }

+  // Don't show streaming content if there's already a stopped message
+  if (lastAssistant?.status === MessageStatus.Stopped) {
+    return null
+  }
+
   // Pass a new object to ThreadContent to avoid reference issues
   // The streaming content is always the last message
   return (

View File

@@ -1,6 +1,30 @@
-import { renderHook, act } from '@testing-library/react'
-import { describe, it, expect, vi, beforeEach } from 'vitest'
-import { useChat } from '../useChat'
+import { renderHook, act, waitFor } from '@testing-library/react'
+import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
+import { MessageStatus, ContentType } from '@janhq/core'
+
+// Store mock functions for assertions - initialize immediately
+const mockAddMessage = vi.fn()
+const mockUpdateMessage = vi.fn()
+const mockGetMessages = vi.fn(() => [])
+const mockStartModel = vi.fn(() => Promise.resolve())
+const mockSendCompletion = vi.fn(() => Promise.resolve({
+  choices: [{
+    message: {
+      content: 'AI response',
+      role: 'assistant',
+    },
+  }],
+}))
+const mockPostMessageProcessing = vi.fn((toolCalls, builder, content) =>
+  Promise.resolve(content)
+)
+const mockCompletionMessagesBuilder = {
+  addUserMessage: vi.fn(),
+  addAssistantMessage: vi.fn(),
+  getMessages: vi.fn(() => []),
+}
+const mockSetPrompt = vi.fn()
+const mockResetTokenSpeed = vi.fn()

 // Mock dependencies
 vi.mock('../usePrompt', () => ({
@@ -8,11 +32,16 @@ vi.mock('../usePrompt', () => ({
     (selector: any) => {
       const state = {
         prompt: 'test prompt',
-        setPrompt: vi.fn(),
+        setPrompt: mockSetPrompt,
       }
       return selector ? selector(state) : state
     },
-    { getState: () => ({ prompt: 'test prompt', setPrompt: vi.fn() }) }
+    {
+      getState: () => ({
+        prompt: 'test prompt',
+        setPrompt: mockSetPrompt
+      })
+    }
   ),
 }))
@@ -22,39 +51,58 @@ vi.mock('../useAppState', () => ({
       const state = {
         tools: [],
         updateTokenSpeed: vi.fn(),
-        resetTokenSpeed: vi.fn(),
+        resetTokenSpeed: mockResetTokenSpeed,
         updateTools: vi.fn(),
         updateStreamingContent: vi.fn(),
         updatePromptProgress: vi.fn(),
         updateLoadingModel: vi.fn(),
         setAbortController: vi.fn(),
+        streamingContent: undefined,
       }
       return selector ? selector(state) : state
     },
     {
       getState: vi.fn(() => ({
+        tools: [],
         tokenSpeed: { tokensPerSecond: 10 },
+        streamingContent: undefined,
       }))
     }
   ),
 }))

 vi.mock('../useAssistant', () => ({
-  useAssistant: (selector: any) => {
-    const state = {
-      assistants: [{
-        id: 'test-assistant',
-        instructions: 'test instructions',
-        parameters: { stream: true },
-      }],
-      currentAssistant: {
-        id: 'test-assistant',
-        instructions: 'test instructions',
-        parameters: { stream: true },
-      },
-    }
-    return selector ? selector(state) : state
-  },
+  useAssistant: Object.assign(
+    (selector: any) => {
+      const state = {
+        assistants: [{
+          id: 'test-assistant',
+          instructions: 'test instructions',
+          parameters: { stream: true },
+        }],
+        currentAssistant: {
+          id: 'test-assistant',
+          instructions: 'test instructions',
+          parameters: { stream: true },
+        },
+      }
+      return selector ? selector(state) : state
+    },
+    {
+      getState: () => ({
+        assistants: [{
+          id: 'test-assistant',
+          instructions: 'test instructions',
+          parameters: { stream: true },
+        }],
+        currentAssistant: {
+          id: 'test-assistant',
+          instructions: 'test instructions',
+          parameters: { stream: true },
+        },
+      })
+    }
+  ),
 }))

 vi.mock('../useModelProvider', () => ({
@@ -62,14 +110,15 @@ vi.mock('../useModelProvider', () => ({
     (selector: any) => {
       const state = {
         getProviderByName: vi.fn(() => ({
-          provider: 'openai',
+          provider: 'llamacpp',
           models: [],
+          settings: [],
         })),
         selectedModel: {
           id: 'test-model',
           capabilities: ['tools'],
         },
-        selectedProvider: 'openai',
+        selectedProvider: 'llamacpp',
         updateProvider: vi.fn(),
       }
       return selector ? selector(state) : state
@@ -77,14 +126,15 @@ vi.mock('../useModelProvider', () => ({
     {
       getState: () => ({
         getProviderByName: vi.fn(() => ({
-          provider: 'openai',
+          provider: 'llamacpp',
           models: [],
+          settings: [],
         })),
         selectedModel: {
           id: 'test-model',
           capabilities: ['tools'],
         },
-        selectedProvider: 'openai',
+        selectedProvider: 'llamacpp',
         updateProvider: vi.fn(),
       })
     }
@@ -96,11 +146,11 @@ vi.mock('../useThreads', () => ({
     const state = {
       getCurrentThread: vi.fn(() => ({
         id: 'test-thread',
-        model: { id: 'test-model', provider: 'openai' },
+        model: { id: 'test-model', provider: 'llamacpp' },
       })),
       createThread: vi.fn(() => Promise.resolve({
         id: 'test-thread',
-        model: { id: 'test-model', provider: 'openai' },
+        model: { id: 'test-model', provider: 'llamacpp' },
       })),
       updateThreadTimestamp: vi.fn(),
     }
@@ -111,22 +161,33 @@
 vi.mock('../useMessages', () => ({
   useMessages: (selector: any) => {
     const state = {
-      getMessages: vi.fn(() => []),
-      addMessage: vi.fn(),
+      getMessages: mockGetMessages,
+      addMessage: mockAddMessage,
+      updateMessage: mockUpdateMessage,
+      setMessages: vi.fn(),
     }
     return selector ? selector(state) : state
   },
 }))

 vi.mock('../useToolApproval', () => ({
-  useToolApproval: (selector: any) => {
-    const state = {
-      approvedTools: [],
-      showApprovalModal: vi.fn(),
-      allowAllMCPPermissions: false,
-    }
-    return selector ? selector(state) : state
-  },
+  useToolApproval: Object.assign(
+    (selector: any) => {
+      const state = {
+        approvedTools: [],
+        showApprovalModal: vi.fn(),
+        allowAllMCPPermissions: false,
+      }
+      return selector ? selector(state) : state
+    },
+    {
+      getState: () => ({
+        approvedTools: [],
+        showApprovalModal: vi.fn(),
+        allowAllMCPPermissions: false,
+      })
+    }
+  ),
 }))

 vi.mock('../useToolAvailable', () => ({
@@ -162,38 +223,57 @@ vi.mock('@tanstack/react-router', () => ({
   })),
 }))

+vi.mock('../useServiceHub', () => ({
+  useServiceHub: vi.fn(() => ({
+    models: () => ({
+      startModel: mockStartModel,
+      stopModel: vi.fn(() => Promise.resolve()),
+      stopAllModels: vi.fn(() => Promise.resolve()),
+    }),
+    providers: () => ({
+      updateSettings: vi.fn(() => Promise.resolve()),
+    }),
+  })),
+}))
+
 vi.mock('@/lib/completion', () => ({
   emptyThreadContent: { thread_id: 'test-thread', content: '' },
   extractToolCall: vi.fn(),
-  newUserThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'user message' })),
-  newAssistantThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'assistant message' })),
-  sendCompletion: vi.fn(),
-  postMessageProcessing: vi.fn(),
-  isCompletionResponse: vi.fn(),
+  newUserThreadContent: vi.fn((threadId, content) => ({
+    thread_id: threadId,
+    content: [{ type: 'text', text: { value: content, annotations: [] } }],
+    role: 'user'
+  })),
+  newAssistantThreadContent: vi.fn((threadId, content) => ({
+    thread_id: threadId,
+    content: [{ type: 'text', text: { value: content, annotations: [] } }],
+    role: 'assistant'
+  })),
+  sendCompletion: mockSendCompletion,
+  postMessageProcessing: mockPostMessageProcessing,
+  isCompletionResponse: vi.fn(() => true),
 }))

 vi.mock('@/lib/messages', () => ({
-  CompletionMessagesBuilder: vi.fn(() => ({
-    addUserMessage: vi.fn(),
-    addAssistantMessage: vi.fn(),
-    getMessages: vi.fn(() => []),
-  })),
+  CompletionMessagesBuilder: vi.fn(() => mockCompletionMessagesBuilder),
+}))
+
+vi.mock('@/lib/instructionTemplate', () => ({
+  renderInstructions: vi.fn((instructions: string) => instructions),
+}))
+
+vi.mock('@/utils/reasoning', () => ({
+  ReasoningProcessor: vi.fn(() => ({
+    processReasoningChunk: vi.fn(() => null),
+    finalize: vi.fn(() => ''),
+  })),
+  extractReasoningFromMessage: vi.fn(() => null),
 }))

 vi.mock('@/services/mcp', () => ({
   getTools: vi.fn(() => Promise.resolve([])),
 }))

-vi.mock('@/services/models', () => ({
-  startModel: vi.fn(() => Promise.resolve()),
-  stopModel: vi.fn(() => Promise.resolve()),
-  stopAllModels: vi.fn(() => Promise.resolve()),
-}))
-
-vi.mock('@/services/providers', () => ({
-  updateSettings: vi.fn(() => Promise.resolve()),
-}))
-
 vi.mock('@tauri-apps/api/event', () => ({
   listen: vi.fn(() => Promise.resolve(vi.fn())),
 }))
@@ -204,9 +284,41 @@ vi.mock('sonner', () => ({
   },
 }))

+// Import after mocks to avoid hoisting issues
+const { useChat } = await import('../useChat')
+const completionLib = await import('@/lib/completion')
+const messagesLib = await import('@/lib/messages')
+
 describe('useChat', () => {
   beforeEach(() => {
+    // Clear mock call history
     vi.clearAllMocks()
+
+    // Reset mock implementations
+    mockAddMessage.mockClear()
+    mockUpdateMessage.mockClear()
+    mockGetMessages.mockReturnValue([])
+    mockStartModel.mockResolvedValue(undefined)
+    mockSetPrompt.mockClear()
+    mockResetTokenSpeed.mockClear()
+    mockSendCompletion.mockResolvedValue({
+      choices: [{
+        message: {
+          content: 'AI response',
+          role: 'assistant',
+        },
+      }],
+    })
+    mockPostMessageProcessing.mockImplementation((toolCalls, builder, content) =>
+      Promise.resolve(content)
+    )
+    mockCompletionMessagesBuilder.addUserMessage.mockClear()
+    mockCompletionMessagesBuilder.addAssistantMessage.mockClear()
+    mockCompletionMessagesBuilder.getMessages.mockReturnValue([])
+  })
+
+  afterEach(() => {
+    vi.clearAllTimers()
   })

   it('returns sendMessage function', () => {
@@ -216,13 +328,270 @@ describe('useChat', () => {
     expect(typeof result.current).toBe('function')
   })

-  it('sends message successfully', async () => {
-    const { result } = renderHook(() => useChat())
-
-    await act(async () => {
-      await result.current('Hello world')
-    })
-
-    expect(result.current).toBeDefined()
-  })
+  describe('Continue with AI response functionality', () => {
+    it('should add new user message when troubleshooting is true and no continueFromMessageId', async () => {
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('Hello world', true, undefined, undefined, undefined)
+      })
+
+      expect(completionLib.newUserThreadContent).toHaveBeenCalledWith(
+        'test-thread',
+        'Hello world',
+        undefined
+      )
+      expect(mockAddMessage).toHaveBeenCalledWith(
+        expect.objectContaining({
+          thread_id: 'test-thread',
+          role: 'user',
+        })
+      )
+      expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith(
+        'Hello world',
+        undefined
+      )
+    })
+
+    it('should NOT add new user message when continueFromMessageId is provided', async () => {
+      const stoppedMessage = {
+        id: 'msg-123',
+        thread_id: 'test-thread',
+        role: 'assistant',
+        content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }],
+        status: MessageStatus.Stopped,
+        metadata: {},
+      }
+      mockGetMessages.mockReturnValue([stoppedMessage])
+
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('', true, undefined, undefined, 'msg-123')
+      })
+
+      expect(completionLib.newUserThreadContent).not.toHaveBeenCalled()
+      const userMessageCalls = mockAddMessage.mock.calls.filter(
+        (call: any) => call[0]?.role === 'user'
+      )
+      expect(userMessageCalls).toHaveLength(0)
+      expect(mockCompletionMessagesBuilder.addUserMessage).not.toHaveBeenCalled()
+    })
+
+    it('should add partial assistant message to builder when continuing', async () => {
+      const stoppedMessage = {
+        id: 'msg-123',
+        thread_id: 'test-thread',
+        role: 'assistant',
+        content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }],
+        status: MessageStatus.Stopped,
+        metadata: {},
+      }
+      mockGetMessages.mockReturnValue([stoppedMessage])
+
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('', true, undefined, undefined, 'msg-123')
+      })
+
+      // Should be called twice: once with partial message (line 517-521), once after completion (line 689)
+      const assistantCalls = mockCompletionMessagesBuilder.addAssistantMessage.mock.calls
+      expect(assistantCalls.length).toBeGreaterThanOrEqual(1)
+      // First call should be with the partial response content
+      expect(assistantCalls[0]).toEqual([
+        'Partial response',
+        undefined,
+        []
+      ])
+    })
+
+    it('should filter out stopped message from context when continuing', async () => {
+      const userMsg = {
+        id: 'msg-1',
+        thread_id: 'test-thread',
+        role: 'user',
+        content: [{ type: 'text', text: { value: 'Hello', annotations: [] } }],
+      }
+      const stoppedMessage = {
+        id: 'msg-123',
+        thread_id: 'test-thread',
+        role: 'assistant',
+        content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
+        status: MessageStatus.Stopped,
+      }
+      mockGetMessages.mockReturnValue([userMsg, stoppedMessage])
+
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('', true, undefined, undefined, 'msg-123')
+      })
+
+      // The CompletionMessagesBuilder is called with filtered messages (line 507-512)
+      // The stopped message should be filtered out from the context
+      expect(messagesLib.CompletionMessagesBuilder).toHaveBeenCalled()
+      const builderCall = (messagesLib.CompletionMessagesBuilder as any).mock.calls[0]
+      expect(builderCall[0]).toEqual([userMsg]) // stopped message filtered out
+      expect(builderCall[1]).toEqual('test instructions')
+    })
+
+    it('should update existing message instead of adding new one when continuing', async () => {
+      const stoppedMessage = {
+        id: 'msg-123',
+        thread_id: 'test-thread',
+        role: 'assistant',
+        content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
+        status: MessageStatus.Stopped,
+        metadata: {},
+      }
+      mockGetMessages.mockReturnValue([stoppedMessage])
+
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('', true, undefined, undefined, 'msg-123')
+      })
+
+      // finalizeMessage is called at line 700-708, which should update the message
+      expect(mockUpdateMessage).toHaveBeenCalledWith(
+        expect.objectContaining({
+          id: 'msg-123',
+          status: MessageStatus.Ready,
+        })
+      )
+    })
+
+    it('should start with previous content when continuing', async () => {
+      const stoppedMessage = {
+        id: 'msg-123',
+        thread_id: 'test-thread',
+        role: 'assistant',
+        content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }],
+        status: MessageStatus.Stopped,
+        metadata: {},
+      }
+      mockGetMessages.mockReturnValue([stoppedMessage])
+      mockSendCompletion.mockResolvedValue({
+        choices: [{
+          message: {
+            content: ' continued',
+            role: 'assistant',
+          },
+        }],
+      })
+
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('', true, undefined, undefined, 'msg-123')
+      })
+
+      // The accumulated text should contain the previous content plus new content
+      // accumulatedTextRef starts with 'Partial response' (line 490)
+      // Then gets ' continued' appended (line 585)
+      expect(mockUpdateMessage).toHaveBeenCalledWith(
+        expect.objectContaining({
+          id: 'msg-123',
+          content: expect.arrayContaining([
+            expect.objectContaining({
+              text: expect.objectContaining({
+                value: 'Partial response continued',
+              })
+            })
+          ])
+        })
+      )
+    })
+
+    it('should handle attachments correctly when not continuing', async () => {
+      const { result } = renderHook(() => useChat())
+
+      const attachments = [
+        {
+          name: 'test.png',
+          type: 'image/png',
+          size: 1024,
+          base64: 'base64data',
+          dataUrl: 'data:image/png;base64,base64data',
+        },
+      ]
+
+      await act(async () => {
+        await result.current('Message with attachment', true, attachments, undefined, undefined)
+      })
+
+      expect(completionLib.newUserThreadContent).toHaveBeenCalledWith(
+        'test-thread',
+        'Message with attachment',
+        attachments
+      )
+      expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith(
+        'Message with attachment',
+        attachments
+      )
+    })
+
+    it('should preserve message status as Ready after continuation completes', async () => {
+      const stoppedMessage = {
+        id: 'msg-123',
+        thread_id: 'test-thread',
+        role: 'assistant',
+        content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
+        status: MessageStatus.Stopped,
+        metadata: {},
+      }
+      mockGetMessages.mockReturnValue([stoppedMessage])
+
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('', true, undefined, undefined, 'msg-123')
+      })
+
+      // finalContent is created at line 678-683 with status Ready when continuing
+      expect(mockUpdateMessage).toHaveBeenCalledWith(
+        expect.objectContaining({
+          id: 'msg-123',
+          status: MessageStatus.Ready,
+        })
+      )
+    })
+  })
+
+  describe('Normal message sending', () => {
+    it('sends message successfully without continuation', async () => {
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('Hello world')
+      })
+
+      expect(mockSendCompletion).toHaveBeenCalled()
+      expect(mockStartModel).toHaveBeenCalled()
+    })
+  })
+
+  describe('Error handling', () => {
+    it('should handle errors gracefully during continuation', async () => {
+      mockSendCompletion.mockRejectedValueOnce(new Error('API Error'))
+
+      const stoppedMessage = {
+        id: 'msg-123',
+        thread_id: 'test-thread',
+        role: 'assistant',
+        content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
+        status: MessageStatus.Stopped,
+        metadata: {},
+      }
+      mockGetMessages.mockReturnValue([stoppedMessage])
+
+      const { result } = renderHook(() => useChat())
+
+      await act(async () => {
+        await result.current('', true, undefined, undefined, 'msg-123')
+      })
+
+      expect(result.current).toBeDefined()
+    })
+  })
 })
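One pattern worth noting for future mocks: the Object.assign wrapping used for useAssistant and useToolApproval mocks a zustand store that is consumed both as a hook (called with a selector) and statically via getState(). A generic sketch of that shape, with the state argument as the only assumption:

// Generic zustand store mock: callable with a selector, plus getState().
const makeStoreMock = <S extends object>(state: S) =>
  Object.assign(
    (selector?: (s: S) => unknown) => (selector ? selector(state) : state),
    { getState: () => state }
  )

// e.g. inside vi.mock:
// vi.mock('../useToolApproval', () => ({
//   useToolApproval: makeStoreMock({
//     approvedTools: [],
//     showApprovalModal: vi.fn(),
//     allowAllMCPPermissions: false,
//   }),
// }))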

View File

@@ -225,9 +225,25 @@ describe('useMessages', () => {
       })
     )

-    // Wait for async operation
+    // Message should be immediately available (optimistic update)
+    expect(result.current.messages['thread1']).toContainEqual(
+      expect.objectContaining({
+        id: messageToAdd.id,
+        thread_id: messageToAdd.thread_id,
+        role: messageToAdd.role,
+        content: messageToAdd.content,
+        metadata: expect.objectContaining({
+          assistant: expect.objectContaining({
+            id: expect.any(String),
+            name: expect.any(String),
+          }),
+        }),
+      })
+    )
+
+    // Verify persistence was attempted
     await vi.waitFor(() => {
-      expect(result.current.messages['thread1']).toContainEqual(mockCreatedMessage)
+      expect(mockCreateMessage).toHaveBeenCalled()
     })
   })
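The rewritten assertion separates the synchronous optimistic insert from the background persistence. Distilled, the two-phase pattern looks like this (names taken from the diff above; the surrounding setup is assumed):

// Phase 1: the optimistic write is visible synchronously, no await needed.
expect(result.current.messages['thread1']).toContainEqual(
  expect.objectContaining({ id: messageToAdd.id })
)

// Phase 2: persistence runs in the background, so poll for the service call.
await vi.waitFor(() => {
  expect(mockCreateMessage).toHaveBeenCalled()
})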

View File

@@ -4,7 +4,7 @@ import { MCPTool } from '@/types/completion'
 import { useAssistant } from './useAssistant'
 import { ChatCompletionMessageToolCall } from 'openai/resources'

-type PromptProgress = {
+export type PromptProgress = {
   cache: number
   processed: number
   time_ms: number

View File

@@ -3,7 +3,7 @@ import { flushSync } from 'react-dom'
 import { usePrompt } from './usePrompt'
 import { useModelProvider } from './useModelProvider'
 import { useThreads } from './useThreads'
-import { useAppState } from './useAppState'
+import { useAppState, type PromptProgress } from './useAppState'
 import { useMessages } from './useMessages'
 import { useRouter } from '@tanstack/react-router'
 import { defaultModel } from '@/lib/models'
@@ -23,6 +23,7 @@ import {
   ChatCompletionMessageToolCall,
   CompletionUsage,
 } from 'openai/resources'
+import { MessageStatus, ContentType, ThreadMessage } from '@janhq/core'
 import { useServiceHub } from '@/hooks/useServiceHub'
 import { useToolApproval } from '@/hooks/useToolApproval'
@@ -38,6 +39,198 @@ import { useAssistant } from './useAssistant'
 import { useShallow } from 'zustand/shallow'
 import { TEMPORARY_CHAT_QUERY_ID, TEMPORARY_CHAT_ID } from '@/constants/chat'

+// Helper to create thread content with consistent structure
+const createThreadContent = (
+  threadId: string,
+  text: string,
+  toolCalls: ChatCompletionMessageToolCall[],
+  messageId?: string
+) => {
+  const content = newAssistantThreadContent(threadId, text, {
+    tool_calls: toolCalls.map((e) => ({
+      ...e,
+      state: 'pending',
+    })),
+  })
+  // If continuing from a message, preserve the message ID
+  if (messageId) {
+    return { ...content, id: messageId }
+  }
+  return content
+}
+
+// Helper to cancel animation frame cross-platform
+const cancelFrame = (handle: number | undefined) => {
+  if (handle === undefined) return
+  if (typeof cancelAnimationFrame !== 'undefined') {
+    cancelAnimationFrame(handle)
+  } else {
+    clearTimeout(handle)
+  }
+}
+
+// Helper to finalize and save a message
+const finalizeMessage = (
+  finalContent: ThreadMessage,
+  addMessage: (message: ThreadMessage) => void,
+  updateStreamingContent: (content: ThreadMessage | undefined) => void,
+  updatePromptProgress: (progress: PromptProgress | undefined) => void,
+  updateThreadTimestamp: (threadId: string) => void,
+  updateMessage?: (message: ThreadMessage) => void,
+  continueFromMessageId?: string
+) => {
+  // If continuing from a message, update it; otherwise add new message
+  if (continueFromMessageId && updateMessage) {
+    updateMessage({ ...finalContent, id: continueFromMessageId })
+  } else {
+    addMessage(finalContent)
+  }
+  updateStreamingContent(emptyThreadContent)
+  updatePromptProgress(undefined)
+  updateThreadTimestamp(finalContent.thread_id)
+}
+
+// Helper to process streaming completion
+const processStreamingCompletion = async (
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  completion: AsyncIterable<any>,
+  abortController: AbortController,
+  activeThread: Thread,
+  accumulatedText: { value: string },
+  toolCalls: ChatCompletionMessageToolCall[],
+  currentCall: ChatCompletionMessageToolCall | null,
+  updateStreamingContent: (content: ThreadMessage | undefined) => void,
+  updateTokenSpeed: (message: ThreadMessage, increment?: number) => void,
+  setTokenSpeed: (message: ThreadMessage, tokensPerSecond: number, totalTokens: number) => void,
+  updatePromptProgress: (progress: PromptProgress | undefined) => void,
+  timeToFirstToken: number,
+  tokenUsageRef: { current: CompletionUsage | undefined },
+  continueFromMessageId?: string,
+  updateMessage?: (message: ThreadMessage) => void,
+  continueFromMessage?: ThreadMessage
+) => {
+  // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
+  let rafScheduled = false
+  let rafHandle: number | undefined
+  let pendingDeltaCount = 0
+  const reasoningProcessor = new ReasoningProcessor()
+
+  const flushStreamingContent = () => {
+    const currentContent = createThreadContent(
+      activeThread.id,
+      accumulatedText.value,
+      toolCalls,
+      continueFromMessageId
+    )
+    // When continuing, update the message directly instead of using streamingContent
+    if (continueFromMessageId && updateMessage && continueFromMessage) {
+      updateMessage({
+        ...continueFromMessage, // Preserve original message metadata
+        content: currentContent.content, // Update content
+        status: MessageStatus.Stopped, // Keep as Stopped while streaming
+      })
+    } else {
+      updateStreamingContent(currentContent)
+    }
+    if (tokenUsageRef.current) {
+      setTokenSpeed(
+        currentContent,
+        tokenUsageRef.current.completion_tokens /
+          Math.max((Date.now() - timeToFirstToken) / 1000, 1),
+        tokenUsageRef.current.completion_tokens
+      )
+    } else if (pendingDeltaCount > 0) {
+      updateTokenSpeed(currentContent, pendingDeltaCount)
+    }
+    pendingDeltaCount = 0
+    rafScheduled = false
+  }
+
+  const scheduleFlush = () => {
+    if (rafScheduled || abortController.signal.aborted) return
+    rafScheduled = true
+    const doSchedule = (cb: () => void) => {
+      if (typeof requestAnimationFrame !== 'undefined') {
+        rafHandle = requestAnimationFrame(() => cb())
+      } else {
+        // Fallback for non-browser test environments
+        const t = setTimeout(() => cb(), 0) as unknown as number
+        rafHandle = t
+      }
+    }
+    doSchedule(() => {
+      // Check abort status before executing the scheduled callback
+      if (abortController.signal.aborted) {
+        rafScheduled = false
+        return
+      }
+      flushStreamingContent()
+    })
+  }
+
+  try {
+    for await (const part of completion) {
+      // Check if aborted before processing each part
+      if (abortController.signal.aborted) {
+        break
+      }
+      // Handle prompt progress if available
+      if ('prompt_progress' in part && part.prompt_progress) {
+        // Force immediate state update to ensure we see intermediate values
+        flushSync(() => {
+          updatePromptProgress(part.prompt_progress)
+        })
+        // Add a small delay to make progress visible
+        await new Promise((resolve) => setTimeout(resolve, 100))
+      }
+      // Error message
+      if (!part.choices) {
+        throw new Error(
+          'message' in part
+            ? (part.message as string)
+            : (JSON.stringify(part) ?? '')
+        )
+      }
+      if ('usage' in part && part.usage) {
+        tokenUsageRef.current = part.usage
+      }
+      if (part.choices[0]?.delta?.tool_calls) {
+        extractToolCall(part, currentCall, toolCalls)
+        // Schedule a flush to reflect tool update
+        scheduleFlush()
+      }
+      const deltaReasoning = reasoningProcessor.processReasoningChunk(part)
+      if (deltaReasoning) {
+        accumulatedText.value += deltaReasoning
+        pendingDeltaCount += 1
+        // Schedule flush for reasoning updates
+        scheduleFlush()
+      }
+      const deltaContent = part.choices[0]?.delta?.content || ''
+      if (deltaContent) {
+        accumulatedText.value += deltaContent
+        pendingDeltaCount += 1
+        // Batch UI update on next animation frame
+        scheduleFlush()
+      }
+    }
+  } finally {
+    // Always clean up scheduled RAF when stream ends (either normally or via abort)
+    cancelFrame(rafHandle)
+    rafHandle = undefined
+    rafScheduled = false
+    // Finalize reasoning (close any open think tags)
+    accumulatedText.value += reasoningProcessor.finalize()
+  }
+}
+
 export const useChat = () => {
   const [
     updateTokenSpeed,
@@ -86,6 +279,7 @@ export const useChat = () => {
   const getMessages = useMessages((state) => state.getMessages)
   const addMessage = useMessages((state) => state.addMessage)
+  const updateMessage = useMessages((state) => state.updateMessage)
   const setMessages = useMessages((state) => state.setMessages)
   const setModelLoadError = useModelLoad((state) => state.setModelLoadError)
   const router = useRouter()
@@ -149,7 +343,7 @@ export const useChat = () => {
       })
     }
     return currentThread
-  }, [createThread, retrieveThread, router, setMessages])
+  }, [createThread, retrieveThread, router, setMessages, serviceHub])

   const restartModel = useCallback(
     async (provider: ProviderObject, modelId: string) => {
@@ -264,7 +458,8 @@ export const useChat = () => {
         base64: string
         dataUrl: string
       }>,
-      projectId?: string
+      projectId?: string,
+      continueFromMessageId?: string
     ) => {
       const activeThread = await getCurrentThread(projectId)
       const selectedProvider = useModelProvider.getState().selectedProvider
@@ -277,26 +472,54 @@ export const useChat = () => {
       setAbortController(activeThread.id, abortController)
       updateStreamingContent(emptyThreadContent)
       updatePromptProgress(undefined)
-      // Do not add new message on retry
-      if (troubleshooting)
+
+      // Find the message to continue from if provided
+      const continueFromMessage = continueFromMessageId
+        ? messages.find((m) => m.id === continueFromMessageId)
+        : undefined
+
+      // Do not add new message on retry or when continuing
+      if (troubleshooting && !continueFromMessageId)
         addMessage(newUserThreadContent(activeThread.id, message, attachments))
       updateThreadTimestamp(activeThread.id)
       usePrompt.getState().setPrompt('')
       const selectedModel = useModelProvider.getState().selectedModel
+
+      // If continuing, start with the previous content
+      const accumulatedTextRef = {
+        value: continueFromMessage?.content?.[0]?.text?.value || ''
+      }
+
+      let currentAssistant: Assistant | undefined | null
+
       try {
         if (selectedModel?.id) {
           updateLoadingModel(true)
           await serviceHub.models().startModel(activeProvider, selectedModel.id)
           updateLoadingModel(false)
         }
-        const currentAssistant = useAssistant.getState().currentAssistant
+        currentAssistant = useAssistant.getState().currentAssistant
+
+        // Filter out the stopped message from context if continuing
+        const contextMessages = continueFromMessageId
+          ? messages.filter((m) => m.id !== continueFromMessageId)
+          : messages
+
         const builder = new CompletionMessagesBuilder(
-          messages,
+          contextMessages,
          currentAssistant
            ? renderInstructions(currentAssistant.instructions)
            : undefined
        )
-        if (troubleshooting) builder.addUserMessage(message, attachments)
+        if (troubleshooting && !continueFromMessageId) {
+          builder.addUserMessage(message, attachments)
+        } else if (continueFromMessage) {
+          // When continuing, add the partial assistant response to the context
+          builder.addAssistantMessage(
+            continueFromMessage.content?.[0]?.text?.value || '',
+            undefined,
+            []
+          )
+        }

         let isCompleted = false
@@ -349,7 +572,6 @@ export const useChat = () => {
         )
         if (!completion) throw new Error('No completion received')

-        let accumulatedText = ''
         const currentCall: ChatCompletionMessageToolCall | null = null
         const toolCalls: ChatCompletionMessageToolCall[] = []
         const timeToFirstToken = Date.now()
@@ -357,13 +579,19 @@ export const useChat = () => {
        try {
          if (isCompletionResponse(completion)) {
            const message = completion.choices[0]?.message
-           accumulatedText = (message?.content as string) || ''
+           // When continuing, append to existing content; otherwise replace
+           const newContent = (message?.content as string) || ''
+           if (continueFromMessageId && accumulatedTextRef.value) {
+             accumulatedTextRef.value += newContent
+           } else {
+             accumulatedTextRef.value = newContent
+           }

            // Handle reasoning field if there is one
            const reasoning = extractReasoningFromMessage(message)
            if (reasoning) {
-             accumulatedText =
-               `<think>${reasoning}</think>` + accumulatedText
+             accumulatedTextRef.value =
+               `<think>${reasoning}</think>` + accumulatedTextRef.value
            }

            if (message?.tool_calls) {
@@ -373,161 +601,25 @@ export const useChat = () => {
             tokenUsage = completion.usage
           }
         } else {
-          // High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
-          let rafScheduled = false
-          let rafHandle: number | undefined
-          let pendingDeltaCount = 0
-          const reasoningProcessor = new ReasoningProcessor()
-          const scheduleFlush = () => {
-            if (rafScheduled || abortController.signal.aborted) return
-            rafScheduled = true
-            const doSchedule = (cb: () => void) => {
-              if (typeof requestAnimationFrame !== 'undefined') {
-                rafHandle = requestAnimationFrame(() => cb())
-              } else {
-                // Fallback for non-browser test environments
-                const t = setTimeout(() => cb(), 0) as unknown as number
-                rafHandle = t
-              }
-            }
-            doSchedule(() => {
-              // Check abort status before executing the scheduled callback
-              if (abortController.signal.aborted) {
-                rafScheduled = false
-                return
-              }
-              const currentContent = newAssistantThreadContent(
-                activeThread.id,
-                accumulatedText,
-                {
-                  tool_calls: toolCalls.map((e) => ({
-                    ...e,
-                    state: 'pending',
-                  })),
-                }
-              )
-              updateStreamingContent(currentContent)
-              if (tokenUsage) {
-                setTokenSpeed(
-                  currentContent,
-                  tokenUsage.completion_tokens /
-                    Math.max((Date.now() - timeToFirstToken) / 1000, 1),
-                  tokenUsage.completion_tokens
-                )
-              } else if (pendingDeltaCount > 0) {
-                updateTokenSpeed(currentContent, pendingDeltaCount)
-              }
-              pendingDeltaCount = 0
-              rafScheduled = false
-            })
-          }
-          const flushIfPending = () => {
-            if (!rafScheduled) return
-            if (
-              typeof cancelAnimationFrame !== 'undefined' &&
-              rafHandle !== undefined
-            ) {
-              cancelAnimationFrame(rafHandle)
-            } else if (rafHandle !== undefined) {
-              clearTimeout(rafHandle)
-            }
-            // Do an immediate flush
-            const currentContent = newAssistantThreadContent(
-              activeThread.id,
-              accumulatedText,
-              {
-                tool_calls: toolCalls.map((e) => ({
-                  ...e,
-                  state: 'pending',
-                })),
-              }
-            )
-            updateStreamingContent(currentContent)
-            if (tokenUsage) {
-              setTokenSpeed(
-                currentContent,
-                tokenUsage.completion_tokens /
-                  Math.max((Date.now() - timeToFirstToken) / 1000, 1),
-                tokenUsage.completion_tokens
-              )
-            } else if (pendingDeltaCount > 0) {
-              updateTokenSpeed(currentContent, pendingDeltaCount)
-            }
-            pendingDeltaCount = 0
-            rafScheduled = false
-          }
-          try {
-            for await (const part of completion) {
-              // Check if aborted before processing each part
-              if (abortController.signal.aborted) {
-                break
-              }
-              // Handle prompt progress if available
-              if ('prompt_progress' in part && part.prompt_progress) {
-                // Force immediate state update to ensure we see intermediate values
-                flushSync(() => {
-                  updatePromptProgress(part.prompt_progress)
-                })
-                // Add a small delay to make progress visible
-                await new Promise((resolve) => setTimeout(resolve, 100))
-              }
-              // Error message
-              if (!part.choices) {
-                throw new Error(
-                  'message' in part
-                    ? (part.message as string)
-                    : (JSON.stringify(part) ?? '')
-                )
-              }
-              if ('usage' in part && part.usage) {
-                tokenUsage = part.usage
-              }
-              if (part.choices[0]?.delta?.tool_calls) {
-                extractToolCall(part, currentCall, toolCalls)
-                // Schedule a flush to reflect tool update
-                scheduleFlush()
-              }
-              const deltaReasoning =
-                reasoningProcessor.processReasoningChunk(part)
-              if (deltaReasoning) {
-                accumulatedText += deltaReasoning
-                pendingDeltaCount += 1
-                // Schedule flush for reasoning updates
-                scheduleFlush()
-              }
-              const deltaContent = part.choices[0]?.delta?.content || ''
-              if (deltaContent) {
-                accumulatedText += deltaContent
-                pendingDeltaCount += 1
-                // Batch UI update on next animation frame
-                scheduleFlush()
-              }
-            }
-          } finally {
-            // Always clean up scheduled RAF when stream ends (either normally or via abort)
-            if (rafHandle !== undefined) {
-              if (typeof cancelAnimationFrame !== 'undefined') {
-                cancelAnimationFrame(rafHandle)
-              } else {
-                clearTimeout(rafHandle)
-              }
-              rafHandle = undefined
-              rafScheduled = false
-            }
-            // Only finalize and flush if not aborted
-            if (!abortController.signal.aborted) {
-              // Finalize reasoning (close any open think tags)
-              accumulatedText += reasoningProcessor.finalize()
-              // Ensure any pending buffered content is rendered at the end
-              flushIfPending()
-            }
-          }
+          const tokenUsageRef = { current: tokenUsage }
+          await processStreamingCompletion(
+            completion,
+            abortController,
+            activeThread,
+            accumulatedTextRef,
+            toolCalls,
+            currentCall,
+            updateStreamingContent,
+            updateTokenSpeed,
+            setTokenSpeed,
+            updatePromptProgress,
+            timeToFirstToken,
+            tokenUsageRef,
+            continueFromMessageId,
+            updateMessage,
+            continueFromMessage
+          )
+          tokenUsage = tokenUsageRef.current
         }
       } catch (error) {
         const errorMessage =
@@ -561,7 +653,7 @@ export const useChat = () => {
       }
       // TODO: Remove this check when integrating new llama.cpp extension
       if (
-        accumulatedText.length === 0 &&
+        accumulatedTextRef.value.length === 0 &&
         toolCalls.length === 0 &&
         activeThread.model?.id &&
         activeProvider?.provider === 'llamacpp'
@@ -573,16 +665,29 @@ export const useChat = () => {
      }

      // Create a final content object for adding to the thread
-     const finalContent = newAssistantThreadContent(
+     let finalContent = newAssistantThreadContent(
        activeThread.id,
-       accumulatedText,
+       accumulatedTextRef.value,
        {
          tokenSpeed: useAppState.getState().tokenSpeed,
          assistant: currentAssistant,
        }
      )
-     builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
+
+     // If continuing from a message, preserve the ID and set status to Ready
+     if (continueFromMessageId) {
+       finalContent = {
+         ...finalContent,
+         id: continueFromMessageId,
+         status: MessageStatus.Ready,
+       }
+     }
+
+     // Normal completion flow (abort is handled after loop exits)
+     // Don't add assistant message to builder if continuing - it's already there
+     if (!continueFromMessageId) {
+       builder.addAssistantMessage(accumulatedTextRef.value, undefined, toolCalls)
+     }
+
      const updatedMessage = await postMessageProcessing(
        toolCalls,
        builder,
@@ -592,10 +697,15 @@ export const useChat = () => {
        allowAllMCPPermissions ? undefined : showApprovalModal,
        allowAllMCPPermissions
      )
-     addMessage(updatedMessage ?? finalContent)
-     updateStreamingContent(emptyThreadContent)
-     updatePromptProgress(undefined)
-     updateThreadTimestamp(activeThread.id)
+     finalizeMessage(
+       updatedMessage ?? finalContent,
+       addMessage,
+       updateStreamingContent,
+       updatePromptProgress,
+       updateThreadTimestamp,
+       updateMessage,
+       continueFromMessageId
+     )
      isCompleted = !toolCalls.length

      // Do not create agent loop if there is no need for it
@@ -605,8 +715,109 @@ export const useChat = () => {
          availableTools = []
        }
      }
+
+     // IMPORTANT: Check if aborted AFTER the while loop exits
+     // The while loop exits when abort is true, so we handle it here
+     // Only save interrupted messages for llamacpp provider
+     // Other providers (OpenAI, Claude, etc.) handle streaming differently
+     if (
+       abortController.signal.aborted &&
+       accumulatedTextRef.value.length > 0 &&
+       activeProvider?.provider === 'llamacpp'
+     ) {
+       // If continuing, update the existing message; otherwise add new
+       if (continueFromMessageId && continueFromMessage) {
+         // Preserve the original message metadata
+         updateMessage({
+           ...continueFromMessage,
+           content: [
+             {
+               type: ContentType.Text,
+               text: {
+                 value: accumulatedTextRef.value,
+                 annotations: [],
+               },
+             },
+           ],
+           status: MessageStatus.Stopped,
+           metadata: {
+             ...continueFromMessage.metadata,
+             tokenSpeed: useAppState.getState().tokenSpeed,
+             assistant: currentAssistant,
+           },
+         })
+       } else {
+         // Create final content for the partial message with Stopped status
+         const partialContent = {
+           ...newAssistantThreadContent(
+             activeThread.id,
+             accumulatedTextRef.value,
+             {
+               tokenSpeed: useAppState.getState().tokenSpeed,
+               assistant: currentAssistant,
+             }
+           ),
+           status: MessageStatus.Stopped,
+         }
+         addMessage(partialContent)
+       }
+       updatePromptProgress(undefined)
+       updateThreadTimestamp(activeThread.id)
+     }
    } catch (error) {
-     if (!abortController.signal.aborted) {
+     // If aborted, save the partial message even though an error occurred
+     // Only save for llamacpp provider - other providers handle streaming differently
+     const streamingContent = useAppState.getState().streamingContent
+     const hasPartialContent = accumulatedTextRef.value.length > 0 ||
+       (streamingContent && streamingContent.content?.[0]?.text?.value)
+
+     if (
+       abortController.signal.aborted &&
+       hasPartialContent &&
+       activeProvider?.provider === 'llamacpp'
+     ) {
+       // Use streaming content if available, otherwise use accumulatedTextRef
+       const contentText = streamingContent?.content?.[0]?.text?.value || accumulatedTextRef.value
+
+       // If continuing, update the existing message; otherwise add new
+       if (continueFromMessageId && continueFromMessage) {
+         // Preserve the original message metadata
+         updateMessage({
+           ...continueFromMessage,
+           content: [
+             {
+               type: ContentType.Text,
+               text: {
+                 value: contentText,
+                 annotations: [],
+               },
+             },
+           ],
+           status: MessageStatus.Stopped,
+           metadata: {
+             ...continueFromMessage.metadata,
+             tokenSpeed: useAppState.getState().tokenSpeed,
+             assistant: currentAssistant,
+           },
+         })
+       } else {
+         const partialContent = {
+           ...newAssistantThreadContent(
+             activeThread.id,
+             contentText,
+             {
+               tokenSpeed: useAppState.getState().tokenSpeed,
+               assistant: currentAssistant,
+             }
+           ),
+           status: MessageStatus.Stopped,
+         }
+         addMessage(partialContent)
+       }
+       updatePromptProgress(undefined)
+       updateThreadTimestamp(activeThread.id)
+     } else if (!abortController.signal.aborted) {
+       // Only show error if not aborted
        if (error && typeof error === 'object' && 'message' in error) {
          setModelLoadError(error as ErrorObject)
        } else {
@@ -628,12 +839,14 @@ export const useChat = () => {
       updateStreamingContent,
       updatePromptProgress,
       addMessage,
+      updateMessage,
       updateThreadTimestamp,
       updateLoadingModel,
       getDisabledToolsForThread,
       allowAllMCPPermissions,
       showApprovalModal,
       updateTokenSpeed,
+      setTokenSpeed,
       showIncreaseContextSizeModal,
       increaseModelContextSize,
       toggleOnContextShifting,
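Taken together with the button diff earlier, the continuation entry point looks roughly like this; messages and sendMessage come from the hooks above, and the argument order follows the new signature (message, troubleshooting, attachments, projectId, continueFromMessageId). A sketch, not code from the PR:

const sendMessage = useChat()
const last = messages[messages.length - 1]
const prevUser = messages[messages.length - 2]

if (last?.status === MessageStatus.Stopped && prevUser?.content?.[0]?.text?.value) {
  // troubleshooting=false: no new user message is added; the fifth argument
  // makes useChat seed accumulatedTextRef with the partial text, stream the
  // remainder, and finalize the original message in place with status Ready.
  sendMessage(prevUser.content[0].text.value, false, undefined, undefined, last.id)
}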

View File

@@ -8,6 +8,7 @@ type MessageState = {
   getMessages: (threadId: string) => ThreadMessage[]
   setMessages: (threadId: string, messages: ThreadMessage[]) => void
   addMessage: (message: ThreadMessage) => void
+  updateMessage: (message: ThreadMessage) => void
   deleteMessage: (threadId: string, messageId: string) => void
   clearAllMessages: () => void
 }
@@ -40,16 +41,52 @@ export const useMessages = create<MessageState>()((set, get) => ({
         assistant: selectedAssistant,
       },
     }
-    getServiceHub().messages().createMessage(newMessage).then((createdMessage) => {
-      set((state) => ({
-        messages: {
-          ...state.messages,
-          [message.thread_id]: [
-            ...(state.messages[message.thread_id] || []),
-            createdMessage,
-          ],
-        },
-      }))
-    })
+
+    // Optimistically update state immediately for instant UI feedback
+    set((state) => ({
+      messages: {
+        ...state.messages,
+        [message.thread_id]: [
+          ...(state.messages[message.thread_id] || []),
+          newMessage,
+        ],
+      },
+    }))
+
+    // Persist to storage asynchronously
+    getServiceHub().messages().createMessage(newMessage).catch((error) => {
+      console.error('Failed to persist message:', error)
+    })
+  },
+  updateMessage: (message) => {
+    const assistants = useAssistant.getState().assistants
+    const currentAssistant = useAssistant.getState().currentAssistant
+    const selectedAssistant =
+      assistants.find((a) => a.id === currentAssistant?.id) || assistants[0]
+    const updatedMessage = {
+      ...message,
+      metadata: {
+        ...message.metadata,
+        assistant: selectedAssistant,
+      },
+    }
+
+    // Optimistically update state immediately for instant UI feedback
+    set((state) => ({
+      messages: {
+        ...state.messages,
+        [message.thread_id]: (state.messages[message.thread_id] || []).map((m) =>
+          m.id === message.id ? updatedMessage : m
+        ),
+      },
+    }))
+
+    // Persist to storage asynchronously using modifyMessage instead of createMessage
+    // to prevent duplicates when updating existing messages
+    getServiceHub().messages().modifyMessage(updatedMessage).catch((error) => {
+      console.error('Failed to persist message update:', error)
+    })
   },
   deleteMessage: (threadId, messageId) => {
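Both store actions update state first and only log if persistence fails; nothing rolls the optimistic write back. If rollback were wanted, a sketch under the same set/get/getServiceHub API could look like this (behavior added here for illustration, not part of this PR):

updateMessage: (message) => {
  const previous = get().messages[message.thread_id] ?? []
  // Optimistic replacement, as above
  set((state) => ({
    messages: {
      ...state.messages,
      [message.thread_id]: previous.map((m) =>
        m.id === message.id ? message : m
      ),
    },
  }))
  // Restore the previous thread state if persistence fails
  getServiceHub().messages().modifyMessage(message).catch(() => {
    set((state) => ({
      messages: { ...state.messages, [message.thread_id]: previous },
    }))
  })
},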

View File

@@ -23,8 +23,9 @@ export const useTools = () => {
       updateTools(data)

       // Initialize default disabled tools for new users (only once)
-      if (!isDefaultsInitialized() && data.length > 0 && mcpExtension?.getDefaultDisabledTools) {
-        const defaultDisabled = await mcpExtension.getDefaultDisabledTools()
+      const mcpExt = mcpExtension as MCPExtension & { getDefaultDisabledTools?: () => Promise<string[]> }
+      if (!isDefaultsInitialized() && data.length > 0 && mcpExt?.getDefaultDisabledTools) {
+        const defaultDisabled = await mcpExt.getDefaultDisabledTools()
         if (defaultDisabled.length > 0) {
           setDefaultDisabledTools(defaultDisabled)
           markDefaultsAsInitialized()

View File

@@ -135,6 +135,7 @@
     "enterApiKey": "Enter API Key",
     "scrollToBottom": "Scroll to bottom",
     "generateAiResponse": "Generate AI Response",
+    "continueAiResponse": "Continue with AI Response",
     "addModel": {
       "title": "Add Model",
       "modelId": "Model ID",

View File

@@ -40,6 +40,20 @@ export class DefaultMessagesService implements MessagesService {
     )
   }

+  async modifyMessage(message: ThreadMessage): Promise<ThreadMessage> {
+    // Don't modify messages on server for temporary chat - it's local only
+    if (message.thread_id === TEMPORARY_CHAT_ID) {
+      return message
+    }
+    return (
+      ExtensionManager.getInstance()
+        .get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
+        ?.modifyMessage(message)
+        ?.catch(() => message) ?? message
+    )
+  }
+
   async deleteMessage(threadId: string, messageId: string): Promise<void> {
     // Don't delete messages on server for temporary chat - it's local only
     if (threadId === TEMPORARY_CHAT_ID) {

View File

@@ -7,5 +7,6 @@ import { ThreadMessage } from '@janhq/core'
 export interface MessagesService {
   fetchMessages(threadId: string): Promise<ThreadMessage[]>
   createMessage(message: ThreadMessage): Promise<ThreadMessage>
+  modifyMessage(message: ThreadMessage): Promise<ThreadMessage>
   deleteMessage(threadId: string, messageId: string): Promise<void>
 }
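For tests, the extended interface can be satisfied by a small in-memory double; everything below is illustrative only, and the import path is assumed:

import { ThreadMessage } from '@janhq/core'
import { MessagesService } from './types' // path assumed

export class FakeMessagesService implements MessagesService {
  private store = new Map<string, ThreadMessage[]>()

  async fetchMessages(threadId: string): Promise<ThreadMessage[]> {
    return this.store.get(threadId) ?? []
  }

  async createMessage(message: ThreadMessage): Promise<ThreadMessage> {
    const list = this.store.get(message.thread_id) ?? []
    this.store.set(message.thread_id, [...list, message])
    return message
  }

  // Replaces the stored message sharing the incoming ID, mirroring modifyMessage.
  async modifyMessage(message: ThreadMessage): Promise<ThreadMessage> {
    const list = this.store.get(message.thread_id) ?? []
    this.store.set(
      message.thread_id,
      list.map((m) => (m.id === message.id ? message : m))
    )
    return message
  }

  async deleteMessage(threadId: string, messageId: string): Promise<void> {
    const list = this.store.get(threadId) ?? []
    this.store.set(threadId, list.filter((m) => m.id !== messageId))
  }
}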