feat: Add tests for the Continuing with AI response
This commit is contained in:
parent
ccca331d6c
commit
52f73af08c
@ -1,6 +1,20 @@
|
|||||||
import { renderHook, act } from '@testing-library/react'
|
import { renderHook, act, waitFor } from '@testing-library/react'
|
||||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
|
||||||
import { useChat } from '../useChat'
|
import { useChat } from '../useChat'
|
||||||
|
import * as completionLib from '@/lib/completion'
|
||||||
|
import * as messagesLib from '@/lib/messages'
|
||||||
|
import { MessageStatus, ContentType } from '@janhq/core'
|
||||||
|
|
||||||
|
// Store mock functions for assertions
|
||||||
|
let mockAddMessage: ReturnType<typeof vi.fn>
|
||||||
|
let mockUpdateMessage: ReturnType<typeof vi.fn>
|
||||||
|
let mockGetMessages: ReturnType<typeof vi.fn>
|
||||||
|
let mockStartModel: ReturnType<typeof vi.fn>
|
||||||
|
let mockSendCompletion: ReturnType<typeof vi.fn>
|
||||||
|
let mockPostMessageProcessing: ReturnType<typeof vi.fn>
|
||||||
|
let mockCompletionMessagesBuilder: any
|
||||||
|
let mockSetPrompt: ReturnType<typeof vi.fn>
|
||||||
|
let mockResetTokenSpeed: ReturnType<typeof vi.fn>
|
||||||
|
|
||||||
// Mock dependencies
|
// Mock dependencies
|
||||||
vi.mock('../usePrompt', () => ({
|
vi.mock('../usePrompt', () => ({
|
||||||
@ -8,11 +22,16 @@ vi.mock('../usePrompt', () => ({
|
|||||||
(selector: any) => {
|
(selector: any) => {
|
||||||
const state = {
|
const state = {
|
||||||
prompt: 'test prompt',
|
prompt: 'test prompt',
|
||||||
setPrompt: vi.fn(),
|
setPrompt: mockSetPrompt,
|
||||||
}
|
}
|
||||||
return selector ? selector(state) : state
|
return selector ? selector(state) : state
|
||||||
},
|
},
|
||||||
{ getState: () => ({ prompt: 'test prompt', setPrompt: vi.fn() }) }
|
{
|
||||||
|
getState: () => ({
|
||||||
|
prompt: 'test prompt',
|
||||||
|
setPrompt: mockSetPrompt
|
||||||
|
})
|
||||||
|
}
|
||||||
),
|
),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@ -22,25 +41,29 @@ vi.mock('../useAppState', () => ({
|
|||||||
const state = {
|
const state = {
|
||||||
tools: [],
|
tools: [],
|
||||||
updateTokenSpeed: vi.fn(),
|
updateTokenSpeed: vi.fn(),
|
||||||
resetTokenSpeed: vi.fn(),
|
resetTokenSpeed: mockResetTokenSpeed,
|
||||||
updateTools: vi.fn(),
|
updateTools: vi.fn(),
|
||||||
updateStreamingContent: vi.fn(),
|
updateStreamingContent: vi.fn(),
|
||||||
updatePromptProgress: vi.fn(),
|
updatePromptProgress: vi.fn(),
|
||||||
updateLoadingModel: vi.fn(),
|
updateLoadingModel: vi.fn(),
|
||||||
setAbortController: vi.fn(),
|
setAbortController: vi.fn(),
|
||||||
|
streamingContent: undefined,
|
||||||
}
|
}
|
||||||
return selector ? selector(state) : state
|
return selector ? selector(state) : state
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
getState: vi.fn(() => ({
|
getState: vi.fn(() => ({
|
||||||
|
tools: [],
|
||||||
tokenSpeed: { tokensPerSecond: 10 },
|
tokenSpeed: { tokensPerSecond: 10 },
|
||||||
|
streamingContent: undefined,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('../useAssistant', () => ({
|
vi.mock('../useAssistant', () => ({
|
||||||
useAssistant: (selector: any) => {
|
useAssistant: Object.assign(
|
||||||
|
(selector: any) => {
|
||||||
const state = {
|
const state = {
|
||||||
assistants: [{
|
assistants: [{
|
||||||
id: 'test-assistant',
|
id: 'test-assistant',
|
||||||
@ -55,6 +78,21 @@ vi.mock('../useAssistant', () => ({
|
|||||||
}
|
}
|
||||||
return selector ? selector(state) : state
|
return selector ? selector(state) : state
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
getState: () => ({
|
||||||
|
assistants: [{
|
||||||
|
id: 'test-assistant',
|
||||||
|
instructions: 'test instructions',
|
||||||
|
parameters: { stream: true },
|
||||||
|
}],
|
||||||
|
currentAssistant: {
|
||||||
|
id: 'test-assistant',
|
||||||
|
instructions: 'test instructions',
|
||||||
|
parameters: { stream: true },
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('../useModelProvider', () => ({
|
vi.mock('../useModelProvider', () => ({
|
||||||
@ -62,14 +100,15 @@ vi.mock('../useModelProvider', () => ({
|
|||||||
(selector: any) => {
|
(selector: any) => {
|
||||||
const state = {
|
const state = {
|
||||||
getProviderByName: vi.fn(() => ({
|
getProviderByName: vi.fn(() => ({
|
||||||
provider: 'openai',
|
provider: 'llamacpp',
|
||||||
models: [],
|
models: [],
|
||||||
|
settings: [],
|
||||||
})),
|
})),
|
||||||
selectedModel: {
|
selectedModel: {
|
||||||
id: 'test-model',
|
id: 'test-model',
|
||||||
capabilities: ['tools'],
|
capabilities: ['tools'],
|
||||||
},
|
},
|
||||||
selectedProvider: 'openai',
|
selectedProvider: 'llamacpp',
|
||||||
updateProvider: vi.fn(),
|
updateProvider: vi.fn(),
|
||||||
}
|
}
|
||||||
return selector ? selector(state) : state
|
return selector ? selector(state) : state
|
||||||
@ -77,14 +116,15 @@ vi.mock('../useModelProvider', () => ({
|
|||||||
{
|
{
|
||||||
getState: () => ({
|
getState: () => ({
|
||||||
getProviderByName: vi.fn(() => ({
|
getProviderByName: vi.fn(() => ({
|
||||||
provider: 'openai',
|
provider: 'llamacpp',
|
||||||
models: [],
|
models: [],
|
||||||
|
settings: [],
|
||||||
})),
|
})),
|
||||||
selectedModel: {
|
selectedModel: {
|
||||||
id: 'test-model',
|
id: 'test-model',
|
||||||
capabilities: ['tools'],
|
capabilities: ['tools'],
|
||||||
},
|
},
|
||||||
selectedProvider: 'openai',
|
selectedProvider: 'llamacpp',
|
||||||
updateProvider: vi.fn(),
|
updateProvider: vi.fn(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -96,11 +136,11 @@ vi.mock('../useThreads', () => ({
|
|||||||
const state = {
|
const state = {
|
||||||
getCurrentThread: vi.fn(() => ({
|
getCurrentThread: vi.fn(() => ({
|
||||||
id: 'test-thread',
|
id: 'test-thread',
|
||||||
model: { id: 'test-model', provider: 'openai' },
|
model: { id: 'test-model', provider: 'llamacpp' },
|
||||||
})),
|
})),
|
||||||
createThread: vi.fn(() => Promise.resolve({
|
createThread: vi.fn(() => Promise.resolve({
|
||||||
id: 'test-thread',
|
id: 'test-thread',
|
||||||
model: { id: 'test-model', provider: 'openai' },
|
model: { id: 'test-model', provider: 'llamacpp' },
|
||||||
})),
|
})),
|
||||||
updateThreadTimestamp: vi.fn(),
|
updateThreadTimestamp: vi.fn(),
|
||||||
}
|
}
|
||||||
@ -111,15 +151,18 @@ vi.mock('../useThreads', () => ({
|
|||||||
vi.mock('../useMessages', () => ({
|
vi.mock('../useMessages', () => ({
|
||||||
useMessages: (selector: any) => {
|
useMessages: (selector: any) => {
|
||||||
const state = {
|
const state = {
|
||||||
getMessages: vi.fn(() => []),
|
getMessages: mockGetMessages,
|
||||||
addMessage: vi.fn(),
|
addMessage: mockAddMessage,
|
||||||
|
updateMessage: mockUpdateMessage,
|
||||||
|
setMessages: vi.fn(),
|
||||||
}
|
}
|
||||||
return selector ? selector(state) : state
|
return selector ? selector(state) : state
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('../useToolApproval', () => ({
|
vi.mock('../useToolApproval', () => ({
|
||||||
useToolApproval: (selector: any) => {
|
useToolApproval: Object.assign(
|
||||||
|
(selector: any) => {
|
||||||
const state = {
|
const state = {
|
||||||
approvedTools: [],
|
approvedTools: [],
|
||||||
showApprovalModal: vi.fn(),
|
showApprovalModal: vi.fn(),
|
||||||
@ -127,6 +170,14 @@ vi.mock('../useToolApproval', () => ({
|
|||||||
}
|
}
|
||||||
return selector ? selector(state) : state
|
return selector ? selector(state) : state
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
getState: () => ({
|
||||||
|
approvedTools: [],
|
||||||
|
showApprovalModal: vi.fn(),
|
||||||
|
allowAllMCPPermissions: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('../useToolAvailable', () => ({
|
vi.mock('../useToolAvailable', () => ({
|
||||||
@ -162,38 +213,57 @@ vi.mock('@tanstack/react-router', () => ({
|
|||||||
})),
|
})),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
vi.mock('../useServiceHub', () => ({
|
||||||
|
useServiceHub: vi.fn(() => ({
|
||||||
|
models: () => ({
|
||||||
|
startModel: mockStartModel,
|
||||||
|
stopModel: vi.fn(() => Promise.resolve()),
|
||||||
|
stopAllModels: vi.fn(() => Promise.resolve()),
|
||||||
|
}),
|
||||||
|
providers: () => ({
|
||||||
|
updateSettings: vi.fn(() => Promise.resolve()),
|
||||||
|
}),
|
||||||
|
})),
|
||||||
|
}))
|
||||||
|
|
||||||
vi.mock('@/lib/completion', () => ({
|
vi.mock('@/lib/completion', () => ({
|
||||||
emptyThreadContent: { thread_id: 'test-thread', content: '' },
|
emptyThreadContent: { thread_id: 'test-thread', content: '' },
|
||||||
extractToolCall: vi.fn(),
|
extractToolCall: vi.fn(),
|
||||||
newUserThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'user message' })),
|
newUserThreadContent: vi.fn((threadId, content) => ({
|
||||||
newAssistantThreadContent: vi.fn(() => ({ thread_id: 'test-thread', content: 'assistant message' })),
|
thread_id: threadId,
|
||||||
sendCompletion: vi.fn(),
|
content: [{ type: ContentType.Text, text: { value: content, annotations: [] } }],
|
||||||
postMessageProcessing: vi.fn(),
|
role: 'user'
|
||||||
isCompletionResponse: vi.fn(),
|
})),
|
||||||
|
newAssistantThreadContent: vi.fn((threadId, content) => ({
|
||||||
|
thread_id: threadId,
|
||||||
|
content: [{ type: ContentType.Text, text: { value: content, annotations: [] } }],
|
||||||
|
role: 'assistant'
|
||||||
|
})),
|
||||||
|
sendCompletion: mockSendCompletion,
|
||||||
|
postMessageProcessing: mockPostMessageProcessing,
|
||||||
|
isCompletionResponse: vi.fn(() => true),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@/lib/messages', () => ({
|
vi.mock('@/lib/messages', () => ({
|
||||||
CompletionMessagesBuilder: vi.fn(() => ({
|
CompletionMessagesBuilder: vi.fn(() => mockCompletionMessagesBuilder),
|
||||||
addUserMessage: vi.fn(),
|
}))
|
||||||
addAssistantMessage: vi.fn(),
|
|
||||||
getMessages: vi.fn(() => []),
|
vi.mock('@/lib/instructionTemplate', () => ({
|
||||||
|
renderInstructions: vi.fn((instructions: string) => instructions),
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/utils/reasoning', () => ({
|
||||||
|
ReasoningProcessor: vi.fn(() => ({
|
||||||
|
processReasoningChunk: vi.fn(() => null),
|
||||||
|
finalize: vi.fn(() => ''),
|
||||||
})),
|
})),
|
||||||
|
extractReasoningFromMessage: vi.fn(() => null),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@/services/mcp', () => ({
|
vi.mock('@/services/mcp', () => ({
|
||||||
getTools: vi.fn(() => Promise.resolve([])),
|
getTools: vi.fn(() => Promise.resolve([])),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@/services/models', () => ({
|
|
||||||
startModel: vi.fn(() => Promise.resolve()),
|
|
||||||
stopModel: vi.fn(() => Promise.resolve()),
|
|
||||||
stopAllModels: vi.fn(() => Promise.resolve()),
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@/services/providers', () => ({
|
|
||||||
updateSettings: vi.fn(() => Promise.resolve()),
|
|
||||||
}))
|
|
||||||
|
|
||||||
vi.mock('@tauri-apps/api/event', () => ({
|
vi.mock('@tauri-apps/api/event', () => ({
|
||||||
listen: vi.fn(() => Promise.resolve(vi.fn())),
|
listen: vi.fn(() => Promise.resolve(vi.fn())),
|
||||||
}))
|
}))
|
||||||
@ -206,9 +276,37 @@ vi.mock('sonner', () => ({
|
|||||||
|
|
||||||
describe('useChat', () => {
|
describe('useChat', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
|
// Reset all mocks
|
||||||
|
mockAddMessage = vi.fn()
|
||||||
|
mockUpdateMessage = vi.fn()
|
||||||
|
mockGetMessages = vi.fn(() => [])
|
||||||
|
mockStartModel = vi.fn(() => Promise.resolve())
|
||||||
|
mockSetPrompt = vi.fn()
|
||||||
|
mockResetTokenSpeed = vi.fn()
|
||||||
|
mockSendCompletion = vi.fn(() => Promise.resolve({
|
||||||
|
choices: [{
|
||||||
|
message: {
|
||||||
|
content: 'AI response',
|
||||||
|
role: 'assistant',
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
}))
|
||||||
|
mockPostMessageProcessing = vi.fn((toolCalls, builder, content) =>
|
||||||
|
Promise.resolve(content)
|
||||||
|
)
|
||||||
|
mockCompletionMessagesBuilder = {
|
||||||
|
addUserMessage: vi.fn(),
|
||||||
|
addAssistantMessage: vi.fn(),
|
||||||
|
getMessages: vi.fn(() => []),
|
||||||
|
}
|
||||||
|
|
||||||
vi.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
})
|
})
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.clearAllTimers()
|
||||||
|
})
|
||||||
|
|
||||||
it('returns sendMessage function', () => {
|
it('returns sendMessage function', () => {
|
||||||
const { result } = renderHook(() => useChat())
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
@ -216,13 +314,268 @@ describe('useChat', () => {
|
|||||||
expect(typeof result.current).toBe('function')
|
expect(typeof result.current).toBe('function')
|
||||||
})
|
})
|
||||||
|
|
||||||
it('sends message successfully', async () => {
|
describe('Continue with AI response functionality', () => {
|
||||||
|
it('should add new user message when troubleshooting is true and no continueFromMessageId', async () => {
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('Hello world', true, undefined, undefined)
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(completionLib.newUserThreadContent).toHaveBeenCalledWith(
|
||||||
|
'test-thread',
|
||||||
|
'Hello world',
|
||||||
|
undefined
|
||||||
|
)
|
||||||
|
expect(mockAddMessage).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'user',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith(
|
||||||
|
'Hello world',
|
||||||
|
undefined
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should NOT add new user message when continueFromMessageId is provided', async () => {
|
||||||
|
const stoppedMessage = {
|
||||||
|
id: 'msg-123',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: ContentType.Text, text: { value: 'Partial response', annotations: [] } }],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
metadata: {},
|
||||||
|
}
|
||||||
|
mockGetMessages.mockReturnValue([stoppedMessage])
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('Continue', true, undefined, 'msg-123')
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(completionLib.newUserThreadContent).not.toHaveBeenCalled()
|
||||||
|
const userMessageCalls = mockAddMessage.mock.calls.filter(
|
||||||
|
(call: any) => call[0]?.role === 'user'
|
||||||
|
)
|
||||||
|
expect(userMessageCalls).toHaveLength(0)
|
||||||
|
expect(mockCompletionMessagesBuilder.addUserMessage).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should add partial assistant message to builder when continuing', async () => {
|
||||||
|
const stoppedMessage = {
|
||||||
|
id: 'msg-123',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: ContentType.Text, text: { value: 'Partial response', annotations: [] } }],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
metadata: {},
|
||||||
|
}
|
||||||
|
mockGetMessages.mockReturnValue([stoppedMessage])
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('', true, undefined, 'msg-123')
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(mockCompletionMessagesBuilder.addAssistantMessage).toHaveBeenCalledWith(
|
||||||
|
'Partial response',
|
||||||
|
undefined,
|
||||||
|
[]
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should filter out stopped message from context when continuing', async () => {
|
||||||
|
const userMsg = {
|
||||||
|
id: 'msg-1',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'user',
|
||||||
|
content: [{ type: ContentType.Text, text: { value: 'Hello', annotations: [] } }],
|
||||||
|
}
|
||||||
|
const stoppedMessage = {
|
||||||
|
id: 'msg-123',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
}
|
||||||
|
mockGetMessages.mockReturnValue([userMsg, stoppedMessage])
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('', true, undefined, 'msg-123')
|
||||||
|
})
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(messagesLib.CompletionMessagesBuilder).toHaveBeenCalledWith(
|
||||||
|
[userMsg], // stopped message filtered out
|
||||||
|
'test instructions'
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should update existing message instead of adding new one when continuing', async () => {
|
||||||
|
const stoppedMessage = {
|
||||||
|
id: 'msg-123',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
metadata: {},
|
||||||
|
}
|
||||||
|
mockGetMessages.mockReturnValue([stoppedMessage])
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('', true, undefined, 'msg-123')
|
||||||
|
})
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(mockUpdateMessage).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
id: 'msg-123',
|
||||||
|
status: MessageStatus.Ready,
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should start with previous content when continuing', async () => {
|
||||||
|
const stoppedMessage = {
|
||||||
|
id: 'msg-123',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: ContentType.Text, text: { value: 'Partial response', annotations: [] } }],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
metadata: {},
|
||||||
|
}
|
||||||
|
mockGetMessages.mockReturnValue([stoppedMessage])
|
||||||
|
|
||||||
|
mockSendCompletion.mockResolvedValue({
|
||||||
|
choices: [{
|
||||||
|
message: {
|
||||||
|
content: ' continued',
|
||||||
|
role: 'assistant',
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
})
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('', true, undefined, 'msg-123')
|
||||||
|
})
|
||||||
|
|
||||||
|
// The accumulated text should contain the previous content
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(mockUpdateMessage).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
id: 'msg-123',
|
||||||
|
content: expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
text: expect.objectContaining({
|
||||||
|
value: expect.stringContaining('Partial response'),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
])
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle attachments correctly when not continuing', async () => {
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
const attachments = [
|
||||||
|
{
|
||||||
|
name: 'test.png',
|
||||||
|
type: 'image/png',
|
||||||
|
size: 1024,
|
||||||
|
base64: 'base64data',
|
||||||
|
dataUrl: 'data:image/png;base64,base64data',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('Message with attachment', true, attachments, undefined)
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(completionLib.newUserThreadContent).toHaveBeenCalledWith(
|
||||||
|
'test-thread',
|
||||||
|
'Message with attachment',
|
||||||
|
attachments
|
||||||
|
)
|
||||||
|
expect(mockCompletionMessagesBuilder.addUserMessage).toHaveBeenCalledWith(
|
||||||
|
'Message with attachment',
|
||||||
|
attachments
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should preserve message status as Ready after continuation completes', async () => {
|
||||||
|
const stoppedMessage = {
|
||||||
|
id: 'msg-123',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
metadata: {},
|
||||||
|
}
|
||||||
|
mockGetMessages.mockReturnValue([stoppedMessage])
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('', true, undefined, 'msg-123')
|
||||||
|
})
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(mockUpdateMessage).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
id: 'msg-123',
|
||||||
|
status: MessageStatus.Ready,
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Normal message sending', () => {
|
||||||
|
it('sends message successfully without continuation', async () => {
|
||||||
const { result } = renderHook(() => useChat())
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
await act(async () => {
|
await act(async () => {
|
||||||
await result.current('Hello world')
|
await result.current('Hello world')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
expect(mockSendCompletion).toHaveBeenCalled()
|
||||||
|
expect(mockStartModel).toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Error handling', () => {
|
||||||
|
it('should handle errors gracefully during continuation', async () => {
|
||||||
|
mockSendCompletion.mockRejectedValueOnce(new Error('API Error'))
|
||||||
|
const stoppedMessage = {
|
||||||
|
id: 'msg-123',
|
||||||
|
thread_id: 'test-thread',
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: ContentType.Text, text: { value: 'Partial', annotations: [] } }],
|
||||||
|
status: MessageStatus.Stopped,
|
||||||
|
metadata: {},
|
||||||
|
}
|
||||||
|
mockGetMessages.mockReturnValue([stoppedMessage])
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useChat())
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current('', true, undefined, 'msg-123')
|
||||||
|
})
|
||||||
|
|
||||||
expect(result.current).toBeDefined()
|
expect(result.current).toBeDefined()
|
||||||
})
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@ -791,6 +791,7 @@ export const useChat = () => {
|
|||||||
updateStreamingContent,
|
updateStreamingContent,
|
||||||
updatePromptProgress,
|
updatePromptProgress,
|
||||||
addMessage,
|
addMessage,
|
||||||
|
updateMessage,
|
||||||
updateThreadTimestamp,
|
updateThreadTimestamp,
|
||||||
updateLoadingModel,
|
updateLoadingModel,
|
||||||
getDisabledToolsForThread,
|
getDisabledToolsForThread,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user