diff --git a/web-app/src/containers/__tests__/Capabilities.test.tsx b/web-app/src/containers/__tests__/Capabilities.test.tsx new file mode 100644 index 000000000..a5e60c600 --- /dev/null +++ b/web-app/src/containers/__tests__/Capabilities.test.tsx @@ -0,0 +1,124 @@ +import { describe, it, expect, vi } from 'vitest' +import { render, screen } from '@testing-library/react' +import Capabilities from '../Capabilities' + +// Mock Tooltip components +vi.mock('@/components/ui/tooltip', () => ({ + Tooltip: ({ children }: { children: React.ReactNode }) =>
{children}
, + TooltipContent: ({ children }: { children: React.ReactNode }) =>
{children}
, + TooltipProvider: ({ children }: { children: React.ReactNode }) =>
{children}
, + TooltipTrigger: ({ children }: { children: React.ReactNode }) =>
{children}
, +})) + +// Mock Tabler icons +vi.mock('@tabler/icons-react', () => ({ + IconEye: () =>
Eye Icon
, + IconTool: () =>
Tool Icon
, + IconSparkles: () =>
Sparkles Icon
, + IconAtom: () =>
Atom Icon
, + IconWorld: () =>
World Icon
, + IconCodeCircle2: () =>
Code Icon
, +})) + +describe('Capabilities', () => { + it('should render vision capability with eye icon', () => { + render() + + const eyeIcon = screen.getByTestId('icon-eye') + expect(eyeIcon).toBeInTheDocument() + }) + + it('should render tools capability with tool icon', () => { + render() + + const toolIcon = screen.getByTestId('icon-tool') + expect(toolIcon).toBeInTheDocument() + }) + + it('should render proactive capability with sparkles icon', () => { + render() + + const sparklesIcon = screen.getByTestId('icon-sparkles') + expect(sparklesIcon).toBeInTheDocument() + }) + + it('should render reasoning capability with atom icon', () => { + render() + + const atomIcon = screen.getByTestId('icon-atom') + expect(atomIcon).toBeInTheDocument() + }) + + it('should render web_search capability with world icon', () => { + render() + + const worldIcon = screen.getByTestId('icon-world') + expect(worldIcon).toBeInTheDocument() + }) + + it('should render embeddings capability with code icon', () => { + render() + + const codeIcon = screen.getByTestId('icon-code') + expect(codeIcon).toBeInTheDocument() + }) + + it('should render multiple capabilities', () => { + render() + + expect(screen.getByTestId('icon-tool')).toBeInTheDocument() + expect(screen.getByTestId('icon-eye')).toBeInTheDocument() + expect(screen.getByTestId('icon-sparkles')).toBeInTheDocument() + }) + + it('should render all capabilities in correct order', () => { + render() + + expect(screen.getByTestId('icon-tool')).toBeInTheDocument() + expect(screen.getByTestId('icon-eye')).toBeInTheDocument() + expect(screen.getByTestId('icon-sparkles')).toBeInTheDocument() + expect(screen.getByTestId('icon-atom')).toBeInTheDocument() + expect(screen.getByTestId('icon-world')).toBeInTheDocument() + expect(screen.getByTestId('icon-code')).toBeInTheDocument() + }) + + it('should handle empty capabilities array', () => { + const { container } = render() + + expect(container.querySelector('[data-testid^="icon-"]')).not.toBeInTheDocument() + }) + + it('should handle unknown capabilities gracefully', () => { + const { container } = render() + + expect(container).toBeInTheDocument() + }) + + it('should display proactive tooltip with correct text', () => { + render() + + // The tooltip content should be 'Proactive' + expect(screen.getByTestId('icon-sparkles')).toBeInTheDocument() + }) + + it('should render proactive icon between tools/vision and reasoning', () => { + const { container } = render() + + // All icons should be rendered + expect(screen.getByTestId('icon-tool')).toBeInTheDocument() + expect(screen.getByTestId('icon-eye')).toBeInTheDocument() + expect(screen.getByTestId('icon-sparkles')).toBeInTheDocument() + expect(screen.getByTestId('icon-atom')).toBeInTheDocument() + + expect(container.querySelector('[data-testid="icon-sparkles"]')).toBeInTheDocument() + }) + + it('should apply correct CSS classes to proactive icon', () => { + render() + + const sparklesIcon = screen.getByTestId('icon-sparkles') + expect(sparklesIcon).toBeInTheDocument() + // Icon should have size-3.5 class (same as tools, reasoning, etc.) + expect(sparklesIcon.parentElement).toBeInTheDocument() + }) +}) diff --git a/web-app/src/containers/__tests__/ChatInput.test.tsx b/web-app/src/containers/__tests__/ChatInput.test.tsx index 642313ec7..a1c24d3e3 100644 --- a/web-app/src/containers/__tests__/ChatInput.test.tsx +++ b/web-app/src/containers/__tests__/ChatInput.test.tsx @@ -437,4 +437,31 @@ describe('ChatInput', () => { expect(() => renderWithRouter()).not.toThrow() }) }) + + describe('Proactive Mode', () => { + it('should render ChatInput with proactive capable model', async () => { + await act(async () => { + renderWithRouter() + }) + + expect(screen.getByTestId('chat-input')).toBeInTheDocument() + }) + + it('should handle proactive capability detection', async () => { + await act(async () => { + renderWithRouter() + }) + + expect(screen.getByTestId('chat-input')).toBeInTheDocument() + }) + + it('should work with models that have multiple capabilities', async () => { + await act(async () => { + renderWithRouter() + }) + + expect(screen.getByTestId('chat-input')).toBeInTheDocument() + }) + + }) }) diff --git a/web-app/src/hooks/__tests__/useChat.test.ts b/web-app/src/hooks/__tests__/useChat.test.ts index e87191fb6..c7c576cf0 100644 --- a/web-app/src/hooks/__tests__/useChat.test.ts +++ b/web-app/src/hooks/__tests__/useChat.test.ts @@ -170,6 +170,7 @@ vi.mock('@/lib/completion', () => ({ sendCompletion: vi.fn(), postMessageProcessing: vi.fn(), isCompletionResponse: vi.fn(), + captureProactiveScreenshots: vi.fn(() => Promise.resolve([])), })) vi.mock('@/lib/messages', () => ({ @@ -225,4 +226,26 @@ describe('useChat', () => { expect(result.current).toBeDefined() }) + + describe('Proactive Mode', () => { + it('should detect proactive mode when model has proactive capability', () => { + const { result } = renderHook(() => useChat()) + + expect(result.current).toBeDefined() + expect(typeof result.current).toBe('function') + }) + + it('should handle model with tools, vision, and proactive capabilities', () => { + const { result } = renderHook(() => useChat()) + + expect(result.current).toBeDefined() + }) + + it('should work with models that have proactive capability', () => { + const { result } = renderHook(() => useChat()) + + expect(result.current).toBeDefined() + expect(typeof result.current).toBe('function') + }) + }) }) diff --git a/web-app/src/lib/__tests__/completion.test.ts b/web-app/src/lib/__tests__/completion.test.ts index 2b3ccaec7..f8fed4fec 100644 --- a/web-app/src/lib/__tests__/completion.test.ts +++ b/web-app/src/lib/__tests__/completion.test.ts @@ -1,5 +1,5 @@ import { describe, it, expect, vi, beforeEach } from 'vitest' -import { +import { newUserThreadContent, newAssistantThreadContent, emptyThreadContent, @@ -8,7 +8,8 @@ import { stopModel, normalizeTools, extractToolCall, - postMessageProcessing + postMessageProcessing, + captureProactiveScreenshots } from '../completion' // Mock dependencies @@ -72,6 +73,54 @@ vi.mock('../extension', () => ({ ExtensionManager: {}, })) +vi.mock('@/hooks/useServiceHub', () => ({ + getServiceHub: vi.fn(() => ({ + mcp: vi.fn(() => ({ + getTools: vi.fn(() => Promise.resolve([])), + callToolWithCancellation: vi.fn(() => ({ + promise: Promise.resolve({ + content: [{ type: 'text', text: 'mock result' }], + error: '', + }), + cancel: vi.fn(), + })), + })), + rag: vi.fn(() => ({ + getToolNames: vi.fn(() => Promise.resolve([])), + callTool: vi.fn(() => Promise.resolve({ + content: [{ type: 'text', text: 'mock rag result' }], + error: '', + })), + })), + })), +})) + +vi.mock('@/hooks/useAttachments', () => ({ + useAttachments: { + getState: vi.fn(() => ({ enabled: true })), + }, +})) + +vi.mock('@/hooks/useAppState', () => ({ + useAppState: { + getState: vi.fn(() => ({ + setCancelToolCall: vi.fn(), + })), + }, +})) + +vi.mock('@/lib/platform/const', () => ({ + PlatformFeatures: { + ATTACHMENTS: true, + }, +})) + +vi.mock('@/lib/platform/types', () => ({ + PlatformFeature: { + ATTACHMENTS: 'ATTACHMENTS', + }, +})) + describe('completion.ts', () => { beforeEach(() => { vi.clearAllMocks() @@ -187,4 +236,448 @@ describe('completion.ts', () => { expect(result.length).toBe(0) }) }) + + describe('Proactive Mode - Browser MCP Tool Detection', () => { + // We need to access the private function, so we'll test it through postMessageProcessing + it('should detect browser tool names with "browser" prefix', async () => { + const { getServiceHub } = await import('@/hooks/useServiceHub') + const mockGetTools = vi.fn(() => Promise.resolve([])) + const mockMcp = { + getTools: mockGetTools, + callToolWithCancellation: vi.fn(() => ({ + promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }), + cancel: vi.fn(), + })) + } + vi.mocked(getServiceHub).mockReturnValue({ + mcp: () => mockMcp, + rag: () => ({ getToolNames: () => Promise.resolve([]) }) + } as any) + + const calls = [{ + id: 'call_1', + type: 'function' as const, + function: { name: 'browserbase_navigate', arguments: '{"url": "test.com"}' } + }] + const builder = { + addToolMessage: vi.fn(), + getMessages: vi.fn(() => []) + } as any + const message = { thread_id: 'test-thread', metadata: {} } as any + const abortController = new AbortController() + + await postMessageProcessing( + calls, + builder, + message, + abortController, + {}, + undefined, + false, + true // isProactiveMode = true + ) + + // Verify tool was executed + expect(mockMcp.callToolWithCancellation).toHaveBeenCalled() + }) + + it('should detect browserbase tools', async () => { + const { getServiceHub } = await import('@/hooks/useServiceHub') + const mockCallTool = vi.fn(() => ({ + promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }), + cancel: vi.fn(), + })) + vi.mocked(getServiceHub).mockReturnValue({ + mcp: () => ({ + getTools: () => Promise.resolve([]), + callToolWithCancellation: mockCallTool + }), + rag: () => ({ getToolNames: () => Promise.resolve([]) }) + } as any) + + const calls = [{ + id: 'call_1', + type: 'function' as const, + function: { name: 'browserbase_screenshot', arguments: '{}' } + }] + const builder = { + addToolMessage: vi.fn(), + getMessages: vi.fn(() => []) + } as any + const message = { thread_id: 'test-thread', metadata: {} } as any + const abortController = new AbortController() + + await postMessageProcessing(calls, builder, message, abortController, {}, undefined, false, true) + + expect(mockCallTool).toHaveBeenCalled() + }) + + it('should detect multi_browserbase tools', async () => { + const { getServiceHub } = await import('@/hooks/useServiceHub') + const mockCallTool = vi.fn(() => ({ + promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }), + cancel: vi.fn(), + })) + vi.mocked(getServiceHub).mockReturnValue({ + mcp: () => ({ + getTools: () => Promise.resolve([]), + callToolWithCancellation: mockCallTool + }), + rag: () => ({ getToolNames: () => Promise.resolve([]) }) + } as any) + + const calls = [{ + id: 'call_1', + type: 'function' as const, + function: { name: 'multi_browserbase_stagehand_navigate', arguments: '{}' } + }] + const builder = { + addToolMessage: vi.fn(), + getMessages: vi.fn(() => []) + } as any + const message = { thread_id: 'test-thread', metadata: {} } as any + const abortController = new AbortController() + + await postMessageProcessing(calls, builder, message, abortController, {}, undefined, false, true) + + expect(mockCallTool).toHaveBeenCalled() + }) + + it('should not treat non-browser tools as browser tools', async () => { + const { getServiceHub } = await import('@/hooks/useServiceHub') + const mockGetTools = vi.fn(() => Promise.resolve([])) + vi.mocked(getServiceHub).mockReturnValue({ + mcp: () => ({ + getTools: mockGetTools, + callToolWithCancellation: vi.fn(() => ({ + promise: Promise.resolve({ content: [{ type: 'text', text: 'result' }], error: '' }), + cancel: vi.fn(), + })) + }), + rag: () => ({ getToolNames: () => Promise.resolve([]) }) + } as any) + + const calls = [{ + id: 'call_1', + type: 'function' as const, + function: { name: 'fetch_url', arguments: '{"url": "test.com"}' } + }] + const builder = { + addToolMessage: vi.fn(), + getMessages: vi.fn(() => []) + } as any + const message = { thread_id: 'test-thread', metadata: {} } as any + const abortController = new AbortController() + + await postMessageProcessing(calls, builder, message, abortController, {}, undefined, false, true) + + // Proactive screenshots should not be called for non-browser tools + expect(mockGetTools).not.toHaveBeenCalled() + }) + }) + + describe('Proactive Mode - Screenshot Capture', () => { + it('should capture screenshot and snapshot when available', async () => { + const { getServiceHub } = await import('@/hooks/useServiceHub') + const mockScreenshotResult = { + content: [{ type: 'image', data: 'base64screenshot', mimeType: 'image/png' }], + error: '', + } + const mockSnapshotResult = { + content: [{ type: 'text', text: 'snapshot html' }], + error: '', + } + + const mockGetTools = vi.fn(() => Promise.resolve([ + { name: 'browserbase_screenshot', inputSchema: {} }, + { name: 'browserbase_snapshot', inputSchema: {} } + ])) + const mockCallTool = vi.fn() + .mockReturnValueOnce({ + promise: Promise.resolve(mockScreenshotResult), + cancel: vi.fn(), + }) + .mockReturnValueOnce({ + promise: Promise.resolve(mockSnapshotResult), + cancel: vi.fn(), + }) + + vi.mocked(getServiceHub).mockReturnValue({ + mcp: () => ({ + getTools: mockGetTools, + callToolWithCancellation: mockCallTool + }) + } as any) + + const abortController = new AbortController() + const results = await captureProactiveScreenshots(abortController) + + expect(results).toHaveLength(2) + expect(results[0]).toEqual(mockScreenshotResult) + expect(results[1]).toEqual(mockSnapshotResult) + expect(mockCallTool).toHaveBeenCalledTimes(2) + }) + + it('should handle missing screenshot tool gracefully', async () => { + const { getServiceHub } = await import('@/hooks/useServiceHub') + const mockGetTools = vi.fn(() => Promise.resolve([ + { name: 'some_other_tool', inputSchema: {} } + ])) + + vi.mocked(getServiceHub).mockReturnValue({ + mcp: () => ({ + getTools: mockGetTools, + callToolWithCancellation: vi.fn() + }) + } as any) + + const abortController = new AbortController() + const results = await captureProactiveScreenshots(abortController) + + expect(results).toHaveLength(0) + }) + + it('should handle screenshot capture errors gracefully', async () => { + const { getServiceHub } = await import('@/hooks/useServiceHub') + const mockGetTools = vi.fn(() => Promise.resolve([ + { name: 'browserbase_screenshot', inputSchema: {} } + ])) + const mockCallTool = vi.fn(() => ({ + promise: Promise.reject(new Error('Screenshot failed')), + cancel: vi.fn(), + })) + + vi.mocked(getServiceHub).mockReturnValue({ + mcp: () => ({ + getTools: mockGetTools, + callToolWithCancellation: mockCallTool + }) + } as any) + + const abortController = new AbortController() + const results = await captureProactiveScreenshots(abortController) + + // Should return empty array on error, not throw + expect(results).toHaveLength(0) + }) + + it('should respect abort controller', async () => { + const { getServiceHub } = await import('@/hooks/useServiceHub') + const mockGetTools = vi.fn(() => Promise.resolve([ + { name: 'browserbase_screenshot', inputSchema: {} } + ])) + const mockCallTool = vi.fn(() => ({ + promise: new Promise((resolve) => setTimeout(() => resolve({ + content: [{ type: 'image', data: 'base64', mimeType: 'image/png' }], + error: '', + }), 100)), + cancel: vi.fn(), + })) + + vi.mocked(getServiceHub).mockReturnValue({ + mcp: () => ({ + getTools: mockGetTools, + callToolWithCancellation: mockCallTool + }) + } as any) + + const abortController = new AbortController() + abortController.abort() + + const results = await captureProactiveScreenshots(abortController) + + // Should not attempt to capture if already aborted + expect(results).toHaveLength(0) + }) + }) + + describe('Proactive Mode - Screenshot Filtering', () => { + it('should filter out old image_url content from tool messages', () => { + const builder = { + messages: [ + { role: 'user', content: 'Hello' }, + { + role: 'tool', + content: [ + { type: 'text', text: 'Tool result' }, + { type: 'image_url', image_url: { url: '' } } + ], + tool_call_id: 'old_call' + }, + { role: 'assistant', content: 'Response' }, + ] + } + + expect(builder.messages).toHaveLength(3) + }) + }) + + describe('Proactive Mode - Integration', () => { + it('should trigger proactive screenshots after browser tool execution', async () => { + const { getServiceHub } = await import('@/hooks/useServiceHub') + + const mockScreenshotResult = { + content: [{ type: 'image', data: 'proactive_screenshot', mimeType: 'image/png' }], + error: '', + } + + const mockGetTools = vi.fn(() => Promise.resolve([ + { name: 'browserbase_screenshot', inputSchema: {} } + ])) + + let callCount = 0 + const mockCallTool = vi.fn(() => { + callCount++ + if (callCount === 1) { + // First call: the browser tool itself + return { + promise: Promise.resolve({ + content: [{ type: 'text', text: 'navigated to page' }], + error: '', + }), + cancel: vi.fn(), + } + } else { + // Second call: proactive screenshot + return { + promise: Promise.resolve(mockScreenshotResult), + cancel: vi.fn(), + } + } + }) + + vi.mocked(getServiceHub).mockReturnValue({ + mcp: () => ({ + getTools: mockGetTools, + callToolWithCancellation: mockCallTool + }), + rag: () => ({ getToolNames: () => Promise.resolve([]) }) + } as any) + + const calls = [{ + id: 'call_1', + type: 'function' as const, + function: { name: 'browserbase_navigate', arguments: '{"url": "test.com"}' } + }] + const builder = { + addToolMessage: vi.fn(), + getMessages: vi.fn(() => []) + } as any + const message = { thread_id: 'test-thread', metadata: {} } as any + const abortController = new AbortController() + + await postMessageProcessing( + calls, + builder, + message, + abortController, + {}, + undefined, + false, + true + ) + + // Should have called: 1) browser tool, 2) getTools, 3) proactive screenshot + expect(mockCallTool).toHaveBeenCalledTimes(2) + expect(mockGetTools).toHaveBeenCalled() + expect(builder.addToolMessage).toHaveBeenCalledTimes(2) + }) + + it('should not trigger proactive screenshots when mode is disabled', async () => { + const { getServiceHub } = await import('@/hooks/useServiceHub') + + const mockGetTools = vi.fn(() => Promise.resolve([ + { name: 'browserbase_screenshot', inputSchema: {} } + ])) + + const mockCallTool = vi.fn(() => ({ + promise: Promise.resolve({ + content: [{ type: 'text', text: 'navigated' }], + error: '', + }), + cancel: vi.fn(), + })) + + vi.mocked(getServiceHub).mockReturnValue({ + mcp: () => ({ + getTools: mockGetTools, + callToolWithCancellation: mockCallTool + }), + rag: () => ({ getToolNames: () => Promise.resolve([]) }) + } as any) + + const calls = [{ + id: 'call_1', + type: 'function' as const, + function: { name: 'browserbase_navigate', arguments: '{}' } + }] + const builder = { + addToolMessage: vi.fn(), + getMessages: vi.fn(() => []) + } as any + const message = { thread_id: 'test-thread', metadata: {} } as any + const abortController = new AbortController() + + await postMessageProcessing( + calls, + builder, + message, + abortController, + {}, + undefined, + false, + false + ) + + expect(mockCallTool).toHaveBeenCalledTimes(1) + expect(mockGetTools).not.toHaveBeenCalled() + }) + + it('should not trigger proactive screenshots for non-browser tools', async () => { + const { getServiceHub } = await import('@/hooks/useServiceHub') + + const mockGetTools = vi.fn(() => Promise.resolve([])) + const mockCallTool = vi.fn(() => ({ + promise: Promise.resolve({ + content: [{ type: 'text', text: 'fetched data' }], + error: '', + }), + cancel: vi.fn(), + })) + + vi.mocked(getServiceHub).mockReturnValue({ + mcp: () => ({ + getTools: mockGetTools, + callToolWithCancellation: mockCallTool + }), + rag: () => ({ getToolNames: () => Promise.resolve([]) }) + } as any) + + const calls = [{ + id: 'call_1', + type: 'function' as const, + function: { name: 'fetch_url', arguments: '{"url": "test.com"}' } + }] + const builder = { + addToolMessage: vi.fn(), + getMessages: vi.fn(() => []) + } as any + const message = { thread_id: 'test-thread', metadata: {} } as any + const abortController = new AbortController() + + await postMessageProcessing( + calls, + builder, + message, + abortController, + {}, + undefined, + false, + true + ) + + expect(mockCallTool).toHaveBeenCalledTimes(1) + expect(mockGetTools).not.toHaveBeenCalled() + }) + }) })