149 lines
4.5 KiB
Python
149 lines
4.5 KiB
Python
"""Tests for streaming functionality."""
|
|
|
|
import pytest
|
|
from unittest.mock import Mock, patch
|
|
from src.providers.openrouter import OpenRouterProvider
|
|
from src.providers.lmstudio import LMStudioProvider
|
|
from src.agent import DebateAgent
|
|
|
|
|
|
def test_openrouter_stream_yields_chunks():
    """OpenRouterProvider.generate_response_stream yields each delta's text in order.

    This is a mock test since we can't make real API calls: the client's
    ``chat.completions.create`` is patched to return an iterator of fake
    streaming chunks.
    """
    pieces = ("Hello ", "world", "!")
    provider = OpenRouterProvider(model="test-model", api_key="test-key")

    # Each fake chunk exposes its text as chunk.choices[0].delta.content.
    fake_stream = iter(
        [Mock(choices=[Mock(delta=Mock(content=piece))]) for piece in pieces]
    )
    conversation = [{"role": "user", "content": "Test"}]

    with patch.object(
        provider.client.chat.completions, "create", return_value=fake_stream
    ):
        streamed = list(provider.generate_response_stream(conversation))

    assert streamed == list(pieces)
|
|
|
|
|
|
def test_agent_stream_accumulates_response():
    """DebateAgent re-yields provider chunks and stores the joined text in memory.

    After the stream is exhausted, memory should hold exactly one assistant
    entry whose content is the concatenation of every chunk.
    """
    provider = Mock(
        **{"generate_response_stream.return_value": iter(["Hello ", "world", "!"])}
    )
    agent = DebateAgent(
        name="Test Agent",
        provider=provider,
        system_prompt="You are a test agent",
    )

    received = list(agent.generate_response_stream())

    # The chunks pass through unchanged...
    assert received == ["Hello ", "world", "!"]

    # ...and the full response is recorded once in memory.
    assert len(agent.memory) == 1
    entry = agent.memory[0]
    assert entry["role"] == "assistant"
    assert entry["content"] == "Hello world!"
|
|
|
|
|
|
def test_agent_stream_with_existing_memory():
    """Streaming appends one assistant reply after an existing conversation history."""
    provider = Mock(**{"generate_response_stream.return_value": iter(["Response"])})
    agent = DebateAgent(
        name="Test Agent",
        provider=provider,
        system_prompt="You are a test agent",
    )

    # Seed prior history: two user messages via add_message and one assistant
    # reply appended to memory directly.
    agent.add_message("user", "First message")
    agent.memory.append({"role": "assistant", "content": "First response"})
    agent.add_message("user", "Second message")

    # Drain the stream; only its side effect on memory matters here.
    for _ in agent.generate_response_stream():
        pass

    assert len(agent.memory) == 4
    last = agent.memory[-1]
    assert last["role"] == "assistant"
    assert last["content"] == "Response"
|
|
|
|
|
|
def test_streaming_vs_non_streaming_same_result():
    """Test that streaming and non-streaming produce the same result.

    One mock provider backs both agents. ``generate_response`` returns the
    whole text at once, while ``generate_response_stream`` returns the same
    text split into chunks. The non-streaming reply and the joined streamed
    chunks must equal each other and the expected text.
    """
    expected = "Complete response"

    mock_provider = Mock()

    # Set up mock for both methods
    mock_provider.generate_response.return_value = expected
    mock_provider.generate_response_stream.return_value = iter(["Complete ", "response"])

    agent1 = DebateAgent("Agent1", mock_provider, "Prompt")
    agent2 = DebateAgent("Agent2", mock_provider, "Prompt")

    # Non-streaming
    response1 = agent1.generate_response()

    # Streaming: reassemble the chunks into the full text.
    response2 = "".join(agent2.generate_response_stream())

    # Assert the property under test directly, so a mismatch reports both
    # values, then pin each path to the expected text.
    assert response1 == response2, (response1, response2)
    assert response1 == expected
    assert response2 == expected
|
|
|
|
|
|
def test_empty_stream_handling():
    """An empty provider stream yields nothing but still leaves one memory entry with empty content."""
    provider = Mock(**{"generate_response_stream.return_value": iter([])})
    agent = DebateAgent(
        name="Test Agent",
        provider=provider,
        system_prompt="You are a test agent",
    )

    received = list(agent.generate_response_stream())

    # Nothing is yielded...
    assert not received

    # ...yet exactly one entry, with empty content, lands in memory.
    assert len(agent.memory) == 1
    assert agent.memory[0]["content"] == ""
|
|
|
|
|
|
def test_stream_with_none_chunks():
    """Test that None chunks are filtered out.

    The mocked provider stream interleaves ``None`` values with text, as an
    OpenRouter-style response with empty deltas might. The agent itself must
    drop them: ``None`` must not be yielded to the caller, and the text stored
    in memory must be the concatenation of the real chunks only.
    """
    def mock_stream():
        yield "Hello"
        yield None  # Should be filtered
        yield " world"
        yield None  # Should be filtered
        yield "!"

    mock_provider = Mock()
    mock_provider.generate_response_stream.return_value = mock_stream()

    agent = DebateAgent(
        name="Test Agent",
        provider=mock_provider,
        system_prompt="You are a test agent",
    )

    # Collect exactly what the agent yields. Filtering None here, as the test
    # previously did, would hide a None leaking out of the agent.
    chunks = list(agent.generate_response_stream())

    # None chunks should not appear
    assert None not in chunks, chunks
    assert chunks == ["Hello", " world", "!"]
    assert agent.memory[0]["content"] == "Hello world!"
|