Middleware
Middleware provides hooks into the request/response lifecycle for logging, tracing, metrics, and custom policies.
Basic Usage
from agent import Agent, Middleware
class LoggingMiddleware(Middleware):
    """Print a short preview of each request, response, and error.

    Every hook hands its argument back unchanged, so this middleware only
    observes the lifecycle.
    """

    # Maximum number of characters of input/output shown in a preview.
    preview_chars = 50

    def _preview(self, text):
        """Return *text* cut to ``preview_chars``, adding "..." only if truncated.

        ``request.input`` / ``response.text`` may be None, so a missing value
        is shown as an empty string instead of failing on a slice of None.
        """
        if not text:
            return ""
        if len(text) <= self.preview_chars:
            return text
        return f"{text[:self.preview_chars]}..."

    def before(self, request):
        print(f"Request: {self._preview(request.input)}")
        return request

    def after(self, request, response):
        print(f"Response: {self._preview(response.text)}")
        return response

    def on_error(self, request, error):
        print(f"Error: {error}")
        return error  # Returning the error propagates it; None would suppress it
# Attach the LoggingMiddleware defined above to a new agent.
basic_logger = LoggingMiddleware()
agent = Agent(
    middleware=[basic_logger],
    provider="openai",
    model="gpt-4o",
)
Middleware Hooks
before(request) -> AgentRequest
Called before the request is sent to the provider. It can modify the request in place or return a replacement.
class SystemPromptMiddleware(Middleware):
    """Put a fixed prefix in front of the request's system prompt.

    If the request has no (or an empty) system prompt, the prefix becomes
    the whole system prompt; otherwise the two are joined by a blank line.
    The request is updated in place and returned.
    """

    def __init__(self, prefix: str):
        self.prefix = prefix

    def before(self, request):
        existing = request.system
        request.system = (
            self.prefix if not existing else f"{self.prefix}\n\n{existing}"
        )
        return request
after(request, response) -> AgentResponse
Called after a response is received. It can modify the response in place or return a replacement.
from datetime import datetime, timezone


class MetadataMiddleware(Middleware):
    """Annotate every response with processing metadata.

    Adds two keys to ``response.metadata``:

    * ``processed_at`` -- timezone-aware ISO-8601 UTC timestamp of when the
      response passed through this hook.
    * ``input_length`` -- length of ``request.input``, treating None as empty.
    """

    def after(self, request, response):
        # Timezone-aware UTC so the stored timestamp is unambiguous across hosts.
        response.metadata["processed_at"] = datetime.now(timezone.utc).isoformat()
        response.metadata["input_length"] = len(request.input or "")
        return response
on_error(request, error) -> Exception | None
Called when an error occurs. Return the error to let it propagate, or return None to suppress it.
from agent.errors import ProviderError, RateLimitError


class ErrorRecoveryMiddleware(Middleware):
    """Decide, per error type, whether to propagate or suppress an error.

    Returning the error propagates it; returning None suppresses it.
    """

    def on_error(self, request, error):
        if isinstance(error, RateLimitError):
            print("Rate limited, will retry...")
            return error  # Let the retry handler deal with it
        # NOTE(review): assumes ProviderError exposes an HTTP-style status_code — confirm.
        if isinstance(error, ProviderError) and error.status_code == 503:
            print("Service unavailable, suppressing error")
            return None  # Suppress the error
        return error  # Propagate every other error
Built-in Middleware
LoggingMiddleware
from agent.middleware import LoggingMiddleware
import logging

# Default: LoggingMiddleware writes its output with print
agent = Agent(
    provider="openai",
    model="gpt-4o",
    middleware=[LoggingMiddleware()],
)

# Route the same output through a standard-library logger instead
logger = logging.getLogger(__name__)
agent = Agent(
    provider="openai",
    model="gpt-4o",
    middleware=[LoggingMiddleware(log_fn=logger.info)],
)
MetricsMiddleware
from agent.middleware import MetricsMiddleware
metrics = MetricsMiddleware()
agent = Agent(provider="openai", model="gpt-4o", middleware=[metrics])

# Send a couple of requests so there is something to measure
agent.run("Hello")
agent.run("World")

# Read the aggregated counters back from the middleware instance
stats = metrics.stats()
print(f"Requests: {stats['request_count']}")
print(f"Total tokens: {stats['total_tokens']}")
print(f"Errors: {stats['error_count']}")
print(f"Avg latency: {stats['avg_latency_ms']:.2f}ms")
RedactionMiddleware
from agent.middleware import RedactionMiddleware
# Redact sensitive patterns; each entry is a regular expression.
redactor = RedactionMiddleware(patterns=[
    # API keys: "sk-" followed by 20+ key characters. Hyphens and underscores
    # are allowed so prefixed keys such as "sk-proj-..." are caught too.
    r"sk-[A-Za-z0-9_-]{20,}",
    # Card numbers: 13-19 digits, optionally grouped with single spaces or
    # dashes ("4111 1111 1111 1111"). Errs on the side of over-redacting
    # other long digit runs.
    r"\b(?:\d[ -]?){12,18}\d\b",
    # Email addresses
    r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
])
agent = Agent(
    provider="openai",
    model="gpt-4o",
    middleware=[redactor],
)
RetryPolicyMiddleware
from agent.middleware import RetryPolicyMiddleware
from agent.errors import RateLimitError, ProviderError
# Retry up to five times on rate-limit or provider errors.
retry_policy = RetryPolicyMiddleware(
    retryable_errors=(RateLimitError, ProviderError),
    max_retries=5,
)
agent = Agent(provider="openai", model="gpt-4o", middleware=[retry_policy])
Chaining Middleware
Middleware runs in list order for before hooks and in reverse list order for after hooks:
import time


class TimingMiddleware(Middleware):
    """Print how long each request took between ``before`` and ``after``.

    ``before`` stores a ``time.perf_counter()`` reading in
    ``request.metadata["start_time"]``. That is a monotonic, high-resolution
    counter (unaffected by system clock changes), not an epoch timestamp.
    """

    def before(self, request):
        request.metadata["start_time"] = time.perf_counter()
        return request

    def after(self, request, response):
        start = request.metadata.get("start_time")
        if start is not None:
            # Skip the measurement rather than raising KeyError if this
            # middleware's before() never ran for the request.
            elapsed = time.perf_counter() - start
            print(f"Total time: {elapsed:.2f}s")
        return response
class CachingMiddleware(Middleware):
    """Serve repeated inputs from an in-memory cache.

    Entries are keyed on ``request.input`` itself rather than on
    ``hash(request.input)``. Two different inputs can share a hash value, so
    keying on the hash alone could return another input's cached response.
    A dict key compares by equality after hashing, so collisions are harmless.

    NOTE(review): only the input is part of the key, so requests that differ
    solely in system prompt or other fields share an entry. The cache is
    unbounded and lives as long as this instance.
    """

    def __init__(self):
        self.cache = {}

    def before(self, request):
        key = request.input
        if key in self.cache:
            # Return cached response (skip provider call). CacheHit is not
            # defined on this page; presumably the framework provides it — confirm.
            raise CacheHit(self.cache[key])
        return request

    def after(self, request, response):
        self.cache[request.input] = response
        return response
# Execution order for the chain below:
#   before hooks: TimingMiddleware -> CachingMiddleware -> Provider
#   after hooks:  CachingMiddleware -> TimingMiddleware
timing, caching = TimingMiddleware(), CachingMiddleware()
agent = Agent(
    provider="openai",
    model="gpt-4o",
    middleware=[timing, caching],
)
Custom Middleware Examples
Rate Limiting
import time
from collections import deque
class RateLimiter(Middleware):
    """Client-side sliding-window rate limiter.

    Admits at most ``requests_per_minute`` requests in any ``window_seconds``
    interval (60 s by default). When the window is full, ``before`` blocks
    the calling thread with ``time.sleep`` until the oldest admitted request
    ages out.

    Args:
        requests_per_minute: Maximum number of requests per window; must be positive.
        window_seconds: Length of the sliding window in seconds; must be positive.

    Raises:
        ValueError: If either argument is not positive.
    """

    def __init__(self, requests_per_minute: int = 60, window_seconds: float = 60.0):
        if requests_per_minute <= 0:
            raise ValueError("requests_per_minute must be positive")
        if window_seconds <= 0:
            raise ValueError("window_seconds must be positive")
        self.rpm = requests_per_minute
        self.window = window_seconds
        # time.monotonic() readings of admitted requests, oldest first.
        self.requests = deque()

    def _evict_expired(self, now: float) -> None:
        """Drop timestamps that fall outside the window ending at *now*."""
        cutoff = now - self.window
        while self.requests and self.requests[0] <= cutoff:
            self.requests.popleft()

    def before(self, request):
        # Monotonic clock: immune to wall-clock jumps, unlike time.time().
        now = time.monotonic()
        self._evict_expired(now)
        while len(self.requests) >= self.rpm:
            # Sleep until the oldest request leaves the window, then re-check.
            time.sleep(max(0.0, self.requests[0] + self.window - now))
            now = time.monotonic()
            self._evict_expired(now)
        # Record the time the request is actually admitted (after any wait).
        # Recording the pre-sleep time would let the slot expire early and
        # allow more than `rpm` requests per window.
        self.requests.append(now)
        return request
Content Filtering
import string


class ContentFilter(Middleware):
    """Reject requests and mask responses that contain blocked words.

    Matching is case-insensitive (``str.casefold``) and word-based. Text is
    split on whitespace, and each token is checked both as-is and with
    surrounding punctuation stripped. So ``"badword,"`` or ``"(badword)"``
    match ``"badword"``, while blocked words that contain punctuation
    (e.g. ``"c++"``) still match their exact token.

    Args:
        blocked_words: Words to block.

    Raises (from ``before``):
        ValueError: If the request input contains a blocked word.
    """

    def __init__(self, blocked_words: list[str]):
        self.blocked = {w.casefold() for w in blocked_words}

    def _contains_blocked(self, text) -> bool:
        """Return True if any word of *text* is in the blocked set."""
        tokens = text.casefold().split()
        candidates = set(tokens) | {t.strip(string.punctuation) for t in tokens}
        candidates.discard("")  # all-punctuation tokens strip to ""
        return not candidates.isdisjoint(self.blocked)

    def before(self, request):
        if request.input and self._contains_blocked(request.input):
            raise ValueError("Request contains blocked content")
        return request

    def after(self, request, response):
        if response.text and self._contains_blocked(response.text):
            response.text = "[Content filtered]"
        return response
Request Tracing
import uuid
class TracingMiddleware(Middleware):
    """Tag each request with a random trace id and print its lifecycle.

    The id lives in ``request.metadata["trace_id"]``; hooks that cannot find
    it print ``"unknown"`` instead.
    """

    @staticmethod
    def _lookup_trace_id(request):
        return request.metadata.get("trace_id", "unknown")

    def before(self, request):
        new_id = str(uuid.uuid4())
        request.metadata["trace_id"] = new_id
        print(f"[{new_id}] Starting request")
        return request

    def after(self, request, response):
        print(f"[{self._lookup_trace_id(request)}] Completed in {response.latency_ms}ms")
        return response

    def on_error(self, request, error):
        print(f"[{self._lookup_trace_id(request)}] Error: {error}")
        return error  # propagate
Cost Tracking
class BudgetExceededError(RuntimeError):
    """Raised by CostTracker once accumulated spend has reached the budget.

    Defined here so the example is self-contained; if your framework exports
    its own BudgetExceededError, import that instead.
    """


class CostTracker(Middleware):
    """Accumulate estimated spend and refuse new requests once over budget.

    The budget is checked *before* each request. The request that pushes the
    total over the limit therefore still completes; every request after it
    is rejected.

    Args:
        budget: Spending limit, in the same units as ``response.cost_estimate``
            (presumably USD given the ``$`` in messages — confirm). Defaults
            to no limit.
    """

    def __init__(self, budget: float = float('inf')):
        self.total_cost = 0.0
        self.budget = budget

    def set_budget(self, amount: float):
        """Replace the spending limit (same as assigning ``self.budget``)."""
        self.budget = amount

    def before(self, request):
        if self.total_cost >= self.budget:
            raise BudgetExceededError(f"Budget of ${self.budget} exceeded")
        return request

    def after(self, request, response):
        # Falsy estimates (e.g. missing or zero) are not added or printed.
        if response.cost_estimate:
            self.total_cost += response.cost_estimate
            print(f"Cost: ${response.cost_estimate:.4f} (Total: ${self.total_cost:.4f})")
        return response
Response Validation
class ResponseValidator(Middleware):
    """Reject responses whose text is shorter than ``min_length`` characters.

    An empty string counts as too short. A response whose ``text`` is None
    (no text content at all) is passed through unchanged.

    Raises (from ``after``):
        ValueError: If the response text is shorter than ``min_length``.
    """

    def __init__(self, min_length: int = 10):
        self.min_length = min_length

    def after(self, request, response):
        text = response.text
        # Compare against None explicitly: a truthiness check would let an
        # empty response ("") through despite being shorter than min_length.
        if text is not None and len(text) < self.min_length:
            raise ValueError(f"Response too short: {len(text)} chars")
        return response
Middleware Chain
Build and extend a middleware chain programmatically:
from agent.middleware import MiddlewareChain
initial_middleware = [LoggingMiddleware(), MetricsMiddleware()]
chain = MiddlewareChain(initial_middleware)

# Extend the chain after it has been constructed
chain.add(RateLimiter())

# Hand the chain's middleware list to the agent
agent = Agent(
    provider="openai",
    model="gpt-4o",
    middleware=chain.middlewares,
)
Best Practices
1. Keep Middleware Focused
# Good - single responsibility: each class handles exactly one concern
class LoggingMiddleware(Middleware):
    """Log the incoming request and nothing else."""

    def before(self, request):
        line = f"Request: {request.input}"
        print(line)
        return request


class MetricsMiddleware(Middleware):
    """Record response latency and nothing else.

    ``record_latency`` is an illustrative placeholder; it is not defined on
    this page.
    """

    def after(self, request, response):
        latency_ms = response.latency_ms
        self.record_latency(latency_ms)
        return response
# Bad - too many responsibilities
class KitchenSinkMiddleware(Middleware):
    """Anti-pattern: one hook mixes logging, rate limiting, validation and tracing.

    Because all four concerns live in a single before() hook, none of them can
    be reused, placed at a different position in the chain, or disabled on
    its own. check_rate_limit / validate_input / add_tracing are illustrative
    placeholders; they are not defined on this page.
    """
    def before(self, request):
        print(f"Request: {request.input}")
        self.check_rate_limit()
        self.validate_input()
        self.add_tracing()
        return request
2. Don’t Swallow Errors Silently
# Good - log before suppressing
def on_error(self, request, error):
    """Suppress errors selected by ``should_suppress``, leaving a log record.

    ``should_suppress`` and ``logger`` are placeholders assumed to exist in
    the surrounding module. Returns None to suppress, or the error to
    propagate it.
    """
    if should_suppress(error):
        # Lazy %-style args: the message is only formatted if WARNING is enabled.
        logger.warning("Suppressing error: %s", error)
        return None
    return error
# Bad - silent suppression
def on_error(self, request, error):
    """Anti-pattern: suppress every error without recording anything.

    Returning None suppresses the error, and nothing is logged, so failures
    disappear without a trace.
    """
    return None  # What happened?
3. Be Careful with Request Modification
import copy


# Good - copy before modifying
def before(self, request):
    """Return a copy of *request* with a prefixed system prompt.

    The caller's request object is left untouched. ``copy.copy`` carries
    over every field, so none can be forgotten the way they can when
    rebuilding an AgentRequest field by field. This assumes AgentRequest
    supports ``copy.copy``, as plain classes and dataclasses do. The copy is
    shallow: mutable fields such as ``metadata`` are still shared with the
    original.
    """
    modified = copy.copy(request)
    # Avoid producing the literal text "PREFIX: None" when no system prompt was set.
    modified.system = f"PREFIX: {request.system}" if request.system else "PREFIX:"
    return modified
# Risky - mutating shared state
def before(self, request):
    """Anti-pattern: prefix the system prompt by mutating *request* in place.

    The change is visible to anything else holding a reference to the same
    request object. A None system prompt also becomes the literal text
    "PREFIX: None".
    """
    request.system = f"PREFIX: {request.system}"
    return request
Next Steps
Error Handling - Handle errors in middleware
Routing - Apply middleware to routers
Custom Providers - Provider-specific middleware