Middleware

Middleware provides hooks into the request/response lifecycle for logging, tracing, metrics, and custom policies.

Basic Usage

from agent import Agent, Middleware

class LoggingMiddleware(Middleware):
    """Print a short preview of every request, response, and error."""

    # Maximum number of characters of input/output shown in a log line.
    PREVIEW_CHARS = 50

    def _preview(self, text):
        """Return the leading ``PREVIEW_CHARS`` characters of *text*.

        A missing value (``None``) is treated as an empty string so that
        logging never raises on requests or responses without text.
        """
        return (text or "")[: self.PREVIEW_CHARS]

    def before(self, request):
        """Log the start of the request input and return the request unchanged."""
        print(f"Request: {self._preview(request.input)}...")
        return request

    def after(self, request, response):
        """Log the start of the response text and return the response unchanged."""
        print(f"Response: {self._preview(response.text)}...")
        return response

    def on_error(self, request, error):
        """Log the error and hand it back so it is not suppressed."""
        print(f"Error: {error}")
        return error  # Returning the error (rather than None) keeps it alive

# Register middleware by passing instances in the `middleware` list.
agent = Agent(
    provider="openai",
    model="gpt-4o",
    middleware=[LoggingMiddleware()],
)

Middleware Hooks

before(request) -> AgentRequest

Called before the request is sent to the provider. Can modify or replace the request.

class SystemPromptMiddleware(Middleware):
    """Prepend a fixed prefix to the system prompt of every request."""

    def __init__(self, prefix: str):
        self.prefix = prefix

    def before(self, request):
        """Place the prefix ahead of any existing system prompt and return the request."""
        # A missing or empty system prompt is replaced by the prefix alone;
        # otherwise the two are joined with a blank line between them.
        request.system = (
            f"{self.prefix}\n\n{request.system}" if request.system else self.prefix
        )
        return request

after(request, response) -> AgentResponse

Called after receiving a response. Can modify or replace the response.

from datetime import datetime, timezone


class MetadataMiddleware(Middleware):
    """Stamp each response with a processing time and the input length."""

    def after(self, request, response):
        """Add ``processed_at`` and ``input_length`` to the response metadata."""
        # Timezone-aware UTC keeps the timestamp unambiguous across hosts.
        response.metadata["processed_at"] = datetime.now(timezone.utc).isoformat()
        # A missing input counts as length 0.
        response.metadata["input_length"] = len(request.input or "")
        return response

on_error(request, error) -> Exception | None

Called when an error occurs. Return the error to let it propagate, or return None to suppress it.

class ErrorRecoveryMiddleware(Middleware):
    """Decide per error type whether to propagate or suppress a failure."""

    def on_error(self, request, error):
        """Return the error to propagate it, or None to suppress it.

        Rate-limit errors are passed on, 503 provider errors are
        suppressed, and every other error is passed on unchanged.
        """
        if isinstance(error, RateLimitError):
            print("Rate limited, will retry...")
            return error  # Let retry handler deal with it

        if isinstance(error, ProviderError) and error.status_code == 503:
            print("Service unavailable, suppressing error")
            return None  # Suppress the error

        return error  # Propagate all other errors

Built-in Middleware

LoggingMiddleware

from agent.middleware import LoggingMiddleware

# With default print (no log_fn argument is given)
agent = Agent(
    provider="openai",
    model="gpt-4o",
    middleware=[LoggingMiddleware()],
)

# With custom logger: pass the function to log with as log_fn
# (here the standard-library logger's info method).
import logging
logger = logging.getLogger(__name__)

agent = Agent(
    provider="openai",
    model="gpt-4o",
    middleware=[LoggingMiddleware(log_fn=logger.info)],
)

MetricsMiddleware

from agent.middleware import MetricsMiddleware

# Keep a reference to the middleware so its stats can be read afterwards.
metrics = MetricsMiddleware()

agent = Agent(
    provider="openai",
    model="gpt-4o",
    middleware=[metrics],
)

# Make some requests
agent.run("Hello")
agent.run("World")

# Get metrics: stats() is indexed by key below; these four keys are the
# ones this example reads.
stats = metrics.stats()
print(f"Requests: {stats['request_count']}")
print(f"Total tokens: {stats['total_tokens']}")
print(f"Errors: {stats['error_count']}")
print(f"Avg latency: {stats['avg_latency_ms']:.2f}ms")

RedactionMiddleware

from agent.middleware import RedactionMiddleware

# Redact sensitive patterns (each entry is a regular expression, written
# as a raw string so backslashes reach the regex engine unchanged)
redactor = RedactionMiddleware(patterns=[
    r"sk-[a-zA-Z0-9]{20,}",           # API keys: "sk-" followed by 20+ alphanumerics
    r"\b\d{16}\b",                     # Credit cards: 16 contiguous digits only (spaced/dashed numbers do not match)
    r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",  # Emails
])

agent = Agent(
    provider="openai",
    model="gpt-4o",
    middleware=[redactor],
)

RetryPolicyMiddleware

from agent.middleware import RetryPolicyMiddleware
from agent.errors import RateLimitError, ProviderError

# retryable_errors is given as a tuple of exception classes.
retry_policy = RetryPolicyMiddleware(
    max_retries=5,
    retryable_errors=(RateLimitError, ProviderError),
)

agent = Agent(
    provider="openai",
    model="gpt-4o",
    middleware=[retry_policy],
)

Chaining Middleware

Middleware executes in order for before hooks and reverse order for after hooks:

import time


class TimingMiddleware(Middleware):
    """Measure the time between the ``before`` and ``after`` hooks of a request."""

    def before(self, request):
        """Record the start of the measurement in the request metadata."""
        # perf_counter() is monotonic, so the measurement cannot be skewed
        # by system clock adjustments the way time.time() can. The stored
        # value is only meaningful as a difference, not as a wall-clock time.
        request.metadata["start_time"] = time.perf_counter()
        return request

    def after(self, request, response):
        """Print the elapsed time if a start time was recorded."""
        started = request.metadata.get("start_time")
        # Skip the report instead of raising KeyError when no start time
        # is present in the metadata.
        if started is not None:
            elapsed = time.perf_counter() - started
            print(f"Total time: {elapsed:.2f}s")
        return response

class CachingMiddleware(Middleware):
    """Cache responses in memory, keyed by the request input."""

    def __init__(self):
        # Maps request input -> response. The input itself is the key
        # (not hash(input)): the dict then compares keys for equality, so
        # two different inputs with the same hash cannot share an entry.
        self.cache = {}

    def before(self, request):
        """Short-circuit with ``CacheHit`` when this input was seen before."""
        try:
            cached = self.cache[request.input]
        except KeyError:
            return request
        # Return cached response (skip provider call)
        raise CacheHit(cached)

    def after(self, request, response):
        """Store the response under the request input and return it."""
        # NOTE: only the input is part of the key; requests that differ in
        # other fields (e.g. the system prompt) share one cache entry.
        self.cache[request.input] = response
        return response

# The position in the `middleware` list determines the hook order.
# Order: TimingMiddleware.before -> CachingMiddleware.before -> Provider
#        CachingMiddleware.after -> TimingMiddleware.after
agent = Agent(
    provider="openai",
    model="gpt-4o",
    middleware=[TimingMiddleware(), CachingMiddleware()],
)

Custom Middleware Examples

Rate Limiting

import time
from collections import deque

class RateLimiter(Middleware):
    """Delay requests so that at most ``requests_per_minute`` start per window.

    The limiter keeps the start times of recent requests in a deque and
    sleeps inside ``before`` whenever the window is full.
    """

    # Length of the sliding window, in seconds.
    WINDOW_SECONDS = 60

    def __init__(self, requests_per_minute: int = 60):
        """Create a limiter.

        Raises:
            ValueError: if ``requests_per_minute`` is less than 1 (a limit
                of zero could never admit a request).
        """
        if requests_per_minute < 1:
            raise ValueError("requests_per_minute must be at least 1")
        self.rpm = requests_per_minute
        # Monotonic timestamps of admitted requests, oldest first.
        self.requests = deque()

    def _evict_expired(self, now: float) -> None:
        """Drop timestamps that have left the sliding window."""
        cutoff = now - self.WINDOW_SECONDS
        while self.requests and self.requests[0] <= cutoff:
            self.requests.popleft()

    def before(self, request):
        """Block until the request fits in the window, then record it."""
        # Monotonic time: wall-clock adjustments must not change wait times.
        now = time.monotonic()
        self._evict_expired(now)

        # Wait for the oldest entry to expire. The clock is re-read after
        # sleeping so the recorded timestamp is the real admission time,
        # and entries that expired during the sleep are evicted.
        while len(self.requests) >= self.rpm:
            time.sleep(self.WINDOW_SECONDS - (now - self.requests[0]))
            now = time.monotonic()
            self._evict_expired(now)

        self.requests.append(now)
        return request

Content Filtering

import re


class ContentFilter(Middleware):
    """Reject requests and mask responses that contain blocked words."""

    # Runs of word characters; used to strip surrounding punctuation.
    _WORD_RE = re.compile(r"\w+")

    def __init__(self, blocked_words: list[str]):
        # Matching is case-insensitive: both sides are lower-cased.
        self.blocked = {w.lower() for w in blocked_words}

    def _contains_blocked(self, text) -> bool:
        """Return True if *text* contains any blocked word.

        Both whitespace-separated tokens and punctuation-stripped words are
        checked, so "badword," and "badword!" match a blocked "badword"
        while blocked entries that contain punctuation still match as
        whole tokens.
        """
        if not text:
            return False
        lowered = text.lower()
        tokens = set(lowered.split())
        tokens.update(self._WORD_RE.findall(lowered))
        return not tokens.isdisjoint(self.blocked)

    def before(self, request):
        """Raise ``ValueError`` when the request input contains a blocked word."""
        if self._contains_blocked(request.input):
            raise ValueError("Request contains blocked content")
        return request

    def after(self, request, response):
        """Replace the response text when it contains a blocked word."""
        if self._contains_blocked(response.text):
            response.text = "[Content filtered]"
        return response

Request Tracing

import uuid

class TracingMiddleware(Middleware):
    """Tag each request with a trace id and print it at every lifecycle stage."""

    @staticmethod
    def _trace_id_of(request):
        """Return the request's trace id, or "unknown" when none was stored."""
        return request.metadata.get("trace_id", "unknown")

    def before(self, request):
        """Generate a fresh UUID4 trace id and store it in the request metadata."""
        new_id = str(uuid.uuid4())
        request.metadata["trace_id"] = new_id
        print(f"[{new_id}] Starting request")
        return request

    def after(self, request, response):
        """Print the response latency under the request's trace id."""
        print(f"[{self._trace_id_of(request)}] Completed in {response.latency_ms}ms")
        return response

    def on_error(self, request, error):
        """Print the error under the request's trace id and hand it back."""
        print(f"[{self._trace_id_of(request)}] Error: {error}")
        return error

Cost Tracking

class CostTracker(Middleware):
    """Accumulate response cost estimates and stop once a budget is used up."""

    def __init__(self, budget: float = float("inf")):
        """Create a tracker.

        Args:
            budget: Spending limit. Defaults to infinity (no limit), so
                ``CostTracker()`` behaves as before; the limit can also be
                changed later with ``set_budget``.
        """
        self.total_cost = 0.0
        self.budget = budget

    def set_budget(self, amount: float):
        """Replace the spending limit."""
        self.budget = amount

    def before(self, request):
        """Raise ``BudgetExceededError`` once the total cost has reached the budget."""
        if self.total_cost >= self.budget:
            raise BudgetExceededError(f"Budget of ${self.budget} exceeded")
        return request

    def after(self, request, response):
        """Add the response's cost estimate, if any, to the running total."""
        # A missing or zero estimate leaves the total untouched and prints nothing.
        if response.cost_estimate:
            self.total_cost += response.cost_estimate
            print(f"Cost: ${response.cost_estimate:.4f} (Total: ${self.total_cost:.4f})")
        return response

Response Validation

class ResponseValidator(Middleware):
    """Reject responses whose text is shorter than a minimum length."""

    def __init__(self, min_length: int = 10):
        self.min_length = min_length

    def after(self, request, response):
        """Raise ``ValueError`` when the response text is too short.

        A response with no text at all (``None``) is not validated. An
        empty string is validated like any other text, so it is rejected
        whenever ``min_length`` is greater than zero.
        """
        text = response.text
        # Test for None explicitly: a truthiness test would let the
        # shortest possible text, "", skip validation.
        if text is not None and len(text) < self.min_length:
            raise ValueError(f"Response too short: {len(text)} chars")
        return response

Middleware Chain

Build and extend a middleware chain programmatically, then pass its middlewares to the agent:

from agent.middleware import MiddlewareChain

# Start the chain from an initial list of middleware instances.
chain = MiddlewareChain([
    LoggingMiddleware(),
    MetricsMiddleware(),
])

# Add more middleware
chain.add(RateLimiter())

# Use with agent: the chain's `middlewares` attribute is what gets
# passed as the agent's middleware argument.
agent = Agent(
    provider="openai",
    model="gpt-4o",
    middleware=chain.middlewares,
)

Best Practices

1. Keep Middleware Focused

# Good - single responsibility
class LoggingMiddleware(Middleware):
    """Log each request's input and nothing else."""

    def before(self, request):
        # Print the input, then hand the request on unchanged.
        print(f"Request: {request.input}")
        return request

class MetricsMiddleware(Middleware):
    """Record each response's latency and nothing else."""

    def after(self, request, response):
        # record_latency is not shown in this snippet; presumably a helper
        # defined on this class. Confirm before copying the example.
        self.record_latency(response.latency_ms)
        return response

# Bad - too many responsibilities
class KitchenSinkMiddleware(Middleware):
    """Anti-pattern: a single hook that mixes four unrelated concerns."""

    def before(self, request):
        # Logging, rate limiting, validation, and tracing all in one hook.
        # The helper methods are not shown in this snippet.
        print(f"Request: {request.input}")
        self.check_rate_limit()
        self.validate_input()
        self.add_tracing()
        return request

2. Don’t Swallow Errors Silently

# Good - log before suppressing
def on_error(self, request, error):
    """Suppress selected errors, leaving a log record of each one."""
    if should_suppress(error):
        # Lazy %-style arguments: the message is only formatted when the
        # WARNING level is actually enabled for this logger.
        logger.warning("Suppressing error: %s", error)
        return None
    return error

# Bad - silent suppression
def on_error(self, request, error):
    # Every error is suppressed and nothing is logged or recorded.
    return None  # What happened?

3. Be Careful with Request Modification

# Good - copy before modifying
def before(self, request):
    # Only prepend to an existing system prompt; formatting a missing one
    # would put the literal text "None" into the prompt.
    system = f"PREFIX: {request.system}" if request.system else "PREFIX"
    modified = AgentRequest(
        input=request.input,
        system=system,
        # ... copy other fields
    )
    return modified

# Risky - mutating shared state
def before(self, request):
    # Same prompt handling as above; the only difference is that the
    # caller's request object is changed in place.
    request.system = f"PREFIX: {request.system}" if request.system else "PREFIX"
    return request

Next Steps