"""
Agent router for multi-agent routing and fallback.
"""
import asyncio
import concurrent.futures
import threading
from collections.abc import Callable
from typing import Any
from pydantic import BaseModel
from agent.agent import Agent
from agent.config import PRICING, resolve_model
from agent.errors import AgentError, RoutingError
from agent.messages import AgentRequest, Message
from agent.response import AgentResponse
from agent.stream import AsyncStreamResponse, StreamResponse
from agent.types.router import RouteResult, RoutingStrategy
class AgentRouter:
    """
    Routes requests across multiple agents with fallback support.

    The router provides strategies for:

    - Automatic failover when providers fail
    - Load balancing across providers
    - Capability-based routing
    - Cost optimization

    Only ``AgentError`` is treated as a recoverable per-agent failure; any
    other exception raised by an agent propagates to the caller immediately.

    Example:
        ```python
        router = AgentRouter(
            agents=[
                Agent(provider="anthropic", model="claude-sonnet"),
                Agent(provider="openai", model="gpt-4o"),
            ],
            strategy="fallback",
        )
        # Automatically falls back if first provider fails
        response = router.run("Hello!")
        ```
    """

    def __init__(
        self,
        agents: list[Agent],
        strategy: RoutingStrategy | str = RoutingStrategy.FALLBACK,
        custom_router: Callable[[AgentRequest, list[Agent]], RouteResult] | None = None,
    ):
        """
        Initialize the router.

        Args:
            agents: Non-empty list of agents to route between. Its order is the
                attempt order for strategies that do not reorder agents.
            strategy: Routing strategy to use; a plain string is converted with
                ``RoutingStrategy(strategy)``.
            custom_router: Custom routing function (required if strategy is CUSTOM).

        Raises:
            ValueError: If ``agents`` is empty, or if the strategy is CUSTOM and
                no ``custom_router`` was given.
        """
        if not agents:
            raise ValueError("At least one agent is required")
        self.agents = agents
        self.strategy = RoutingStrategy(strategy) if isinstance(strategy, str) else strategy
        self.custom_router = custom_router
        # Starting position of the next ROUND_ROBIN rotation. Read-and-advance
        # happens under _rr_lock so concurrent callers each take their own step.
        self._round_robin_index = 0
        self._rr_lock = threading.Lock()
        if self.strategy == RoutingStrategy.CUSTOM and not custom_router:
            raise ValueError("custom_router required when strategy is CUSTOM")

    def run(
        self,
        input: str | None = None,
        *,
        messages: list[Message] | None = None,
        system: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        stop: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> AgentResponse:
        """
        Execute a request using the routing strategy.

        Args:
            input: User input text
            messages: Optional message history
            system: System prompt
            temperature: Sampling temperature
            max_tokens: Max tokens
            stop: Stop sequences
            metadata: Request metadata

        Returns:
            AgentResponse from the agent that handled the request.

        Raises:
            RoutingError: If every eligible agent fails with AgentError (the
                CUSTOM strategy does not wrap errors; its agent's errors
                propagate unchanged).
            ValueError: If the configured strategy has no handler.
        """
        request = self._build_request(
            input, messages, system, temperature, max_tokens, stop, metadata
        )
        handlers = {
            RoutingStrategy.FALLBACK: self._run_fallback,
            RoutingStrategy.ROUND_ROBIN: self._run_round_robin,
            RoutingStrategy.FASTEST: self._run_fastest_sync,
            RoutingStrategy.CHEAPEST: self._run_cheapest,
            RoutingStrategy.CAPABILITY: self._run_capability,
            RoutingStrategy.CUSTOM: self._run_custom,
        }
        handler = handlers.get(self.strategy)
        if handler is None:
            raise ValueError(f"Unknown strategy: {self.strategy}")
        return handler(request)

    async def run_async(
        self,
        input: str | None = None,
        *,
        messages: list[Message] | None = None,
        system: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        stop: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> AgentResponse:
        """
        Execute a request asynchronously using the routing strategy.

        Args:
            input: User input text
            messages: Optional message history
            system: System prompt
            temperature: Sampling temperature
            max_tokens: Max tokens
            stop: Stop sequences
            metadata: Request metadata

        Returns:
            AgentResponse from the agent that handled the request.

        Raises:
            RoutingError: If every eligible agent fails with AgentError (not
                raised by the CUSTOM strategy).
            ValueError: If the configured strategy has no handler.
        """
        request = self._build_request(
            input, messages, system, temperature, max_tokens, stop, metadata
        )
        handlers = {
            RoutingStrategy.FALLBACK: self._run_fallback_async,
            RoutingStrategy.ROUND_ROBIN: self._run_round_robin_async,
            RoutingStrategy.FASTEST: self._run_fastest_async,
            RoutingStrategy.CHEAPEST: self._run_cheapest_async,
            RoutingStrategy.CAPABILITY: self._run_capability_async,
            RoutingStrategy.CUSTOM: self._run_custom_async,
        }
        handler = handlers.get(self.strategy)
        if handler is None:
            raise ValueError(f"Unknown strategy: {self.strategy}")
        return await handler(request)

    def stream(
        self,
        input: str | None = None,
        *,
        messages: list[Message] | None = None,
        system: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> StreamResponse:
        """
        Stream a response using the routing strategy.

        Agents are tried in the order given by ``_get_ordered_agents``
        (rotated for ROUND_ROBIN, cost-sorted for CHEAPEST, declaration order
        otherwise). The next agent is tried only when *starting* the stream
        raises AgentError; errors raised while consuming the returned stream
        are not retried.

        Args:
            input: User input text
            messages: Optional message history
            system: System prompt
            temperature: Sampling temperature
            max_tokens: Max tokens
            metadata: Request metadata

        Returns:
            StreamResponse from the selected agent

        Raises:
            RoutingError: If every agent fails to start a stream.
        """
        return self._first_success(
            self._get_ordered_agents(),
            lambda agent: agent.stream(
                input=input,
                messages=messages,
                system=system,
                temperature=temperature,
                max_tokens=max_tokens,
                metadata=metadata,
            ),
            f"All {len(self.agents)} agents failed",
        )

    async def stream_async(
        self,
        input: str | None = None,
        *,
        messages: list[Message] | None = None,
        system: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> AsyncStreamResponse:
        """
        Stream a response asynchronously using the routing strategy.

        Same agent ordering and fallback rule as ``stream``: only a failure
        to start the stream moves on to the next agent.

        Args:
            input: User input text
            messages: Optional message history
            system: System prompt
            temperature: Sampling temperature
            max_tokens: Max tokens
            metadata: Request metadata

        Returns:
            AsyncStreamResponse from the selected agent

        Raises:
            RoutingError: If every agent fails to start a stream.
        """
        return await self._first_success_async(
            self._get_ordered_agents(),
            lambda agent: agent.stream_async(
                input=input,
                messages=messages,
                system=system,
                temperature=temperature,
                max_tokens=max_tokens,
                metadata=metadata,
            ),
            f"All {len(self.agents)} agents failed",
        )

    def json(
        self,
        input: str | None = None,
        *,
        schema: type[BaseModel] | dict[str, Any],
        messages: list[Message] | None = None,
        system: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> AgentResponse:
        """
        Execute a structured output request with routing.

        Agents are tried in ``_get_ordered_agents`` order until one succeeds.

        Args:
            input: User input text
            schema: Output schema
            messages: Optional message history
            system: System prompt
            temperature: Sampling temperature
            max_tokens: Max tokens
            metadata: Request metadata

        Returns:
            AgentResponse with parsed output

        Raises:
            RoutingError: If every agent fails with AgentError.
        """
        return self._first_success(
            self._get_ordered_agents(),
            lambda agent: agent.json(
                input=input,
                schema=schema,
                messages=messages,
                system=system,
                temperature=temperature,
                max_tokens=max_tokens,
                metadata=metadata,
            ),
            f"All {len(self.agents)} agents failed for structured output",
        )

    # Shared helpers

    @staticmethod
    def _build_request(
        input: str | None,
        messages: list[Message] | None,
        system: str | None,
        temperature: float | None,
        max_tokens: int | None,
        stop: list[str] | None,
        metadata: dict[str, Any] | None,
    ) -> AgentRequest:
        """Build the AgentRequest shared by ``run`` and ``run_async``."""
        return AgentRequest(
            input=input,
            messages=messages or [],
            system=system,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            metadata=metadata or {},
        )

    @staticmethod
    def _first_success(
        agents: list[Agent],
        call: Callable[[Agent], Any],
        failure_message: str,
    ) -> Any:
        """
        Return ``call(agent)`` for the first agent that does not raise AgentError.

        Non-AgentError exceptions propagate immediately. If every agent fails,
        raise RoutingError with the collected errors in attempt order.
        """
        errors: list[Exception] = []
        for agent in agents:
            try:
                return call(agent)
            except AgentError as e:
                errors.append(e)
        raise RoutingError(failure_message, errors=errors)

    @staticmethod
    async def _first_success_async(
        agents: list[Agent],
        call: Callable[[Agent], Any],
        failure_message: str,
    ) -> Any:
        """Async counterpart of ``_first_success``; ``call`` returns an awaitable."""
        errors: list[Exception] = []
        for agent in agents:
            try:
                return await call(agent)
            except AgentError as e:
                errors.append(e)
        raise RoutingError(failure_message, errors=errors)

    def _rotated_agents(self) -> list[Agent]:
        """Advance the round-robin cursor and return agents starting at its old position."""
        with self._rr_lock:
            start = self._round_robin_index
            self._round_robin_index = (start + 1) % len(self.agents)
        return self.agents[start:] + self.agents[:start]

    def _capable_agents(self, request: AgentRequest) -> list[Agent]:
        """
        Return agents whose provider supports what ``request`` needs.

        Tool support is required when ``request.tools`` is truthy, and
        structured-output support when ``request.schema`` is truthy.

        Raises:
            RoutingError: If no agent supports the required capabilities.
        """
        needs_tools = bool(request.tools)
        needs_schema = bool(request.schema)
        capable: list[Agent] = []
        for agent in self.agents:
            provider = agent._provider
            if needs_tools and not provider.supports_tools():
                continue
            if needs_schema and not provider.supports_structured_output():
                continue
            capable.append(agent)
        if not capable:
            raise RoutingError(
                "No agents support the required capabilities",
                errors=[],
            )
        return capable

    # Strategy implementations

    def _run_fallback(self, request: AgentRequest) -> AgentResponse:
        """Try each agent in declaration order until one succeeds."""
        return self._first_success(
            self.agents,
            lambda agent: agent._runtime.run(request),
            f"All {len(self.agents)} agents failed",
        )

    async def _run_fallback_async(self, request: AgentRequest) -> AgentResponse:
        """Try each agent in declaration order until one succeeds (async)."""
        return await self._first_success_async(
            self.agents,
            lambda agent: agent._runtime.run_async(request),
            f"All {len(self.agents)} agents failed",
        )

    def _run_round_robin(self, request: AgentRequest) -> AgentResponse:
        """Start at the next agent in rotation, falling back through the rest."""
        return self._first_success(
            self._rotated_agents(),
            lambda agent: agent._runtime.run(request),
            f"All {len(self.agents)} agents failed",
        )

    async def _run_round_robin_async(self, request: AgentRequest) -> AgentResponse:
        """Start at the next agent in rotation, falling back through the rest (async)."""
        return await self._first_success_async(
            self._rotated_agents(),
            lambda agent: agent._runtime.run_async(request),
            f"All {len(self.agents)} agents failed",
        )

    def _run_fastest_sync(self, request: AgentRequest) -> AgentResponse:
        """
        Race all agents in worker threads; return the first successful response.

        Threads cannot be interrupted, so agents still running when a winner
        arrives keep running in the background. The executor is shut down
        without waiting for them, and calls that have not started are cancelled,
        so the caller is not held up by the slowest agent.

        Raises:
            RoutingError: If every agent fails with AgentError.
        """
        errors: list[Exception] = []
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=len(self.agents))
        try:
            futures = [executor.submit(agent._runtime.run, request) for agent in self.agents]
            for future in concurrent.futures.as_completed(futures):
                try:
                    return future.result()
                except AgentError as e:
                    errors.append(e)
        finally:
            # A `with` block here would call shutdown(wait=True) and block
            # until every losing agent finished, defeating the race.
            executor.shutdown(wait=False, cancel_futures=True)
        raise RoutingError(
            f"All {len(self.agents)} agents failed",
            errors=errors,
        )

    async def _run_fastest_async(self, request: AgentRequest) -> AgentResponse:
        """
        Race all agents; return the first successful response.

        Failed attempts are recorded and the race continues with the remaining
        tasks. On every exit path (success, unexpected exception, or
        cancellation of this coroutine) still-pending tasks are cancelled and
        awaited so none are leaked.

        Raises:
            RoutingError: If every agent fails with AgentError.
        """
        pending = {
            asyncio.create_task(agent._runtime.run_async(request)) for agent in self.agents
        }
        errors: list[Exception] = []
        try:
            while pending:
                done, pending = await asyncio.wait(
                    pending,
                    return_when=asyncio.FIRST_COMPLETED,
                )
                for task in done:
                    try:
                        return task.result()
                    except AgentError as e:
                        errors.append(e)
        finally:
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)
        raise RoutingError(
            f"All {len(self.agents)} agents failed",
            errors=errors,
        )

    @staticmethod
    def _get_agent_cost(agent: Agent) -> float:
        """
        Estimate cost for an agent as its model's input + output price.

        Models missing from PRICING (or missing either price) get ``inf`` so
        they sort after every priced model.
        """
        model = resolve_model(agent.model)
        pricing = PRICING.get(model, {})
        return pricing.get("input", float("inf")) + pricing.get("output", float("inf"))

    def _run_cheapest(self, request: AgentRequest) -> AgentResponse:
        """Try agents from cheapest to most expensive until one succeeds."""
        return self._first_success(
            sorted(self.agents, key=self._get_agent_cost),
            lambda agent: agent._runtime.run(request),
            f"All {len(self.agents)} agents failed",
        )

    async def _run_cheapest_async(self, request: AgentRequest) -> AgentResponse:
        """Try agents from cheapest to most expensive until one succeeds (async)."""
        return await self._first_success_async(
            sorted(self.agents, key=self._get_agent_cost),
            lambda agent: agent._runtime.run_async(request),
            f"All {len(self.agents)} agents failed",
        )

    def _run_capability(self, request: AgentRequest) -> AgentResponse:
        """Fall back, in declaration order, among agents that support the request."""
        capable_agents = self._capable_agents(request)
        return self._first_success(
            capable_agents,
            lambda agent: agent._runtime.run(request),
            f"All capable agents ({len(capable_agents)}) failed",
        )

    async def _run_capability_async(self, request: AgentRequest) -> AgentResponse:
        """Fall back, in declaration order, among agents that support the request (async)."""
        capable_agents = self._capable_agents(request)
        return await self._first_success_async(
            capable_agents,
            lambda agent: agent._runtime.run_async(request),
            f"All capable agents ({len(capable_agents)}) failed",
        )

    def _run_custom(self, request: AgentRequest) -> AgentResponse:
        """Run the agent chosen by ``custom_router``; its errors are not wrapped."""
        if not self.custom_router:
            raise ValueError("custom_router not configured")
        result = self.custom_router(request, self.agents)
        return result.agent._runtime.run(request)

    async def _run_custom_async(self, request: AgentRequest) -> AgentResponse:
        """Run the agent chosen by ``custom_router`` (async); errors are not wrapped."""
        if not self.custom_router:
            raise ValueError("custom_router not configured")
        # custom_router is a plain (synchronous) callable.
        result = self.custom_router(request, self.agents)
        return await result.agent._runtime.run_async(request)

    def _get_ordered_agents(self) -> list[Agent]:
        """
        Get agents in attempt order for stream/json calls.

        ROUND_ROBIN rotates (and advances the shared cursor), CHEAPEST sorts by
        estimated cost, and every other strategy uses declaration order.
        """
        if self.strategy == RoutingStrategy.ROUND_ROBIN:
            return self._rotated_agents()
        if self.strategy == RoutingStrategy.CHEAPEST:
            return sorted(self.agents, key=self._get_agent_cost)
        return self.agents

    def __repr__(self) -> str:
        return f"AgentRouter(agents={len(self.agents)}, strategy={self.strategy.value})"