added streaming

parent ae8c6a4f04
commit ff247c1bcd

.cursorindexingignore (new file, 3 lines)
@@ -0,0 +1,3 @@
+
+# Don't index SpecStory auto-save files, but allow explicit context inclusion via @ references
+.specstory/**
.specstory/.gitignore (new file, vendored, 4 lines)
@@ -0,0 +1,4 @@
+# SpecStory project identity file
+/.project.json
+# SpecStory explanation file
+/.what-is-this.md
@@ -14,10 +14,11 @@ A terminal application that enables two LLMs to engage in structured debates on
 - **Beautiful UI**: Rich terminal interface with side-by-side display, color-coded positions, and formatted output
 
 ### Advanced Features
+- **Streaming Responses**: Real-time streaming of LLM responses with live side-by-side display and tokens/second metrics
 - **Automatic Memory Management**: Token counting and automatic memory truncation to prevent context overflow
 - **Auto-Save**: Debates automatically saved after each round (configurable)
 - **Response Validation**: Ensures agents provide valid, non-empty responses
-- **Statistics Tracking**: Real-time tracking of response times, token usage, and memory consumption
+- **Statistics Tracking**: Real-time tracking of response times, token usage, memory consumption, and streaming speeds
 - **Comprehensive Logging**: Optional file and console logging with configurable levels
 - **CLI Arguments**: Control all aspects via command-line flags
 - **Environment Variables**: Secure API key management via `.env` files
@@ -138,6 +139,7 @@ python -m src.main [OPTIONS]
 - `--topic, -t TEXT` - Debate topic (skips interactive prompt)
 - `--exchanges, -e NUMBER` - Exchanges per round (default: 10)
 - `--no-auto-save` - Disable automatic saving after each round
+- `--no-streaming` - Disable streaming responses (show complete responses at once instead of real-time streaming)
 - `--log-level LEVEL` - Logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL
 - `--log-file PATH` - Log to file (default: console only)
 - `--max-memory-tokens NUMBER` - Maximum tokens to keep in agent memory
@@ -157,6 +159,9 @@ python -m src.main --log-level DEBUG --log-file debug.log
 # Disable auto-save for manual control
 python -m src.main --no-auto-save
 
+# Disable streaming for slower connections
+python -m src.main --no-streaming
+
 # Use custom config and memory limit
 python -m src.main --config my_config.yaml --max-memory-tokens 50000
 
File diff suppressed because one or more lines are too long

debates/debate_pee_is_stored_in_the_balls_20251111_195431.json (new file, 157 lines)
File diff suppressed because one or more lines are too long

src/agent.py (29 lines changed)
@@ -123,6 +123,35 @@ class DebateAgent:
 
         return response
 
+    def generate_response_stream(self, **kwargs):
+        """
+        Generate a streaming response based on current memory.
+
+        Yields chunks as they arrive and accumulates them into memory
+        after streaming completes.
+
+        Args:
+            **kwargs: Additional parameters for the LLM provider
+
+        Yields:
+            str: Response chunks as they arrive
+
+        Returns:
+            str: The complete accumulated response
+        """
+        accumulated = []
+
+        # Stream chunks from provider
+        for chunk in self.provider.generate_response_stream(self.memory, **kwargs):
+            accumulated.append(chunk)
+            yield chunk
+
+        # After streaming completes, add full response to memory
+        full_response = ''.join(accumulated)
+        self.memory.append({"role": "assistant", "content": full_response})
+
+        return full_response
+
     def get_memory(self) -> List[Dict[str, str]]:
         """
         Get the agent's conversation memory.
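The method above is a generator that both yields chunks and returns the accumulated text, so a plain for-loop gets the chunks but silently discards the return value; a caller that wants the full text without re-reading the agent's memory has to catch StopIteration (or delegate with yield from). A minimal consumption sketch, not part of the commit, with an illustrative helper name:

    def drain_stream(gen):
        """Consume a streaming generator; return (chunks, generator_return_value)."""
        chunks = []
        while True:
            try:
                chunks.append(next(gen))
            except StopIteration as stop:
                # stop.value carries the generator's return value (the full accumulated response)
                return chunks, stop.value

In this codebase the UI layer only needs the yielded chunks, since the agent also appends the finished response to its own memory once the stream is exhausted.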
@@ -50,3 +50,8 @@ DEFAULT_PRESENCE_PENALTY = 0.0
 
 # Token Estimation (approximate tokens per character for different languages)
 TOKENS_PER_CHAR_ENGLISH = 0.25  # Rough estimate for English text
+
+# Streaming Configuration
+STREAMING_ENABLED_DEFAULT = True  # Whether streaming is enabled by default
+STREAMING_REFRESH_RATE = 10  # UI updates per second during streaming
+STREAMING_MIN_TERMINAL_WIDTH = 100  # Minimum terminal width for side-by-side streaming
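These constants drive the streaming display: the refresh rate is handed to Rich's Live view in src/ui.py, and the tokens-per-character ratio backs the cheap token estimate behind the tok/s metric. A rough sketch of that estimate, assuming the repo's token counter uses this heuristic (it may use a real tokenizer instead):

    TOKENS_PER_CHAR_ENGLISH = 0.25

    def estimate_tokens(text: str) -> int:
        # Roughly 4 characters per token for English text
        return max(1, round(len(text) * TOKENS_PER_CHAR_ENGLISH))

    print(estimate_tokens("Real-time streaming of LLM responses"))  # about 9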
src/debate.py (162 lines changed)
@@ -194,6 +194,160 @@ class DebateOrchestrator:
 
         return response_for, response_against
 
+    def conduct_exchange_stream(
+        self, agent_for: DebateAgent, agent_against: DebateAgent
+    ) -> Tuple[str, str]:
+        """
+        Conduct one exchange with streaming responses (both agents respond once).
+
+        Args:
+            agent_for: Agent arguing 'for'
+            agent_against: Agent arguing 'against'
+
+        Returns:
+            Tuple of (response_for, response_against)
+
+        Raises:
+            ProviderResponseError: If response validation fails
+        """
+        logger.info(f"Starting streaming exchange {self.current_exchange + 1}")
+
+        # Build prompts (same as non-streaming)
+        if self.current_exchange == 0:
+            prompt_for = f"Present your opening argument for the position that {self.topic}."
+        else:
+            prompt_for = self._build_context_prompt(agent_for)
+
+        agent_for.add_message("user", prompt_for)
+
+        # Prepare prompt for AGAINST agent (will use after FOR finishes)
+        if self.current_exchange == 0:
+            # Will be updated with FOR's response after streaming
+            prompt_against_template = "against_first_exchange"
+        else:
+            prompt_against = self._build_context_prompt(agent_against)
+            agent_against.add_message("user", prompt_against)
+
+        # Get streaming generators
+        stream_for = agent_for.generate_response_stream()
+
+        # We need to consume stream_for first to get the complete response
+        # before we can build the prompt for agent_against in the first exchange
+        if self.current_exchange == 0:
+            # For first exchange, we need FOR's complete response before AGAINST can start
+            # So we'll handle this specially in the UI function
+            pass
+
+        # Create generator for AGAINST (will be consumed after FOR completes)
+        def get_stream_against():
+            """Generator that yields chunks from AGAINST agent."""
+            # For first exchange, build prompt with FOR's response
+            if self.current_exchange == 0:
+                # The response_for will be available after stream_for is consumed
+                # We'll handle this in the UI layer by passing a callback
+                pass
+
+            # Generate streaming response
+            for chunk in agent_against.generate_response_stream():
+                yield chunk
+
+        # Use streaming UI to display both responses
+        # It will consume FOR first, then AGAINST
+        from . import ui
+
+        # Track timing
+        start_time = time.time()
+
+        # Handle first exchange specially (AGAINST needs FOR's response)
+        if self.current_exchange == 0:
+            # Manually consume FOR stream first
+            response_for_chunks = []
+            for chunk in stream_for:
+                response_for_chunks.append(chunk)
+            response_for = ''.join(response_for_chunks)
+
+            # Validate FOR response
+            response_for = self._validate_response(response_for, agent_for.name)
+
+            # Record in debate history
+            exchange_data_for = {
+                "exchange": self.current_exchange + 1,
+                "agent": agent_for.name,
+                "position": "for",
+                "content": response_for,
+            }
+            self.debate_history.append(exchange_data_for)
+
+            # Now build AGAINST prompt with FOR's response
+            prompt_against = (
+                f"Your opponent's opening argument: {response_for}\n\n"
+                f"Present your opening counter-argument against the position that {self.topic}."
+            )
+            agent_against.add_message("user", prompt_against)
+
+            # Get AGAINST stream
+            stream_against = agent_against.generate_response_stream()
+
+            # Display with UI (FOR already complete, just show it while AGAINST streams)
+            def for_replay():
+                """Generator that just yields the complete FOR response."""
+                yield response_for
+
+            response_for_display, response_against, _, tokens_per_sec_against = ui.stream_exchange_pair(
+                exchange_num=self.current_exchange + 1,
+                agent_for_name=agent_for.name,
+                agent_for_stream=for_replay(),
+                agent_against_name=agent_against.name,
+                agent_against_stream=stream_against,
+                total_exchanges=self.exchanges_per_round,
+            )
+        else:
+            # Normal case: stream both
+            stream_against = agent_against.generate_response_stream()
+
+            response_for, response_against, tokens_per_sec_for, tokens_per_sec_against = ui.stream_exchange_pair(
+                exchange_num=self.current_exchange + 1,
+                agent_for_name=agent_for.name,
+                agent_for_stream=stream_for,
+                agent_against_name=agent_against.name,
+                agent_against_stream=stream_against,
+                total_exchanges=self.exchanges_per_round,
+            )
+
+        # Validate FOR response
+        response_for = self._validate_response(response_for, agent_for.name)
+
+        # Record FOR in debate history
+        exchange_data_for = {
+            "exchange": self.current_exchange + 1,
+            "agent": agent_for.name,
+            "position": "for",
+            "content": response_for,
+        }
+        self.debate_history.append(exchange_data_for)
+
+        # Track timing
+        response_time = time.time() - start_time
+        self.response_times.append(response_time)
+        self.total_response_time += response_time
+
+        # Validate AGAINST response
+        response_against = self._validate_response(response_against, agent_against.name)
+
+        # Record AGAINST in debate history
+        exchange_data_against = {
+            "exchange": self.current_exchange + 1,
+            "agent": agent_against.name,
+            "position": "against",
+            "content": response_against,
+        }
+        self.debate_history.append(exchange_data_against)
+
+        self.current_exchange += 1
+        logger.info(f"Streaming exchange {self.current_exchange} completed")
+
+        return response_for, response_against
+
     def _build_context_prompt(self, agent: DebateAgent) -> str:
         """
         Build a context-aware prompt that includes recent debate history.
@@ -255,7 +409,7 @@ class DebateOrchestrator:
         return response
 
     def run_round(
-        self, agent_for: DebateAgent, agent_against: DebateAgent
+        self, agent_for: DebateAgent, agent_against: DebateAgent, streaming: bool = True
     ) -> List[Dict[str, str]]:
         """
         Run a full round of exchanges.
@@ -263,6 +417,7 @@
         Args:
             agent_for: Agent arguing 'for'
             agent_against: Agent arguing 'against'
+            streaming: Whether to use streaming responses (default: True)
 
         Returns:
             List of exchanges from this round
@@ -271,7 +426,10 @@
         exchanges_to_run = self.exchanges_per_round
 
         for _ in range(exchanges_to_run):
-            self.conduct_exchange(agent_for, agent_against)
+            if streaming:
+                self.conduct_exchange_stream(agent_for, agent_against)
+            else:
+                self.conduct_exchange(agent_for, agent_against)
 
         return self.debate_history[round_start:]
 
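The first exchange cannot stream both sides at once: the AGAINST prompt needs FOR's complete opening argument, so the code drains FOR's stream first and then wraps the finished text in a one-shot generator so ui.stream_exchange_pair() can treat a replayed response and a live stream identically. A stripped-down sketch of that pattern with illustrative names, not part of the commit:

    def replay(text):
        # One-shot generator: makes an already-complete response look like a stream
        yield text

    def live(chunks):
        # Stands in for agent.generate_response_stream()
        for c in chunks:
            yield c

    for side in (replay("FOR's opening, already complete"), live(["AGAINST ", "streams ", "its rebuttal"])):
        print(''.join(side))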
src/main.py (14 lines changed)
@@ -244,7 +244,7 @@ def setup_configuration(config_path: str = DEFAULT_CONFIG_FILE) -> Config:
     return config
 
 
-def run_debate_loop(orchestrator: DebateOrchestrator, agent_for, agent_against, auto_save: bool = True):
+def run_debate_loop(orchestrator: DebateOrchestrator, agent_for, agent_against, auto_save: bool = True, streaming: bool = True):
     """
     Run the debate loop with user interaction.
 
@@ -253,6 +253,7 @@ def run_debate_loop(orchestrator: DebateOrchestrator, agent_for, agent_against,
         agent_for: Agent arguing 'for'
         agent_against: Agent arguing 'against'
         auto_save: Whether to auto-save after each round
+        streaming: Whether to use streaming responses (default: True)
 
     Raises:
         DebateError: If debate encounters an error
@@ -265,7 +266,7 @@ def run_debate_loop(orchestrator: DebateOrchestrator, agent_for, agent_against,
 
         try:
             # Run the round (exchanges are displayed as they happen)
-            orchestrator.run_round(agent_for, agent_against)
+            orchestrator.run_round(agent_for, agent_against, streaming=streaming)
 
             # Auto-save after each round if enabled
             if auto_save:
@@ -412,6 +413,12 @@ Examples:
         help="Maximum tokens to keep in agent memory"
     )
 
+    parser.add_argument(
+        "--no-streaming",
+        action="store_true",
+        help="Disable streaming responses (show complete responses at once)"
+    )
+
     return parser.parse_args()
 
 
@@ -485,7 +492,8 @@ def main():
 
         # Run the debate loop
         auto_save = not args.no_auto_save
-        run_debate_loop(orchestrator, agent_for, agent_against, auto_save=auto_save)
+        streaming = not args.no_streaming  # Streaming enabled by default
+        run_debate_loop(orchestrator, agent_for, agent_against, auto_save=auto_save, streaming=streaming)
 
     except Exception as e:
         ui.print_error(f"Error during debate: {str(e)}")
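The flag wiring follows the usual inversion pattern: a --no-X switch that leaves the feature on unless the user opts out. A minimal argparse sketch of just that piece (the flag name matches the diff; everything else is illustrative):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--no-streaming", action="store_true",
                        help="Disable streaming responses (show complete responses at once)")

    args = parser.parse_args([])       # no flags given: streaming stays enabled
    streaming = not args.no_streaming
    print(streaming)                   # True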
@@ -1,5 +1,6 @@
 """LM Studio LLM provider implementation."""
 
+import json
 import time
 from typing import List, Dict, Any
 import requests
@@ -89,6 +90,105 @@ class LMStudioProvider(BaseLLMProvider):
             logger.error(f"Unexpected error in LMStudio generate_response: {e}")
             raise ProviderError(f"Unexpected error: {str(e)}") from e
 
+    def generate_response_stream(
+        self, messages: List[Dict[str, str]], **kwargs
+    ):
+        """
+        Generate a streaming response using LM Studio local API.
+
+        Yields chunks of the response as they arrive via Server-Sent Events (SSE).
+
+        Args:
+            messages: List of message dictionaries with 'role' and 'content'
+            **kwargs: Additional parameters (temperature, max_tokens, etc.)
+
+        Yields:
+            str: Response chunks as they arrive
+
+        Raises:
+            ProviderError: If the API call fails
+            ProviderTimeoutError: If request times out
+            ProviderConnectionError: If connection fails
+        """
+        logger.debug(f"Generating streaming response with LMStudio model: {self.model}")
+
+        try:
+            # Prepare the request payload
+            payload = {
+                "model": self.model,
+                "messages": messages,
+                "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
+                "max_tokens": kwargs.get("max_tokens", DEFAULT_MAX_TOKENS_PER_RESPONSE),
+                "stream": True,  # Enable streaming
+            }
+
+            # Add any additional kwargs
+            for key, value in kwargs.items():
+                if key not in ["model", "messages", "stream"]:
+                    payload[key] = value
+
+            logger.debug(f"Stream API params: model={payload['model']}, temp={payload['temperature']}, "
+                         f"max_tokens={payload['max_tokens']}")
+
+            # Make the streaming API request
+            response = requests.post(
+                f"{self.base_url}/chat/completions",
+                json=payload,
+                headers={"Content-Type": "application/json"},
+                timeout=API_TIMEOUT_SECONDS,
+                stream=True,  # Enable streaming mode
+            )
+
+            response.raise_for_status()
+
+            # Parse Server-Sent Events stream
+            for line in response.iter_lines():
+                if line:
+                    line = line.decode('utf-8')
+
+                    # Skip comments and empty lines
+                    if line.startswith(':') or not line.strip():
+                        continue
+
+                    # Parse data lines
+                    if line.startswith('data: '):
+                        data_str = line[6:]  # Remove 'data: ' prefix
+
+                        # Check for stream end
+                        if data_str.strip() == '[DONE]':
+                            break
+
+                        try:
+                            data = json.loads(data_str)
+                            if 'choices' in data and len(data['choices']) > 0:
+                                delta = data['choices'][0].get('delta', {}).get('content')
+                                if delta:
+                                    yield delta
+                        except json.JSONDecodeError:
+                            logger.warning(f"Failed to parse SSE chunk: {data_str}")
+                            continue
+
+        except requests.exceptions.ConnectionError as e:
+            logger.warning(f"Cannot connect to LMStudio during streaming: {e}")
+            raise ProviderConnectionError(
+                f"Cannot connect to LM Studio at {self.base_url}. "
+                "Make sure LM Studio is running and the server is started."
+            ) from e
+
+        except requests.exceptions.Timeout as e:
+            logger.warning(f"LMStudio request timed out during streaming: {e}")
+            raise ProviderTimeoutError(
+                f"LM Studio request timed out after {API_TIMEOUT_SECONDS}s"
+            ) from e
+
+        except requests.exceptions.HTTPError as e:
+            logger.error(f"LMStudio HTTP error during streaming: {e}")
+            raise ProviderError(f"LM Studio HTTP error: {str(e)}") from e
+
+        except Exception as e:
+            logger.error(f"Unexpected error in streaming: {e}")
+            raise ProviderError(f"Unexpected error: {str(e)}") from e
+
     def _call_api(self, messages: List[Dict[str, str]], **kwargs) -> str:
         """
         Make the actual API call to LM Studio.
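The LM Studio path parses the raw Server-Sent Events framing itself: payload lines start with "data: ", comment lines start with ":", and a literal "[DONE]" marks the end of the stream. A self-contained illustration of that parsing on canned lines, no network involved (the JSON bodies are made up but follow the OpenAI-compatible chunk shape the code reads):

    import json

    raw_lines = [
        b'data: {"choices":[{"delta":{"content":"Hel"}}]}',
        b': keep-alive comment',
        b'data: {"choices":[{"delta":{"content":"lo"}}]}',
        b'data: [DONE]',
    ]

    out = []
    for raw in raw_lines:
        line = raw.decode('utf-8')
        if line.startswith(':') or not line.strip():
            continue                      # skip SSE comments and blank lines
        if line.startswith('data: '):
            data_str = line[6:]
            if data_str.strip() == '[DONE]':
                break                     # end-of-stream sentinel
            delta = json.loads(data_str)['choices'][0].get('delta', {}).get('content')
            if delta:
                out.append(delta)

    print(''.join(out))  # Hello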
@@ -96,6 +96,82 @@ class OpenRouterProvider(BaseLLMProvider):
             logger.error(f"Unexpected error in OpenRouter generate_response: {e}")
             raise ProviderError(f"Unexpected error: {str(e)}") from e
 
+    def generate_response_stream(
+        self, messages: List[Dict[str, str]], **kwargs
+    ):
+        """
+        Generate a streaming response using OpenRouter API.
+
+        Yields chunks of the response as they arrive from the API.
+
+        Args:
+            messages: List of message dictionaries with 'role' and 'content'
+            **kwargs: Additional parameters (temperature, max_tokens, etc.)
+
+        Yields:
+            str: Response chunks as they arrive
+
+        Raises:
+            ProviderError: If the API call fails
+            ProviderRateLimitError: If rate limit is exceeded
+            ProviderTimeoutError: If request times out
+            ProviderConnectionError: If connection fails
+        """
+        logger.debug(f"Generating streaming response with OpenRouter model: {self.model}")
+
+        try:
+            # Set up params for streaming
+            params = {
+                "model": self.model,
+                "messages": messages,
+                "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
+                "max_tokens": kwargs.get("max_tokens", DEFAULT_MAX_TOKENS_PER_RESPONSE),
+                "stream": True,  # Enable streaming
+            }
+
+            # Add any additional kwargs
+            for key, value in kwargs.items():
+                if key not in ["model", "messages", "stream"]:
+                    params[key] = value
+
+            logger.debug(f"Stream API params: model={params['model']}, temp={params['temperature']}, "
+                         f"max_tokens={params['max_tokens']}")
+
+            # Create streaming request
+            stream = self.client.chat.completions.create(**params)
+
+            # Yield chunks as they arrive
+            for chunk in stream:
+                if chunk.choices and len(chunk.choices) > 0:
+                    delta = chunk.choices[0].delta.content
+                    if delta is not None:
+                        yield delta
+
+        except RateLimitError as e:
+            logger.warning(f"Rate limit exceeded during streaming: {e}")
+            raise ProviderRateLimitError(f"Rate limit exceeded: {str(e)}") from e
+
+        except APITimeoutError as e:
+            logger.warning(f"Request timed out during streaming: {e}")
+            raise ProviderTimeoutError(f"Request timed out after {API_TIMEOUT_SECONDS}s") from e
+
+        except APIConnectionError as e:
+            logger.warning(f"Connection error during streaming: {e}")
+            raise ProviderConnectionError(f"Failed to connect to OpenRouter: {str(e)}") from e
+
+        except APIError as e:
+            # Check for authentication errors
+            if "401" in str(e) or "unauthorized" in str(e).lower():
+                logger.error(f"Authentication failed during streaming: {e}")
+                raise ProviderAuthenticationError(f"Invalid API key or authentication failed") from e
+
+            logger.error(f"API error during streaming: {e}")
+            raise ProviderError(f"OpenRouter API error: {str(e)}") from e
+
+        except Exception as e:
+            logger.error(f"Unexpected error in streaming: {e}")
+            raise ProviderError(f"Unexpected error: {str(e)}") from e
+
     def _call_api(self, messages: List[Dict[str, str]], **kwargs) -> str:
         """
         Make the actual API call to OpenRouter.
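Unlike the LM Studio provider, this path never touches raw SSE: the OpenAI SDK client yields already-parsed chunk objects, and the code only reads chunk.choices[0].delta.content, skipping None deltas. The new tests mimic that shape with Mock objects; a one-off sketch in the same spirit, purely illustrative:

    from unittest.mock import Mock

    chunks = [Mock(choices=[Mock(delta=Mock(content=c))]) for c in ("Hel", "lo", None)]
    print(''.join(c.choices[0].delta.content
                  for c in chunks
                  if c.choices and c.choices[0].delta.content is not None))  # Hello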
src/ui.py (129 lines changed)
@@ -5,7 +5,11 @@ from rich.panel import Panel
 from rich.markdown import Markdown
 from rich.prompt import Prompt
 from rich.table import Table
+from rich.live import Live
+from rich.layout import Layout
+from rich.text import Text
 from typing import Optional, Dict
+from .constants import STREAMING_REFRESH_RATE
 
 console = Console()
 
@@ -132,6 +136,131 @@ def print_exchange_pair(
     console.print()
 
 
+def stream_exchange_pair(
+    exchange_num: int,
+    agent_for_name: str,
+    agent_for_stream,  # Generator yielding chunks
+    agent_against_name: str,
+    agent_against_stream,  # Generator yielding chunks
+    total_exchanges: int,
+):
+    """
+    Display streaming responses side-by-side with live updates.
+
+    Args:
+        exchange_num: Exchange number
+        agent_for_name: Name of agent arguing FOR
+        agent_for_stream: Generator yielding FOR agent's response chunks
+        agent_against_name: Name of agent arguing AGAINST
+        agent_against_stream: Generator yielding AGAINST agent's response chunks
+        total_exchanges: Total number of exchanges in round
+
+    Returns:
+        Tuple of (complete_for_response, complete_against_response, tokens_per_sec_for, tokens_per_sec_against)
+    """
+    import time
+    from .utils.token_counter import count_tokens
+
+    content_for = []
+    content_against = []
+
+    # Track timing for tokens/second calculation
+    start_time_for = time.time()
+    start_time_against = None
+    end_time_for = None
+    end_time_against = None
+
+    # Create layout with two columns
+    layout = Layout()
+    layout.split_row(
+        Layout(name="left"),
+        Layout(name="right")
+    )
+
+    def update_display(for_text, against_text, for_done=False, against_done=False,
+                       for_tokens_per_sec=None, against_tokens_per_sec=None):
+        """Helper to update the live display."""
+        # Build status indicators
+        for_status = ""
+        if for_done:
+            for_status = " ✓"
+            if for_tokens_per_sec:
+                for_status += f" ({for_tokens_per_sec:.1f} tok/s)"
+        else:
+            for_status = " [dim][streaming...][/dim]"
+
+        against_status = ""
+        if not against_text and not against_done:
+            against_status = " [dim][waiting...][/dim]"
+        elif against_done:
+            against_status = " ✓"
+            if against_tokens_per_sec:
+                against_status += f" ({against_tokens_per_sec:.1f} tok/s)"
+        else:
+            against_status = " [dim][streaming...][/dim]"
+
+        # Left column - FOR agent
+        layout["left"].update(
+            Panel(
+                Markdown(for_text) if for_text else Text("Starting...", style="dim"),
+                title=f"[bold]Exchange {exchange_num}/{total_exchanges}[/bold]\n{agent_for_name} - [green]FOR[/green]{for_status}",
+                border_style="green",
+                padding=(1, 2),
+            )
+        )
+
+        # Right column - AGAINST agent
+        layout["right"].update(
+            Panel(
+                Markdown(against_text) if against_text else Text("Waiting...", style="dim"),
+                title=f"[bold]Exchange {exchange_num}/{total_exchanges}[/bold]\n{agent_against_name} - [red]AGAINST[/red]{against_status}",
+                border_style="red",
+                padding=(1, 2),
+            )
+        )
+
+    with Live(layout, refresh_per_second=STREAMING_REFRESH_RATE, console=console) as live:
+        # Stream Agent FOR first
+        for chunk in agent_for_stream:
+            content_for.append(chunk)
+            update_display(''.join(content_for), '', for_done=False, against_done=False)
+
+        # Mark FOR as complete and calculate tokens/sec
+        end_time_for = time.time()
+        for_duration = end_time_for - start_time_for
+        for_text = ''.join(content_for)
+        for_tokens = count_tokens(for_text)
+        for_tokens_per_sec = for_tokens / for_duration if for_duration > 0 else 0
+
+        update_display(for_text, '', for_done=True, against_done=False,
+                       for_tokens_per_sec=for_tokens_per_sec)
+
+        # Stream Agent AGAINST
+        start_time_against = time.time()
+        for chunk in agent_against_stream:
+            content_against.append(chunk)
+            update_display(for_text, ''.join(content_against),
+                           for_done=True, against_done=False,
+                           for_tokens_per_sec=for_tokens_per_sec)
+
+        # Mark AGAINST as complete and calculate tokens/sec
+        end_time_against = time.time()
+        against_duration = end_time_against - start_time_against
+        against_text = ''.join(content_against)
+        against_tokens = count_tokens(against_text)
+        against_tokens_per_sec = against_tokens / against_duration if against_duration > 0 else 0
+
+        update_display(for_text, against_text,
+                       for_done=True, against_done=True,
+                       for_tokens_per_sec=for_tokens_per_sec,
+                       against_tokens_per_sec=against_tokens_per_sec)
+
+    # After Live context, display remains on screen
+    console.print()
+
+    return for_text, against_text, for_tokens_per_sec, against_tokens_per_sec
+
+
 def print_round_complete(exchange_count: int):
     """
     Print round completion message.
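The function above leans on Rich's Live plus a two-column Layout: each incoming chunk rewrites one panel while the other stays put, and the refresh rate caps how often the terminal is redrawn. A self-contained sketch of that pattern with fake character streams standing in for LLM output (requires the rich package; the "left"/"right" layout names mirror the diff):

    import time
    from rich.console import Console
    from rich.layout import Layout
    from rich.live import Live
    from rich.panel import Panel

    console = Console()
    layout = Layout()
    layout.split_row(Layout(name="left"), Layout(name="right"))

    def fake_stream(text):
        for ch in text:
            time.sleep(0.02)
            yield ch

    buf_for, buf_against = [], []
    with Live(layout, refresh_per_second=10, console=console):
        layout["left"].update(Panel("Starting...", title="FOR", border_style="green"))
        layout["right"].update(Panel("Waiting...", title="AGAINST", border_style="red"))
        for ch in fake_stream("Opening argument..."):
            buf_for.append(ch)
            layout["left"].update(Panel(''.join(buf_for), title="FOR", border_style="green"))
        for ch in fake_stream("Counter-argument..."):
            buf_against.append(ch)
            layout["right"].update(Panel(''.join(buf_against), title="AGAINST", border_style="red"))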
tests/test_streaming.py (new file, 148 lines)
@@ -0,0 +1,148 @@
+"""Tests for streaming functionality."""
+
+import pytest
+from unittest.mock import Mock, patch
+from src.providers.openrouter import OpenRouterProvider
+from src.providers.lmstudio import LMStudioProvider
+from src.agent import DebateAgent
+
+
+def test_openrouter_stream_yields_chunks():
+    """Test that OpenRouter streaming yields chunks."""
+    # This is a mock test since we can't make real API calls
+    provider = OpenRouterProvider(
+        model="test-model",
+        api_key="test-key"
+    )
+
+    # Mock the streaming response
+    mock_chunks = [
+        Mock(choices=[Mock(delta=Mock(content="Hello "))]),
+        Mock(choices=[Mock(delta=Mock(content="world"))]),
+        Mock(choices=[Mock(delta=Mock(content="!"))]),
+    ]
+
+    with patch.object(provider.client.chat.completions, 'create', return_value=iter(mock_chunks)):
+        messages = [{"role": "user", "content": "Test"}]
+        chunks = list(provider.generate_response_stream(messages))
+
+    assert chunks == ["Hello ", "world", "!"]
+
+
+def test_agent_stream_accumulates_response():
+    """Test that agent streaming accumulates response in memory."""
+    mock_provider = Mock()
+    mock_provider.generate_response_stream.return_value = iter(["Hello ", "world", "!"])
+
+    agent = DebateAgent(
+        name="Test Agent",
+        provider=mock_provider,
+        system_prompt="You are a test agent",
+    )
+
+    # Stream the response
+    chunks = list(agent.generate_response_stream())
+
+    # Check chunks were yielded
+    assert chunks == ["Hello ", "world", "!"]
+
+    # Check full response was added to memory
+    assert len(agent.memory) == 1
+    assert agent.memory[0]["role"] == "assistant"
+    assert agent.memory[0]["content"] == "Hello world!"
+
+
+def test_agent_stream_with_existing_memory():
+    """Test streaming with existing conversation memory."""
+    mock_provider = Mock()
+    mock_provider.generate_response_stream.return_value = iter(["Response"])
+
+    agent = DebateAgent(
+        name="Test Agent",
+        provider=mock_provider,
+        system_prompt="You are a test agent",
+    )
+
+    # Add some existing messages
+    agent.add_message("user", "First message")
+    agent.memory.append({"role": "assistant", "content": "First response"})
+    agent.add_message("user", "Second message")
+
+    # Stream the response
+    list(agent.generate_response_stream())
+
+    # Check memory has all messages
+    assert len(agent.memory) == 4
+    assert agent.memory[-1]["role"] == "assistant"
+    assert agent.memory[-1]["content"] == "Response"
+
+
+def test_streaming_vs_non_streaming_same_result():
+    """Test that streaming and non-streaming produce the same result."""
+    mock_provider = Mock()
+
+    # Set up mock for both methods
+    mock_provider.generate_response.return_value = "Complete response"
+    mock_provider.generate_response_stream.return_value = iter(["Complete ", "response"])
+
+    agent1 = DebateAgent("Agent1", mock_provider, "Prompt")
+    agent2 = DebateAgent("Agent2", mock_provider, "Prompt")
+
+    # Non-streaming
+    response1 = agent1.generate_response()
+
+    # Streaming
+    chunks = list(agent2.generate_response_stream())
+    response2 = ''.join(chunks)
+
+    # Both should produce same text
+    assert response1 == "Complete response"
+    assert response2 == "Complete response"
+
+
+def test_empty_stream_handling():
+    """Test handling of empty streams."""
+    mock_provider = Mock()
+    mock_provider.generate_response_stream.return_value = iter([])
+
+    agent = DebateAgent(
+        name="Test Agent",
+        provider=mock_provider,
+        system_prompt="You are a test agent",
+    )
+
+    # Stream the response
+    chunks = list(agent.generate_response_stream())
+
+    # Should handle empty stream
+    assert chunks == []
+    assert len(agent.memory) == 1
+    assert agent.memory[0]["content"] == ""
+
+
+def test_stream_with_none_chunks():
+    """Test that None chunks are filtered out."""
+    # Mock OpenRouter-style response with None deltas
+    mock_provider = Mock()
+
+    def mock_stream():
+        yield "Hello"
+        yield None  # Should be filtered
+        yield " world"
+        yield None  # Should be filtered
+        yield "!"
+
+    mock_provider.generate_response_stream.return_value = mock_stream()
+
+    agent = DebateAgent(
+        name="Test Agent",
+        provider=mock_provider,
+        system_prompt="You are a test agent",
+    )
+
+    # Stream the response
+    chunks = [c for c in agent.generate_response_stream() if c is not None]
+
+    # None chunks should not appear
+    assert chunks == ["Hello", " world", "!"]
+    assert agent.memory[0]["content"] == "Hello world!"
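All of these tests drive streaming through mocked providers, so they need neither API keys nor a running LM Studio instance. Assuming pytest is the runner configured for this repo, something like:

    pytest tests/test_streaming.py -q

runs just this file.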