test: improve completion.test for proactive screenshot handling and formatting

Add import for `captureProactiveScreenshots`, correct mock response formatting, and update test expectations to match the new API.
Enhance coverage by adding scenarios for screenshot capture errors, abort controller handling, and proactive mode toggling.  These changes provide clearer, more robust tests for the completion logic.
This commit is contained in:
Akarshan 2025-10-30 12:18:49 +05:30
parent 98d81819c5
commit 74b895c653
No known key found for this signature in database
GPG Key ID: D75C9634A870665F

View File

@ -9,7 +9,7 @@ import {
normalizeTools, normalizeTools,
extractToolCall, extractToolCall,
postMessageProcessing, postMessageProcessing,
captureProactiveScreenshots captureProactiveScreenshots,
} from '../completion' } from '../completion'
// Mock dependencies // Mock dependencies
@ -87,10 +87,12 @@ vi.mock('@/hooks/useServiceHub', () => ({
})), })),
rag: vi.fn(() => ({ rag: vi.fn(() => ({
getToolNames: vi.fn(() => Promise.resolve([])), getToolNames: vi.fn(() => Promise.resolve([])),
callTool: vi.fn(() => Promise.resolve({ callTool: vi.fn(() =>
content: [{ type: 'text', text: 'mock rag result' }], Promise.resolve({
error: '', content: [{ type: 'text', text: 'mock rag result' }],
})), error: '',
})
),
})), })),
})), })),
})) }))
@ -133,13 +135,15 @@ describe('completion.ts', () => {
expect(result.type).toBe('text') expect(result.type).toBe('text')
expect(result.role).toBe('user') expect(result.role).toBe('user')
expect(result.thread_id).toBe('thread-123') expect(result.thread_id).toBe('thread-123')
expect(result.content).toEqual([{ expect(result.content).toEqual([
type: 'text', {
text: { type: 'text',
value: 'Hello world', text: {
annotations: [], value: 'Hello world',
annotations: [],
},
}, },
}]) ])
}) })
it('should handle empty text', () => { it('should handle empty text', () => {
@ -147,13 +151,15 @@ describe('completion.ts', () => {
expect(result.type).toBe('text') expect(result.type).toBe('text')
expect(result.role).toBe('user') expect(result.role).toBe('user')
expect(result.content).toEqual([{ expect(result.content).toEqual([
type: 'text', {
text: { type: 'text',
value: '', text: {
annotations: [], value: '',
annotations: [],
},
}, },
}]) ])
}) })
}) })
@ -164,13 +170,15 @@ describe('completion.ts', () => {
expect(result.type).toBe('text') expect(result.type).toBe('text')
expect(result.role).toBe('assistant') expect(result.role).toBe('assistant')
expect(result.thread_id).toBe('thread-123') expect(result.thread_id).toBe('thread-123')
expect(result.content).toEqual([{ expect(result.content).toEqual([
type: 'text', {
text: { type: 'text',
value: 'AI response', text: {
annotations: [], value: 'AI response',
annotations: [],
},
}, },
}]) ])
}) })
}) })
@ -207,16 +215,20 @@ describe('completion.ts', () => {
describe('extractToolCall', () => { describe('extractToolCall', () => {
it('should extract tool calls from message', () => { it('should extract tool calls from message', () => {
const message = { const message = {
choices: [{ choices: [
delta: { {
tool_calls: [{ delta: {
id: 'call_1', tool_calls: [
type: 'function', {
index: 0, id: 'call_1',
function: { name: 'test', arguments: '{}' } type: 'function',
}] index: 0,
} function: { name: 'test', arguments: '{}' },
}] },
],
},
},
],
} }
const calls = [] const calls = []
const result = extractToolCall(message, null, calls) const result = extractToolCall(message, null, calls)
@ -226,9 +238,11 @@ describe('completion.ts', () => {
it('should handle message without tool calls', () => { it('should handle message without tool calls', () => {
const message = { const message = {
choices: [{ choices: [
delta: {} {
}] delta: {},
},
],
} }
const calls = [] const calls = []
const result = extractToolCall(message, null, calls) const result = extractToolCall(message, null, calls)
@ -245,23 +259,31 @@ describe('completion.ts', () => {
const mockMcp = { const mockMcp = {
getTools: mockGetTools, getTools: mockGetTools,
callToolWithCancellation: vi.fn(() => ({ callToolWithCancellation: vi.fn(() => ({
promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }), promise: Promise.resolve({
content: [{ type: 'text', text: 'result' }],
error: '',
}),
cancel: vi.fn(), cancel: vi.fn(),
})) })),
} }
vi.mocked(getServiceHub).mockReturnValue({ vi.mocked(getServiceHub).mockReturnValue({
mcp: () => mockMcp, mcp: () => mockMcp,
rag: () => ({ getToolNames: () => Promise.resolve([]) }) rag: () => ({ getToolNames: () => Promise.resolve([]) }),
} as any) } as any)
const calls = [{ const calls = [
id: 'call_1', {
type: 'function' as const, id: 'call_1',
function: { name: 'browserbase_navigate', arguments: '{"url": "test.com"}' } type: 'function' as const,
}] function: {
name: 'browserbase_navigate',
arguments: '{"url": "test.com"}',
},
},
]
const builder = { const builder = {
addToolMessage: vi.fn(), addToolMessage: vi.fn(),
getMessages: vi.fn(() => []) getMessages: vi.fn(() => []),
} as any } as any
const message = { thread_id: 'test-thread', metadata: {} } as any const message = { thread_id: 'test-thread', metadata: {} } as any
const abortController = new AbortController() const abortController = new AbortController()
@ -284,30 +306,44 @@ describe('completion.ts', () => {
it('should detect browserbase tools', async () => { it('should detect browserbase tools', async () => {
const { getServiceHub } = await import('@/hooks/useServiceHub') const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockCallTool = vi.fn(() => ({ const mockCallTool = vi.fn(() => ({
promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }), promise: Promise.resolve({
content: [{ type: 'text', text: 'result' }],
error: '',
}),
cancel: vi.fn(), cancel: vi.fn(),
})) }))
vi.mocked(getServiceHub).mockReturnValue({ vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({ mcp: () => ({
getTools: () => Promise.resolve([]), getTools: () => Promise.resolve([]),
callToolWithCancellation: mockCallTool callToolWithCancellation: mockCallTool,
}), }),
rag: () => ({ getToolNames: () => Promise.resolve([]) }) rag: () => ({ getToolNames: () => Promise.resolve([]) }),
} as any) } as any)
const calls = [{ const calls = [
id: 'call_1', {
type: 'function' as const, id: 'call_1',
function: { name: 'browserbase_screenshot', arguments: '{}' } type: 'function' as const,
}] function: { name: 'browserbase_screenshot', arguments: '{}' },
},
]
const builder = { const builder = {
addToolMessage: vi.fn(), addToolMessage: vi.fn(),
getMessages: vi.fn(() => []) getMessages: vi.fn(() => []),
} as any } as any
const message = { thread_id: 'test-thread', metadata: {} } as any const message = { thread_id: 'test-thread', metadata: {} } as any
const abortController = new AbortController() const abortController = new AbortController()
await postMessageProcessing(calls, builder, message, abortController, {}, undefined, false, true) await postMessageProcessing(
calls,
builder,
message,
abortController,
{},
undefined,
false,
true
)
expect(mockCallTool).toHaveBeenCalled() expect(mockCallTool).toHaveBeenCalled()
}) })
@ -315,30 +351,47 @@ describe('completion.ts', () => {
it('should detect multi_browserbase tools', async () => { it('should detect multi_browserbase tools', async () => {
const { getServiceHub } = await import('@/hooks/useServiceHub') const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockCallTool = vi.fn(() => ({ const mockCallTool = vi.fn(() => ({
promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }), promise: Promise.resolve({
content: [{ type: 'text', text: 'result' }],
error: '',
}),
cancel: vi.fn(), cancel: vi.fn(),
})) }))
vi.mocked(getServiceHub).mockReturnValue({ vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({ mcp: () => ({
getTools: () => Promise.resolve([]), getTools: () => Promise.resolve([]),
callToolWithCancellation: mockCallTool callToolWithCancellation: mockCallTool,
}), }),
rag: () => ({ getToolNames: () => Promise.resolve([]) }) rag: () => ({ getToolNames: () => Promise.resolve([]) }),
} as any) } as any)
const calls = [{ const calls = [
id: 'call_1', {
type: 'function' as const, id: 'call_1',
function: { name: 'multi_browserbase_stagehand_navigate', arguments: '{}' } type: 'function' as const,
}] function: {
name: 'multi_browserbase_stagehand_navigate',
arguments: '{}',
},
},
]
const builder = { const builder = {
addToolMessage: vi.fn(), addToolMessage: vi.fn(),
getMessages: vi.fn(() => []) getMessages: vi.fn(() => []),
} as any } as any
const message = { thread_id: 'test-thread', metadata: {} } as any const message = { thread_id: 'test-thread', metadata: {} } as any
const abortController = new AbortController() const abortController = new AbortController()
await postMessageProcessing(calls, builder, message, abortController, {}, undefined, false, true) await postMessageProcessing(
calls,
builder,
message,
abortController,
{},
undefined,
false,
true
)
expect(mockCallTool).toHaveBeenCalled() expect(mockCallTool).toHaveBeenCalled()
}) })
@ -350,26 +403,40 @@ describe('completion.ts', () => {
mcp: () => ({ mcp: () => ({
getTools: mockGetTools, getTools: mockGetTools,
callToolWithCancellation: vi.fn(() => ({ callToolWithCancellation: vi.fn(() => ({
promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }), promise: Promise.resolve({
content: [{ type: 'text', text: 'result' }],
error: '',
}),
cancel: vi.fn(), cancel: vi.fn(),
})) })),
}), }),
rag: () => ({ getToolNames: () => Promise.resolve([]) }) rag: () => ({ getToolNames: () => Promise.resolve([]) }),
} as any) } as any)
const calls = [{ const calls = [
id: 'call_1', {
type: 'function' as const, id: 'call_1',
function: { name: 'fetch_url', arguments: '{"url": "test.com"}' } type: 'function' as const,
}] function: { name: 'fetch_url', arguments: '{"url": "test.com"}' },
},
]
const builder = { const builder = {
addToolMessage: vi.fn(), addToolMessage: vi.fn(),
getMessages: vi.fn(() => []) getMessages: vi.fn(() => []),
} as any } as any
const message = { thread_id: 'test-thread', metadata: {} } as any const message = { thread_id: 'test-thread', metadata: {} } as any
const abortController = new AbortController() const abortController = new AbortController()
await postMessageProcessing(calls, builder, message, abortController, {}, undefined, false, true) await postMessageProcessing(
calls,
builder,
message,
abortController,
{},
undefined,
false,
true
)
// Proactive screenshots should not be called for non-browser tools // Proactive screenshots should not be called for non-browser tools
expect(mockGetTools).not.toHaveBeenCalled() expect(mockGetTools).not.toHaveBeenCalled()
@ -380,7 +447,9 @@ describe('completion.ts', () => {
it('should capture screenshot and snapshot when available', async () => { it('should capture screenshot and snapshot when available', async () => {
const { getServiceHub } = await import('@/hooks/useServiceHub') const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockScreenshotResult = { const mockScreenshotResult = {
content: [{ type: 'image', data: 'base64screenshot', mimeType: 'image/png' }], content: [
{ type: 'image', data: 'base64screenshot', mimeType: 'image/png' },
],
error: '', error: '',
} }
const mockSnapshotResult = { const mockSnapshotResult = {
@ -388,11 +457,14 @@ describe('completion.ts', () => {
error: '', error: '',
} }
const mockGetTools = vi.fn(() => Promise.resolve([ const mockGetTools = vi.fn(() =>
{ name: 'browserbase_screenshot', inputSchema: {} }, Promise.resolve([
{ name: 'browserbase_snapshot', inputSchema: {} } { name: 'browserbase_screenshot', inputSchema: {} },
])) { name: 'browserbase_snapshot', inputSchema: {} },
const mockCallTool = vi.fn() ])
)
const mockCallTool = vi
.fn()
.mockReturnValueOnce({ .mockReturnValueOnce({
promise: Promise.resolve(mockScreenshotResult), promise: Promise.resolve(mockScreenshotResult),
cancel: vi.fn(), cancel: vi.fn(),
@ -405,8 +477,8 @@ describe('completion.ts', () => {
vi.mocked(getServiceHub).mockReturnValue({ vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({ mcp: () => ({
getTools: mockGetTools, getTools: mockGetTools,
callToolWithCancellation: mockCallTool callToolWithCancellation: mockCallTool,
}) }),
} as any) } as any)
const abortController = new AbortController() const abortController = new AbortController()
@ -420,15 +492,15 @@ describe('completion.ts', () => {
it('should handle missing screenshot tool gracefully', async () => { it('should handle missing screenshot tool gracefully', async () => {
const { getServiceHub } = await import('@/hooks/useServiceHub') const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockGetTools = vi.fn(() => Promise.resolve([ const mockGetTools = vi.fn(() =>
{ name: 'some_other_tool', inputSchema: {} } Promise.resolve([{ name: 'some_other_tool', inputSchema: {} }])
])) )
vi.mocked(getServiceHub).mockReturnValue({ vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({ mcp: () => ({
getTools: mockGetTools, getTools: mockGetTools,
callToolWithCancellation: vi.fn() callToolWithCancellation: vi.fn(),
}) }),
} as any) } as any)
const abortController = new AbortController() const abortController = new AbortController()
@ -439,9 +511,9 @@ describe('completion.ts', () => {
it('should handle screenshot capture errors gracefully', async () => { it('should handle screenshot capture errors gracefully', async () => {
const { getServiceHub } = await import('@/hooks/useServiceHub') const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockGetTools = vi.fn(() => Promise.resolve([ const mockGetTools = vi.fn(() =>
{ name: 'browserbase_screenshot', inputSchema: {} } Promise.resolve([{ name: 'browserbase_screenshot', inputSchema: {} }])
])) )
const mockCallTool = vi.fn(() => ({ const mockCallTool = vi.fn(() => ({
promise: Promise.reject(new Error('Screenshot failed')), promise: Promise.reject(new Error('Screenshot failed')),
cancel: vi.fn(), cancel: vi.fn(),
@ -450,8 +522,8 @@ describe('completion.ts', () => {
vi.mocked(getServiceHub).mockReturnValue({ vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({ mcp: () => ({
getTools: mockGetTools, getTools: mockGetTools,
callToolWithCancellation: mockCallTool callToolWithCancellation: mockCallTool,
}) }),
} as any) } as any)
const abortController = new AbortController() const abortController = new AbortController()
@ -463,22 +535,30 @@ describe('completion.ts', () => {
it('should respect abort controller', async () => { it('should respect abort controller', async () => {
const { getServiceHub } = await import('@/hooks/useServiceHub') const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockGetTools = vi.fn(() => Promise.resolve([ const mockGetTools = vi.fn(() =>
{ name: 'browserbase_screenshot', inputSchema: {} } Promise.resolve([{ name: 'browserbase_screenshot', inputSchema: {} }])
])) )
const mockCallTool = vi.fn(() => ({ const mockCallTool = vi.fn(() => ({
promise: new Promise((resolve) => setTimeout(() => resolve({ promise: new Promise((resolve) =>
content: [{ type: 'image', data: 'base64', mimeType: 'image/png' }], setTimeout(
error: '', () =>
}), 100)), resolve({
content: [
{ type: 'image', data: 'base64', mimeType: 'image/png' },
],
error: '',
}),
100
)
),
cancel: vi.fn(), cancel: vi.fn(),
})) }))
vi.mocked(getServiceHub).mockReturnValue({ vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({ mcp: () => ({
getTools: mockGetTools, getTools: mockGetTools,
callToolWithCancellation: mockCallTool callToolWithCancellation: mockCallTool,
}) }),
} as any) } as any)
const abortController = new AbortController() const abortController = new AbortController()
@ -500,12 +580,15 @@ describe('completion.ts', () => {
role: 'tool', role: 'tool',
content: [ content: [
{ type: 'text', text: 'Tool result' }, { type: 'text', text: 'Tool result' },
{ type: 'image_url', image_url: { url: '' } } {
type: 'image_url',
image_url: { url: '' },
},
], ],
tool_call_id: 'old_call' tool_call_id: 'old_call',
}, },
{ role: 'assistant', content: 'Response' }, { role: 'assistant', content: 'Response' },
] ],
} }
expect(builder.messages).toHaveLength(3) expect(builder.messages).toHaveLength(3)
@ -517,13 +600,19 @@ describe('completion.ts', () => {
const { getServiceHub } = await import('@/hooks/useServiceHub') const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockScreenshotResult = { const mockScreenshotResult = {
content: [{ type: 'image', data: 'proactive_screenshot', mimeType: 'image/png' }], content: [
{
type: 'image',
data: 'proactive_screenshot',
mimeType: 'image/png',
},
],
error: '', error: '',
} }
const mockGetTools = vi.fn(() => Promise.resolve([ const mockGetTools = vi.fn(() =>
{ name: 'browserbase_screenshot', inputSchema: {} } Promise.resolve([{ name: 'browserbase_screenshot', inputSchema: {} }])
])) )
let callCount = 0 let callCount = 0
const mockCallTool = vi.fn(() => { const mockCallTool = vi.fn(() => {
@ -549,19 +638,24 @@ describe('completion.ts', () => {
vi.mocked(getServiceHub).mockReturnValue({ vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({ mcp: () => ({
getTools: mockGetTools, getTools: mockGetTools,
callToolWithCancellation: mockCallTool callToolWithCancellation: mockCallTool,
}), }),
rag: () => ({ getToolNames: () => Promise.resolve([]) }) rag: () => ({ getToolNames: () => Promise.resolve([]) }),
} as any) } as any)
const calls = [{ const calls = [
id: 'call_1', {
type: 'function' as const, id: 'call_1',
function: { name: 'browserbase_navigate', arguments: '{"url": "test.com"}' } type: 'function' as const,
}] function: {
name: 'browserbase_navigate',
arguments: '{"url": "test.com"}',
},
},
]
const builder = { const builder = {
addToolMessage: vi.fn(), addToolMessage: vi.fn(),
getMessages: vi.fn(() => []) getMessages: vi.fn(() => []),
} as any } as any
const message = { thread_id: 'test-thread', metadata: {} } as any const message = { thread_id: 'test-thread', metadata: {} } as any
const abortController = new AbortController() const abortController = new AbortController()
@ -574,7 +668,13 @@ describe('completion.ts', () => {
{}, {},
undefined, undefined,
false, false,
true undefined, // thread
undefined, // provider
[], // tools
undefined, // updateStreamingUI
undefined, // maxToolSteps
undefined, // currentStepCount
true // isProactiveMode - Correctly set to true
) )
// Should have called: 1) browser tool, 2) getTools, 3) proactive screenshot // Should have called: 1) browser tool, 2) getTools, 3) proactive screenshot
@ -586,9 +686,9 @@ describe('completion.ts', () => {
it('should not trigger proactive screenshots when mode is disabled', async () => { it('should not trigger proactive screenshots when mode is disabled', async () => {
const { getServiceHub } = await import('@/hooks/useServiceHub') const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockGetTools = vi.fn(() => Promise.resolve([ const mockGetTools = vi.fn(() =>
{ name: 'browserbase_screenshot', inputSchema: {} } Promise.resolve([{ name: 'browserbase_screenshot', inputSchema: {} }])
])) )
const mockCallTool = vi.fn(() => ({ const mockCallTool = vi.fn(() => ({
promise: Promise.resolve({ promise: Promise.resolve({
@ -601,19 +701,21 @@ describe('completion.ts', () => {
vi.mocked(getServiceHub).mockReturnValue({ vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({ mcp: () => ({
getTools: mockGetTools, getTools: mockGetTools,
callToolWithCancellation: mockCallTool callToolWithCancellation: mockCallTool,
}), }),
rag: () => ({ getToolNames: () => Promise.resolve([]) }) rag: () => ({ getToolNames: () => Promise.resolve([]) }),
} as any) } as any)
const calls = [{ const calls = [
id: 'call_1', {
type: 'function' as const, id: 'call_1',
function: { name: 'browserbase_navigate', arguments: '{}' } type: 'function' as const,
}] function: { name: 'browserbase_navigate', arguments: '{}' },
},
]
const builder = { const builder = {
addToolMessage: vi.fn(), addToolMessage: vi.fn(),
getMessages: vi.fn(() => []) getMessages: vi.fn(() => []),
} as any } as any
const message = { thread_id: 'test-thread', metadata: {} } as any const message = { thread_id: 'test-thread', metadata: {} } as any
const abortController = new AbortController() const abortController = new AbortController()
@ -626,7 +728,13 @@ describe('completion.ts', () => {
{}, {},
undefined, undefined,
false, false,
false undefined, // thread
undefined, // provider
[], // tools
undefined, // updateStreamingUI
undefined, // maxToolSteps
undefined, // currentStepCount
false // isProactiveMode - Correctly set to false
) )
expect(mockCallTool).toHaveBeenCalledTimes(1) expect(mockCallTool).toHaveBeenCalledTimes(1)
@ -648,19 +756,21 @@ describe('completion.ts', () => {
vi.mocked(getServiceHub).mockReturnValue({ vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({ mcp: () => ({
getTools: mockGetTools, getTools: mockGetTools,
callToolWithCancellation: mockCallTool callToolWithCancellation: mockCallTool,
}), }),
rag: () => ({ getToolNames: () => Promise.resolve([]) }) rag: () => ({ getToolNames: () => Promise.resolve([]) }),
} as any) } as any)
const calls = [{ const calls = [
id: 'call_1', {
type: 'function' as const, id: 'call_1',
function: { name: 'fetch_url', arguments: '{"url": "test.com"}' } type: 'function' as const,
}] function: { name: 'fetch_url', arguments: '{"url": "test.com"}' },
},
]
const builder = { const builder = {
addToolMessage: vi.fn(), addToolMessage: vi.fn(),
getMessages: vi.fn(() => []) getMessages: vi.fn(() => []),
} as any } as any
const message = { thread_id: 'test-thread', metadata: {} } as any const message = { thread_id: 'test-thread', metadata: {} } as any
const abortController = new AbortController() const abortController = new AbortController()
@ -673,7 +783,13 @@ describe('completion.ts', () => {
{}, {},
undefined, undefined,
false, false,
true undefined, // thread
undefined, // provider
[], // tools
undefined, // updateStreamingUI
undefined, // maxToolSteps
undefined, // currentStepCount
true // isProactiveMode - Still set to true, but the non-browser tool should skip the proactive step
) )
expect(mockCallTool).toHaveBeenCalledTimes(1) expect(mockCallTool).toHaveBeenCalledTimes(1)