"""
Agent router for multi-agent routing and fallback.
"""
import asyncio
import concurrent.futures
from collections.abc import Callable
from typing import Any
from pydantic import BaseModel
from agent.agent import Agent
from agent.config import PRICING, resolve_model
from agent.errors import AgentError, RoutingError
from agent.messages import AgentRequest, Message
from agent.response import AgentResponse
from agent.stream import AsyncStreamResponse, StreamResponse
from agent.types.router import RouteResult, RoutingStrategy
[docs]
class AgentRouter:
"""
Routes requests across multiple agents with fallback support.
The router provides strategies for:
- Automatic failover when providers fail
- Load balancing across providers
- Capability-based routing
- Cost optimization
Example:
```python
router = AgentRouter(
agents=[
Agent(provider="anthropic", model="claude-sonnet"),
Agent(provider="openai", model="gpt-4o"),
],
strategy="fallback",
)
# Automatically falls back if first provider fails
response = router.run("Hello!")
```
"""
def __init__(
self,
agents: list[Agent],
strategy: RoutingStrategy | str = RoutingStrategy.FALLBACK,
custom_router: Callable[[AgentRequest, list[Agent]], RouteResult] | None = None,
):
"""
Initialize the router.
Args:
agents: List of agents to route between
strategy: Routing strategy to use
custom_router: Custom routing function (required if strategy is CUSTOM)
"""
if not agents:
raise ValueError("At least one agent is required")
self.agents = agents
self.strategy = RoutingStrategy(strategy) if isinstance(strategy, str) else strategy
self.custom_router = custom_router
self._round_robin_index = 0
if self.strategy == RoutingStrategy.CUSTOM and not custom_router:
raise ValueError("custom_router required when strategy is CUSTOM")
[docs]
def run(
self,
input: str | None = None,
*,
messages: list[Message] | None = None,
system: str | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
stop: list[str] | None = None,
metadata: dict[str, Any] | None = None,
) -> AgentResponse:
"""
Execute a request using the routing strategy.
Args:
input: User input text
messages: Optional message history
system: System prompt
temperature: Sampling temperature
max_tokens: Max tokens
stop: Stop sequences
metadata: Request metadata
Returns:
AgentResponse from the selected agent
Raises:
RoutingError: If all agents fail
"""
request = AgentRequest(
input=input,
messages=messages or [],
system=system,
temperature=temperature,
max_tokens=max_tokens,
stop=stop,
metadata=metadata or {},
)
if self.strategy == RoutingStrategy.FALLBACK:
return self._run_fallback(request)
elif self.strategy == RoutingStrategy.ROUND_ROBIN:
return self._run_round_robin(request)
elif self.strategy == RoutingStrategy.FASTEST:
return self._run_fastest_sync(request)
elif self.strategy == RoutingStrategy.CHEAPEST:
return self._run_cheapest(request)
elif self.strategy == RoutingStrategy.CAPABILITY:
return self._run_capability(request)
elif self.strategy == RoutingStrategy.CUSTOM:
return self._run_custom(request)
else:
raise ValueError(f"Unknown strategy: {self.strategy}")
[docs]
async def run_async(
self,
input: str | None = None,
*,
messages: list[Message] | None = None,
system: str | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
stop: list[str] | None = None,
metadata: dict[str, Any] | None = None,
) -> AgentResponse:
"""
Execute a request asynchronously using the routing strategy.
Args:
input: User input text
messages: Optional message history
system: System prompt
temperature: Sampling temperature
max_tokens: Max tokens
stop: Stop sequences
metadata: Request metadata
Returns:
AgentResponse from the selected agent
"""
request = AgentRequest(
input=input,
messages=messages or [],
system=system,
temperature=temperature,
max_tokens=max_tokens,
stop=stop,
metadata=metadata or {},
)
if self.strategy == RoutingStrategy.FALLBACK:
return await self._run_fallback_async(request)
elif self.strategy == RoutingStrategy.ROUND_ROBIN:
return await self._run_round_robin_async(request)
elif self.strategy == RoutingStrategy.FASTEST:
return await self._run_fastest_async(request)
elif self.strategy == RoutingStrategy.CHEAPEST:
return await self._run_cheapest_async(request)
elif self.strategy == RoutingStrategy.CAPABILITY:
return await self._run_capability_async(request)
elif self.strategy == RoutingStrategy.CUSTOM:
return await self._run_custom_async(request)
else:
raise ValueError(f"Unknown strategy: {self.strategy}")
[docs]
def stream(
self,
input: str | None = None,
*,
messages: list[Message] | None = None,
system: str | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
metadata: dict[str, Any] | None = None,
) -> StreamResponse:
"""
Stream a response using the routing strategy.
For fallback strategy, only tries next agent if initial connection fails.
Args:
input: User input text
messages: Optional message history
system: System prompt
temperature: Sampling temperature
max_tokens: Max tokens
metadata: Request metadata
Returns:
StreamResponse from the selected agent
"""
errors: list[Exception] = []
for agent in self._get_ordered_agents():
try:
return agent.stream(
input=input,
messages=messages,
system=system,
temperature=temperature,
max_tokens=max_tokens,
metadata=metadata,
)
except AgentError as e:
errors.append(e)
continue
raise RoutingError(
f"All {len(self.agents)} agents failed",
errors=errors,
)
[docs]
async def stream_async(
self,
input: str | None = None,
*,
messages: list[Message] | None = None,
system: str | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
metadata: dict[str, Any] | None = None,
) -> AsyncStreamResponse:
"""
Stream a response asynchronously using the routing strategy.
Args:
input: User input text
messages: Optional message history
system: System prompt
temperature: Sampling temperature
max_tokens: Max tokens
metadata: Request metadata
Returns:
AsyncStreamResponse from the selected agent
"""
errors: list[Exception] = []
for agent in self._get_ordered_agents():
try:
return await agent.stream_async(
input=input,
messages=messages,
system=system,
temperature=temperature,
max_tokens=max_tokens,
metadata=metadata,
)
except AgentError as e:
errors.append(e)
continue
raise RoutingError(
f"All {len(self.agents)} agents failed",
errors=errors,
)
[docs]
def json(
self,
input: str | None = None,
*,
schema: type[BaseModel] | dict[str, Any],
messages: list[Message] | None = None,
system: str | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
metadata: dict[str, Any] | None = None,
) -> AgentResponse:
"""
Execute a structured output request with routing.
Args:
input: User input text
schema: Output schema
messages: Optional message history
system: System prompt
temperature: Sampling temperature
max_tokens: Max tokens
metadata: Request metadata
Returns:
AgentResponse with parsed output
"""
errors: list[Exception] = []
for agent in self._get_ordered_agents():
try:
return agent.json(
input=input,
schema=schema,
messages=messages,
system=system,
temperature=temperature,
max_tokens=max_tokens,
metadata=metadata,
)
except AgentError as e:
errors.append(e)
continue
raise RoutingError(
f"All {len(self.agents)} agents failed for structured output",
errors=errors,
)
# Strategy implementations
def _run_fallback(self, request: AgentRequest) -> AgentResponse:
"""Try each agent in order until one succeeds."""
errors: list[Exception] = []
for agent in self.agents:
try:
return agent._runtime.run(request)
except AgentError as e:
errors.append(e)
continue
raise RoutingError(
f"All {len(self.agents)} agents failed",
errors=errors,
)
async def _run_fallback_async(self, request: AgentRequest) -> AgentResponse:
"""Try each agent in order until one succeeds (async)."""
errors: list[Exception] = []
for agent in self.agents:
try:
return await agent._runtime.run_async(request)
except AgentError as e:
errors.append(e)
continue
raise RoutingError(
f"All {len(self.agents)} agents failed",
errors=errors,
)
def _run_round_robin(self, request: AgentRequest) -> AgentResponse:
"""Rotate through agents."""
errors: list[Exception] = []
start_index = self._round_robin_index
for i in range(len(self.agents)):
index = (start_index + i) % len(self.agents)
agent = self.agents[index]
try:
self._round_robin_index = (index + 1) % len(self.agents)
return agent._runtime.run(request)
except AgentError as e:
errors.append(e)
continue
raise RoutingError(
f"All {len(self.agents)} agents failed",
errors=errors,
)
async def _run_round_robin_async(self, request: AgentRequest) -> AgentResponse:
"""Rotate through agents (async)."""
errors: list[Exception] = []
start_index = self._round_robin_index
for i in range(len(self.agents)):
index = (start_index + i) % len(self.agents)
agent = self.agents[index]
try:
self._round_robin_index = (index + 1) % len(self.agents)
return await agent._runtime.run_async(request)
except AgentError as e:
errors.append(e)
continue
raise RoutingError(
f"All {len(self.agents)} agents failed",
errors=errors,
)
def _run_fastest_sync(self, request: AgentRequest) -> AgentResponse:
"""Race agents synchronously (uses threads)."""
with concurrent.futures.ThreadPoolExecutor(max_workers=len(self.agents)) as executor:
futures = {executor.submit(agent._runtime.run, request): agent for agent in self.agents}
errors: list[Exception] = []
for future in concurrent.futures.as_completed(futures):
try:
return future.result()
except AgentError as e:
errors.append(e)
continue
raise RoutingError(
f"All {len(self.agents)} agents failed",
errors=errors,
)
async def _run_fastest_async(self, request: AgentRequest) -> AgentResponse:
"""Race agents, return first successful response."""
tasks = [asyncio.create_task(agent._runtime.run_async(request)) for agent in self.agents]
errors: list[Exception] = []
done, pending = await asyncio.wait(
tasks,
return_when=asyncio.FIRST_COMPLETED,
)
# Get first successful result
for task in done:
try:
result = task.result()
# Cancel pending tasks
for p in pending:
p.cancel()
return result
except AgentError as e:
errors.append(e)
# Wait for remaining if first failed
if pending:
done, _ = await asyncio.wait(pending)
for task in done:
try:
return task.result()
except AgentError as e:
errors.append(e)
raise RoutingError(
f"All {len(self.agents)} agents failed",
errors=errors,
)
def _run_cheapest(self, request: AgentRequest) -> AgentResponse:
"""Use the cheapest available agent."""
# Sort agents by estimated cost
def get_cost(agent: Agent) -> float:
model = resolve_model(agent.model)
pricing = PRICING.get(model, {})
return pricing.get("input", float("inf")) + pricing.get("output", float("inf"))
sorted_agents = sorted(self.agents, key=get_cost)
errors: list[Exception] = []
for agent in sorted_agents:
try:
return agent._runtime.run(request)
except AgentError as e:
errors.append(e)
continue
raise RoutingError(
f"All {len(self.agents)} agents failed",
errors=errors,
)
async def _run_cheapest_async(self, request: AgentRequest) -> AgentResponse:
"""Use the cheapest available agent (async)."""
def get_cost(agent: Agent) -> float:
model = resolve_model(agent.model)
pricing = PRICING.get(model, {})
return pricing.get("input", float("inf")) + pricing.get("output", float("inf"))
sorted_agents = sorted(self.agents, key=get_cost)
errors: list[Exception] = []
for agent in sorted_agents:
try:
return await agent._runtime.run_async(request)
except AgentError as e:
errors.append(e)
continue
raise RoutingError(
f"All {len(self.agents)} agents failed",
errors=errors,
)
def _run_capability(self, request: AgentRequest) -> AgentResponse:
"""Route based on required capabilities."""
# Determine required capabilities from request
needs_tools = bool(request.tools)
needs_schema = bool(request.schema)
# Filter to capable agents
capable_agents = []
for agent in self.agents:
provider = agent._provider
if needs_tools and not provider.supports_tools():
continue
if needs_schema and not provider.supports_structured_output():
continue
capable_agents.append(agent)
if not capable_agents:
raise RoutingError(
"No agents support the required capabilities",
errors=[],
)
# Use fallback among capable agents
errors: list[Exception] = []
for agent in capable_agents:
try:
return agent._runtime.run(request)
except AgentError as e:
errors.append(e)
continue
raise RoutingError(
f"All capable agents ({len(capable_agents)}) failed",
errors=errors,
)
async def _run_capability_async(self, request: AgentRequest) -> AgentResponse:
"""Route based on required capabilities (async)."""
needs_tools = bool(request.tools)
needs_schema = bool(request.schema)
capable_agents = []
for agent in self.agents:
provider = agent._provider
if needs_tools and not provider.supports_tools():
continue
if needs_schema and not provider.supports_structured_output():
continue
capable_agents.append(agent)
if not capable_agents:
raise RoutingError(
"No agents support the required capabilities",
errors=[],
)
errors: list[Exception] = []
for agent in capable_agents:
try:
return await agent._runtime.run_async(request)
except AgentError as e:
errors.append(e)
continue
raise RoutingError(
f"All capable agents ({len(capable_agents)}) failed",
errors=errors,
)
def _run_custom(self, request: AgentRequest) -> AgentResponse:
"""Use custom routing function."""
if not self.custom_router:
raise ValueError("custom_router not configured")
result = self.custom_router(request, self.agents)
return result.agent._runtime.run(request)
async def _run_custom_async(self, request: AgentRequest) -> AgentResponse:
"""Use custom routing function (async)."""
if not self.custom_router:
raise ValueError("custom_router not configured")
result = self.custom_router(request, self.agents)
return await result.agent._runtime.run_async(request)
def _get_ordered_agents(self) -> list[Agent]:
"""Get agents in order based on strategy."""
if self.strategy == RoutingStrategy.ROUND_ROBIN:
# Rotate starting point
start = self._round_robin_index
self._round_robin_index = (start + 1) % len(self.agents)
return self.agents[start:] + self.agents[:start]
if self.strategy == RoutingStrategy.CHEAPEST:
def get_cost(agent: Agent) -> float:
model = resolve_model(agent.model)
pricing = PRICING.get(model, {})
return pricing.get("input", float("inf")) + pricing.get("output", float("inf"))
return sorted(self.agents, key=get_cost)
return self.agents
def __repr__(self) -> str:
return f"AgentRouter(agents={len(self.agents)}, strategy={self.strategy.value})"