diff --git a/web-app/src/hooks/__tests__/useChat.test.ts b/web-app/src/hooks/__tests__/useChat.test.ts index e87191fb6..3f89e24cd 100644 --- a/web-app/src/hooks/__tests__/useChat.test.ts +++ b/web-app/src/hooks/__tests__/useChat.test.ts @@ -1,6 +1,20 @@ -import { renderHook, act } from '@testing-library/react' -import { describe, it, expect, vi, beforeEach } from 'vitest' +import { renderHook, act, waitFor } from '@testing-library/react' +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' import { useChat } from '../useChat' +import * as completionLib from '@/lib/completion' +import * as messagesLib from '@/lib/messages' +import { MessageStatus, ContentType } from '@janhq/core' + +// Store mock functions for assertions +let mockAddMessage: ReturnType +let mockUpdateMessage: ReturnType +let mockGetMessages: ReturnType +let mockStartModel: ReturnType +let mockSendCompletion: ReturnType +let mockPostMessageProcessing: ReturnType +let mockCompletionMessagesBuilder: any +let mockSetPrompt: ReturnType +let mockResetTokenSpeed: ReturnType // Mock dependencies vi.mock('../usePrompt', () => ({ @@ -8,11 +22,16 @@ vi.mock('../usePrompt', () => ({ (selector: any) => { const state = { prompt: 'test prompt', - setPrompt: vi.fn(), + setPrompt: mockSetPrompt, } return selector ? selector(state) : state }, - { getState: () => ({ prompt: 'test prompt', setPrompt: vi.fn() }) } + { + getState: () => ({ + prompt: 'test prompt', + setPrompt: mockSetPrompt + }) + } ), })) @@ -22,39 +41,58 @@ vi.mock('../useAppState', () => ({ const state = { tools: [], updateTokenSpeed: vi.fn(), - resetTokenSpeed: vi.fn(), + resetTokenSpeed: mockResetTokenSpeed, updateTools: vi.fn(), updateStreamingContent: vi.fn(), updatePromptProgress: vi.fn(), updateLoadingModel: vi.fn(), setAbortController: vi.fn(), + streamingContent: undefined, } return selector ? selector(state) : state }, { getState: vi.fn(() => ({ + tools: [], tokenSpeed: { tokensPerSecond: 10 }, + streamingContent: undefined, })) } ), })) vi.mock('../useAssistant', () => ({ - useAssistant: (selector: any) => { - const state = { - assistants: [{ - id: 'test-assistant', - instructions: 'test instructions', - parameters: { stream: true }, - }], - currentAssistant: { - id: 'test-assistant', - instructions: 'test instructions', - parameters: { stream: true }, - }, + useAssistant: Object.assign( + (selector: any) => { + const state = { + assistants: [{ + id: 'test-assistant', + instructions: 'test instructions', + parameters: { stream: true }, + }], + currentAssistant: { + id: 'test-assistant', + instructions: 'test instructions', + parameters: { stream: true }, + }, + } + return selector ? selector(state) : state + }, + { + getState: () => ({ + assistants: [{ + id: 'test-assistant', + instructions: 'test instructions', + parameters: { stream: true }, + }], + currentAssistant: { + id: 'test-assistant', + instructions: 'test instructions', + parameters: { stream: true }, + }, + }) } - return selector ? selector(state) : state - }, + ), })) vi.mock('../useModelProvider', () => ({ @@ -62,14 +100,15 @@ vi.mock('../useModelProvider', () => ({ (selector: any) => { const state = { getProviderByName: vi.fn(() => ({ - provider: 'openai', + provider: 'llamacpp', models: [], + settings: [], })), selectedModel: { id: 'test-model', capabilities: ['tools'], }, - selectedProvider: 'openai', + selectedProvider: 'llamacpp', updateProvider: vi.fn(), } return selector ? selector(state) : state @@ -77,14 +116,15 @@ vi.mock('../useModelProvider', () => ({ { getState: () => ({ getProviderByName: vi.fn(() => ({ - provider: 'openai', + provider: 'llamacpp', models: [], + settings: [], })), selectedModel: { id: 'test-model', capabilities: ['tools'], }, - selectedProvider: 'openai', + selectedProvider: 'llamacpp', updateProvider: vi.fn(), }) } @@ -96,11 +136,11 @@ vi.mock('../useThreads', () => ({ const state = { getCurrentThread: vi.fn(() => ({ id: 'test-thread', - model: { id: 'test-model', provider: 'openai' }, + model: { id: 'test-model', provider: 'llamacpp' }, })), createThread: vi.fn(() => Promise.resolve({ id: 'test-thread', - model: { id: 'test-model', provider: 'openai' }, + model: { id: 'test-model', provider: 'llamacpp' }, })), updateThreadTimestamp: vi.fn(), } @@ -111,22 +151,33 @@ vi.mock('../useThreads', () => ({ vi.mock('../useMessages', () => ({ useMessages: (selector: any) => { const state = { - getMessages: vi.fn(() => []), - addMessage: vi.fn(), + getMessages: mockGetMessages, + addMessage: mockAddMessage, + updateMessage: mockUpdateMessage, + setMessages: vi.fn(), } return selector ? selector(state) : state }, })) vi.mock('../useToolApproval', () => ({ - useToolApproval: (selector: any) => { - const state = { - approvedTools: [], - showApprovalModal: vi.fn(), - allowAllMCPPermissions: false, + useToolApproval: Object.assign( + (selector: any) => { + const state = { + approvedTools: [], + showApprovalModal: vi.fn(), + allowAllMCPPermissions: false, + } + return selector ? selector(state) : state + }, + { + getState: () => ({ + approvedTools: [], + showApprovalModal: vi.fn(), + allowAllMCPPermissions: false, + }) } - return selector ? selector(state) : state - }, + ), })) vi.mock('../useToolAvailable', () => ({ @@ -162,38 +213,57 @@ vi.mock('@tanstack/react-router', () => ({ })), })) +vi.mock('../useServiceHub', () => ({ + useServiceHub: vi.fn(() => ({ + models: () => ({ + startModel: mockStartModel, + stopModel: vi.fn(() => Promise.resolve()), + stopAllModels: vi.fn(() => Promise.resolve()), + }), + providers: () => ({ + updateSettings: vi.fn(() => Promise.resolve()), + }), + })), +})) + vi.mock('@/lib/completion', () => ({ emptyThreadContent: { thread_id: 'test-thread', content: '' }, extractToolCall: vi.fn(), - newUserThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'user message' })), - newAssistantThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'assistant message' })), - sendCompletion: vi.fn(), - postMessageProcessing: vi.fn(), - isCompletionResponse: vi.fn(), + newUserThreadContent: vi.fn((threadId, content) => ({ + thread_id: threadId, + content: [{ type: ContentType.Text, text: { value: content, annotations: [] } }], + role: 'user' + })), + newAssistantThreadContent: vi.fn((threadId, content) => ({ + thread_id: threadId, + content: [{ type: ContentType.Text, text: { value: content, annotations: [] } }], + role: 'assistant' + })), + sendCompletion: mockSendCompletion, + postMessageProcessing: mockPostMessageProcessing, + isCompletionResponse: vi.fn(() => true), })) vi.mock('@/lib/messages', () => ({ - CompletionMessagesBuilder: vi.fn(() => ({ - addUserMessage: vi.fn(), - addAssistantMessage: vi.fn(), - getMessages: vi.fn(() => []), + CompletionMessagesBuilder: vi.fn(() => mockCompletionMessagesBuilder), +})) + +vi.mock('@/lib/instructionTemplate', () => ({ + renderInstructions: vi.fn((instructions: string) => instructions), +})) + +vi.mock('@/utils/reasoning', () => ({ + ReasoningProcessor: vi.fn(() => ({ + processReasoningChunk: vi.fn(() => null), + finalize: vi.fn(() => ''), })), + extractReasoningFromMessage: vi.fn(() => null), })) vi.mock('@/services/mcp', () => ({ getTools: vi.fn(() => Promise.resolve([])), })) -vi.mock('@/services/models', () => ({ - startModel: vi.fn(() => Promise.resolve()), - stopModel: vi.fn(() => Promise.resolve()), - stopAllModels: vi.fn(() => Promise.resolve()), -})) - -vi.mock('@/services/providers', () => ({ - updateSettings: vi.fn(() => Promise.resolve()), -})) - vi.mock('@tauri-apps/api/event', () => ({ listen: vi.fn(() => Promise.resolve(vi.fn())), })) @@ -206,9 +276,37 @@ vi.mock('sonner', () => ({ describe('useChat', () => { beforeEach(() => { + // Reset all mocks + mockAddMessage = vi.fn() + mockUpdateMessage = vi.fn() + mockGetMessages = vi.fn(() => []) + mockStartModel = vi.fn(() => Promise.resolve()) + mockSetPrompt = vi.fn() + mockResetTokenSpeed = vi.fn() + mockSendCompletion = vi.fn(() => Promise.resolve({ + choices: [{ + message: { + content: 'AI response', + role: 'assistant', + }, + }], + })) + mockPostMessageProcessing = vi.fn((toolCalls, builder, content) => + Promise.resolve(content) + ) + mockCompletionMessagesBuilder = { + addUserMessage: vi.fn(), + addAssistantMessage: vi.fn(), + getMessages: vi.fn(() => []), + } + vi.clearAllMocks() }) + afterEach(() => { + vi.clearAllTimers() + }) + it('returns sendMessage function', () => { const { result } = renderHook(() => useChat()) @@ -216,13 +314,268 @@ describe('useChat', () => { expect(typeof result.current).toBe('function') }) - it('sends message successfully', async () => { - const { result } = renderHook(() => useChat()) + describe('Continue with AI response functionality', () => { + it('should add new user message when troubleshooting is true and no continueFromMessageId', async () => { + const { result } = renderHook(() => useChat()) - await act(async () => { - await result.current('Hello world') + await act(async () => { + await result.current('Hello world', true, undefined, undefined) + }) + + expect(completionLib.newUserThreadContent).toHaveBeenCalledWith( + 'test-thread', + 'Hello world', + undefined + ) + expect(mockAddMessage).toHaveBeenCalledWith( + expect.objectContaining({ + thread_id: 'test-thread', + role: 'user', + }) + ) + expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith( + 'Hello world', + undefined + ) }) - expect(result.current).toBeDefined() + it('should NOT add new user message when continueFromMessageId is provided', async () => { + const stoppedMessage = { + id: 'msg-123', + thread_id: 'test-thread', + role: 'assistant', + content: [{ type: ContentType.Text, text: { value: 'Partial response', annotations: [] } }], + status: MessageStatus.Stopped, + metadata: {}, + } + mockGetMessages.mockReturnValue([stoppedMessage]) + + const { result } = renderHook(() => useChat()) + + await act(async () => { + await result.current('Continue', true, undefined, 'msg-123') + }) + + expect(completionLib.newUserThreadContent).not.toHaveBeenCalled() + const userMessageCalls = mockAddMessage.mock.calls.filter( + (call: any) => call[0]?.role === 'user' + ) + expect(userMessageCalls).toHaveLength(0) + expect(mockCompletionMessagesBuilder.addUserMessage).not.toHaveBeenCalled() + }) + + it('should add partial assistant message to builder when continuing', async () => { + const stoppedMessage = { + id: 'msg-123', + thread_id: 'test-thread', + role: 'assistant', + content: [{ type: ContentType.Text, text: { value: 'Partial response', annotations: [] } }], + status: MessageStatus.Stopped, + metadata: {}, + } + mockGetMessages.mockReturnValue([stoppedMessage]) + + const { result } = renderHook(() => useChat()) + + await act(async () => { + await result.current('', true, undefined, 'msg-123') + }) + + expect(mockCompletionMessagesBuilder.addAssistantMessage).toHaveBeenCalledWith( + 'Partial response', + undefined, + [] + ) + }) + + it('should filter out stopped message from context when continuing', async () => { + const userMsg = { + id: 'msg-1', + thread_id: 'test-thread', + role: 'user', + content: [{ type: ContentType.Text, text: { value: 'Hello', annotations: [] } }], + } + const stoppedMessage = { + id: 'msg-123', + thread_id: 'test-thread', + role: 'assistant', + content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }], + status: MessageStatus.Stopped, + } + mockGetMessages.mockReturnValue([userMsg, stoppedMessage]) + + const { result } = renderHook(() => useChat()) + + await act(async () => { + await result.current('', true, undefined, 'msg-123') + }) + + await waitFor(() => { + expect(messagesLib.CompletionMessagesBuilder).toHaveBeenCalledWith( + [userMsg], // stopped message filtered out + 'test instructions' + ) + }) + }) + + it('should update existing message instead of adding new one when continuing', async () => { + const stoppedMessage = { + id: 'msg-123', + thread_id: 'test-thread', + role: 'assistant', + content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }], + status: MessageStatus.Stopped, + metadata: {}, + } + mockGetMessages.mockReturnValue([stoppedMessage]) + + const { result } = renderHook(() => useChat()) + + await act(async () => { + await result.current('', true, undefined, 'msg-123') + }) + + await waitFor(() => { + expect(mockUpdateMessage).toHaveBeenCalledWith( + expect.objectContaining({ + id: 'msg-123', + status: MessageStatus.Ready, + }) + ) + }) + }) + + it('should start with previous content when continuing', async () => { + const stoppedMessage = { + id: 'msg-123', + thread_id: 'test-thread', + role: 'assistant', + content: [{ type: ContentType.Text, text: { value: 'Partial response', annotations: [] } }], + status: MessageStatus.Stopped, + metadata: {}, + } + mockGetMessages.mockReturnValue([stoppedMessage]) + + mockSendCompletion.mockResolvedValue({ + choices: [{ + message: { + content: ' continued', + role: 'assistant', + }, + }], + }) + + const { result } = renderHook(() => useChat()) + + await act(async () => { + await result.current('', true, undefined, 'msg-123') + }) + + // The accumulated text should contain the previous content + await waitFor(() => { + expect(mockUpdateMessage).toHaveBeenCalledWith( + expect.objectContaining({ + id: 'msg-123', + content: expect.arrayContaining([ + expect.objectContaining({ + text: expect.objectContaining({ + value: expect.stringContaining('Partial response'), + }) + }) + ]) + }) + ) + }) + }) + + it('should handle attachments correctly when not continuing', async () => { + const { result } = renderHook(() => useChat()) + const attachments = [ + { + name: 'test.png', + type: 'image/png', + size: 1024, + base64: 'base64data', + dataUrl: '', + }, + ] + + await act(async () => { + await result.current('Message with attachment', true, attachments, undefined) + }) + + expect(completionLib.newUserThreadContent).toHaveBeenCalledWith( + 'test-thread', + 'Message with attachment', + attachments + ) + expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith( + 'Message with attachment', + attachments + ) + }) + + it('should preserve message status as Ready after continuation completes', async () => { + const stoppedMessage = { + id: 'msg-123', + thread_id: 'test-thread', + role: 'assistant', + content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }], + status: MessageStatus.Stopped, + metadata: {}, + } + mockGetMessages.mockReturnValue([stoppedMessage]) + + const { result } = renderHook(() => useChat()) + + await act(async () => { + await result.current('', true, undefined, 'msg-123') + }) + + await waitFor(() => { + expect(mockUpdateMessage).toHaveBeenCalledWith( + expect.objectContaining({ + id: 'msg-123', + status: MessageStatus.Ready, + }) + ) + }) + }) + }) + + describe('Normal message sending', () => { + it('sends message successfully without continuation', async () => { + const { result } = renderHook(() => useChat()) + + await act(async () => { + await result.current('Hello world') + }) + + expect(mockSendCompletion).toHaveBeenCalled() + expect(mockStartModel).toHaveBeenCalled() + }) + }) + + describe('Error handling', () => { + it('should handle errors gracefully during continuation', async () => { + mockSendCompletion.mockRejectedValueOnce(new Error('API Error')) + const stoppedMessage = { + id: 'msg-123', + thread_id: 'test-thread', + role: 'assistant', + content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }], + status: MessageStatus.Stopped, + metadata: {}, + } + mockGetMessages.mockReturnValue([stoppedMessage]) + + const { result } = renderHook(() => useChat()) + + await act(async () => { + await result.current('', true, undefined, 'msg-123') + }) + + expect(result.current).toBeDefined() + }) }) }) diff --git a/web-app/src/hooks/useChat.ts b/web-app/src/hooks/useChat.ts index a730e41e6..0a434e975 100644 --- a/web-app/src/hooks/useChat.ts +++ b/web-app/src/hooks/useChat.ts @@ -791,6 +791,7 @@ export const useChat = () => { updateStreamingContent, updatePromptProgress, addMessage, + updateMessage, updateThreadTimestamp, updateLoadingModel, getDisabledToolsForThread,