feat: Add tests for the Continuing with AI response

This commit is contained in:
Vanalite 2025-10-01 19:13:10 +07:00
parent ccca331d6c
commit 52f73af08c
2 changed files with 412 additions and 58 deletions

View File

@ -1,6 +1,20 @@
import { renderHook, act } from '@testing-library/react'
import { describe, it, expect, vi, beforeEach } from 'vitest'
import { renderHook, act, waitFor } from '@testing-library/react'
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
import { useChat } from '../useChat'
import * as completionLib from '@/lib/completion'
import * as messagesLib from '@/lib/messages'
import { MessageStatus, ContentType } from '@janhq/core'
// Store mock functions for assertions
let mockAddMessage: ReturnType<typeof vi.fn>
let mockUpdateMessage: ReturnType<typeof vi.fn>
let mockGetMessages: ReturnType<typeof vi.fn>
let mockStartModel: ReturnType<typeof vi.fn>
let mockSendCompletion: ReturnType<typeof vi.fn>
let mockPostMessageProcessing: ReturnType<typeof vi.fn>
let mockCompletionMessagesBuilder: any
let mockSetPrompt: ReturnType<typeof vi.fn>
let mockResetTokenSpeed: ReturnType<typeof vi.fn>
// Mock dependencies
vi.mock('../usePrompt', () => ({
@ -8,11 +22,16 @@ vi.mock('../usePrompt', () => ({
(selector: any) => {
const state = {
prompt: 'test prompt',
setPrompt: vi.fn(),
setPrompt: mockSetPrompt,
}
return selector ? selector(state) : state
},
{ getState: () => ({ prompt: 'test prompt', setPrompt: vi.fn() }) }
{
getState: () => ({
prompt: 'test prompt',
setPrompt: mockSetPrompt
})
}
),
}))
@ -22,39 +41,58 @@ vi.mock('../useAppState', () => ({
const state = {
tools: [],
updateTokenSpeed: vi.fn(),
resetTokenSpeed: vi.fn(),
resetTokenSpeed: mockResetTokenSpeed,
updateTools: vi.fn(),
updateStreamingContent: vi.fn(),
updatePromptProgress: vi.fn(),
updateLoadingModel: vi.fn(),
setAbortController: vi.fn(),
streamingContent: undefined,
}
return selector ? selector(state) : state
},
{
getState: vi.fn(() => ({
tools: [],
tokenSpeed: { tokensPerSecond: 10 },
streamingContent: undefined,
}))
}
),
}))
vi.mock('../useAssistant', () => ({
useAssistant: (selector: any) => {
const state = {
assistants: [{
id: 'test-assistant',
instructions: 'test instructions',
parameters: { stream: true },
}],
currentAssistant: {
id: 'test-assistant',
instructions: 'test instructions',
parameters: { stream: true },
},
useAssistant: Object.assign(
(selector: any) => {
const state = {
assistants: [{
id: 'test-assistant',
instructions: 'test instructions',
parameters: { stream: true },
}],
currentAssistant: {
id: 'test-assistant',
instructions: 'test instructions',
parameters: { stream: true },
},
}
return selector ? selector(state) : state
},
{
getState: () => ({
assistants: [{
id: 'test-assistant',
instructions: 'test instructions',
parameters: { stream: true },
}],
currentAssistant: {
id: 'test-assistant',
instructions: 'test instructions',
parameters: { stream: true },
},
})
}
return selector ? selector(state) : state
},
),
}))
vi.mock('../useModelProvider', () => ({
@ -62,14 +100,15 @@ vi.mock('../useModelProvider', () => ({
(selector: any) => {
const state = {
getProviderByName: vi.fn(() => ({
provider: 'openai',
provider: 'llamacpp',
models: [],
settings: [],
})),
selectedModel: {
id: 'test-model',
capabilities: ['tools'],
},
selectedProvider: 'openai',
selectedProvider: 'llamacpp',
updateProvider: vi.fn(),
}
return selector ? selector(state) : state
@ -77,14 +116,15 @@ vi.mock('../useModelProvider', () => ({
{
getState: () => ({
getProviderByName: vi.fn(() => ({
provider: 'openai',
provider: 'llamacpp',
models: [],
settings: [],
})),
selectedModel: {
id: 'test-model',
capabilities: ['tools'],
},
selectedProvider: 'openai',
selectedProvider: 'llamacpp',
updateProvider: vi.fn(),
})
}
@ -96,11 +136,11 @@ vi.mock('../useThreads', () => ({
const state = {
getCurrentThread: vi.fn(() => ({
id: 'test-thread',
model: { id: 'test-model', provider: 'openai' },
model: { id: 'test-model', provider: 'llamacpp' },
})),
createThread: vi.fn(() => Promise.resolve({
id: 'test-thread',
model: { id: 'test-model', provider: 'openai' },
model: { id: 'test-model', provider: 'llamacpp' },
})),
updateThreadTimestamp: vi.fn(),
}
@ -111,22 +151,33 @@ vi.mock('../useThreads', () => ({
vi.mock('../useMessages', () => ({
useMessages: (selector: any) => {
const state = {
getMessages: vi.fn(() => []),
addMessage: vi.fn(),
getMessages: mockGetMessages,
addMessage: mockAddMessage,
updateMessage: mockUpdateMessage,
setMessages: vi.fn(),
}
return selector ? selector(state) : state
},
}))
vi.mock('../useToolApproval', () => ({
useToolApproval: (selector: any) => {
const state = {
approvedTools: [],
showApprovalModal: vi.fn(),
allowAllMCPPermissions: false,
useToolApproval: Object.assign(
(selector: any) => {
const state = {
approvedTools: [],
showApprovalModal: vi.fn(),
allowAllMCPPermissions: false,
}
return selector ? selector(state) : state
},
{
getState: () => ({
approvedTools: [],
showApprovalModal: vi.fn(),
allowAllMCPPermissions: false,
})
}
return selector ? selector(state) : state
},
),
}))
vi.mock('../useToolAvailable', () => ({
@ -162,38 +213,57 @@ vi.mock('@tanstack/react-router', () => ({
})),
}))
vi.mock('../useServiceHub', () => ({
useServiceHub: vi.fn(() => ({
models: () => ({
startModel: mockStartModel,
stopModel: vi.fn(() => Promise.resolve()),
stopAllModels: vi.fn(() => Promise.resolve()),
}),
providers: () => ({
updateSettings: vi.fn(() => Promise.resolve()),
}),
})),
}))
vi.mock('@/lib/completion', () => ({
emptyThreadContent: { thread_id: 'test-thread', content: '' },
extractToolCall: vi.fn(),
newUserThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'user message' })),
newAssistantThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'assistant message' })),
sendCompletion: vi.fn(),
postMessageProcessing: vi.fn(),
isCompletionResponse: vi.fn(),
newUserThreadContent: vi.fn((threadId, content) => ({
thread_id: threadId,
content: [{ type: ContentType.Text, text: { value: content, annotations: [] } }],
role: 'user'
})),
newAssistantThreadContent: vi.fn((threadId, content) => ({
thread_id: threadId,
content: [{ type: ContentType.Text, text: { value: content, annotations: [] } }],
role: 'assistant'
})),
sendCompletion: mockSendCompletion,
postMessageProcessing: mockPostMessageProcessing,
isCompletionResponse: vi.fn(() => true),
}))
vi.mock('@/lib/messages', () => ({
CompletionMessagesBuilder: vi.fn(() => ({
addUserMessage: vi.fn(),
addAssistantMessage: vi.fn(),
getMessages: vi.fn(() => []),
CompletionMessagesBuilder: vi.fn(() => mockCompletionMessagesBuilder),
}))
vi.mock('@/lib/instructionTemplate', () => ({
renderInstructions: vi.fn((instructions: string) => instructions),
}))
vi.mock('@/utils/reasoning', () => ({
ReasoningProcessor: vi.fn(() => ({
processReasoningChunk: vi.fn(() => null),
finalize: vi.fn(() => ''),
})),
extractReasoningFromMessage: vi.fn(() => null),
}))
vi.mock('@/services/mcp', () => ({
getTools: vi.fn(() => Promise.resolve([])),
}))
vi.mock('@/services/models', () => ({
startModel: vi.fn(() => Promise.resolve()),
stopModel: vi.fn(() => Promise.resolve()),
stopAllModels: vi.fn(() => Promise.resolve()),
}))
vi.mock('@/services/providers', () => ({
updateSettings: vi.fn(() => Promise.resolve()),
}))
vi.mock('@tauri-apps/api/event', () => ({
listen: vi.fn(() => Promise.resolve(vi.fn())),
}))
@ -206,9 +276,37 @@ vi.mock('sonner', () => ({
describe('useChat', () => {
beforeEach(() => {
// Reset all mocks
mockAddMessage = vi.fn()
mockUpdateMessage = vi.fn()
mockGetMessages = vi.fn(() => [])
mockStartModel = vi.fn(() => Promise.resolve())
mockSetPrompt = vi.fn()
mockResetTokenSpeed = vi.fn()
mockSendCompletion = vi.fn(() => Promise.resolve({
choices: [{
message: {
content: 'AI response',
role: 'assistant',
},
}],
}))
mockPostMessageProcessing = vi.fn((toolCalls, builder, content) =>
Promise.resolve(content)
)
mockCompletionMessagesBuilder = {
addUserMessage: vi.fn(),
addAssistantMessage: vi.fn(),
getMessages: vi.fn(() => []),
}
vi.clearAllMocks()
})
afterEach(() => {
vi.clearAllTimers()
})
it('returns sendMessage function', () => {
const { result } = renderHook(() => useChat())
@ -216,13 +314,268 @@ describe('useChat', () => {
expect(typeof result.current).toBe('function')
})
it('sends message successfully', async () => {
const { result } = renderHook(() => useChat())
describe('Continue with AI response functionality', () => {
it('should add new user message when troubleshooting is true and no continueFromMessageId', async () => {
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('Hello world')
await act(async () => {
await result.current('Hello world', true, undefined, undefined)
})
expect(completionLib.newUserThreadContent).toHaveBeenCalledWith(
'test-thread',
'Hello world',
undefined
)
expect(mockAddMessage).toHaveBeenCalledWith(
expect.objectContaining({
thread_id: 'test-thread',
role: 'user',
})
)
expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith(
'Hello world',
undefined
)
})
expect(result.current).toBeDefined()
it('should NOT add new user message when continueFromMessageId is provided', async () => {
const stoppedMessage = {
id: 'msg-123',
thread_id: 'test-thread',
role: 'assistant',
content: [{ type: ContentType.Text, text: { value: 'Partial response', annotations: [] } }],
status: MessageStatus.Stopped,
metadata: {},
}
mockGetMessages.mockReturnValue([stoppedMessage])
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('Continue', true, undefined, 'msg-123')
})
expect(completionLib.newUserThreadContent).not.toHaveBeenCalled()
const userMessageCalls = mockAddMessage.mock.calls.filter(
(call: any) => call[0]?.role === 'user'
)
expect(userMessageCalls).toHaveLength(0)
expect(mockCompletionMessagesBuilder.addUserMessage).not.toHaveBeenCalled()
})
it('should add partial assistant message to builder when continuing', async () => {
const stoppedMessage = {
id: 'msg-123',
thread_id: 'test-thread',
role: 'assistant',
content: [{ type: ContentType.Text, text: { value: 'Partial response', annotations: [] } }],
status: MessageStatus.Stopped,
metadata: {},
}
mockGetMessages.mockReturnValue([stoppedMessage])
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('', true, undefined, 'msg-123')
})
expect(mockCompletionMessagesBuilder.addAssistantMessage).toHaveBeenCalledWith(
'Partial response',
undefined,
[]
)
})
it('should filter out stopped message from context when continuing', async () => {
const userMsg = {
id: 'msg-1',
thread_id: 'test-thread',
role: 'user',
content: [{ type: ContentType.Text, text: { value: 'Hello', annotations: [] } }],
}
const stoppedMessage = {
id: 'msg-123',
thread_id: 'test-thread',
role: 'assistant',
content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }],
status: MessageStatus.Stopped,
}
mockGetMessages.mockReturnValue([userMsg, stoppedMessage])
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('', true, undefined, 'msg-123')
})
await waitFor(() => {
expect(messagesLib.CompletionMessagesBuilder).toHaveBeenCalledWith(
[userMsg], // stopped message filtered out
'test instructions'
)
})
})
it('should update existing message instead of adding new one when continuing', async () => {
const stoppedMessage = {
id: 'msg-123',
thread_id: 'test-thread',
role: 'assistant',
content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }],
status: MessageStatus.Stopped,
metadata: {},
}
mockGetMessages.mockReturnValue([stoppedMessage])
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('', true, undefined, 'msg-123')
})
await waitFor(() => {
expect(mockUpdateMessage).toHaveBeenCalledWith(
expect.objectContaining({
id: 'msg-123',
status: MessageStatus.Ready,
})
)
})
})
it('should start with previous content when continuing', async () => {
const stoppedMessage = {
id: 'msg-123',
thread_id: 'test-thread',
role: 'assistant',
content: [{ type: ContentType.Text, text: { value: 'Partial response', annotations: [] } }],
status: MessageStatus.Stopped,
metadata: {},
}
mockGetMessages.mockReturnValue([stoppedMessage])
mockSendCompletion.mockResolvedValue({
choices: [{
message: {
content: ' continued',
role: 'assistant',
},
}],
})
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('', true, undefined, 'msg-123')
})
// The accumulated text should contain the previous content
await waitFor(() => {
expect(mockUpdateMessage).toHaveBeenCalledWith(
expect.objectContaining({
id: 'msg-123',
content: expect.arrayContaining([
expect.objectContaining({
text: expect.objectContaining({
value: expect.stringContaining('Partial response'),
})
})
])
})
)
})
})
it('should handle attachments correctly when not continuing', async () => {
const { result } = renderHook(() => useChat())
const attachments = [
{
name: 'test.png',
type: 'image/png',
size: 1024,
base64: 'base64data',
dataUrl: '',
},
]
await act(async () => {
await result.current('Message with attachment', true, attachments, undefined)
})
expect(completionLib.newUserThreadContent).toHaveBeenCalledWith(
'test-thread',
'Message with attachment',
attachments
)
expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith(
'Message with attachment',
attachments
)
})
it('should preserve message status as Ready after continuation completes', async () => {
const stoppedMessage = {
id: 'msg-123',
thread_id: 'test-thread',
role: 'assistant',
content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }],
status: MessageStatus.Stopped,
metadata: {},
}
mockGetMessages.mockReturnValue([stoppedMessage])
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('', true, undefined, 'msg-123')
})
await waitFor(() => {
expect(mockUpdateMessage).toHaveBeenCalledWith(
expect.objectContaining({
id: 'msg-123',
status: MessageStatus.Ready,
})
)
})
})
})
describe('Normal message sending', () => {
it('sends message successfully without continuation', async () => {
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('Hello world')
})
expect(mockSendCompletion).toHaveBeenCalled()
expect(mockStartModel).toHaveBeenCalled()
})
})
describe('Error handling', () => {
it('should handle errors gracefully during continuation', async () => {
mockSendCompletion.mockRejectedValueOnce(new Error('API Error'))
const stoppedMessage = {
id: 'msg-123',
thread_id: 'test-thread',
role: 'assistant',
content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }],
status: MessageStatus.Stopped,
metadata: {},
}
mockGetMessages.mockReturnValue([stoppedMessage])
const { result } = renderHook(() => useChat())
await act(async () => {
await result.current('', true, undefined, 'msg-123')
})
expect(result.current).toBeDefined()
})
})
})

View File

@ -791,6 +791,7 @@ export const useChat = () => {
updateStreamingContent,
updatePromptProgress,
addMessage,
updateMessage,
updateThreadTimestamp,
updateLoadingModel,
getDisabledToolsForThread,