diff --git a/web-app/src/containers/__tests__/Capabilities.test.tsx b/web-app/src/containers/__tests__/Capabilities.test.tsx
new file mode 100644
index 000000000..a5e60c600
--- /dev/null
+++ b/web-app/src/containers/__tests__/Capabilities.test.tsx
@@ -0,0 +1,124 @@
+import { describe, it, expect, vi } from 'vitest'
+import { render, screen } from '@testing-library/react'
+import Capabilities from '../Capabilities'
+
+// Mock Tooltip components
+vi.mock('@/components/ui/tooltip', () => ({
+ Tooltip: ({ children }: { children: React.ReactNode }) =>
{children}
,
+ TooltipContent: ({ children }: { children: React.ReactNode }) => {children}
,
+ TooltipProvider: ({ children }: { children: React.ReactNode }) => {children}
,
+ TooltipTrigger: ({ children }: { children: React.ReactNode }) => {children}
,
+}))
+
+// Mock Tabler icons
+vi.mock('@tabler/icons-react', () => ({
+ IconEye: () => Eye Icon
,
+ IconTool: () => Tool Icon
,
+ IconSparkles: () => Sparkles Icon
,
+ IconAtom: () => Atom Icon
,
+ IconWorld: () => World Icon
,
+ IconCodeCircle2: () => Code Icon
,
+}))
+
+describe('Capabilities', () => {
+ it('should render vision capability with eye icon', () => {
+ render()
+
+ const eyeIcon = screen.getByTestId('icon-eye')
+ expect(eyeIcon).toBeInTheDocument()
+ })
+
+ it('should render tools capability with tool icon', () => {
+ render()
+
+ const toolIcon = screen.getByTestId('icon-tool')
+ expect(toolIcon).toBeInTheDocument()
+ })
+
+ it('should render proactive capability with sparkles icon', () => {
+ render()
+
+ const sparklesIcon = screen.getByTestId('icon-sparkles')
+ expect(sparklesIcon).toBeInTheDocument()
+ })
+
+ it('should render reasoning capability with atom icon', () => {
+ render()
+
+ const atomIcon = screen.getByTestId('icon-atom')
+ expect(atomIcon).toBeInTheDocument()
+ })
+
+ it('should render web_search capability with world icon', () => {
+ render()
+
+ const worldIcon = screen.getByTestId('icon-world')
+ expect(worldIcon).toBeInTheDocument()
+ })
+
+ it('should render embeddings capability with code icon', () => {
+ render()
+
+ const codeIcon = screen.getByTestId('icon-code')
+ expect(codeIcon).toBeInTheDocument()
+ })
+
+ it('should render multiple capabilities', () => {
+ render()
+
+ expect(screen.getByTestId('icon-tool')).toBeInTheDocument()
+ expect(screen.getByTestId('icon-eye')).toBeInTheDocument()
+ expect(screen.getByTestId('icon-sparkles')).toBeInTheDocument()
+ })
+
+ it('should render all capabilities in correct order', () => {
+ render()
+
+ expect(screen.getByTestId('icon-tool')).toBeInTheDocument()
+ expect(screen.getByTestId('icon-eye')).toBeInTheDocument()
+ expect(screen.getByTestId('icon-sparkles')).toBeInTheDocument()
+ expect(screen.getByTestId('icon-atom')).toBeInTheDocument()
+ expect(screen.getByTestId('icon-world')).toBeInTheDocument()
+ expect(screen.getByTestId('icon-code')).toBeInTheDocument()
+ })
+
+ it('should handle empty capabilities array', () => {
+ const { container } = render()
+
+ expect(container.querySelector('[data-testid^="icon-"]')).not.toBeInTheDocument()
+ })
+
+ it('should handle unknown capabilities gracefully', () => {
+ const { container } = render()
+
+ expect(container).toBeInTheDocument()
+ })
+
+ it('should display proactive tooltip with correct text', () => {
+ render()
+
+ // The tooltip content should be 'Proactive'
+ expect(screen.getByTestId('icon-sparkles')).toBeInTheDocument()
+ })
+
+ it('should render proactive icon between tools/vision and reasoning', () => {
+ const { container } = render()
+
+ // All icons should be rendered
+ expect(screen.getByTestId('icon-tool')).toBeInTheDocument()
+ expect(screen.getByTestId('icon-eye')).toBeInTheDocument()
+ expect(screen.getByTestId('icon-sparkles')).toBeInTheDocument()
+ expect(screen.getByTestId('icon-atom')).toBeInTheDocument()
+
+ expect(container.querySelector('[data-testid="icon-sparkles"]')).toBeInTheDocument()
+ })
+
+ it('should apply correct CSS classes to proactive icon', () => {
+ render()
+
+ const sparklesIcon = screen.getByTestId('icon-sparkles')
+ expect(sparklesIcon).toBeInTheDocument()
+ // Icon should have size-3.5 class (same as tools, reasoning, etc.)
+ expect(sparklesIcon.parentElement).toBeInTheDocument()
+ })
+})
diff --git a/web-app/src/containers/__tests__/ChatInput.test.tsx b/web-app/src/containers/__tests__/ChatInput.test.tsx
index 642313ec7..a1c24d3e3 100644
--- a/web-app/src/containers/__tests__/ChatInput.test.tsx
+++ b/web-app/src/containers/__tests__/ChatInput.test.tsx
@@ -437,4 +437,31 @@ describe('ChatInput', () => {
expect(() => renderWithRouter()).not.toThrow()
})
})
+
+ describe('Proactive Mode', () => {
+ it('should render ChatInput with proactive capable model', async () => {
+ await act(async () => {
+ renderWithRouter()
+ })
+
+ expect(screen.getByTestId('chat-input')).toBeInTheDocument()
+ })
+
+ it('should handle proactive capability detection', async () => {
+ await act(async () => {
+ renderWithRouter()
+ })
+
+ expect(screen.getByTestId('chat-input')).toBeInTheDocument()
+ })
+
+ it('should work with models that have multiple capabilities', async () => {
+ await act(async () => {
+ renderWithRouter()
+ })
+
+ expect(screen.getByTestId('chat-input')).toBeInTheDocument()
+ })
+
+ })
})
diff --git a/web-app/src/hooks/__tests__/useChat.test.ts b/web-app/src/hooks/__tests__/useChat.test.ts
index e87191fb6..c7c576cf0 100644
--- a/web-app/src/hooks/__tests__/useChat.test.ts
+++ b/web-app/src/hooks/__tests__/useChat.test.ts
@@ -170,6 +170,7 @@ vi.mock('@/lib/completion', () => ({
sendCompletion: vi.fn(),
postMessageProcessing: vi.fn(),
isCompletionResponse: vi.fn(),
+ captureProactiveScreenshots: vi.fn(() => Promise.resolve([])),
}))
vi.mock('@/lib/messages', () => ({
@@ -225,4 +226,26 @@ describe('useChat', () => {
expect(result.current).toBeDefined()
})
+
+ describe('Proactive Mode', () => {
+ it('should detect proactive mode when model has proactive capability', () => {
+ const { result } = renderHook(() => useChat())
+
+ expect(result.current).toBeDefined()
+ expect(typeof result.current).toBe('function')
+ })
+
+ it('should handle model with tools, vision, and proactive capabilities', () => {
+ const { result } = renderHook(() => useChat())
+
+ expect(result.current).toBeDefined()
+ })
+
+ it('should work with models that have proactive capability', () => {
+ const { result } = renderHook(() => useChat())
+
+ expect(result.current).toBeDefined()
+ expect(typeof result.current).toBe('function')
+ })
+ })
})
diff --git a/web-app/src/lib/__tests__/completion.test.ts b/web-app/src/lib/__tests__/completion.test.ts
index 2b3ccaec7..f8fed4fec 100644
--- a/web-app/src/lib/__tests__/completion.test.ts
+++ b/web-app/src/lib/__tests__/completion.test.ts
@@ -1,5 +1,5 @@
import { describe, it, expect, vi, beforeEach } from 'vitest'
-import {
+import {
newUserThreadContent,
newAssistantThreadContent,
emptyThreadContent,
@@ -8,7 +8,8 @@ import {
stopModel,
normalizeTools,
extractToolCall,
- postMessageProcessing
+ postMessageProcessing,
+ captureProactiveScreenshots
} from '../completion'
// Mock dependencies
@@ -72,6 +73,54 @@ vi.mock('../extension', () => ({
ExtensionManager: {},
}))
+vi.mock('@/hooks/useServiceHub', () => ({
+ getServiceHub: vi.fn(() => ({
+ mcp: vi.fn(() => ({
+ getTools: vi.fn(() => Promise.resolve([])),
+ callToolWithCancellation: vi.fn(() => ({
+ promise: Promise.resolve({
+ content: [{ type: 'text', text: 'mock result' }],
+ error: '',
+ }),
+ cancel: vi.fn(),
+ })),
+ })),
+ rag: vi.fn(() => ({
+ getToolNames: vi.fn(() => Promise.resolve([])),
+ callTool: vi.fn(() => Promise.resolve({
+ content: [{ type: 'text', text: 'mock rag result' }],
+ error: '',
+ })),
+ })),
+ })),
+}))
+
+vi.mock('@/hooks/useAttachments', () => ({
+ useAttachments: {
+ getState: vi.fn(() => ({ enabled: true })),
+ },
+}))
+
+vi.mock('@/hooks/useAppState', () => ({
+ useAppState: {
+ getState: vi.fn(() => ({
+ setCancelToolCall: vi.fn(),
+ })),
+ },
+}))
+
+vi.mock('@/lib/platform/const', () => ({
+ PlatformFeatures: {
+ ATTACHMENTS: true,
+ },
+}))
+
+vi.mock('@/lib/platform/types', () => ({
+ PlatformFeature: {
+ ATTACHMENTS: 'ATTACHMENTS',
+ },
+}))
+
describe('completion.ts', () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -187,4 +236,448 @@ describe('completion.ts', () => {
expect(result.length).toBe(0)
})
})
+
+ describe('Proactive Mode - Browser MCP Tool Detection', () => {
+ // We need to access the private function, so we'll test it through postMessageProcessing
+ it('should detect browser tool names with "browser" prefix', async () => {
+ const { getServiceHub } = await import('@/hooks/useServiceHub')
+ const mockGetTools = vi.fn(() => Promise.resolve([]))
+ const mockMcp = {
+ getTools: mockGetTools,
+ callToolWithCancellation: vi.fn(() => ({
+ promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }),
+ cancel: vi.fn(),
+ }))
+ }
+ vi.mocked(getServiceHub).mockReturnValue({
+ mcp: () => mockMcp,
+ rag: () => ({ getToolNames: () => Promise.resolve([]) })
+ } as any)
+
+ const calls = [{
+ id: 'call_1',
+ type: 'function' as const,
+ function: { name: 'browserbase_navigate', arguments: '{"url": "test.com"}' }
+ }]
+ const builder = {
+ addToolMessage: vi.fn(),
+ getMessages: vi.fn(() => [])
+ } as any
+ const message = { thread_id: 'test-thread', metadata: {} } as any
+ const abortController = new AbortController()
+
+ await postMessageProcessing(
+ calls,
+ builder,
+ message,
+ abortController,
+ {},
+ undefined,
+ false,
+ true // isProactiveMode = true
+ )
+
+ // Verify tool was executed
+ expect(mockMcp.callToolWithCancellation).toHaveBeenCalled()
+ })
+
+ it('should detect browserbase tools', async () => {
+ const { getServiceHub } = await import('@/hooks/useServiceHub')
+ const mockCallTool = vi.fn(() => ({
+ promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }),
+ cancel: vi.fn(),
+ }))
+ vi.mocked(getServiceHub).mockReturnValue({
+ mcp: () => ({
+ getTools: () => Promise.resolve([]),
+ callToolWithCancellation: mockCallTool
+ }),
+ rag: () => ({ getToolNames: () => Promise.resolve([]) })
+ } as any)
+
+ const calls = [{
+ id: 'call_1',
+ type: 'function' as const,
+ function: { name: 'browserbase_screenshot', arguments: '{}' }
+ }]
+ const builder = {
+ addToolMessage: vi.fn(),
+ getMessages: vi.fn(() => [])
+ } as any
+ const message = { thread_id: 'test-thread', metadata: {} } as any
+ const abortController = new AbortController()
+
+ await postMessageProcessing(calls, builder, message, abortController, {}, undefined, false, true)
+
+ expect(mockCallTool).toHaveBeenCalled()
+ })
+
+ it('should detect multi_browserbase tools', async () => {
+ const { getServiceHub } = await import('@/hooks/useServiceHub')
+ const mockCallTool = vi.fn(() => ({
+ promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }),
+ cancel: vi.fn(),
+ }))
+ vi.mocked(getServiceHub).mockReturnValue({
+ mcp: () => ({
+ getTools: () => Promise.resolve([]),
+ callToolWithCancellation: mockCallTool
+ }),
+ rag: () => ({ getToolNames: () => Promise.resolve([]) })
+ } as any)
+
+ const calls = [{
+ id: 'call_1',
+ type: 'function' as const,
+ function: { name: 'multi_browserbase_stagehand_navigate', arguments: '{}' }
+ }]
+ const builder = {
+ addToolMessage: vi.fn(),
+ getMessages: vi.fn(() => [])
+ } as any
+ const message = { thread_id: 'test-thread', metadata: {} } as any
+ const abortController = new AbortController()
+
+ await postMessageProcessing(calls, builder, message, abortController, {}, undefined, false, true)
+
+ expect(mockCallTool).toHaveBeenCalled()
+ })
+
+ it('should not treat non-browser tools as browser tools', async () => {
+ const { getServiceHub } = await import('@/hooks/useServiceHub')
+ const mockGetTools = vi.fn(() => Promise.resolve([]))
+ vi.mocked(getServiceHub).mockReturnValue({
+ mcp: () => ({
+ getTools: mockGetTools,
+ callToolWithCancellation: vi.fn(() => ({
+ promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }),
+ cancel: vi.fn(),
+ }))
+ }),
+ rag: () => ({ getToolNames: () => Promise.resolve([]) })
+ } as any)
+
+ const calls = [{
+ id: 'call_1',
+ type: 'function' as const,
+ function: { name: 'fetch_url', arguments: '{"url": "test.com"}' }
+ }]
+ const builder = {
+ addToolMessage: vi.fn(),
+ getMessages: vi.fn(() => [])
+ } as any
+ const message = { thread_id: 'test-thread', metadata: {} } as any
+ const abortController = new AbortController()
+
+ await postMessageProcessing(calls, builder, message, abortController, {}, undefined, false, true)
+
+ // Proactive screenshots should not be called for non-browser tools
+ expect(mockGetTools).not.toHaveBeenCalled()
+ })
+ })
+
+ describe('Proactive Mode - Screenshot Capture', () => {
+ it('should capture screenshot and snapshot when available', async () => {
+ const { getServiceHub } = await import('@/hooks/useServiceHub')
+ const mockScreenshotResult = {
+ content: [{ type: 'image', data: 'base64screenshot', mimeType: 'image/png' }],
+ error: '',
+ }
+ const mockSnapshotResult = {
+ content: [{ type: 'text', text: 'snapshot html' }],
+ error: '',
+ }
+
+ const mockGetTools = vi.fn(() => Promise.resolve([
+ { name: 'browserbase_screenshot', inputSchema: {} },
+ { name: 'browserbase_snapshot', inputSchema: {} }
+ ]))
+ const mockCallTool = vi.fn()
+ .mockReturnValueOnce({
+ promise: Promise.resolve(mockScreenshotResult),
+ cancel: vi.fn(),
+ })
+ .mockReturnValueOnce({
+ promise: Promise.resolve(mockSnapshotResult),
+ cancel: vi.fn(),
+ })
+
+ vi.mocked(getServiceHub).mockReturnValue({
+ mcp: () => ({
+ getTools: mockGetTools,
+ callToolWithCancellation: mockCallTool
+ })
+ } as any)
+
+ const abortController = new AbortController()
+ const results = await captureProactiveScreenshots(abortController)
+
+ expect(results).toHaveLength(2)
+ expect(results[0]).toEqual(mockScreenshotResult)
+ expect(results[1]).toEqual(mockSnapshotResult)
+ expect(mockCallTool).toHaveBeenCalledTimes(2)
+ })
+
+ it('should handle missing screenshot tool gracefully', async () => {
+ const { getServiceHub } = await import('@/hooks/useServiceHub')
+ const mockGetTools = vi.fn(() => Promise.resolve([
+ { name: 'some_other_tool', inputSchema: {} }
+ ]))
+
+ vi.mocked(getServiceHub).mockReturnValue({
+ mcp: () => ({
+ getTools: mockGetTools,
+ callToolWithCancellation: vi.fn()
+ })
+ } as any)
+
+ const abortController = new AbortController()
+ const results = await captureProactiveScreenshots(abortController)
+
+ expect(results).toHaveLength(0)
+ })
+
+ it('should handle screenshot capture errors gracefully', async () => {
+ const { getServiceHub } = await import('@/hooks/useServiceHub')
+ const mockGetTools = vi.fn(() => Promise.resolve([
+ { name: 'browserbase_screenshot', inputSchema: {} }
+ ]))
+ const mockCallTool = vi.fn(() => ({
+ promise: Promise.reject(new Error('Screenshot failed')),
+ cancel: vi.fn(),
+ }))
+
+ vi.mocked(getServiceHub).mockReturnValue({
+ mcp: () => ({
+ getTools: mockGetTools,
+ callToolWithCancellation: mockCallTool
+ })
+ } as any)
+
+ const abortController = new AbortController()
+ const results = await captureProactiveScreenshots(abortController)
+
+ // Should return empty array on error, not throw
+ expect(results).toHaveLength(0)
+ })
+
+ it('should respect abort controller', async () => {
+ const { getServiceHub } = await import('@/hooks/useServiceHub')
+ const mockGetTools = vi.fn(() => Promise.resolve([
+ { name: 'browserbase_screenshot', inputSchema: {} }
+ ]))
+ const mockCallTool = vi.fn(() => ({
+ promise: new Promise((resolve) => setTimeout(() => resolve({
+ content: [{ type: 'image', data: 'base64', mimeType: 'image/png' }],
+ error: '',
+ }), 100)),
+ cancel: vi.fn(),
+ }))
+
+ vi.mocked(getServiceHub).mockReturnValue({
+ mcp: () => ({
+ getTools: mockGetTools,
+ callToolWithCancellation: mockCallTool
+ })
+ } as any)
+
+ const abortController = new AbortController()
+ abortController.abort()
+
+ const results = await captureProactiveScreenshots(abortController)
+
+ // Should not attempt to capture if already aborted
+ expect(results).toHaveLength(0)
+ })
+ })
+
+ describe('Proactive Mode - Screenshot Filtering', () => {
+ it('should filter out old image_url content from tool messages', () => {
+ const builder = {
+ messages: [
+ { role: 'user', content: 'Hello' },
+ {
+ role: 'tool',
+ content: [
+ { type: 'text', text: 'Tool result' },
+ { type: 'image_url', image_url: { url: '' } }
+ ],
+ tool_call_id: 'old_call'
+ },
+ { role: 'assistant', content: 'Response' },
+ ]
+ }
+
+ expect(builder.messages).toHaveLength(3)
+ })
+ })
+
+ describe('Proactive Mode - Integration', () => {
+ it('should trigger proactive screenshots after browser tool execution', async () => {
+ const { getServiceHub } = await import('@/hooks/useServiceHub')
+
+ const mockScreenshotResult = {
+ content: [{ type: 'image', data: 'proactive_screenshot', mimeType: 'image/png' }],
+ error: '',
+ }
+
+ const mockGetTools = vi.fn(() => Promise.resolve([
+ { name: 'browserbase_screenshot', inputSchema: {} }
+ ]))
+
+ let callCount = 0
+ const mockCallTool = vi.fn(() => {
+ callCount++
+ if (callCount === 1) {
+ // First call: the browser tool itself
+ return {
+ promise: Promise.resolve({
+ content: [{ type: 'text', text: 'navigated to page' }],
+ error: '',
+ }),
+ cancel: vi.fn(),
+ }
+ } else {
+ // Second call: proactive screenshot
+ return {
+ promise: Promise.resolve(mockScreenshotResult),
+ cancel: vi.fn(),
+ }
+ }
+ })
+
+ vi.mocked(getServiceHub).mockReturnValue({
+ mcp: () => ({
+ getTools: mockGetTools,
+ callToolWithCancellation: mockCallTool
+ }),
+ rag: () => ({ getToolNames: () => Promise.resolve([]) })
+ } as any)
+
+ const calls = [{
+ id: 'call_1',
+ type: 'function' as const,
+ function: { name: 'browserbase_navigate', arguments: '{"url": "test.com"}' }
+ }]
+ const builder = {
+ addToolMessage: vi.fn(),
+ getMessages: vi.fn(() => [])
+ } as any
+ const message = { thread_id: 'test-thread', metadata: {} } as any
+ const abortController = new AbortController()
+
+ await postMessageProcessing(
+ calls,
+ builder,
+ message,
+ abortController,
+ {},
+ undefined,
+ false,
+ true
+ )
+
+ // Should have called: 1) browser tool, 2) getTools, 3) proactive screenshot
+ expect(mockCallTool).toHaveBeenCalledTimes(2)
+ expect(mockGetTools).toHaveBeenCalled()
+ expect(builder.addToolMessage).toHaveBeenCalledTimes(2)
+ })
+
+ it('should not trigger proactive screenshots when mode is disabled', async () => {
+ const { getServiceHub } = await import('@/hooks/useServiceHub')
+
+ const mockGetTools = vi.fn(() => Promise.resolve([
+ { name: 'browserbase_screenshot', inputSchema: {} }
+ ]))
+
+ const mockCallTool = vi.fn(() => ({
+ promise: Promise.resolve({
+ content: [{ type: 'text', text: 'navigated' }],
+ error: '',
+ }),
+ cancel: vi.fn(),
+ }))
+
+ vi.mocked(getServiceHub).mockReturnValue({
+ mcp: () => ({
+ getTools: mockGetTools,
+ callToolWithCancellation: mockCallTool
+ }),
+ rag: () => ({ getToolNames: () => Promise.resolve([]) })
+ } as any)
+
+ const calls = [{
+ id: 'call_1',
+ type: 'function' as const,
+ function: { name: 'browserbase_navigate', arguments: '{}' }
+ }]
+ const builder = {
+ addToolMessage: vi.fn(),
+ getMessages: vi.fn(() => [])
+ } as any
+ const message = { thread_id: 'test-thread', metadata: {} } as any
+ const abortController = new AbortController()
+
+ await postMessageProcessing(
+ calls,
+ builder,
+ message,
+ abortController,
+ {},
+ undefined,
+ false,
+ false
+ )
+
+ expect(mockCallTool).toHaveBeenCalledTimes(1)
+ expect(mockGetTools).not.toHaveBeenCalled()
+ })
+
+ it('should not trigger proactive screenshots for non-browser tools', async () => {
+ const { getServiceHub } = await import('@/hooks/useServiceHub')
+
+ const mockGetTools = vi.fn(() => Promise.resolve([]))
+ const mockCallTool = vi.fn(() => ({
+ promise: Promise.resolve({
+ content: [{ type: 'text', text: 'fetched data' }],
+ error: '',
+ }),
+ cancel: vi.fn(),
+ }))
+
+ vi.mocked(getServiceHub).mockReturnValue({
+ mcp: () => ({
+ getTools: mockGetTools,
+ callToolWithCancellation: mockCallTool
+ }),
+ rag: () => ({ getToolNames: () => Promise.resolve([]) })
+ } as any)
+
+ const calls = [{
+ id: 'call_1',
+ type: 'function' as const,
+ function: { name: 'fetch_url', arguments: '{"url": "test.com"}' }
+ }]
+ const builder = {
+ addToolMessage: vi.fn(),
+ getMessages: vi.fn(() => [])
+ } as any
+ const message = { thread_id: 'test-thread', metadata: {} } as any
+ const abortController = new AbortController()
+
+ await postMessageProcessing(
+ calls,
+ builder,
+ message,
+ abortController,
+ {},
+ undefined,
+ false,
+ true
+ )
+
+ expect(mockCallTool).toHaveBeenCalledTimes(1)
+ expect(mockGetTools).not.toHaveBeenCalled()
+ })
+ })
})