Compare commits

...

12 Commits

Author SHA1 Message Date
Nghia Doan
35264e9a22
Merge branch 'dev' into feat/retain-interruption-message 2025-10-03 14:10:43 +07:00
Vanalite
e7c9275488 fix: Fix tests on useChat 2025-10-03 10:17:27 +07:00
Vanalite
4ac45aba23 Merge remote-tracking branch 'origin/dev' into feat/retain-interruption-message 2025-10-03 09:47:08 +07:00
Vanalite
34036d895a Merge remote-tracking branch 'origin/dev' into feat/retain-interruption-message
# Conflicts:
#	web-app/src/containers/ChatInput.tsx
#	web-app/src/hooks/useChat.ts
2025-10-02 10:59:19 +07:00
Vanalite
7127ff1244 fix: Exposing PromptProgress to be passed as param 2025-10-01 21:52:30 +07:00
Vanalite
1c0e135077 Merge remote-tracking branch 'origin/dev' into feat/retain-interruption-message
# Conflicts:
#	web-app/src/hooks/useChat.ts
2025-10-01 19:35:23 +07:00
Vanalite
99473ed568 fix: Consolidate comments 2025-10-01 19:21:34 +07:00
Vanalite
52f73af08c feat: Add tests for the Continuing with AI response 2025-10-01 19:13:10 +07:00
Vanalite
ccca331d6c feat: Modify on-going response instead of creating new message to avoid message ID duplication 2025-10-01 17:14:59 +07:00
Vanalite
f4b187ba11 feat: Continue with AI response for llamacpp 2025-10-01 16:43:27 +07:00
Vanalite
4ea9d296ea feat: Continue with AI response button if it got interrupted 2025-10-01 16:05:58 +07:00
Vanalite
2e86d4e421 feat: Allow to save the last message upon interrupting llm response 2025-10-01 15:43:05 +07:00
14 changed files with 967 additions and 256 deletions

View File

@ -20,6 +20,13 @@ export interface MessageInterface {
*/
listMessages(threadId: string): Promise<ThreadMessage[]>
/**
* Updates an existing message in a thread.
* @param {ThreadMessage} message - The message to be updated (must have existing ID).
* @returns {Promise<ThreadMessage>} A promise that resolves to the updated message.
*/
modifyMessage(message: ThreadMessage): Promise<ThreadMessage>
/**
* Deletes a specific message from a thread.
* @param {string} threadId - The ID of the thread from which the message will be deleted.

View File

@ -176,7 +176,6 @@ const ChatInput = ({
const mcpExtension = extensionManager.get<MCPExtension>(ExtensionTypeEnum.MCP)
const MCPToolComponent = mcpExtension?.getToolComponent?.()
const handleSendMesage = async (prompt: string) => {
if (!selectedModel) {
setMessage('Please select a model to start chatting.')

View File

@ -3,6 +3,8 @@ import { useMessages } from '@/hooks/useMessages'
import { useTranslation } from '@/i18n/react-i18next-compat'
import { Play } from 'lucide-react'
import { useShallow } from 'zustand/react/shallow'
import { useMemo } from 'react'
import { MessageStatus } from '@janhq/core'
export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
const { t } = useTranslation()
@ -13,7 +15,36 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
}))
)
const sendMessage = useChat()
// Detect if last message is a partial assistant response (user stopped midway)
const isPartialResponse = useMemo(() => {
if (!messages || messages.length < 2) return false
const lastMessage = messages[messages.length - 1]
const secondLastMessage = messages[messages.length - 2]
return (
lastMessage?.role === 'assistant' &&
lastMessage?.status === MessageStatus.Stopped &&
secondLastMessage?.role === 'user' &&
!lastMessage?.metadata?.tool_calls
)
}, [messages])
const generateAIResponse = () => {
if (isPartialResponse) {
const partialMessage = messages[messages.length - 1]
const userMessage = messages[messages.length - 2]
if (userMessage?.content?.[0]?.text?.value) {
sendMessage(
userMessage.content[0].text.value,
false,
undefined,
partialMessage.id
)
}
return
}
const latestUserMessage = messages[messages.length - 1]
if (
latestUserMessage?.content?.[0]?.text?.value &&
@ -39,7 +70,11 @@ export const GenerateResponseButton = ({ threadId }: { threadId: string }) => {
className="mx-2 bg-main-view-fg/10 px-2 border border-main-view-fg/5 flex items-center justify-center rounded-xl gap-x-2 cursor-pointer pointer-events-auto"
onClick={generateAIResponse}
>
<p className="text-xs">{t('common:generateAiResponse')}</p>
<p className="text-xs">
{isPartialResponse
? t('common:continueAiResponse')
: t('common:generateAiResponse')}
</p>
<Play size={12} />
</div>
)

View File

@ -8,6 +8,7 @@ import { cn } from '@/lib/utils'
import { ArrowDown } from 'lucide-react'
import { useTranslation } from '@/i18n/react-i18next-compat'
import { useAppState } from '@/hooks/useAppState'
import { MessageStatus } from '@janhq/core'
const ScrollToBottom = ({
threadId,
@ -28,11 +29,20 @@ const ScrollToBottom = ({
const streamingContent = useAppState((state) => state.streamingContent)
// Check if last message is a partial assistant response and show continue buton (user interrupted)
const isPartialResponse =
messages.length >= 2 &&
messages[messages.length - 1]?.role === 'assistant' &&
messages[messages.length - 1]?.status === MessageStatus.Stopped &&
messages[messages.length - 2]?.role === 'user' &&
!messages[messages.length - 1]?.metadata?.tool_calls
const showGenerateAIResponseBtn =
(messages[messages.length - 1]?.role === 'user' ||
((messages[messages.length - 1]?.role === 'user' ||
(messages[messages.length - 1]?.metadata &&
'tool_calls' in (messages[messages.length - 1].metadata ?? {}))) &&
!streamingContent
'tool_calls' in (messages[messages.length - 1].metadata ?? {})) ||
isPartialResponse) &&
!streamingContent)
return (
<div

View File

@ -2,6 +2,7 @@ import { useAppState } from '@/hooks/useAppState'
import { ThreadContent } from './ThreadContent'
import { memo, useMemo } from 'react'
import { useMessages } from '@/hooks/useMessages'
import { MessageStatus } from '@janhq/core'
type Props = {
threadId: string
@ -48,12 +49,19 @@ export const StreamingContent = memo(({ threadId }: Props) => {
return extractReasoningSegment(text)
}, [lastAssistant])
if (!streamingContent || streamingContent.thread_id !== threadId) return null
if (!streamingContent || streamingContent.thread_id !== threadId) {
return null
}
if (streamingReasoning && streamingReasoning === lastAssistantReasoning) {
return null
}
// Don't show streaming content if there's already a stopped message
if (lastAssistant?.status === MessageStatus.Stopped) {
return null
}
// Pass a new object to ThreadContent to avoid reference issues
// The streaming content is always the last message
return (

View File

@ -1,6 +1,30 @@
import { renderHook, act } from '@testing-library/react'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import { useChat } from '../useChat'
import { renderHook, act, waitFor } from '@testing-library/react'
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
import { MessageStatus, ContentType } from '@janhq/core'
// Store mock functions for assertions - initialize immediately
const mockAddMessage = vi.fn()
const mockUpdateMessage = vi.fn()
const mockGetMessages = vi.fn(() => [])
const mockStartModel = vi.fn(() => Promise.resolve())
const mockSendCompletion = vi.fn(() => Promise.resolve({
choices: [{
message: {
content: 'AI response',
role: 'assistant',
},
}],
}))
const mockPostMessageProcessing = vi.fn((toolCalls, builder, content) =>
Promise.resolve(content)
)
const mockCompletionMessagesBuilder = {
addUserMessage: vi.fn(),
addAssistantMessage: vi.fn(),
getMessages: vi.fn(() => []),
}
const mockSetPrompt = vi.fn()
const mockResetTokenSpeed = vi.fn()
// Mock dependencies
vi.mock('../usePrompt', () => ({
@ -8,11 +32,16 @@ vi.mock('../usePrompt', () => ({
(selector: any) => {
const state = {
prompt: 'test prompt',
setPrompt: vi.fn(),
setPrompt: mockSetPrompt,
}
return selector ? selector(state) : state
},
{ getState: () => ({ prompt: 'test prompt', setPrompt: vi.fn() }) }
{
getState: () => ({
prompt: 'test prompt',
setPrompt: mockSetPrompt
})
}
),
}))
@ -22,25 +51,29 @@ vi.mock('../useAppState', () => ({
const state = {
tools: [],
updateTokenSpeed: vi.fn(),
resetTokenSpeed: vi.fn(),
resetTokenSpeed: mockResetTokenSpeed,
updateTools: vi.fn(),
updateStreamingContent: vi.fn(),
updatePromptProgress: vi.fn(),
updateLoadingModel: vi.fn(),
setAbortController: vi.fn(),
streamingContent: undefined,
}
return selector ? selector(state) : state
},
{
getState: vi.fn(() => ({
tools: [],
tokenSpeed: { tokensPerSecond: 10 },
streamingContent: undefined,
}))
}
),
}))
vi.mock('../useAssistant', () => ({
useAssistant: (selector: any) => {
useAssistant: Object.assign(
(selector: any) => {
const state = {
assistants: [{
id: 'test-assistant',
@ -55,6 +88,21 @@ vi.mock('../useAssistant', () => ({
}
return selector ? selector(state) : state
},
{
getState: () => ({
assistants: [{
id: 'test-assistant',
instructions: 'test instructions',
parameters: { stream: true },
}],
currentAssistant: {
id: 'test-assistant',
instructions: 'test instructions',
parameters: { stream: true },
},
})
}
),
}))
vi.mock('../useModelProvider', () => ({
@ -62,14 +110,15 @@ vi.mock('../useModelProvider', () => ({
(selector: any) => {
const state = {
getProviderByName: vi.fn(() => ({
provider: 'openai',
provider: 'llamacpp',
models: [],
settings: [],
})),
selectedModel: {
id: 'test-model',
capabilities: ['tools'],
},
selectedProvider: 'openai',
selectedProvider: 'llamacpp',
updateProvider: vi.fn(),
}
return selector ? selector(state) : state
@ -77,14 +126,15 @@ vi.mock('../useModelProvider', () => ({
{
getState: () => ({
getProviderByName: vi.fn(() => ({
provider: 'openai',
provider: 'llamacpp',
models: [],
settings: [],
})),
selectedModel: {
id: 'test-model',
capabilities: ['tools'],
},
selectedProvider: 'openai',
selectedProvider: 'llamacpp',
updateProvider: vi.fn(),
})
}
@ -96,11 +146,11 @@ vi.mock('../useThreads', () => ({
const state = {
getCurrentThread: vi.fn(() => ({
id: 'test-thread',
model: { id: 'test-model', provider: 'openai' },
model: { id: 'test-model', provider: 'llamacpp' },
})),
createThread: vi.fn(() => Promise.resolve({
id: 'test-thread',
model: { id: 'test-model', provider: 'openai' },
model: { id: 'test-model', provider: 'llamacpp' },
})),
updateThreadTimestamp: vi.fn(),
}
@ -111,15 +161,18 @@ vi.mock('../useThreads', () => ({
vi.mock('../useMessages', () => ({
useMessages: (selector: any) => {
const state = {
getMessages: vi.fn(() => []),
addMessage: vi.fn(),
getMessages: mockGetMessages,
addMessage: mockAddMessage,
updateMessage: mockUpdateMessage,
setMessages: vi.fn(),
}
return selector ? selector(state) : state
},
}))
vi.mock('../useToolApproval', () => ({
useToolApproval: (selector: any) => {
useToolApproval: Object.assign(
(selector: any) => {
const state = {
approvedTools: [],
showApprovalModal: vi.fn(),
@ -127,6 +180,14 @@ vi.mock('../useToolApproval', () => ({
}
return selector ? selector(state) : state
},
{
getState: () => ({
approvedTools: [],
showApprovalModal: vi.fn(),
allowAllMCPPermissions: false,
})
}
),
}))
vi.mock('../useToolAvailable', () => ({
@ -162,38 +223,57 @@ vi.mock('@tanstack/react-router', () => ({
})),
}))
vi.mock('../useServiceHub', () => ({
useServiceHub: vi.fn(() => ({
models: () => ({
startModel: mockStartModel,
stopModel: vi.fn(() => Promise.resolve()),
stopAllModels: vi.fn(() => Promise.resolve()),
}),
providers: () => ({
updateSettings: vi.fn(() => Promise.resolve()),
}),
})),
}))
vi.mock('@/lib/completion', () => ({
emptyThreadContent: { thread_id: 'test-thread', content: '' },
extractToolCall: vi.fn(),
newUserThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'user message' })),
newAssistantThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'assistant message' })),
sendCompletion: vi.fn(),
postMessageProcessing: vi.fn(),
isCompletionResponse: vi.fn(),
newUserThreadContent: vi.fn((threadId, content) => ({
thread_id: threadId,
content: [{ type: 'text', text: { value: content, annotations: [] } }],
role: 'user'
})),
newAssistantThreadContent: vi.fn((threadId, content) => ({
thread_id: threadId,
content: [{ type: 'text', text: { value: content, annotations: [] } }],
role: 'assistant'
})),
sendCompletion: mockSendCompletion,
postMessageProcessing: mockPostMessageProcessing,
isCompletionResponse: vi.fn(() => true),
}))
vi.mock('@/lib/messages', () => ({
CompletionMessagesBuilder: vi.fn(() => ({
addUserMessage: vi.fn(),
addAssistantMessage: vi.fn(),
getMessages: vi.fn(() => []),
CompletionMessagesBuilder: vi.fn(() => mockCompletionMessagesBuilder),
}))
vi.mock('@/lib/instructionTemplate', () => ({
renderInstructions: vi.fn((instructions: string) => instructions),
}))
vi.mock('@/utils/reasoning', () => ({
ReasoningProcessor: vi.fn(() => ({
processReasoningChunk: vi.fn(() => null),
finalize: vi.fn(() => ''),
})),
extractReasoningFromMessage: vi.fn(() => null),
}))
vi.mock('@/services/mcp', () => ({
getTools: vi.fn(() => Promise.resolve([])),
}))
vi.mock('@/services/models', () => ({
startModel: vi.fn(() => Promise.resolve()),
stopModel: vi.fn(() => Promise.resolve()),
stopAllModels: vi.fn(() => Promise.resolve()),
}))
vi.mock('@/services/providers', () => ({
updateSettings: vi.fn(() => Promise.resolve()),
}))
vi.mock('@tauri-apps/api/event', () => ({
listen: vi.fn(() => Promise.resolve(vi.fn())),
}))
@ -204,9 +284,41 @@ vi.mock('sonner', () => ({
},
}))
// Import after mocks to avoid hoisting issues
const { useChat } = await import('../useChat')
const completionLib = await import('@/lib/completion')
const messagesLib = await import('@/lib/messages')
describe('useChat', () => {
beforeEach(() => {
// Clear mock call history
vi.clearAllMocks()
// Reset mock implementations
mockAddMessage.mockClear()
mockUpdateMessage.mockClear()
mockGetMessages.mockReturnValue([])
mockStartModel.mockResolvedValue(undefined)
mockSetPrompt.mockClear()
mockResetTokenSpeed.mockClear()
mockSendCompletion.mockResolvedValue({
choices: [{
message: {
content: 'AI response',
role: 'assistant',
},
}],
})
mockPostMessageProcessing.mockImplementation((toolCalls, builder, content) =>
Promise.resolve(content)
)
mockCompletionMessagesBuilder.addUserMessage.mockClear()
mockCompletionMessagesBuilder.addAssistantMessage.mockClear()
mockCompletionMessagesBuilder.getMessages.mockReturnValue([])
})
afterEach(() => {
vi.clearAllTimers()
})
it('returns sendMessage function', () => {
@ -216,13 +328,270 @@ describe('useChat', () => {
expect(typeof result.current).toBe('function')
})
it('sends message successfully', async () => {
describe('Continue with AI response functionality', () => {
it('should add new user message when troubleshooting is true and no continueFromMessageId', async () => {
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('Hello world', true, undefined, undefined, undefined)
})
expect(completionLib.newUserThreadContent).toHaveBeenCalledWith(
'test-thread',
'Hello world',
undefined
)
expect(mockAddMessage).toHaveBeenCalledWith(
expect.objectContaining({
thread_id: 'test-thread',
role: 'user',
})
)
expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith(
'Hello world',
undefined
)
})
it('should NOT add new user message when continueFromMessageId is provided', async () => {
const stoppedMessage = {
id: 'msg-123',
thread_id: 'test-thread',
role: 'assistant',
content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }],
status: MessageStatus.Stopped,
metadata: {},
}
mockGetMessages.mockReturnValue([stoppedMessage])
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('', true, undefined, undefined, 'msg-123')
})
expect(completionLib.newUserThreadContent).not.toHaveBeenCalled()
const userMessageCalls = mockAddMessage.mock.calls.filter(
(call: any) => call[0]?.role === 'user'
)
expect(userMessageCalls).toHaveLength(0)
expect(mockCompletionMessagesBuilder.addUserMessage).not.toHaveBeenCalled()
})
it('should add partial assistant message to builder when continuing', async () => {
const stoppedMessage = {
id: 'msg-123',
thread_id: 'test-thread',
role: 'assistant',
content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }],
status: MessageStatus.Stopped,
metadata: {},
}
mockGetMessages.mockReturnValue([stoppedMessage])
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('', true, undefined, undefined, 'msg-123')
})
// Should be called twice: once with partial message (line 517-521), once after completion (line 689)
const assistantCalls = mockCompletionMessagesBuilder.addAssistantMessage.mock.calls
expect(assistantCalls.length).toBeGreaterThanOrEqual(1)
// First call should be with the partial response content
expect(assistantCalls[0]).toEqual([
'Partial response',
undefined,
[]
])
})
it('should filter out stopped message from context when continuing', async () => {
const userMsg = {
id: 'msg-1',
thread_id: 'test-thread',
role: 'user',
content: [{ type: 'text', text: { value: 'Hello', annotations: [] } }],
}
const stoppedMessage = {
id: 'msg-123',
thread_id: 'test-thread',
role: 'assistant',
content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
status: MessageStatus.Stopped,
}
mockGetMessages.mockReturnValue([userMsg, stoppedMessage])
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('', true, undefined, undefined, 'msg-123')
})
// The CompletionMessagesBuilder is called with filtered messages (line 507-512)
// The stopped message should be filtered out from the context
expect(messagesLib.CompletionMessagesBuilder).toHaveBeenCalled()
const builderCall = (messagesLib.CompletionMessagesBuilder as any).mock.calls[0]
expect(builderCall[0]).toEqual([userMsg]) // stopped message filtered out
expect(builderCall[1]).toEqual('test instructions')
})
it('should update existing message instead of adding new one when continuing', async () => {
const stoppedMessage = {
id: 'msg-123',
thread_id: 'test-thread',
role: 'assistant',
content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
status: MessageStatus.Stopped,
metadata: {},
}
mockGetMessages.mockReturnValue([stoppedMessage])
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('', true, undefined, undefined, 'msg-123')
})
// finalizeMessage is called at line 700-708, which should update the message
expect(mockUpdateMessage).toHaveBeenCalledWith(
expect.objectContaining({
id: 'msg-123',
status: MessageStatus.Ready,
})
)
})
it('should start with previous content when continuing', async () => {
const stoppedMessage = {
id: 'msg-123',
thread_id: 'test-thread',
role: 'assistant',
content: [{ type: 'text', text: { value: 'Partial response', annotations: [] } }],
status: MessageStatus.Stopped,
metadata: {},
}
mockGetMessages.mockReturnValue([stoppedMessage])
mockSendCompletion.mockResolvedValue({
choices: [{
message: {
content: ' continued',
role: 'assistant',
},
}],
})
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('', true, undefined, undefined, 'msg-123')
})
// The accumulated text should contain the previous content plus new content
// accumulatedTextRef starts with 'Partial response' (line 490)
// Then gets ' continued' appended (line 585)
expect(mockUpdateMessage).toHaveBeenCalledWith(
expect.objectContaining({
id: 'msg-123',
content: expect.arrayContaining([
expect.objectContaining({
text: expect.objectContaining({
value: 'Partial response continued',
})
})
])
})
)
})
it('should handle attachments correctly when not continuing', async () => {
const { result } = renderHook(() => useChat())
const attachments = [
{
name: 'test.png',
type: 'image/png',
size: 1024,
base64: 'base64data',
dataUrl: 'data:image/png;base64,base64data',
},
]
await act(async () => {
await result.current('Message with attachment', true, attachments, undefined, undefined)
})
expect(completionLib.newUserThreadContent).toHaveBeenCalledWith(
'test-thread',
'Message with attachment',
attachments
)
expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith(
'Message with attachment',
attachments
)
})
it('should preserve message status as Ready after continuation completes', async () => {
const stoppedMessage = {
id: 'msg-123',
thread_id: 'test-thread',
role: 'assistant',
content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
status: MessageStatus.Stopped,
metadata: {},
}
mockGetMessages.mockReturnValue([stoppedMessage])
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('', true, undefined, undefined, 'msg-123')
})
// finalContent is created at line 678-683 with status Ready when continuing
expect(mockUpdateMessage).toHaveBeenCalledWith(
expect.objectContaining({
id: 'msg-123',
status: MessageStatus.Ready,
})
)
})
})
describe('Normal message sending', () => {
it('sends message successfully without continuation', async () => {
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('Hello world')
})
expect(mockSendCompletion).toHaveBeenCalled()
expect(mockStartModel).toHaveBeenCalled()
})
})
describe('Error handling', () => {
it('should handle errors gracefully during continuation', async () => {
mockSendCompletion.mockRejectedValueOnce(new Error('API Error'))
const stoppedMessage = {
id: 'msg-123',
thread_id: 'test-thread',
role: 'assistant',
content: [{ type: 'text', text: { value: 'Partial', annotations: [] } }],
status: MessageStatus.Stopped,
metadata: {},
}
mockGetMessages.mockReturnValue([stoppedMessage])
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('', true, undefined, undefined, 'msg-123')
})
expect(result.current).toBeDefined()
})
})
})

View File

@ -225,9 +225,25 @@ describe('useMessages', () => {
})
)
// Wait for async operation
// Message should be immediately available (optimistic update)
expect(result.current.messages['thread1']).toContainEqual(
expect.objectContaining({
id: messageToAdd.id,
thread_id: messageToAdd.thread_id,
role: messageToAdd.role,
content: messageToAdd.content,
metadata: expect.objectContaining({
assistant: expect.objectContaining({
id: expect.any(String),
name: expect.any(String),
}),
}),
})
)
// Verify persistence was attempted
await vi.waitFor(() => {
expect(result.current.messages['thread1']).toContainEqual(mockCreatedMessage)
expect(mockCreateMessage).toHaveBeenCalled()
})
})

View File

@ -4,7 +4,7 @@ import { MCPTool } from '@/types/completion'
import { useAssistant } from './useAssistant'
import { ChatCompletionMessageToolCall } from 'openai/resources'
type PromptProgress = {
export type PromptProgress = {
cache: number
processed: number
time_ms: number

View File

@ -3,7 +3,7 @@ import { flushSync } from 'react-dom'
import { usePrompt } from './usePrompt'
import { useModelProvider } from './useModelProvider'
import { useThreads } from './useThreads'
import { useAppState } from './useAppState'
import { useAppState, type PromptProgress } from './useAppState'
import { useMessages } from './useMessages'
import { useRouter } from '@tanstack/react-router'
import { defaultModel } from '@/lib/models'
@ -23,6 +23,7 @@ import {
ChatCompletionMessageToolCall,
CompletionUsage,
} from 'openai/resources'
import { MessageStatus, ContentType, ThreadMessage } from '@janhq/core'
import { useServiceHub } from '@/hooks/useServiceHub'
import { useToolApproval } from '@/hooks/useToolApproval'
@ -38,6 +39,198 @@ import { useAssistant } from './useAssistant'
import { useShallow } from 'zustand/shallow'
import { TEMPORARY_CHAT_QUERY_ID, TEMPORARY_CHAT_ID } from '@/constants/chat'
// Helper to create thread content with consistent structure
const createThreadContent = (
threadId: string,
text: string,
toolCalls: ChatCompletionMessageToolCall[],
messageId?: string
) => {
const content = newAssistantThreadContent(threadId, text, {
tool_calls: toolCalls.map((e) => ({
...e,
state: 'pending',
})),
})
// If continuing from a message, preserve the message ID
if (messageId) {
return { ...content, id: messageId }
}
return content
}
// Helper to cancel animation frame cross-platform
const cancelFrame = (handle: number | undefined) => {
if (handle === undefined) return
if (typeof cancelAnimationFrame !== 'undefined') {
cancelAnimationFrame(handle)
} else {
clearTimeout(handle)
}
}
// Helper to finalize and save a message
const finalizeMessage = (
finalContent: ThreadMessage,
addMessage: (message: ThreadMessage) => void,
updateStreamingContent: (content: ThreadMessage | undefined) => void,
updatePromptProgress: (progress: PromptProgress | undefined) => void,
updateThreadTimestamp: (threadId: string) => void,
updateMessage?: (message: ThreadMessage) => void,
continueFromMessageId?: string
) => {
// If continuing from a message, update it; otherwise add new message
if (continueFromMessageId && updateMessage) {
updateMessage({ ...finalContent, id: continueFromMessageId })
} else {
addMessage(finalContent)
}
updateStreamingContent(emptyThreadContent)
updatePromptProgress(undefined)
updateThreadTimestamp(finalContent.thread_id)
}
// Helper to process streaming completion
const processStreamingCompletion = async (
// eslint-disable-next-line @typescript-eslint/no-explicit-any
completion: AsyncIterable<any>,
abortController: AbortController,
activeThread: Thread,
accumulatedText: { value: string },
toolCalls: ChatCompletionMessageToolCall[],
currentCall: ChatCompletionMessageToolCall | null,
updateStreamingContent: (content: ThreadMessage | undefined) => void,
updateTokenSpeed: (message: ThreadMessage, increment?: number) => void,
setTokenSpeed: (message: ThreadMessage, tokensPerSecond: number, totalTokens: number) => void,
updatePromptProgress: (progress: PromptProgress | undefined) => void,
timeToFirstToken: number,
tokenUsageRef: { current: CompletionUsage | undefined },
continueFromMessageId?: string,
updateMessage?: (message: ThreadMessage) => void,
continueFromMessage?: ThreadMessage
) => {
// High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
let rafScheduled = false
let rafHandle: number | undefined
let pendingDeltaCount = 0
const reasoningProcessor = new ReasoningProcessor()
const flushStreamingContent = () => {
const currentContent = createThreadContent(
activeThread.id,
accumulatedText.value,
toolCalls,
continueFromMessageId
)
// When continuing, update the message directly instead of using streamingContent
if (continueFromMessageId && updateMessage && continueFromMessage) {
updateMessage({
...continueFromMessage, // Preserve original message metadata
content: currentContent.content, // Update content
status: MessageStatus.Stopped, // Keep as Stopped while streaming
})
} else {
updateStreamingContent(currentContent)
}
if (tokenUsageRef.current) {
setTokenSpeed(
currentContent,
tokenUsageRef.current.completion_tokens /
Math.max((Date.now() - timeToFirstToken) / 1000, 1),
tokenUsageRef.current.completion_tokens
)
} else if (pendingDeltaCount > 0) {
updateTokenSpeed(currentContent, pendingDeltaCount)
}
pendingDeltaCount = 0
rafScheduled = false
}
const scheduleFlush = () => {
if (rafScheduled || abortController.signal.aborted) return
rafScheduled = true
const doSchedule = (cb: () => void) => {
if (typeof requestAnimationFrame !== 'undefined') {
rafHandle = requestAnimationFrame(() => cb())
} else {
// Fallback for non-browser test environments
const t = setTimeout(() => cb(), 0) as unknown as number
rafHandle = t
}
}
doSchedule(() => {
// Check abort status before executing the scheduled callback
if (abortController.signal.aborted) {
rafScheduled = false
return
}
flushStreamingContent()
})
}
try {
for await (const part of completion) {
// Check if aborted before processing each part
if (abortController.signal.aborted) {
break
}
// Handle prompt progress if available
if ('prompt_progress' in part && part.prompt_progress) {
// Force immediate state update to ensure we see intermediate values
flushSync(() => {
updatePromptProgress(part.prompt_progress)
})
// Add a small delay to make progress visible
await new Promise((resolve) => setTimeout(resolve, 100))
}
// Error message
if (!part.choices) {
throw new Error(
'message' in part
? (part.message as string)
: (JSON.stringify(part) ?? '')
)
}
if ('usage' in part && part.usage) {
tokenUsageRef.current = part.usage
}
if (part.choices[0]?.delta?.tool_calls) {
extractToolCall(part, currentCall, toolCalls)
// Schedule a flush to reflect tool update
scheduleFlush()
}
const deltaReasoning = reasoningProcessor.processReasoningChunk(part)
if (deltaReasoning) {
accumulatedText.value += deltaReasoning
pendingDeltaCount += 1
// Schedule flush for reasoning updates
scheduleFlush()
}
const deltaContent = part.choices[0]?.delta?.content || ''
if (deltaContent) {
accumulatedText.value += deltaContent
pendingDeltaCount += 1
// Batch UI update on next animation frame
scheduleFlush()
}
}
} finally {
// Always clean up scheduled RAF when stream ends (either normally or via abort)
cancelFrame(rafHandle)
rafHandle = undefined
rafScheduled = false
// Finalize reasoning (close any open think tags)
accumulatedText.value += reasoningProcessor.finalize()
}
}
export const useChat = () => {
const [
updateTokenSpeed,
@ -86,6 +279,7 @@ export const useChat = () => {
const getMessages = useMessages((state) => state.getMessages)
const addMessage = useMessages((state) => state.addMessage)
const updateMessage = useMessages((state) => state.updateMessage)
const setMessages = useMessages((state) => state.setMessages)
const setModelLoadError = useModelLoad((state) => state.setModelLoadError)
const router = useRouter()
@ -149,7 +343,7 @@ export const useChat = () => {
})
}
return currentThread
}, [createThread, retrieveThread, router, setMessages])
}, [createThread, retrieveThread, router, setMessages, serviceHub])
const restartModel = useCallback(
async (provider: ProviderObject, modelId: string) => {
@ -264,7 +458,8 @@ export const useChat = () => {
base64: string
dataUrl: string
}>,
projectId?: string
projectId?: string,
continueFromMessageId?: string
) => {
const activeThread = await getCurrentThread(projectId)
const selectedProvider = useModelProvider.getState().selectedProvider
@ -277,26 +472,54 @@ export const useChat = () => {
setAbortController(activeThread.id, abortController)
updateStreamingContent(emptyThreadContent)
updatePromptProgress(undefined)
// Do not add new message on retry
if (troubleshooting)
// Find the message to continue from if provided
const continueFromMessage = continueFromMessageId
? messages.find((m) => m.id === continueFromMessageId)
: undefined
// Do not add new message on retry or when continuing
if (troubleshooting && !continueFromMessageId)
addMessage(newUserThreadContent(activeThread.id, message, attachments))
updateThreadTimestamp(activeThread.id)
usePrompt.getState().setPrompt('')
const selectedModel = useModelProvider.getState().selectedModel
// If continuing, start with the previous content
const accumulatedTextRef = {
value: continueFromMessage?.content?.[0]?.text?.value || ''
}
let currentAssistant: Assistant | undefined | null
try {
if (selectedModel?.id) {
updateLoadingModel(true)
await serviceHub.models().startModel(activeProvider, selectedModel.id)
updateLoadingModel(false)
}
const currentAssistant = useAssistant.getState().currentAssistant
currentAssistant = useAssistant.getState().currentAssistant
// Filter out the stopped message from context if continuing
const contextMessages = continueFromMessageId
? messages.filter((m) => m.id !== continueFromMessageId)
: messages
const builder = new CompletionMessagesBuilder(
messages,
contextMessages,
currentAssistant
? renderInstructions(currentAssistant.instructions)
: undefined
)
if (troubleshooting) builder.addUserMessage(message, attachments)
if (troubleshooting && !continueFromMessageId) {
builder.addUserMessage(message, attachments)
} else if (continueFromMessage) {
// When continuing, add the partial assistant response to the context
builder.addAssistantMessage(
continueFromMessage.content?.[0]?.text?.value || '',
undefined,
[]
)
}
let isCompleted = false
@ -349,7 +572,6 @@ export const useChat = () => {
)
if (!completion) throw new Error('No completion received')
let accumulatedText = ''
const currentCall: ChatCompletionMessageToolCall | null = null
const toolCalls: ChatCompletionMessageToolCall[] = []
const timeToFirstToken = Date.now()
@ -357,13 +579,19 @@ export const useChat = () => {
try {
if (isCompletionResponse(completion)) {
const message = completion.choices[0]?.message
accumulatedText = (message?.content as string) || ''
// When continuing, append to existing content; otherwise replace
const newContent = (message?.content as string) || ''
if (continueFromMessageId && accumulatedTextRef.value) {
accumulatedTextRef.value += newContent
} else {
accumulatedTextRef.value = newContent
}
// Handle reasoning field if there is one
const reasoning = extractReasoningFromMessage(message)
if (reasoning) {
accumulatedText =
`<think>${reasoning}</think>` + accumulatedText
accumulatedTextRef.value =
`<think>${reasoning}</think>` + accumulatedTextRef.value
}
if (message?.tool_calls) {
@ -373,161 +601,25 @@ export const useChat = () => {
tokenUsage = completion.usage
}
} else {
// High-throughput scheduler: batch UI updates on rAF (requestAnimationFrame)
let rafScheduled = false
let rafHandle: number | undefined
let pendingDeltaCount = 0
const reasoningProcessor = new ReasoningProcessor()
const scheduleFlush = () => {
if (rafScheduled || abortController.signal.aborted) return
rafScheduled = true
const doSchedule = (cb: () => void) => {
if (typeof requestAnimationFrame !== 'undefined') {
rafHandle = requestAnimationFrame(() => cb())
} else {
// Fallback for non-browser test environments
const t = setTimeout(() => cb(), 0) as unknown as number
rafHandle = t
}
}
doSchedule(() => {
// Check abort status before executing the scheduled callback
if (abortController.signal.aborted) {
rafScheduled = false
return
}
const currentContent = newAssistantThreadContent(
activeThread.id,
accumulatedText,
{
tool_calls: toolCalls.map((e) => ({
...e,
state: 'pending',
})),
}
const tokenUsageRef = { current: tokenUsage }
await processStreamingCompletion(
completion,
abortController,
activeThread,
accumulatedTextRef,
toolCalls,
currentCall,
updateStreamingContent,
updateTokenSpeed,
setTokenSpeed,
updatePromptProgress,
timeToFirstToken,
tokenUsageRef,
continueFromMessageId,
updateMessage,
continueFromMessage
)
updateStreamingContent(currentContent)
if (tokenUsage) {
setTokenSpeed(
currentContent,
tokenUsage.completion_tokens /
Math.max((Date.now() - timeToFirstToken) / 1000, 1),
tokenUsage.completion_tokens
)
} else if (pendingDeltaCount > 0) {
updateTokenSpeed(currentContent, pendingDeltaCount)
}
pendingDeltaCount = 0
rafScheduled = false
})
}
const flushIfPending = () => {
if (!rafScheduled) return
if (
typeof cancelAnimationFrame !== 'undefined' &&
rafHandle !== undefined
) {
cancelAnimationFrame(rafHandle)
} else if (rafHandle !== undefined) {
clearTimeout(rafHandle)
}
// Do an immediate flush
const currentContent = newAssistantThreadContent(
activeThread.id,
accumulatedText,
{
tool_calls: toolCalls.map((e) => ({
...e,
state: 'pending',
})),
}
)
updateStreamingContent(currentContent)
if (tokenUsage) {
setTokenSpeed(
currentContent,
tokenUsage.completion_tokens /
Math.max((Date.now() - timeToFirstToken) / 1000, 1),
tokenUsage.completion_tokens
)
} else if (pendingDeltaCount > 0) {
updateTokenSpeed(currentContent, pendingDeltaCount)
}
pendingDeltaCount = 0
rafScheduled = false
}
try {
for await (const part of completion) {
// Check if aborted before processing each part
if (abortController.signal.aborted) {
break
}
// Handle prompt progress if available
if ('prompt_progress' in part && part.prompt_progress) {
// Force immediate state update to ensure we see intermediate values
flushSync(() => {
updatePromptProgress(part.prompt_progress)
})
// Add a small delay to make progress visible
await new Promise((resolve) => setTimeout(resolve, 100))
}
// Error message
if (!part.choices) {
throw new Error(
'message' in part
? (part.message as string)
: (JSON.stringify(part) ?? '')
)
}
if ('usage' in part && part.usage) {
tokenUsage = part.usage
}
if (part.choices[0]?.delta?.tool_calls) {
extractToolCall(part, currentCall, toolCalls)
// Schedule a flush to reflect tool update
scheduleFlush()
}
const deltaReasoning =
reasoningProcessor.processReasoningChunk(part)
if (deltaReasoning) {
accumulatedText += deltaReasoning
pendingDeltaCount += 1
// Schedule flush for reasoning updates
scheduleFlush()
}
const deltaContent = part.choices[0]?.delta?.content || ''
if (deltaContent) {
accumulatedText += deltaContent
pendingDeltaCount += 1
// Batch UI update on next animation frame
scheduleFlush()
}
}
} finally {
// Always clean up scheduled RAF when stream ends (either normally or via abort)
if (rafHandle !== undefined) {
if (typeof cancelAnimationFrame !== 'undefined') {
cancelAnimationFrame(rafHandle)
} else {
clearTimeout(rafHandle)
}
rafHandle = undefined
rafScheduled = false
}
// Only finalize and flush if not aborted
if (!abortController.signal.aborted) {
// Finalize reasoning (close any open think tags)
accumulatedText += reasoningProcessor.finalize()
// Ensure any pending buffered content is rendered at the end
flushIfPending()
}
}
tokenUsage = tokenUsageRef.current
}
} catch (error) {
const errorMessage =
@ -561,7 +653,7 @@ export const useChat = () => {
}
// TODO: Remove this check when integrating new llama.cpp extension
if (
accumulatedText.length === 0 &&
accumulatedTextRef.value.length === 0 &&
toolCalls.length === 0 &&
activeThread.model?.id &&
activeProvider?.provider === 'llamacpp'
@ -573,16 +665,29 @@ export const useChat = () => {
}
// Create a final content object for adding to the thread
const finalContent = newAssistantThreadContent(
let finalContent = newAssistantThreadContent(
activeThread.id,
accumulatedText,
accumulatedTextRef.value,
{
tokenSpeed: useAppState.getState().tokenSpeed,
assistant: currentAssistant,
}
)
builder.addAssistantMessage(accumulatedText, undefined, toolCalls)
// If continuing from a message, preserve the ID and set status to Ready
if (continueFromMessageId) {
finalContent = {
...finalContent,
id: continueFromMessageId,
status: MessageStatus.Ready,
}
}
// Normal completion flow (abort is handled after loop exits)
// Don't add assistant message to builder if continuing - it's already there
if (!continueFromMessageId) {
builder.addAssistantMessage(accumulatedTextRef.value, undefined, toolCalls)
}
const updatedMessage = await postMessageProcessing(
toolCalls,
builder,
@ -592,10 +697,15 @@ export const useChat = () => {
allowAllMCPPermissions ? undefined : showApprovalModal,
allowAllMCPPermissions
)
addMessage(updatedMessage ?? finalContent)
updateStreamingContent(emptyThreadContent)
updatePromptProgress(undefined)
updateThreadTimestamp(activeThread.id)
finalizeMessage(
updatedMessage ?? finalContent,
addMessage,
updateStreamingContent,
updatePromptProgress,
updateThreadTimestamp,
updateMessage,
continueFromMessageId
)
isCompleted = !toolCalls.length
// Do not create agent loop if there is no need for it
@ -605,8 +715,109 @@ export const useChat = () => {
availableTools = []
}
}
// IMPORTANT: Check if aborted AFTER the while loop exits
// The while loop exits when abort is true, so we handle it here
// Only save interrupted messages for llamacpp provider
// Other providers (OpenAI, Claude, etc.) handle streaming differently
if (
abortController.signal.aborted &&
accumulatedTextRef.value.length > 0 &&
activeProvider?.provider === 'llamacpp'
) {
// If continuing, update the existing message; otherwise add new
if (continueFromMessageId && continueFromMessage) {
// Preserve the original message metadata
updateMessage({
...continueFromMessage,
content: [
{
type: ContentType.Text,
text: {
value: accumulatedTextRef.value,
annotations: [],
},
},
],
status: MessageStatus.Stopped,
metadata: {
...continueFromMessage.metadata,
tokenSpeed: useAppState.getState().tokenSpeed,
assistant: currentAssistant,
},
})
} else {
// Create final content for the partial message with Stopped status
const partialContent = {
...newAssistantThreadContent(
activeThread.id,
accumulatedTextRef.value,
{
tokenSpeed: useAppState.getState().tokenSpeed,
assistant: currentAssistant,
}
),
status: MessageStatus.Stopped,
}
addMessage(partialContent)
}
updatePromptProgress(undefined)
updateThreadTimestamp(activeThread.id)
}
} catch (error) {
if (!abortController.signal.aborted) {
// If aborted, save the partial message even though an error occurred
// Only save for llamacpp provider - other providers handle streaming differently
const streamingContent = useAppState.getState().streamingContent
const hasPartialContent = accumulatedTextRef.value.length > 0 ||
(streamingContent && streamingContent.content?.[0]?.text?.value)
if (
abortController.signal.aborted &&
hasPartialContent &&
activeProvider?.provider === 'llamacpp'
) {
// Use streaming content if available, otherwise use accumulatedTextRef
const contentText = streamingContent?.content?.[0]?.text?.value || accumulatedTextRef.value
// If continuing, update the existing message; otherwise add new
if (continueFromMessageId && continueFromMessage) {
// Preserve the original message metadata
updateMessage({
...continueFromMessage,
content: [
{
type: ContentType.Text,
text: {
value: contentText,
annotations: [],
},
},
],
status: MessageStatus.Stopped,
metadata: {
...continueFromMessage.metadata,
tokenSpeed: useAppState.getState().tokenSpeed,
assistant: currentAssistant,
},
})
} else {
const partialContent = {
...newAssistantThreadContent(
activeThread.id,
contentText,
{
tokenSpeed: useAppState.getState().tokenSpeed,
assistant: currentAssistant,
}
),
status: MessageStatus.Stopped,
}
addMessage(partialContent)
}
updatePromptProgress(undefined)
updateThreadTimestamp(activeThread.id)
} else if (!abortController.signal.aborted) {
// Only show error if not aborted
if (error && typeof error === 'object' && 'message' in error) {
setModelLoadError(error as ErrorObject)
} else {
@ -628,12 +839,14 @@ export const useChat = () => {
updateStreamingContent,
updatePromptProgress,
addMessage,
updateMessage,
updateThreadTimestamp,
updateLoadingModel,
getDisabledToolsForThread,
allowAllMCPPermissions,
showApprovalModal,
updateTokenSpeed,
setTokenSpeed,
showIncreaseContextSizeModal,
increaseModelContextSize,
toggleOnContextShifting,

View File

@ -8,6 +8,7 @@ type MessageState = {
getMessages: (threadId: string) => ThreadMessage[]
setMessages: (threadId: string, messages: ThreadMessage[]) => void
addMessage: (message: ThreadMessage) => void
updateMessage: (message: ThreadMessage) => void
deleteMessage: (threadId: string, messageId: string) => void
clearAllMessages: () => void
}
@ -40,16 +41,52 @@ export const useMessages = create<MessageState>()((set, get) => ({
assistant: selectedAssistant,
},
}
getServiceHub().messages().createMessage(newMessage).then((createdMessage) => {
// Optimistically update state immediately for instant UI feedback
set((state) => ({
messages: {
...state.messages,
[message.thread_id]: [
...(state.messages[message.thread_id] || []),
createdMessage,
newMessage,
],
},
}))
// Persist to storage asynchronously
getServiceHub().messages().createMessage(newMessage).catch((error) => {
console.error('Failed to persist message:', error)
})
},
updateMessage: (message) => {
const assistants = useAssistant.getState().assistants
const currentAssistant = useAssistant.getState().currentAssistant
const selectedAssistant =
assistants.find((a) => a.id === currentAssistant?.id) || assistants[0]
const updatedMessage = {
...message,
metadata: {
...message.metadata,
assistant: selectedAssistant,
},
}
// Optimistically update state immediately for instant UI feedback
set((state) => ({
messages: {
...state.messages,
[message.thread_id]: (state.messages[message.thread_id] || []).map((m) =>
m.id === message.id ? updatedMessage : m
),
},
}))
// Persist to storage asynchronously using modifyMessage instead of createMessage
// to prevent duplicates when updating existing messages
getServiceHub().messages().modifyMessage(updatedMessage).catch((error) => {
console.error('Failed to persist message update:', error)
})
},
deleteMessage: (threadId, messageId) => {

View File

@ -23,8 +23,9 @@ export const useTools = () => {
updateTools(data)
// Initialize default disabled tools for new users (only once)
if (!isDefaultsInitialized() && data.length > 0 && mcpExtension?.getDefaultDisabledTools) {
const defaultDisabled = await mcpExtension.getDefaultDisabledTools()
const mcpExt = mcpExtension as MCPExtension & { getDefaultDisabledTools?: () => Promise<string[]> }
if (!isDefaultsInitialized() && data.length > 0 && mcpExt?.getDefaultDisabledTools) {
const defaultDisabled = await mcpExt.getDefaultDisabledTools()
if (defaultDisabled.length > 0) {
setDefaultDisabledTools(defaultDisabled)
markDefaultsAsInitialized()

View File

@ -135,6 +135,7 @@
"enterApiKey": "Enter API Key",
"scrollToBottom": "Scroll to bottom",
"generateAiResponse": "Generate AI Response",
"continueAiResponse": "Continue with AI Response",
"addModel": {
"title": "Add Model",
"modelId": "Model ID",

View File

@ -40,6 +40,20 @@ export class DefaultMessagesService implements MessagesService {
)
}
async modifyMessage(message: ThreadMessage): Promise<ThreadMessage> {
// Don't modify messages on server for temporary chat - it's local only
if (message.thread_id === TEMPORARY_CHAT_ID) {
return message
}
return (
ExtensionManager.getInstance()
.get<ConversationalExtension>(ExtensionTypeEnum.Conversational)
?.modifyMessage(message)
?.catch(() => message) ?? message
)
}
async deleteMessage(threadId: string, messageId: string): Promise<void> {
// Don't delete messages on server for temporary chat - it's local only
if (threadId === TEMPORARY_CHAT_ID) {

View File

@ -7,5 +7,6 @@ import { ThreadMessage } from '@janhq/core'
export interface MessagesService {
fetchMessages(threadId: string): Promise<ThreadMessage[]>
createMessage(message: ThreadMessage): Promise<ThreadMessage>
modifyMessage(message: ThreadMessage): Promise<ThreadMessage>
deleteMessage(threadId: string, messageId: string): Promise<void>
}