added streaming

commit: ff247c1bcd
parent: ae8c6a4f04
.cursorindexingignore (new file, 3 lines)

# Don't index SpecStory auto-save files, but allow explicit context inclusion via @ references
.specstory/**
.specstory/.gitignore (new file, vendored, 4 lines)

# SpecStory project identity file
/.project.json
# SpecStory explanation file
/.what-is-this.md
README.md

@@ -14,10 +14,11 @@ A terminal application that enables two LLMs to engage in structured debates on
 - **Beautiful UI**: Rich terminal interface with side-by-side display, color-coded positions, and formatted output

 ### Advanced Features
+- **Streaming Responses**: Real-time streaming of LLM responses with live side-by-side display and tokens/second metrics
 - **Automatic Memory Management**: Token counting and automatic memory truncation to prevent context overflow
 - **Auto-Save**: Debates automatically saved after each round (configurable)
 - **Response Validation**: Ensures agents provide valid, non-empty responses
-- **Statistics Tracking**: Real-time tracking of response times, token usage, and memory consumption
+- **Statistics Tracking**: Real-time tracking of response times, token usage, memory consumption, and streaming speeds
 - **Comprehensive Logging**: Optional file and console logging with configurable levels
 - **CLI Arguments**: Control all aspects via command-line flags
 - **Environment Variables**: Secure API key management via `.env` files
@@ -138,6 +139,7 @@ python -m src.main [OPTIONS]
 - `--topic, -t TEXT` - Debate topic (skips interactive prompt)
 - `--exchanges, -e NUMBER` - Exchanges per round (default: 10)
 - `--no-auto-save` - Disable automatic saving after each round
+- `--no-streaming` - Disable streaming responses (show complete responses at once instead of real-time streaming)
 - `--log-level LEVEL` - Logging level: DEBUG, INFO, WARNING, ERROR, CRITICAL
 - `--log-file PATH` - Log to file (default: console only)
 - `--max-memory-tokens NUMBER` - Maximum tokens to keep in agent memory
@@ -157,6 +159,9 @@ python -m src.main --log-level DEBUG --log-file debug.log
 # Disable auto-save for manual control
 python -m src.main --no-auto-save

+# Disable streaming for slower connections
+python -m src.main --no-streaming
+
 # Use custom config and memory limit
 python -m src.main --config my_config.yaml --max-memory-tokens 50000
debates/debate_pee_is_stored_in_the_balls_20251111_195431.json (new file, 157 lines)
File diff suppressed because one or more lines are too long
src/agent.py (29 lines changed)

@@ -123,6 +123,35 @@ class DebateAgent:

         return response

+    def generate_response_stream(self, **kwargs):
+        """
+        Generate a streaming response based on current memory.
+
+        Yields chunks as they arrive and accumulates them into memory
+        after streaming completes.
+
+        Args:
+            **kwargs: Additional parameters for the LLM provider
+
+        Yields:
+            str: Response chunks as they arrive
+
+        Returns:
+            str: The complete accumulated response
+        """
+        accumulated = []
+
+        # Stream chunks from provider
+        for chunk in self.provider.generate_response_stream(self.memory, **kwargs):
+            accumulated.append(chunk)
+            yield chunk
+
+        # After streaming completes, add full response to memory
+        full_response = ''.join(accumulated)
+        self.memory.append({"role": "assistant", "content": full_response})
+
+        return full_response
+
     def get_memory(self) -> List[Dict[str, str]]:
         """
         Get the agent's conversation memory.
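Usage note: because `generate_response_stream` is a generator that also appends the accumulated text to the agent's memory, a caller only has to iterate it. The sketch below is hypothetical caller code (not part of this commit) that mirrors how the test suite drives the agent with a mocked provider, then reads the full reply back via `get_memory`.

# Hypothetical caller sketch, not part of this commit: consume the stream,
# then read the accumulated reply back from the agent's memory.
from unittest.mock import Mock
from src.agent import DebateAgent

fake_provider = Mock()
fake_provider.generate_response_stream.return_value = iter(["Streaming ", "works."])

agent = DebateAgent(name="Demo", provider=fake_provider, system_prompt="Demo prompt")

for chunk in agent.generate_response_stream():
    print(chunk, end="", flush=True)   # render chunks as they arrive
print()

full_reply = agent.get_memory()[-1]["content"]   # "Streaming works."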
src/constants.py

@@ -50,3 +50,8 @@ DEFAULT_PRESENCE_PENALTY = 0.0

 # Token Estimation (approximate tokens per character for different languages)
 TOKENS_PER_CHAR_ENGLISH = 0.25  # Rough estimate for English text
+
+# Streaming Configuration
+STREAMING_ENABLED_DEFAULT = True  # Whether streaming is enabled by default
+STREAMING_REFRESH_RATE = 10  # UI updates per second during streaming
+STREAMING_MIN_TERMINAL_WIDTH = 100  # Minimum terminal width for side-by-side streaming
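Note: `STREAMING_REFRESH_RATE` is consumed by the Rich `Live` display in `src/ui.py`, but the call site for `STREAMING_MIN_TERMINAL_WIDTH` is not part of the hunks in this commit. A plausible use (an assumption, not code from this change set) would be gating the side-by-side layout on the current console width:

# Sketch only (assumed usage, not shown in this commit): fall back to a
# stacked layout when the terminal is too narrow for side-by-side streaming.
from rich.console import Console
from src.constants import STREAMING_MIN_TERMINAL_WIDTH

console = Console()
side_by_side = console.width >= STREAMING_MIN_TERMINAL_WIDTH  # 100 columns by default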
src/debate.py (162 lines changed)

@@ -194,6 +194,160 @@ class DebateOrchestrator:

         return response_for, response_against

+    def conduct_exchange_stream(
+        self, agent_for: DebateAgent, agent_against: DebateAgent
+    ) -> Tuple[str, str]:
+        """
+        Conduct one exchange with streaming responses (both agents respond once).
+
+        Args:
+            agent_for: Agent arguing 'for'
+            agent_against: Agent arguing 'against'
+
+        Returns:
+            Tuple of (response_for, response_against)
+
+        Raises:
+            ProviderResponseError: If response validation fails
+        """
+        logger.info(f"Starting streaming exchange {self.current_exchange + 1}")
+
+        # Build prompts (same as non-streaming)
+        if self.current_exchange == 0:
+            prompt_for = f"Present your opening argument for the position that {self.topic}."
+        else:
+            prompt_for = self._build_context_prompt(agent_for)
+
+        agent_for.add_message("user", prompt_for)
+
+        # Prepare prompt for AGAINST agent (will use after FOR finishes)
+        if self.current_exchange == 0:
+            # Will be updated with FOR's response after streaming
+            prompt_against_template = "against_first_exchange"
+        else:
+            prompt_against = self._build_context_prompt(agent_against)
+            agent_against.add_message("user", prompt_against)
+
+        # Get streaming generators
+        stream_for = agent_for.generate_response_stream()
+
+        # We need to consume stream_for first to get the complete response
+        # before we can build the prompt for agent_against in the first exchange
+        if self.current_exchange == 0:
+            # For first exchange, we need FOR's complete response before AGAINST can start
+            # So we'll handle this specially in the UI function
+            pass
+
+        # Create generator for AGAINST (will be consumed after FOR completes)
+        def get_stream_against():
+            """Generator that yields chunks from AGAINST agent."""
+            # For first exchange, build prompt with FOR's response
+            if self.current_exchange == 0:
+                # The response_for will be available after stream_for is consumed
+                # We'll handle this in the UI layer by passing a callback
+                pass
+
+            # Generate streaming response
+            for chunk in agent_against.generate_response_stream():
+                yield chunk
+
+        # Use streaming UI to display both responses
+        # It will consume FOR first, then AGAINST
+        from . import ui
+
+        # Track timing
+        start_time = time.time()
+
+        # Handle first exchange specially (AGAINST needs FOR's response)
+        if self.current_exchange == 0:
+            # Manually consume FOR stream first
+            response_for_chunks = []
+            for chunk in stream_for:
+                response_for_chunks.append(chunk)
+            response_for = ''.join(response_for_chunks)
+
+            # Validate FOR response
+            response_for = self._validate_response(response_for, agent_for.name)
+
+            # Record in debate history
+            exchange_data_for = {
+                "exchange": self.current_exchange + 1,
+                "agent": agent_for.name,
+                "position": "for",
+                "content": response_for,
+            }
+            self.debate_history.append(exchange_data_for)
+
+            # Now build AGAINST prompt with FOR's response
+            prompt_against = (
+                f"Your opponent's opening argument: {response_for}\n\n"
+                f"Present your opening counter-argument against the position that {self.topic}."
+            )
+            agent_against.add_message("user", prompt_against)
+
+            # Get AGAINST stream
+            stream_against = agent_against.generate_response_stream()
+
+            # Display with UI (FOR already complete, just show it while AGAINST streams)
+            def for_replay():
+                """Generator that just yields the complete FOR response."""
+                yield response_for
+
+            response_for_display, response_against, _, tokens_per_sec_against = ui.stream_exchange_pair(
+                exchange_num=self.current_exchange + 1,
+                agent_for_name=agent_for.name,
+                agent_for_stream=for_replay(),
+                agent_against_name=agent_against.name,
+                agent_against_stream=stream_against,
+                total_exchanges=self.exchanges_per_round,
+            )
+        else:
+            # Normal case: stream both
+            stream_against = agent_against.generate_response_stream()
+
+            response_for, response_against, tokens_per_sec_for, tokens_per_sec_against = ui.stream_exchange_pair(
+                exchange_num=self.current_exchange + 1,
+                agent_for_name=agent_for.name,
+                agent_for_stream=stream_for,
+                agent_against_name=agent_against.name,
+                agent_against_stream=stream_against,
+                total_exchanges=self.exchanges_per_round,
+            )
+
+            # Validate FOR response
+            response_for = self._validate_response(response_for, agent_for.name)
+
+            # Record FOR in debate history
+            exchange_data_for = {
+                "exchange": self.current_exchange + 1,
+                "agent": agent_for.name,
+                "position": "for",
+                "content": response_for,
+            }
+            self.debate_history.append(exchange_data_for)
+
+        # Track timing
+        response_time = time.time() - start_time
+        self.response_times.append(response_time)
+        self.total_response_time += response_time
+
+        # Validate AGAINST response
+        response_against = self._validate_response(response_against, agent_against.name)
+
+        # Record AGAINST in debate history
+        exchange_data_against = {
+            "exchange": self.current_exchange + 1,
+            "agent": agent_against.name,
+            "position": "against",
+            "content": response_against,
+        }
+        self.debate_history.append(exchange_data_against)
+
+        self.current_exchange += 1
+        logger.info(f"Streaming exchange {self.current_exchange} completed")
+
+        return response_for, response_against
+
     def _build_context_prompt(self, agent: DebateAgent) -> str:
         """
         Build a context-aware prompt that includes recent debate history.

@@ -255,7 +409,7 @@ class DebateOrchestrator:
         return response

     def run_round(
-        self, agent_for: DebateAgent, agent_against: DebateAgent
+        self, agent_for: DebateAgent, agent_against: DebateAgent, streaming: bool = True
     ) -> List[Dict[str, str]]:
         """
         Run a full round of exchanges.

@@ -263,6 +417,7 @@ class DebateOrchestrator:
         Args:
             agent_for: Agent arguing 'for'
             agent_against: Agent arguing 'against'
+            streaming: Whether to use streaming responses (default: True)

         Returns:
             List of exchanges from this round

@@ -271,7 +426,10 @@ class DebateOrchestrator:
         exchanges_to_run = self.exchanges_per_round

         for _ in range(exchanges_to_run):
-            self.conduct_exchange(agent_for, agent_against)
+            if streaming:
+                self.conduct_exchange_stream(agent_for, agent_against)
+            else:
+                self.conduct_exchange(agent_for, agent_against)

         return self.debate_history[round_start:]
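Design note: the first exchange is special-cased because the AGAINST prompt embeds FOR's complete opening argument, so FOR's stream has to be fully drained before AGAINST can even be prompted. A toy sketch of that ordering constraint (illustrative only, not project code):

# Toy illustration of the ordering constraint, not project code.
def fake_stream(text):
    for word in text.split(" "):
        yield word + " "

# Drain FOR completely first...
opening_for = "".join(fake_stream("Opening argument for the topic"))

# ...only then can the AGAINST prompt embed the opponent's full opening.
prompt_against = (
    f"Your opponent's opening argument: {opening_for}\n\n"
    "Present your opening counter-argument."
)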
src/main.py (14 lines changed)

@@ -244,7 +244,7 @@ def setup_configuration(config_path: str = DEFAULT_CONFIG_FILE) -> Config:
     return config


-def run_debate_loop(orchestrator: DebateOrchestrator, agent_for, agent_against, auto_save: bool = True):
+def run_debate_loop(orchestrator: DebateOrchestrator, agent_for, agent_against, auto_save: bool = True, streaming: bool = True):
     """
     Run the debate loop with user interaction.

@@ -253,6 +253,7 @@ def run_debate_loop(orchestrator: DebateOrchestrator, agent_for, agent_against,
         agent_for: Agent arguing 'for'
         agent_against: Agent arguing 'against'
         auto_save: Whether to auto-save after each round
+        streaming: Whether to use streaming responses (default: True)

     Raises:
         DebateError: If debate encounters an error

@@ -265,7 +266,7 @@ def run_debate_loop(orchestrator: DebateOrchestrator, agent_for, agent_against,

     try:
         # Run the round (exchanges are displayed as they happen)
-        orchestrator.run_round(agent_for, agent_against)
+        orchestrator.run_round(agent_for, agent_against, streaming=streaming)

         # Auto-save after each round if enabled
         if auto_save:

@@ -412,6 +413,12 @@ Examples:
         help="Maximum tokens to keep in agent memory"
     )

+    parser.add_argument(
+        "--no-streaming",
+        action="store_true",
+        help="Disable streaming responses (show complete responses at once)"
+    )
+
     return parser.parse_args()


@@ -485,7 +492,8 @@ def main():

         # Run the debate loop
         auto_save = not args.no_auto_save
-        run_debate_loop(orchestrator, agent_for, agent_against, auto_save=auto_save)
+        streaming = not args.no_streaming  # Streaming enabled by default
+        run_debate_loop(orchestrator, agent_for, agent_against, auto_save=auto_save, streaming=streaming)

     except Exception as e:
         ui.print_error(f"Error during debate: {str(e)}")
src/providers/lmstudio.py

@@ -1,5 +1,6 @@
 """LM Studio LLM provider implementation."""

+import json
 import time
 from typing import List, Dict, Any
 import requests

@@ -89,6 +90,105 @@ class LMStudioProvider(BaseLLMProvider):
             logger.error(f"Unexpected error in LMStudio generate_response: {e}")
             raise ProviderError(f"Unexpected error: {str(e)}") from e

+    def generate_response_stream(
+        self, messages: List[Dict[str, str]], **kwargs
+    ):
+        """
+        Generate a streaming response using LM Studio local API.
+
+        Yields chunks of the response as they arrive via Server-Sent Events (SSE).
+
+        Args:
+            messages: List of message dictionaries with 'role' and 'content'
+            **kwargs: Additional parameters (temperature, max_tokens, etc.)
+
+        Yields:
+            str: Response chunks as they arrive
+
+        Raises:
+            ProviderError: If the API call fails
+            ProviderTimeoutError: If request times out
+            ProviderConnectionError: If connection fails
+        """
+        logger.debug(f"Generating streaming response with LMStudio model: {self.model}")
+
+        try:
+            # Prepare the request payload
+            payload = {
+                "model": self.model,
+                "messages": messages,
+                "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
+                "max_tokens": kwargs.get("max_tokens", DEFAULT_MAX_TOKENS_PER_RESPONSE),
+                "stream": True,  # Enable streaming
+            }
+
+            # Add any additional kwargs
+            for key, value in kwargs.items():
+                if key not in ["model", "messages", "stream"]:
+                    payload[key] = value
+
+            logger.debug(f"Stream API params: model={payload['model']}, temp={payload['temperature']}, "
+                         f"max_tokens={payload['max_tokens']}")
+
+            # Make the streaming API request
+            response = requests.post(
+                f"{self.base_url}/chat/completions",
+                json=payload,
+                headers={"Content-Type": "application/json"},
+                timeout=API_TIMEOUT_SECONDS,
+                stream=True,  # Enable streaming mode
+            )
+
+            response.raise_for_status()
+
+            # Parse Server-Sent Events stream
+            for line in response.iter_lines():
+                if line:
+                    line = line.decode('utf-8')
+
+                    # Skip comments and empty lines
+                    if line.startswith(':') or not line.strip():
+                        continue
+
+                    # Parse data lines
+                    if line.startswith('data: '):
+                        data_str = line[6:]  # Remove 'data: ' prefix
+
+                        # Check for stream end
+                        if data_str.strip() == '[DONE]':
+                            break
+
+                        try:
+                            data = json.loads(data_str)
+                            if 'choices' in data and len(data['choices']) > 0:
+                                delta = data['choices'][0].get('delta', {}).get('content')
+                                if delta:
+                                    yield delta
+                        except json.JSONDecodeError:
+                            logger.warning(f"Failed to parse SSE chunk: {data_str}")
+                            continue
+
+        except requests.exceptions.ConnectionError as e:
+            logger.warning(f"Cannot connect to LMStudio during streaming: {e}")
+            raise ProviderConnectionError(
+                f"Cannot connect to LM Studio at {self.base_url}. "
+                "Make sure LM Studio is running and the server is started."
+            ) from e
+
+        except requests.exceptions.Timeout as e:
+            logger.warning(f"LMStudio request timed out during streaming: {e}")
+            raise ProviderTimeoutError(
+                f"LM Studio request timed out after {API_TIMEOUT_SECONDS}s"
+            ) from e
+
+        except requests.exceptions.HTTPError as e:
+            logger.error(f"LMStudio HTTP error during streaming: {e}")
+            raise ProviderError(f"LM Studio HTTP error: {str(e)}") from e
+
+        except Exception as e:
+            logger.error(f"Unexpected error in streaming: {e}")
+            raise ProviderError(f"Unexpected error: {str(e)}") from e
+
     def _call_api(self, messages: List[Dict[str, str]], **kwargs) -> str:
         """
         Make the actual API call to LM Studio.
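Format note: the SSE parser above expects OpenAI-compatible `data:` lines from LM Studio's `/chat/completions` endpoint. A minimal, self-contained sketch of that wire format and how a single line is decoded follows; the JSON payload shown is made up for illustration, not captured from a real server.

# Illustrative sample of the SSE lines the parser above expects; the payloads
# here are invented, not recorded from a real LM Studio response.
import json

sample_lines = [
    'data: {"choices": [{"delta": {"content": "Hel"}}]}',
    'data: {"choices": [{"delta": {"content": "lo"}}]}',
    'data: [DONE]',
]

for line in sample_lines:
    data_str = line[6:]                      # strip the 'data: ' prefix
    if data_str.strip() == '[DONE]':
        break                                # end-of-stream sentinel
    delta = json.loads(data_str)['choices'][0].get('delta', {}).get('content')
    if delta:
        print(delta, end="")                 # prints "Hello"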
src/providers/openrouter.py

@@ -96,6 +96,82 @@ class OpenRouterProvider(BaseLLMProvider):
             logger.error(f"Unexpected error in OpenRouter generate_response: {e}")
             raise ProviderError(f"Unexpected error: {str(e)}") from e

+    def generate_response_stream(
+        self, messages: List[Dict[str, str]], **kwargs
+    ):
+        """
+        Generate a streaming response using OpenRouter API.
+
+        Yields chunks of the response as they arrive from the API.
+
+        Args:
+            messages: List of message dictionaries with 'role' and 'content'
+            **kwargs: Additional parameters (temperature, max_tokens, etc.)
+
+        Yields:
+            str: Response chunks as they arrive
+
+        Raises:
+            ProviderError: If the API call fails
+            ProviderRateLimitError: If rate limit is exceeded
+            ProviderTimeoutError: If request times out
+            ProviderConnectionError: If connection fails
+        """
+        logger.debug(f"Generating streaming response with OpenRouter model: {self.model}")
+
+        try:
+            # Set up params for streaming
+            params = {
+                "model": self.model,
+                "messages": messages,
+                "temperature": kwargs.get("temperature", DEFAULT_TEMPERATURE),
+                "max_tokens": kwargs.get("max_tokens", DEFAULT_MAX_TOKENS_PER_RESPONSE),
+                "stream": True,  # Enable streaming
+            }
+
+            # Add any additional kwargs
+            for key, value in kwargs.items():
+                if key not in ["model", "messages", "stream"]:
+                    params[key] = value
+
+            logger.debug(f"Stream API params: model={params['model']}, temp={params['temperature']}, "
+                         f"max_tokens={params['max_tokens']}")
+
+            # Create streaming request
+            stream = self.client.chat.completions.create(**params)
+
+            # Yield chunks as they arrive
+            for chunk in stream:
+                if chunk.choices and len(chunk.choices) > 0:
+                    delta = chunk.choices[0].delta.content
+                    if delta is not None:
+                        yield delta
+
+        except RateLimitError as e:
+            logger.warning(f"Rate limit exceeded during streaming: {e}")
+            raise ProviderRateLimitError(f"Rate limit exceeded: {str(e)}") from e
+
+        except APITimeoutError as e:
+            logger.warning(f"Request timed out during streaming: {e}")
+            raise ProviderTimeoutError(f"Request timed out after {API_TIMEOUT_SECONDS}s") from e
+
+        except APIConnectionError as e:
+            logger.warning(f"Connection error during streaming: {e}")
+            raise ProviderConnectionError(f"Failed to connect to OpenRouter: {str(e)}") from e
+
+        except APIError as e:
+            # Check for authentication errors
+            if "401" in str(e) or "unauthorized" in str(e).lower():
+                logger.error(f"Authentication failed during streaming: {e}")
+                raise ProviderAuthenticationError("Invalid API key or authentication failed") from e
+
+            logger.error(f"API error during streaming: {e}")
+            raise ProviderError(f"OpenRouter API error: {str(e)}") from e
+
+        except Exception as e:
+            logger.error(f"Unexpected error in streaming: {e}")
+            raise ProviderError(f"Unexpected error: {str(e)}") from e
+
     def _call_api(self, messages: List[Dict[str, str]], **kwargs) -> str:
         """
         Make the actual API call to OpenRouter.
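Both providers now expose the same `generate_response_stream(messages, **kwargs)` generator alongside the existing blocking `generate_response`, which is what lets the `--no-streaming` flag switch paths without touching provider code. A rough consumer sketch (hypothetical wiring, not code from this commit):

# Hypothetical consumer sketch, not part of this commit: stream when enabled,
# fall back to the existing blocking call otherwise.
def get_reply(provider, messages, use_streaming=True):
    if not use_streaming:
        return provider.generate_response(messages)
    parts = []
    for chunk in provider.generate_response_stream(messages):
        parts.append(chunk)          # a UI layer could also render chunk here
    return "".join(parts)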
src/ui.py (129 lines changed)

@@ -5,7 +5,11 @@ from rich.panel import Panel
 from rich.markdown import Markdown
 from rich.prompt import Prompt
 from rich.table import Table
+from rich.live import Live
+from rich.layout import Layout
+from rich.text import Text
 from typing import Optional, Dict
+from .constants import STREAMING_REFRESH_RATE

 console = Console()

@@ -132,6 +136,131 @@ def print_exchange_pair(
     console.print()


+def stream_exchange_pair(
+    exchange_num: int,
+    agent_for_name: str,
+    agent_for_stream,  # Generator yielding chunks
+    agent_against_name: str,
+    agent_against_stream,  # Generator yielding chunks
+    total_exchanges: int,
+):
+    """
+    Display streaming responses side-by-side with live updates.
+
+    Args:
+        exchange_num: Exchange number
+        agent_for_name: Name of agent arguing FOR
+        agent_for_stream: Generator yielding FOR agent's response chunks
+        agent_against_name: Name of agent arguing AGAINST
+        agent_against_stream: Generator yielding AGAINST agent's response chunks
+        total_exchanges: Total number of exchanges in round
+
+    Returns:
+        Tuple of (complete_for_response, complete_against_response, tokens_per_sec_for, tokens_per_sec_against)
+    """
+    import time
+    from .utils.token_counter import count_tokens
+
+    content_for = []
+    content_against = []
+
+    # Track timing for tokens/second calculation
+    start_time_for = time.time()
+    start_time_against = None
+    end_time_for = None
+    end_time_against = None
+
+    # Create layout with two columns
+    layout = Layout()
+    layout.split_row(
+        Layout(name="left"),
+        Layout(name="right")
+    )
+
+    def update_display(for_text, against_text, for_done=False, against_done=False,
+                       for_tokens_per_sec=None, against_tokens_per_sec=None):
+        """Helper to update the live display."""
+        # Build status indicators
+        for_status = ""
+        if for_done:
+            for_status = " ✓"
+            if for_tokens_per_sec:
+                for_status += f" ({for_tokens_per_sec:.1f} tok/s)"
+        else:
+            for_status = " [dim][streaming...][/dim]"
+
+        against_status = ""
+        if not against_text and not against_done:
+            against_status = " [dim][waiting...][/dim]"
+        elif against_done:
+            against_status = " ✓"
+            if against_tokens_per_sec:
+                against_status += f" ({against_tokens_per_sec:.1f} tok/s)"
+        else:
+            against_status = " [dim][streaming...][/dim]"
+
+        # Left column - FOR agent
+        layout["left"].update(
+            Panel(
+                Markdown(for_text) if for_text else Text("Starting...", style="dim"),
+                title=f"[bold]Exchange {exchange_num}/{total_exchanges}[/bold]\n{agent_for_name} - [green]FOR[/green]{for_status}",
+                border_style="green",
+                padding=(1, 2),
+            )
+        )
+
+        # Right column - AGAINST agent
+        layout["right"].update(
+            Panel(
+                Markdown(against_text) if against_text else Text("Waiting...", style="dim"),
+                title=f"[bold]Exchange {exchange_num}/{total_exchanges}[/bold]\n{agent_against_name} - [red]AGAINST[/red]{against_status}",
+                border_style="red",
+                padding=(1, 2),
+            )
+        )
+
+    with Live(layout, refresh_per_second=STREAMING_REFRESH_RATE, console=console) as live:
+        # Stream Agent FOR first
+        for chunk in agent_for_stream:
+            content_for.append(chunk)
+            update_display(''.join(content_for), '', for_done=False, against_done=False)
+
+        # Mark FOR as complete and calculate tokens/sec
+        end_time_for = time.time()
+        for_duration = end_time_for - start_time_for
+        for_text = ''.join(content_for)
+        for_tokens = count_tokens(for_text)
+        for_tokens_per_sec = for_tokens / for_duration if for_duration > 0 else 0
+
+        update_display(for_text, '', for_done=True, against_done=False,
+                       for_tokens_per_sec=for_tokens_per_sec)
+
+        # Stream Agent AGAINST
+        start_time_against = time.time()
+        for chunk in agent_against_stream:
+            content_against.append(chunk)
+            update_display(for_text, ''.join(content_against),
+                           for_done=True, against_done=False,
+                           for_tokens_per_sec=for_tokens_per_sec)
+
+        # Mark AGAINST as complete and calculate tokens/sec
+        end_time_against = time.time()
+        against_duration = end_time_against - start_time_against
+        against_text = ''.join(content_against)
+        against_tokens = count_tokens(against_text)
+        against_tokens_per_sec = against_tokens / against_duration if against_duration > 0 else 0
+
+        update_display(for_text, against_text,
+                       for_done=True, against_done=True,
+                       for_tokens_per_sec=for_tokens_per_sec,
+                       against_tokens_per_sec=against_tokens_per_sec)
+
+    # After Live context, display remains on screen
+    console.print()
+
+    return for_text, against_text, for_tokens_per_sec, against_tokens_per_sec
+
+
 def print_round_complete(exchange_count: int):
     """
     Print round completion message.
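Usage note: `stream_exchange_pair` only needs two iterables of text chunks, so it can be exercised without any provider at all. A minimal demo (canned chunks, run from the repository root; the sleep simply makes the live update visible):

# Minimal demo of stream_exchange_pair with canned chunks (no LLM involved).
import time
from src import ui

def slow_chunks(words):
    for w in words:
        time.sleep(0.05)          # simulate network latency
        yield w + " "

for_text, against_text, tps_for, tps_against = ui.stream_exchange_pair(
    exchange_num=1,
    agent_for_name="Agent A",
    agent_for_stream=slow_chunks(["Streaming", "makes", "debates", "feel", "live."]),
    agent_against_name="Agent B",
    agent_against_stream=slow_chunks(["Only", "if", "the", "terminal", "is", "wide", "enough."]),
    total_exchanges=1,
)
print(f"FOR: {tps_for:.1f} tok/s, AGAINST: {tps_against:.1f} tok/s")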
tests/test_streaming.py (new file)

"""Tests for streaming functionality."""

import pytest
from unittest.mock import Mock, patch
from src.providers.openrouter import OpenRouterProvider
from src.providers.lmstudio import LMStudioProvider
from src.agent import DebateAgent


def test_openrouter_stream_yields_chunks():
    """Test that OpenRouter streaming yields chunks."""
    # This is a mock test since we can't make real API calls
    provider = OpenRouterProvider(
        model="test-model",
        api_key="test-key"
    )

    # Mock the streaming response
    mock_chunks = [
        Mock(choices=[Mock(delta=Mock(content="Hello "))]),
        Mock(choices=[Mock(delta=Mock(content="world"))]),
        Mock(choices=[Mock(delta=Mock(content="!"))]),
    ]

    with patch.object(provider.client.chat.completions, 'create', return_value=iter(mock_chunks)):
        messages = [{"role": "user", "content": "Test"}]
        chunks = list(provider.generate_response_stream(messages))

    assert chunks == ["Hello ", "world", "!"]


def test_agent_stream_accumulates_response():
    """Test that agent streaming accumulates response in memory."""
    mock_provider = Mock()
    mock_provider.generate_response_stream.return_value = iter(["Hello ", "world", "!"])

    agent = DebateAgent(
        name="Test Agent",
        provider=mock_provider,
        system_prompt="You are a test agent",
    )

    # Stream the response
    chunks = list(agent.generate_response_stream())

    # Check chunks were yielded
    assert chunks == ["Hello ", "world", "!"]

    # Check full response was added to memory
    assert len(agent.memory) == 1
    assert agent.memory[0]["role"] == "assistant"
    assert agent.memory[0]["content"] == "Hello world!"


def test_agent_stream_with_existing_memory():
    """Test streaming with existing conversation memory."""
    mock_provider = Mock()
    mock_provider.generate_response_stream.return_value = iter(["Response"])

    agent = DebateAgent(
        name="Test Agent",
        provider=mock_provider,
        system_prompt="You are a test agent",
    )

    # Add some existing messages
    agent.add_message("user", "First message")
    agent.memory.append({"role": "assistant", "content": "First response"})
    agent.add_message("user", "Second message")

    # Stream the response
    list(agent.generate_response_stream())

    # Check memory has all messages
    assert len(agent.memory) == 4
    assert agent.memory[-1]["role"] == "assistant"
    assert agent.memory[-1]["content"] == "Response"


def test_streaming_vs_non_streaming_same_result():
    """Test that streaming and non-streaming produce the same result."""
    mock_provider = Mock()

    # Set up mock for both methods
    mock_provider.generate_response.return_value = "Complete response"
    mock_provider.generate_response_stream.return_value = iter(["Complete ", "response"])

    agent1 = DebateAgent("Agent1", mock_provider, "Prompt")
    agent2 = DebateAgent("Agent2", mock_provider, "Prompt")

    # Non-streaming
    response1 = agent1.generate_response()

    # Streaming
    chunks = list(agent2.generate_response_stream())
    response2 = ''.join(chunks)

    # Both should produce same text
    assert response1 == "Complete response"
    assert response2 == "Complete response"


def test_empty_stream_handling():
    """Test handling of empty streams."""
    mock_provider = Mock()
    mock_provider.generate_response_stream.return_value = iter([])

    agent = DebateAgent(
        name="Test Agent",
        provider=mock_provider,
        system_prompt="You are a test agent",
    )

    # Stream the response
    chunks = list(agent.generate_response_stream())

    # Should handle empty stream
    assert chunks == []
    assert len(agent.memory) == 1
    assert agent.memory[0]["content"] == ""


def test_stream_with_none_chunks():
    """Test that None deltas from the API are filtered out by the provider."""
    provider = OpenRouterProvider(
        model="test-model",
        api_key="test-key"
    )

    # OpenRouter-style chunks where some deltas carry no content
    mock_chunks = [
        Mock(choices=[Mock(delta=Mock(content="Hello"))]),
        Mock(choices=[Mock(delta=Mock(content=None))]),   # should be filtered
        Mock(choices=[Mock(delta=Mock(content=" world"))]),
        Mock(choices=[Mock(delta=Mock(content=None))]),   # should be filtered
        Mock(choices=[Mock(delta=Mock(content="!"))]),
    ]

    with patch.object(provider.client.chat.completions, 'create', return_value=iter(mock_chunks)):
        messages = [{"role": "user", "content": "Test"}]
        chunks = list(provider.generate_response_stream(messages))

    # None deltas should not appear in the yielded chunks
    assert chunks == ["Hello", " world", "!"]
    assert ''.join(chunks) == "Hello world!"
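All provider calls in these tests are mocked, so they should run offline; assuming pytest is installed, running `pytest tests/test_streaming.py -v` from the repository root is enough to exercise just the streaming suite.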