test: update test cases

This commit is contained in:
Louis 2024-12-09 18:40:57 +07:00
parent 174f1c7dcb
commit f6ba447f1b
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2
7 changed files with 152 additions and 173 deletions

View File

@ -1,4 +1,4 @@
import { Fragment, use, useCallback, useEffect, useRef } from 'react' import { Fragment, useCallback, useEffect, useRef } from 'react'
import { import {
ChatCompletionMessage, ChatCompletionMessage,

View File

@ -67,7 +67,7 @@ describe('useCreateNewThread', () => {
} as any) } as any)
}) })
expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set expect(mockSetAtom).toHaveBeenCalledTimes(1)
expect(extensionManager.get).toHaveBeenCalled() expect(extensionManager.get).toHaveBeenCalled()
}) })
@ -104,7 +104,7 @@ describe('useCreateNewThread', () => {
await result.current.requestCreateNewThread({ await result.current.requestCreateNewThread({
id: 'assistant1', id: 'assistant1',
name: 'Assistant 1', name: 'Assistant 1',
instructions: "Hello Jan Assistant", instructions: 'Hello Jan Assistant',
model: { model: {
id: 'model1', id: 'model1',
parameters: [], parameters: [],
@ -113,16 +113,8 @@ describe('useCreateNewThread', () => {
} as any) } as any)
}) })
expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set expect(mockSetAtom).toHaveBeenCalledTimes(1) // Check if all the necessary atoms were set
expect(extensionManager.get).toHaveBeenCalled() expect(extensionManager.get).toHaveBeenCalled()
expect(mockSetAtom).toHaveBeenNthCalledWith(
2,
expect.objectContaining({
assistants: expect.arrayContaining([
expect.objectContaining({ instructions: 'Hello Jan Assistant' }),
]),
})
)
}) })
it('should create a new thread with previous instructions', async () => { it('should create a new thread with previous instructions', async () => {
@ -166,16 +158,8 @@ describe('useCreateNewThread', () => {
} as any) } as any)
}) })
expect(mockSetAtom).toHaveBeenCalledTimes(6) // Check if all the necessary atoms were set expect(mockSetAtom).toHaveBeenCalledTimes(1) // Check if all the necessary atoms were set
expect(extensionManager.get).toHaveBeenCalled() expect(extensionManager.get).toHaveBeenCalled()
expect(mockSetAtom).toHaveBeenNthCalledWith(
2,
expect.objectContaining({
assistants: expect.arrayContaining([
expect.objectContaining({ instructions: 'Hello Jan' }),
]),
})
)
}) })
it('should show a warning toast if trying to create an empty thread', async () => { it('should show a warning toast if trying to create an empty thread', async () => {
@ -212,13 +196,12 @@ describe('useCreateNewThread', () => {
const { result } = renderHook(() => useCreateNewThread()) const { result } = renderHook(() => useCreateNewThread())
const mockThread = { id: 'thread1', title: 'Test Thread' } const mockThread = { id: 'thread1', title: 'Test Thread', assistants: [{}] }
await act(async () => { await act(async () => {
await result.current.updateThreadMetadata(mockThread as any) await result.current.updateThreadMetadata(mockThread as any)
}) })
expect(mockUpdateThread).toHaveBeenCalledWith(mockThread) expect(mockUpdateThread).toHaveBeenCalledWith(mockThread)
expect(extensionManager.get).toHaveBeenCalled()
}) })
}) })

View File

@ -2,8 +2,7 @@ import { renderHook, act } from '@testing-library/react'
import { useAtom, useAtomValue, useSetAtom } from 'jotai' import { useAtom, useAtomValue, useSetAtom } from 'jotai'
import useDeleteThread from './useDeleteThread' import useDeleteThread from './useDeleteThread'
import { extensionManager } from '@/extension/ExtensionManager' import { extensionManager } from '@/extension/ExtensionManager'
import { toaster } from '@/containers/Toast' import { useCreateNewThread } from './useCreateNewThread'
// Mock the necessary dependencies // Mock the necessary dependencies
// Mock dependencies // Mock dependencies
jest.mock('jotai', () => ({ jest.mock('jotai', () => ({
@ -12,6 +11,7 @@ jest.mock('jotai', () => ({
useAtom: jest.fn(), useAtom: jest.fn(),
atom: jest.fn(), atom: jest.fn(),
})) }))
jest.mock('./useCreateNewThread')
jest.mock('@/extension/ExtensionManager') jest.mock('@/extension/ExtensionManager')
jest.mock('@/containers/Toast') jest.mock('@/containers/Toast')
@ -27,8 +27,13 @@ describe('useDeleteThread', () => {
] ]
const mockSetThreads = jest.fn() const mockSetThreads = jest.fn()
;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads]) ;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads])
;(useSetAtom as jest.Mock).mockReturnValue(() => {})
;(useCreateNewThread as jest.Mock).mockReturnValue({})
const mockDeleteThread = jest.fn().mockImplementation(() => ({
catch: () => jest.fn,
}))
const mockDeleteThread = jest.fn()
extensionManager.get = jest.fn().mockReturnValue({ extensionManager.get = jest.fn().mockReturnValue({
deleteThread: mockDeleteThread, deleteThread: mockDeleteThread,
}) })
@ -50,12 +55,17 @@ describe('useDeleteThread', () => {
const mockCleanMessages = jest.fn() const mockCleanMessages = jest.fn()
;(useSetAtom as jest.Mock).mockReturnValue(() => mockCleanMessages) ;(useSetAtom as jest.Mock).mockReturnValue(() => mockCleanMessages)
;(useAtomValue as jest.Mock).mockReturnValue(['thread 1']) ;(useAtomValue as jest.Mock).mockReturnValue(['thread 1'])
const mockCreateNewThread = jest.fn()
;(useCreateNewThread as jest.Mock).mockReturnValue({
requestCreateNewThread: mockCreateNewThread,
})
const mockWriteMessages = jest.fn()
const mockSaveThread = jest.fn() const mockSaveThread = jest.fn()
const mockDeleteThread = jest.fn().mockResolvedValue({})
extensionManager.get = jest.fn().mockReturnValue({ extensionManager.get = jest.fn().mockReturnValue({
writeMessages: mockWriteMessages,
saveThread: mockSaveThread, saveThread: mockSaveThread,
getThreadAssistant: jest.fn().mockResolvedValue({}),
deleteThread: mockDeleteThread,
}) })
const { result } = renderHook(() => useDeleteThread()) const { result } = renderHook(() => useDeleteThread())
@ -64,20 +74,18 @@ describe('useDeleteThread', () => {
await result.current.cleanThread('thread1') await result.current.cleanThread('thread1')
}) })
expect(mockWriteMessages).toHaveBeenCalled() expect(mockDeleteThread).toHaveBeenCalled()
expect(mockSaveThread).toHaveBeenCalledWith( expect(mockCreateNewThread).toHaveBeenCalled()
expect.objectContaining({
id: 'thread1',
title: 'New Thread',
metadata: expect.objectContaining({ lastMessage: undefined }),
})
)
}) })
it('should handle errors when deleting a thread', async () => { it('should handle errors when deleting a thread', async () => {
const mockThreads = [{ id: 'thread1', title: 'Thread 1' }] const mockThreads = [{ id: 'thread1', title: 'Thread 1' }]
const mockSetThreads = jest.fn() const mockSetThreads = jest.fn()
;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads]) ;(useAtom as jest.Mock).mockReturnValue([mockThreads, mockSetThreads])
const mockCreateNewThread = jest.fn()
;(useCreateNewThread as jest.Mock).mockReturnValue({
requestCreateNewThread: mockCreateNewThread,
})
const mockDeleteThread = jest const mockDeleteThread = jest
.fn() .fn()
@ -98,8 +106,6 @@ describe('useDeleteThread', () => {
expect(mockDeleteThread).toHaveBeenCalledWith('thread1') expect(mockDeleteThread).toHaveBeenCalledWith('thread1')
expect(consoleErrorSpy).toHaveBeenCalledWith(expect.any(Error)) expect(consoleErrorSpy).toHaveBeenCalledWith(expect.any(Error))
expect(mockSetThreads).not.toHaveBeenCalled()
expect(toaster).not.toHaveBeenCalled()
consoleErrorSpy.mockRestore() consoleErrorSpy.mockRestore()
}) })

View File

@ -78,7 +78,7 @@ describe('useThreads', () => {
// Mock extensionManager // Mock extensionManager
const mockGetThreads = jest.fn().mockResolvedValue(mockThreads) const mockGetThreads = jest.fn().mockResolvedValue(mockThreads)
;(extensionManager.get as jest.Mock).mockReturnValue({ ;(extensionManager.get as jest.Mock).mockReturnValue({
getThreads: mockGetThreads, listThreads: mockGetThreads,
}) })
const { result } = renderHook(() => useThreads()) const { result } = renderHook(() => useThreads())
@ -119,7 +119,7 @@ describe('useThreads', () => {
it('should handle empty threads', async () => { it('should handle empty threads', async () => {
// Mock empty threads // Mock empty threads
;(extensionManager.get as jest.Mock).mockReturnValue({ ;(extensionManager.get as jest.Mock).mockReturnValue({
getThreads: jest.fn().mockResolvedValue([]), listThreads: jest.fn().mockResolvedValue([]),
}) })
const mockSetThreadStates = jest.fn() const mockSetThreadStates = jest.fn()

View File

@ -1,7 +1,12 @@
import { renderHook, act } from '@testing-library/react' import { renderHook, act } from '@testing-library/react'
import { useAtom } from 'jotai'
// Mock dependencies // Mock dependencies
jest.mock('ulidx') jest.mock('ulidx')
jest.mock('@/extension') jest.mock('@/extension')
jest.mock('jotai', () => ({
...jest.requireActual('jotai'),
useAtom: jest.fn(),
}))
import useUpdateModelParameters from './useUpdateModelParameters' import useUpdateModelParameters from './useUpdateModelParameters'
import { extensionManager } from '@/extension' import { extensionManager } from '@/extension'
@ -13,7 +18,8 @@ let model: any = {
} }
let extension: any = { let extension: any = {
saveThread: jest.fn(), modifyThread: jest.fn(),
modifyThreadAssistant: jest.fn(),
} }
const mockThread: any = { const mockThread: any = {
@ -35,6 +41,7 @@ const mockThread: any = {
describe('useUpdateModelParameters', () => { describe('useUpdateModelParameters', () => {
beforeAll(() => { beforeAll(() => {
jest.clearAllMocks() jest.clearAllMocks()
jest.useFakeTimers()
jest.mock('./useRecommendedModel', () => ({ jest.mock('./useRecommendedModel', () => ({
useRecommendedModel: () => ({ useRecommendedModel: () => ({
recommendedModel: model, recommendedModel: model,
@ -45,6 +52,12 @@ describe('useUpdateModelParameters', () => {
}) })
it('should update model parameters and save thread when params are valid', async () => { it('should update model parameters and save thread when params are valid', async () => {
;(useAtom as jest.Mock).mockReturnValue([
{
id: 'assistant-1',
},
jest.fn(),
])
const mockValidParameters: any = { const mockValidParameters: any = {
params: { params: {
// Inference // Inference
@ -76,7 +89,8 @@ describe('useUpdateModelParameters', () => {
// Spy functions // Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension) jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({}) jest.spyOn(extension, 'modifyThread').mockReturnValue({})
jest.spyOn(extension, 'modifyThreadAssistant').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters()) const { result } = renderHook(() => useUpdateModelParameters())
@ -84,44 +98,46 @@ describe('useUpdateModelParameters', () => {
await result.current.updateModelParameter(mockThread, mockValidParameters) await result.current.updateModelParameter(mockThread, mockValidParameters)
}) })
jest.runAllTimers()
// Check if the model parameters are valid before persisting // Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({ expect(extension.modifyThreadAssistant).toHaveBeenCalledWith('thread-1', {
assistants: [ id: 'assistant-1',
{ model: {
model: { parameters: {
parameters: { stop: ['<eos>', '<eos2>'],
stop: ['<eos>', '<eos2>'], temperature: 0.5,
temperature: 0.5, token_limit: 1000,
token_limit: 1000, top_k: 0.7,
top_k: 0.7, top_p: 0.1,
top_p: 0.1, stream: true,
stream: true, max_tokens: 1000,
max_tokens: 1000, frequency_penalty: 0.3,
frequency_penalty: 0.3, presence_penalty: 0.2,
presence_penalty: 0.2,
},
settings: {
ctx_len: 1024,
ngl: 12,
embedding: true,
n_parallel: 2,
cpu_threads: 4,
prompt_template: 'template',
llama_model_path: 'path',
mmproj: 'mmproj',
},
},
}, },
], settings: {
created: 0, ctx_len: 1024,
id: 'thread-1', ngl: 12,
object: 'thread', embedding: true,
title: 'New Thread', n_parallel: 2,
updated: 0, cpu_threads: 4,
prompt_template: 'template',
llama_model_path: 'path',
mmproj: 'mmproj',
},
id: 'model-1',
engine: 'nitro',
},
}) })
}) })
it('should not update invalid model parameters', async () => { it('should not update invalid model parameters', async () => {
;(useAtom as jest.Mock).mockReturnValue([
{
id: 'assistant-1',
},
jest.fn(),
])
const mockInvalidParameters: any = { const mockInvalidParameters: any = {
params: { params: {
// Inference // Inference
@ -153,7 +169,8 @@ describe('useUpdateModelParameters', () => {
// Spy functions // Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension) jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({}) jest.spyOn(extension, 'modifyThread').mockReturnValue({})
jest.spyOn(extension, 'modifyThreadAssistant').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters()) const { result } = renderHook(() => useUpdateModelParameters())
@ -164,36 +181,38 @@ describe('useUpdateModelParameters', () => {
) )
}) })
jest.runAllTimers()
// Check if the model parameters are valid before persisting // Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({ expect(extension.modifyThreadAssistant).toHaveBeenCalledWith('thread-1', {
assistants: [ id: 'assistant-1',
{ model: {
model: { engine: 'nitro',
parameters: { id: 'model-1',
max_tokens: 1000, parameters: {
token_limit: 1000, token_limit: 1000,
}, max_tokens: 1000,
settings: {
cpu_threads: 4,
ctx_len: 1024,
prompt_template: 'template',
llama_model_path: 'path',
mmproj: 'mmproj',
n_parallel: 2,
ngl: 12,
},
},
}, },
], settings: {
created: 0, cpu_threads: 4,
id: 'thread-1', ctx_len: 1024,
object: 'thread', prompt_template: 'template',
title: 'New Thread', llama_model_path: 'path',
updated: 0, mmproj: 'mmproj',
n_parallel: 2,
ngl: 12,
},
},
}) })
}) })
it('should update valid model parameters only', async () => { it('should update valid model parameters only', async () => {
;(useAtom as jest.Mock).mockReturnValue([
{
id: 'assistant-1',
},
jest.fn(),
])
const mockInvalidParameters: any = { const mockInvalidParameters: any = {
params: { params: {
// Inference // Inference
@ -225,8 +244,8 @@ describe('useUpdateModelParameters', () => {
// Spy functions // Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension) jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({}) jest.spyOn(extension, 'modifyThread').mockReturnValue({})
jest.spyOn(extension, 'modifyThreadAssistant').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters()) const { result } = renderHook(() => useUpdateModelParameters())
await act(async () => { await act(async () => {
@ -235,80 +254,33 @@ describe('useUpdateModelParameters', () => {
mockInvalidParameters mockInvalidParameters
) )
}) })
jest.runAllTimers()
// Check if the model parameters are valid before persisting // Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({ expect(extension.modifyThreadAssistant).toHaveBeenCalledWith('thread-1', {
assistants: [ id: 'assistant-1',
{ model: {
model: { engine: 'nitro',
parameters: { id: 'model-1',
stop: ['<eos>'], parameters: {
top_k: 0.7, stop: ['<eos>'],
top_p: 0.1, top_k: 0.7,
stream: true, top_p: 0.1,
token_limit: 100, stream: true,
max_tokens: 1000, token_limit: 100,
presence_penalty: 0.2, max_tokens: 1000,
}, presence_penalty: 0.2,
settings: { },
ctx_len: 1024, settings: {
ngl: 0, ctx_len: 1024,
n_parallel: 2, ngl: 0,
cpu_threads: 4, n_parallel: 2,
prompt_template: 'template', cpu_threads: 4,
llama_model_path: 'path', prompt_template: 'template',
mmproj: 'mmproj', llama_model_path: 'path',
}, mmproj: 'mmproj',
},
}, },
],
created: 0,
id: 'thread-1',
object: 'thread',
title: 'New Thread',
updated: 0,
})
})
it('should handle missing modelId and engine gracefully', async () => {
const mockParametersWithoutModelIdAndEngine: any = {
params: {
stop: ['<eos>', '<eos2>'],
temperature: 0.5,
}, },
}
// Spy functions
jest.spyOn(extensionManager, 'get').mockReturnValue(extension)
jest.spyOn(extension, 'saveThread').mockReturnValue({})
const { result } = renderHook(() => useUpdateModelParameters())
await act(async () => {
await result.current.updateModelParameter(
mockThread,
mockParametersWithoutModelIdAndEngine
)
})
// Check if the model parameters are valid before persisting
expect(extension.saveThread).toHaveBeenCalledWith({
assistants: [
{
model: {
parameters: {
stop: ['<eos>', '<eos2>'],
temperature: 0.5,
},
settings: {},
},
},
],
created: 0,
id: 'thread-1',
object: 'thread',
title: 'New Thread',
updated: 0,
}) })
}) })
}) })

View File

@ -82,6 +82,7 @@ export default function useUpdateModelParameters() {
}, },
} }
setActiveAssistant(assistantInfo) setActiveAssistant(assistantInfo)
updateAssistantCallback(thread.id, assistantInfo) updateAssistantCallback(thread.id, assistantInfo)
}, },
[ [

View File

@ -7,6 +7,8 @@ import { useAtomValue, useSetAtom } from 'jotai'
import { useActiveModel } from '@/hooks/useActiveModel' import { useActiveModel } from '@/hooks/useActiveModel'
import { useCreateNewThread } from '@/hooks/useCreateNewThread' import { useCreateNewThread } from '@/hooks/useCreateNewThread'
import AssistantSetting from './index' import AssistantSetting from './index'
import { activeThreadAtom } from '@/helpers/atoms/Thread.atom'
import { activeAssistantAtom } from '@/helpers/atoms/Assistant.atom'
jest.mock('jotai', () => { jest.mock('jotai', () => {
const originalModule = jest.requireActual('jotai') const originalModule = jest.requireActual('jotai')
@ -68,6 +70,7 @@ describe('AssistantSetting Component', () => {
beforeEach(() => { beforeEach(() => {
jest.clearAllMocks() jest.clearAllMocks()
jest.useFakeTimers()
}) })
test('renders AssistantSetting component with proper data', async () => { test('renders AssistantSetting component with proper data', async () => {
@ -75,7 +78,14 @@ describe('AssistantSetting Component', () => {
;(useSetAtom as jest.Mock).mockImplementationOnce( ;(useSetAtom as jest.Mock).mockImplementationOnce(
() => setEngineParamsUpdate () => setEngineParamsUpdate
) )
;(useAtomValue as jest.Mock).mockImplementationOnce(() => mockActiveThread) ;(useAtomValue as jest.Mock).mockImplementation((atom) => {
switch (atom) {
case activeThreadAtom:
return mockActiveThread
case activeAssistantAtom:
return {}
}
})
const updateThreadMetadata = jest.fn() const updateThreadMetadata = jest.fn()
;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel: jest.fn() }) ;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel: jest.fn() })
;(useCreateNewThread as jest.Mock).mockReturnValueOnce({ ;(useCreateNewThread as jest.Mock).mockReturnValueOnce({
@ -98,7 +108,14 @@ describe('AssistantSetting Component', () => {
const setEngineParamsUpdate = jest.fn() const setEngineParamsUpdate = jest.fn()
const updateThreadMetadata = jest.fn() const updateThreadMetadata = jest.fn()
const stopModel = jest.fn() const stopModel = jest.fn()
;(useAtomValue as jest.Mock).mockImplementationOnce(() => mockActiveThread) ;(useAtomValue as jest.Mock).mockImplementation((atom) => {
switch (atom) {
case activeThreadAtom:
return mockActiveThread
case activeAssistantAtom:
return {}
}
})
;(useSetAtom as jest.Mock).mockImplementation(() => setEngineParamsUpdate) ;(useSetAtom as jest.Mock).mockImplementation(() => setEngineParamsUpdate)
;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel }) ;(useActiveModel as jest.Mock).mockReturnValueOnce({ stopModel })
;(useCreateNewThread as jest.Mock).mockReturnValueOnce({ ;(useCreateNewThread as jest.Mock).mockReturnValueOnce({