test: improve completion.test for proactive screenshot handling and formatting

Add import for `captureProactiveScreenshots`, correct mock response formatting, and update test expectations to match the new API.
Enhance coverage by adding scenarios for screenshot capture errors, abort controller handling, and proactive mode toggling.  These changes provide clearer, more robust tests for the completion logic.
This commit is contained in:
Akarshan 2025-10-30 12:18:49 +05:30
parent 98d81819c5
commit 74b895c653
No known key found for this signature in database
GPG Key ID: D75C9634A870665F

View File

@ -9,7 +9,7 @@ import {
normalizeTools,
extractToolCall,
postMessageProcessing,
captureProactiveScreenshots
captureProactiveScreenshots,
} from '../completion'
// Mock dependencies
@ -87,10 +87,12 @@ vi.mock('@/hooks/useServiceHub', () => ({
})),
rag: vi.fn(() => ({
getToolNames: vi.fn(() => Promise.resolve([])),
callTool: vi.fn(() => Promise.resolve({
content: [{ type: 'text', text: 'mock rag result' }],
error: '',
})),
callTool: vi.fn(() =>
Promise.resolve({
content: [{ type: 'text', text: 'mock rag result' }],
error: '',
})
),
})),
})),
}))
@ -133,13 +135,15 @@ describe('completion.ts', () => {
expect(result.type).toBe('text')
expect(result.role).toBe('user')
expect(result.thread_id).toBe('thread-123')
expect(result.content).toEqual([{
type: 'text',
text: {
value: 'Hello world',
annotations: [],
expect(result.content).toEqual([
{
type: 'text',
text: {
value: 'Hello world',
annotations: [],
},
},
}])
])
})
it('should handle empty text', () => {
@ -147,13 +151,15 @@ describe('completion.ts', () => {
expect(result.type).toBe('text')
expect(result.role).toBe('user')
expect(result.content).toEqual([{
type: 'text',
text: {
value: '',
annotations: [],
expect(result.content).toEqual([
{
type: 'text',
text: {
value: '',
annotations: [],
},
},
}])
])
})
})
@ -164,13 +170,15 @@ describe('completion.ts', () => {
expect(result.type).toBe('text')
expect(result.role).toBe('assistant')
expect(result.thread_id).toBe('thread-123')
expect(result.content).toEqual([{
type: 'text',
text: {
value: 'AI response',
annotations: [],
expect(result.content).toEqual([
{
type: 'text',
text: {
value: 'AI response',
annotations: [],
},
},
}])
])
})
})
@ -207,16 +215,20 @@ describe('completion.ts', () => {
describe('extractToolCall', () => {
it('should extract tool calls from message', () => {
const message = {
choices: [{
delta: {
tool_calls: [{
id: 'call_1',
type: 'function',
index: 0,
function: { name: 'test', arguments: '{}' }
}]
}
}]
choices: [
{
delta: {
tool_calls: [
{
id: 'call_1',
type: 'function',
index: 0,
function: { name: 'test', arguments: '{}' },
},
],
},
},
],
}
const calls = []
const result = extractToolCall(message, null, calls)
@ -226,9 +238,11 @@ describe('completion.ts', () => {
it('should handle message without tool calls', () => {
const message = {
choices: [{
delta: {}
}]
choices: [
{
delta: {},
},
],
}
const calls = []
const result = extractToolCall(message, null, calls)
@ -245,23 +259,31 @@ describe('completion.ts', () => {
const mockMcp = {
getTools: mockGetTools,
callToolWithCancellation: vi.fn(() => ({
promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }),
promise: Promise.resolve({
content: [{ type: 'text', text: 'result' }],
error: '',
}),
cancel: vi.fn(),
}))
})),
}
vi.mocked(getServiceHub).mockReturnValue({
mcp: () => mockMcp,
rag: () => ({ getToolNames: () => Promise.resolve([]) })
rag: () => ({ getToolNames: () => Promise.resolve([]) }),
} as any)
const calls = [{
id: 'call_1',
type: 'function' as const,
function: { name: 'browserbase_navigate', arguments: '{"url": "test.com"}' }
}]
const calls = [
{
id: 'call_1',
type: 'function' as const,
function: {
name: 'browserbase_navigate',
arguments: '{"url": "test.com"}',
},
},
]
const builder = {
addToolMessage: vi.fn(),
getMessages: vi.fn(() => [])
getMessages: vi.fn(() => []),
} as any
const message = { thread_id: 'test-thread', metadata: {} } as any
const abortController = new AbortController()
@ -284,30 +306,44 @@ describe('completion.ts', () => {
it('should detect browserbase tools', async () => {
const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockCallTool = vi.fn(() => ({
promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }),
promise: Promise.resolve({
content: [{ type: 'text', text: 'result' }],
error: '',
}),
cancel: vi.fn(),
}))
vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({
getTools: () => Promise.resolve([]),
callToolWithCancellation: mockCallTool
callToolWithCancellation: mockCallTool,
}),
rag: () => ({ getToolNames: () => Promise.resolve([]) })
rag: () => ({ getToolNames: () => Promise.resolve([]) }),
} as any)
const calls = [{
id: 'call_1',
type: 'function' as const,
function: { name: 'browserbase_screenshot', arguments: '{}' }
}]
const calls = [
{
id: 'call_1',
type: 'function' as const,
function: { name: 'browserbase_screenshot', arguments: '{}' },
},
]
const builder = {
addToolMessage: vi.fn(),
getMessages: vi.fn(() => [])
getMessages: vi.fn(() => []),
} as any
const message = { thread_id: 'test-thread', metadata: {} } as any
const abortController = new AbortController()
await postMessageProcessing(calls, builder, message, abortController, {}, undefined, false, true)
await postMessageProcessing(
calls,
builder,
message,
abortController,
{},
undefined,
false,
true
)
expect(mockCallTool).toHaveBeenCalled()
})
@ -315,30 +351,47 @@ describe('completion.ts', () => {
it('should detect multi_browserbase tools', async () => {
const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockCallTool = vi.fn(() => ({
promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }),
promise: Promise.resolve({
content: [{ type: 'text', text: 'result' }],
error: '',
}),
cancel: vi.fn(),
}))
vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({
getTools: () => Promise.resolve([]),
callToolWithCancellation: mockCallTool
callToolWithCancellation: mockCallTool,
}),
rag: () => ({ getToolNames: () => Promise.resolve([]) })
rag: () => ({ getToolNames: () => Promise.resolve([]) }),
} as any)
const calls = [{
id: 'call_1',
type: 'function' as const,
function: { name: 'multi_browserbase_stagehand_navigate', arguments: '{}' }
}]
const calls = [
{
id: 'call_1',
type: 'function' as const,
function: {
name: 'multi_browserbase_stagehand_navigate',
arguments: '{}',
},
},
]
const builder = {
addToolMessage: vi.fn(),
getMessages: vi.fn(() => [])
getMessages: vi.fn(() => []),
} as any
const message = { thread_id: 'test-thread', metadata: {} } as any
const abortController = new AbortController()
await postMessageProcessing(calls, builder, message, abortController, {}, undefined, false, true)
await postMessageProcessing(
calls,
builder,
message,
abortController,
{},
undefined,
false,
true
)
expect(mockCallTool).toHaveBeenCalled()
})
@ -350,26 +403,40 @@ describe('completion.ts', () => {
mcp: () => ({
getTools: mockGetTools,
callToolWithCancellation: vi.fn(() => ({
promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }),
promise: Promise.resolve({
content: [{ type: 'text', text: 'result' }],
error: '',
}),
cancel: vi.fn(),
}))
})),
}),
rag: () => ({ getToolNames: () => Promise.resolve([]) })
rag: () => ({ getToolNames: () => Promise.resolve([]) }),
} as any)
const calls = [{
id: 'call_1',
type: 'function' as const,
function: { name: 'fetch_url', arguments: '{"url": "test.com"}' }
}]
const calls = [
{
id: 'call_1',
type: 'function' as const,
function: { name: 'fetch_url', arguments: '{"url": "test.com"}' },
},
]
const builder = {
addToolMessage: vi.fn(),
getMessages: vi.fn(() => [])
getMessages: vi.fn(() => []),
} as any
const message = { thread_id: 'test-thread', metadata: {} } as any
const abortController = new AbortController()
await postMessageProcessing(calls, builder, message, abortController, {}, undefined, false, true)
await postMessageProcessing(
calls,
builder,
message,
abortController,
{},
undefined,
false,
true
)
// Proactive screenshots should not be called for non-browser tools
expect(mockGetTools).not.toHaveBeenCalled()
@ -380,7 +447,9 @@ describe('completion.ts', () => {
it('should capture screenshot and snapshot when available', async () => {
const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockScreenshotResult = {
content: [{ type: 'image', data: 'base64screenshot', mimeType: 'image/png' }],
content: [
{ type: 'image', data: 'base64screenshot', mimeType: 'image/png' },
],
error: '',
}
const mockSnapshotResult = {
@ -388,11 +457,14 @@ describe('completion.ts', () => {
error: '',
}
const mockGetTools = vi.fn(() => Promise.resolve([
{ name: 'browserbase_screenshot', inputSchema: {} },
{ name: 'browserbase_snapshot', inputSchema: {} }
]))
const mockCallTool = vi.fn()
const mockGetTools = vi.fn(() =>
Promise.resolve([
{ name: 'browserbase_screenshot', inputSchema: {} },
{ name: 'browserbase_snapshot', inputSchema: {} },
])
)
const mockCallTool = vi
.fn()
.mockReturnValueOnce({
promise: Promise.resolve(mockScreenshotResult),
cancel: vi.fn(),
@ -405,8 +477,8 @@ describe('completion.ts', () => {
vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({
getTools: mockGetTools,
callToolWithCancellation: mockCallTool
})
callToolWithCancellation: mockCallTool,
}),
} as any)
const abortController = new AbortController()
@ -420,15 +492,15 @@ describe('completion.ts', () => {
it('should handle missing screenshot tool gracefully', async () => {
const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockGetTools = vi.fn(() => Promise.resolve([
{ name: 'some_other_tool', inputSchema: {} }
]))
const mockGetTools = vi.fn(() =>
Promise.resolve([{ name: 'some_other_tool', inputSchema: {} }])
)
vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({
getTools: mockGetTools,
callToolWithCancellation: vi.fn()
})
callToolWithCancellation: vi.fn(),
}),
} as any)
const abortController = new AbortController()
@ -439,9 +511,9 @@ describe('completion.ts', () => {
it('should handle screenshot capture errors gracefully', async () => {
const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockGetTools = vi.fn(() => Promise.resolve([
{ name: 'browserbase_screenshot', inputSchema: {} }
]))
const mockGetTools = vi.fn(() =>
Promise.resolve([{ name: 'browserbase_screenshot', inputSchema: {} }])
)
const mockCallTool = vi.fn(() => ({
promise: Promise.reject(new Error('Screenshot failed')),
cancel: vi.fn(),
@ -450,8 +522,8 @@ describe('completion.ts', () => {
vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({
getTools: mockGetTools,
callToolWithCancellation: mockCallTool
})
callToolWithCancellation: mockCallTool,
}),
} as any)
const abortController = new AbortController()
@ -463,22 +535,30 @@ describe('completion.ts', () => {
it('should respect abort controller', async () => {
const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockGetTools = vi.fn(() => Promise.resolve([
{ name: 'browserbase_screenshot', inputSchema: {} }
]))
const mockGetTools = vi.fn(() =>
Promise.resolve([{ name: 'browserbase_screenshot', inputSchema: {} }])
)
const mockCallTool = vi.fn(() => ({
promise: new Promise((resolve) => setTimeout(() => resolve({
content: [{ type: 'image', data: 'base64', mimeType: 'image/png' }],
error: '',
}), 100)),
promise: new Promise((resolve) =>
setTimeout(
() =>
resolve({
content: [
{ type: 'image', data: 'base64', mimeType: 'image/png' },
],
error: '',
}),
100
)
),
cancel: vi.fn(),
}))
vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({
getTools: mockGetTools,
callToolWithCancellation: mockCallTool
})
callToolWithCancellation: mockCallTool,
}),
} as any)
const abortController = new AbortController()
@ -500,12 +580,15 @@ describe('completion.ts', () => {
role: 'tool',
content: [
{ type: 'text', text: 'Tool result' },
{ type: 'image_url', image_url: { url: 'data:image/png;base64,old' } }
{
type: 'image_url',
image_url: { url: 'data:image/png;base64,old' },
},
],
tool_call_id: 'old_call'
tool_call_id: 'old_call',
},
{ role: 'assistant', content: 'Response' },
]
],
}
expect(builder.messages).toHaveLength(3)
@ -517,13 +600,19 @@ describe('completion.ts', () => {
const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockScreenshotResult = {
content: [{ type: 'image', data: 'proactive_screenshot', mimeType: 'image/png' }],
content: [
{
type: 'image',
data: 'proactive_screenshot',
mimeType: 'image/png',
},
],
error: '',
}
const mockGetTools = vi.fn(() => Promise.resolve([
{ name: 'browserbase_screenshot', inputSchema: {} }
]))
const mockGetTools = vi.fn(() =>
Promise.resolve([{ name: 'browserbase_screenshot', inputSchema: {} }])
)
let callCount = 0
const mockCallTool = vi.fn(() => {
@ -549,19 +638,24 @@ describe('completion.ts', () => {
vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({
getTools: mockGetTools,
callToolWithCancellation: mockCallTool
callToolWithCancellation: mockCallTool,
}),
rag: () => ({ getToolNames: () => Promise.resolve([]) })
rag: () => ({ getToolNames: () => Promise.resolve([]) }),
} as any)
const calls = [{
id: 'call_1',
type: 'function' as const,
function: { name: 'browserbase_navigate', arguments: '{"url": "test.com"}' }
}]
const calls = [
{
id: 'call_1',
type: 'function' as const,
function: {
name: 'browserbase_navigate',
arguments: '{"url": "test.com"}',
},
},
]
const builder = {
addToolMessage: vi.fn(),
getMessages: vi.fn(() => [])
getMessages: vi.fn(() => []),
} as any
const message = { thread_id: 'test-thread', metadata: {} } as any
const abortController = new AbortController()
@ -574,7 +668,13 @@ describe('completion.ts', () => {
{},
undefined,
false,
true
undefined, // thread
undefined, // provider
[], // tools
undefined, // updateStreamingUI
undefined, // maxToolSteps
undefined, // currentStepCount
true // isProactiveMode - Correctly set to true
)
// Should have called: 1) browser tool, 2) getTools, 3) proactive screenshot
@ -586,9 +686,9 @@ describe('completion.ts', () => {
it('should not trigger proactive screenshots when mode is disabled', async () => {
const { getServiceHub } = await import('@/hooks/useServiceHub')
const mockGetTools = vi.fn(() => Promise.resolve([
{ name: 'browserbase_screenshot', inputSchema: {} }
]))
const mockGetTools = vi.fn(() =>
Promise.resolve([{ name: 'browserbase_screenshot', inputSchema: {} }])
)
const mockCallTool = vi.fn(() => ({
promise: Promise.resolve({
@ -601,19 +701,21 @@ describe('completion.ts', () => {
vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({
getTools: mockGetTools,
callToolWithCancellation: mockCallTool
callToolWithCancellation: mockCallTool,
}),
rag: () => ({ getToolNames: () => Promise.resolve([]) })
rag: () => ({ getToolNames: () => Promise.resolve([]) }),
} as any)
const calls = [{
id: 'call_1',
type: 'function' as const,
function: { name: 'browserbase_navigate', arguments: '{}' }
}]
const calls = [
{
id: 'call_1',
type: 'function' as const,
function: { name: 'browserbase_navigate', arguments: '{}' },
},
]
const builder = {
addToolMessage: vi.fn(),
getMessages: vi.fn(() => [])
getMessages: vi.fn(() => []),
} as any
const message = { thread_id: 'test-thread', metadata: {} } as any
const abortController = new AbortController()
@ -626,7 +728,13 @@ describe('completion.ts', () => {
{},
undefined,
false,
false
undefined, // thread
undefined, // provider
[], // tools
undefined, // updateStreamingUI
undefined, // maxToolSteps
undefined, // currentStepCount
false // isProactiveMode - Correctly set to false
)
expect(mockCallTool).toHaveBeenCalledTimes(1)
@ -648,19 +756,21 @@ describe('completion.ts', () => {
vi.mocked(getServiceHub).mockReturnValue({
mcp: () => ({
getTools: mockGetTools,
callToolWithCancellation: mockCallTool
callToolWithCancellation: mockCallTool,
}),
rag: () => ({ getToolNames: () => Promise.resolve([]) })
rag: () => ({ getToolNames: () => Promise.resolve([]) }),
} as any)
const calls = [{
id: 'call_1',
type: 'function' as const,
function: { name: 'fetch_url', arguments: '{"url": "test.com"}' }
}]
const calls = [
{
id: 'call_1',
type: 'function' as const,
function: { name: 'fetch_url', arguments: '{"url": "test.com"}' },
},
]
const builder = {
addToolMessage: vi.fn(),
getMessages: vi.fn(() => [])
getMessages: vi.fn(() => []),
} as any
const message = { thread_id: 'test-thread', metadata: {} } as any
const abortController = new AbortController()
@ -673,7 +783,13 @@ describe('completion.ts', () => {
{},
undefined,
false,
true
undefined, // thread
undefined, // provider
[], // tools
undefined, // updateStreamingUI
undefined, // maxToolSteps
undefined, // currentStepCount
true // isProactiveMode - Still set to true, but the non-browser tool should skip the proactive step
)
expect(mockCallTool).toHaveBeenCalledTimes(1)