# Source code for agent.router

"""
Agent router for multi-agent routing and fallback.
"""

import asyncio
import concurrent.futures
import threading
from collections.abc import Callable
from typing import Any

from pydantic import BaseModel

from agent.agent import Agent
from agent.config import PRICING, resolve_model
from agent.errors import AgentError, RoutingError
from agent.messages import AgentRequest, Message
from agent.response import AgentResponse
from agent.stream import AsyncStreamResponse, StreamResponse
from agent.types.router import RouteResult, RoutingStrategy


class AgentRouter:
    """
    Routes requests across multiple agents with fallback support.

    The router provides strategies for:
    - Automatic failover when providers fail
    - Load balancing across providers
    - Capability-based routing
    - Cost optimization

    Example:
        ```python
        router = AgentRouter(
            agents=[
                Agent(provider="anthropic", model="claude-sonnet"),
                Agent(provider="openai", model="gpt-4o"),
            ],
            strategy="fallback",
        )

        # Automatically falls back if first provider fails
        response = router.run("Hello!")
        ```
    """

    def __init__(
        self,
        agents: list[Agent],
        strategy: RoutingStrategy | str = RoutingStrategy.FALLBACK,
        custom_router: Callable[[AgentRequest, list[Agent]], RouteResult] | None = None,
    ):
        """
        Initialize the router.

        Args:
            agents: List of agents to route between
            strategy: Routing strategy to use (enum member or its string value)
            custom_router: Custom routing function (required if strategy is CUSTOM)

        Raises:
            ValueError: If ``agents`` is empty, or if the strategy is CUSTOM
                and no ``custom_router`` was supplied
        """
        if not agents:
            raise ValueError("At least one agent is required")

        self.agents = agents
        self.strategy = RoutingStrategy(strategy) if isinstance(strategy, str) else strategy
        self.custom_router = custom_router

        # Index of the agent the next round-robin request starts from.
        # Guarded by a lock because the read-and-advance must be atomic.
        self._round_robin_index = 0
        self._rr_lock = threading.Lock()

        if self.strategy == RoutingStrategy.CUSTOM and not custom_router:
            raise ValueError("custom_router required when strategy is CUSTOM")

    def run(
        self,
        input: str | None = None,
        *,
        messages: list[Message] | None = None,
        system: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        stop: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> AgentResponse:
        """
        Execute a request using the routing strategy.

        Args:
            input: User input text
            messages: Optional message history
            system: System prompt
            temperature: Sampling temperature
            max_tokens: Max tokens
            stop: Stop sequences
            metadata: Request metadata

        Returns:
            AgentResponse from the selected agent

        Raises:
            RoutingError: If all agents fail
            ValueError: If the configured strategy is not recognised
        """
        request = AgentRequest(
            input=input,
            messages=messages or [],
            system=system,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            metadata=metadata or {},
        )

        if self.strategy == RoutingStrategy.FALLBACK:
            return self._run_fallback(request)
        if self.strategy == RoutingStrategy.ROUND_ROBIN:
            return self._run_round_robin(request)
        if self.strategy == RoutingStrategy.FASTEST:
            return self._run_fastest_sync(request)
        if self.strategy == RoutingStrategy.CHEAPEST:
            return self._run_cheapest(request)
        if self.strategy == RoutingStrategy.CAPABILITY:
            return self._run_capability(request)
        if self.strategy == RoutingStrategy.CUSTOM:
            return self._run_custom(request)
        raise ValueError(f"Unknown strategy: {self.strategy}")

    async def run_async(
        self,
        input: str | None = None,
        *,
        messages: list[Message] | None = None,
        system: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        stop: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> AgentResponse:
        """
        Execute a request asynchronously using the routing strategy.

        Args:
            input: User input text
            messages: Optional message history
            system: System prompt
            temperature: Sampling temperature
            max_tokens: Max tokens
            stop: Stop sequences
            metadata: Request metadata

        Returns:
            AgentResponse from the selected agent

        Raises:
            RoutingError: If all agents fail
            ValueError: If the configured strategy is not recognised
        """
        request = AgentRequest(
            input=input,
            messages=messages or [],
            system=system,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            metadata=metadata or {},
        )

        if self.strategy == RoutingStrategy.FALLBACK:
            return await self._run_fallback_async(request)
        if self.strategy == RoutingStrategy.ROUND_ROBIN:
            return await self._run_round_robin_async(request)
        if self.strategy == RoutingStrategy.FASTEST:
            return await self._run_fastest_async(request)
        if self.strategy == RoutingStrategy.CHEAPEST:
            return await self._run_cheapest_async(request)
        if self.strategy == RoutingStrategy.CAPABILITY:
            return await self._run_capability_async(request)
        if self.strategy == RoutingStrategy.CUSTOM:
            return await self._run_custom_async(request)
        raise ValueError(f"Unknown strategy: {self.strategy}")

    def stream(
        self,
        input: str | None = None,
        *,
        messages: list[Message] | None = None,
        system: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> StreamResponse:
        """
        Stream a response using the routing strategy.

        Agents are tried in the order given by ``_get_ordered_agents``. The
        next agent is tried only when ``agent.stream(...)`` itself raises an
        ``AgentError``; errors raised later, while the returned stream is
        being consumed, are not handled here.

        Args:
            input: User input text
            messages: Optional message history
            system: System prompt
            temperature: Sampling temperature
            max_tokens: Max tokens
            metadata: Request metadata

        Returns:
            StreamResponse from the selected agent

        Raises:
            RoutingError: If every agent raises ``AgentError``
        """
        return self._first_success(
            self._get_ordered_agents(),
            lambda agent: agent.stream(
                input=input,
                messages=messages,
                system=system,
                temperature=temperature,
                max_tokens=max_tokens,
                metadata=metadata,
            ),
            f"All {len(self.agents)} agents failed",
        )

    async def stream_async(
        self,
        input: str | None = None,
        *,
        messages: list[Message] | None = None,
        system: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> AsyncStreamResponse:
        """
        Stream a response asynchronously using the routing strategy.

        Agents are tried in the order given by ``_get_ordered_agents``; the
        next one is tried only when awaiting ``agent.stream_async(...)``
        raises an ``AgentError``.

        Args:
            input: User input text
            messages: Optional message history
            system: System prompt
            temperature: Sampling temperature
            max_tokens: Max tokens
            metadata: Request metadata

        Returns:
            AsyncStreamResponse from the selected agent

        Raises:
            RoutingError: If every agent raises ``AgentError``
        """
        return await self._first_success_async(
            self._get_ordered_agents(),
            lambda agent: agent.stream_async(
                input=input,
                messages=messages,
                system=system,
                temperature=temperature,
                max_tokens=max_tokens,
                metadata=metadata,
            ),
            f"All {len(self.agents)} agents failed",
        )

    def json(
        self,
        input: str | None = None,
        *,
        schema: type[BaseModel] | dict[str, Any],
        messages: list[Message] | None = None,
        system: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> AgentResponse:
        """
        Execute a structured output request with routing.

        Agents are tried in the order given by ``_get_ordered_agents`` until
        one returns without raising ``AgentError``.

        Args:
            input: User input text
            schema: Output schema
            messages: Optional message history
            system: System prompt
            temperature: Sampling temperature
            max_tokens: Max tokens
            metadata: Request metadata

        Returns:
            AgentResponse with parsed output

        Raises:
            RoutingError: If every agent raises ``AgentError``
        """
        return self._first_success(
            self._get_ordered_agents(),
            lambda agent: agent.json(
                input=input,
                schema=schema,
                messages=messages,
                system=system,
                temperature=temperature,
                max_tokens=max_tokens,
                metadata=metadata,
            ),
            f"All {len(self.agents)} agents failed for structured output",
        )

    # Shared helpers

    def _first_success(
        self,
        agents: list[Agent],
        call: Callable[[Agent], Any],
        failure_message: str,
    ) -> Any:
        """
        Return the result of ``call`` for the first agent that succeeds.

        Args:
            agents: Agents to try, in order
            call: Function invoked with each agent in turn
            failure_message: Message for the RoutingError raised on total failure

        Raises:
            RoutingError: If ``call`` raised ``AgentError`` for every agent.
                Any other exception type propagates immediately.
        """
        errors: list[Exception] = []
        for agent in agents:
            try:
                return call(agent)
            except AgentError as exc:
                errors.append(exc)
        raise RoutingError(failure_message, errors=errors)

    async def _first_success_async(
        self,
        agents: list[Agent],
        call: Callable[[Agent], Any],
        failure_message: str,
    ) -> Any:
        """
        Async counterpart of ``_first_success``.

        ``call`` must return an awaitable; it is awaited before moving on to
        the next agent, so agents are tried strictly one at a time.

        Raises:
            RoutingError: If awaiting ``call`` raised ``AgentError`` for every agent
        """
        errors: list[Exception] = []
        for agent in agents:
            try:
                return await call(agent)
            except AgentError as exc:
                errors.append(exc)
        raise RoutingError(failure_message, errors=errors)

    def _next_round_robin_order(self) -> list[Agent]:
        """Advance the round-robin cursor and return agents rotated to start at it."""
        with self._rr_lock:
            start = self._round_robin_index
            self._round_robin_index = (start + 1) % len(self.agents)
        return self.agents[start:] + self.agents[:start]

    def _capable_agents(self, request: AgentRequest) -> list[Agent]:
        """
        Return the agents whose provider supports what the request needs.

        Tool support is required when ``request.tools`` is truthy, and
        structured-output support when ``request.schema`` is truthy.

        Raises:
            RoutingError: If no agent qualifies
        """
        needs_tools = bool(request.tools)
        needs_schema = bool(request.schema)

        capable: list[Agent] = []
        for agent in self.agents:
            provider = agent._provider
            if needs_tools and not provider.supports_tools():
                continue
            if needs_schema and not provider.supports_structured_output():
                continue
            capable.append(agent)

        if not capable:
            raise RoutingError(
                "No agents support the required capabilities",
                errors=[],
            )
        return capable

    # Strategy implementations

    def _run_fallback(self, request: AgentRequest) -> AgentResponse:
        """Try each agent in order until one succeeds."""
        return self._first_success(
            self.agents,
            lambda agent: agent._runtime.run(request),
            f"All {len(self.agents)} agents failed",
        )

    async def _run_fallback_async(self, request: AgentRequest) -> AgentResponse:
        """Try each agent in order until one succeeds (async)."""
        return await self._first_success_async(
            self.agents,
            lambda agent: agent._runtime.run_async(request),
            f"All {len(self.agents)} agents failed",
        )

    def _run_round_robin(self, request: AgentRequest) -> AgentResponse:
        """Rotate through agents, falling back to the following ones on failure."""
        return self._first_success(
            self._next_round_robin_order(),
            lambda agent: agent._runtime.run(request),
            f"All {len(self.agents)} agents failed",
        )

    async def _run_round_robin_async(self, request: AgentRequest) -> AgentResponse:
        """Rotate through agents, falling back to the following ones on failure (async)."""
        return await self._first_success_async(
            self._next_round_robin_order(),
            lambda agent: agent._runtime.run_async(request),
            f"All {len(self.agents)} agents failed",
        )

    def _run_fastest_sync(self, request: AgentRequest) -> AgentResponse:
        """
        Race agents synchronously (uses threads).

        Returns the first successful response without waiting for the slower
        agents. Their worker threads cannot be interrupted, so they keep
        running in the background until their own call returns.

        Raises:
            RoutingError: If every agent raises ``AgentError``
        """
        # The executor is managed by hand rather than with ``with``: leaving a
        # ``with`` block calls shutdown(wait=True), which would block until
        # the slowest agent finished and defeat the purpose of racing.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=len(self.agents))
        errors: list[Exception] = []
        try:
            futures = [executor.submit(agent._runtime.run, request) for agent in self.agents]
            for future in concurrent.futures.as_completed(futures):
                try:
                    return future.result()
                except AgentError as exc:
                    errors.append(exc)
        finally:
            # Drop anything not yet started and release the pool without
            # waiting for calls that are still in flight.
            executor.shutdown(wait=False, cancel_futures=True)

        raise RoutingError(
            f"All {len(self.agents)} agents failed",
            errors=errors,
        )

    async def _run_fastest_async(self, request: AgentRequest) -> AgentResponse:
        """
        Race agents, return first successful response.

        Remaining tasks are cancelled (and awaited) as soon as a winner is
        found or an unexpected exception propagates.

        Raises:
            RoutingError: If every agent raises ``AgentError``
        """
        pending = {
            asyncio.create_task(agent._runtime.run_async(request)) for agent in self.agents
        }
        errors: list[Exception] = []

        try:
            # Keep waiting for the *next* completion, so that a fast failure
            # does not force us to wait for every remaining agent.
            while pending:
                done, pending = await asyncio.wait(
                    pending,
                    return_when=asyncio.FIRST_COMPLETED,
                )
                for task in done:
                    try:
                        return task.result()
                    except AgentError as exc:
                        errors.append(exc)
        finally:
            # Runs on success, on total failure (pending is empty then) and on
            # unexpected errors, so no task is left running unattended.
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)

        raise RoutingError(
            f"All {len(self.agents)} agents failed",
            errors=errors,
        )

    @staticmethod
    def _get_agent_cost(agent: Agent) -> float:
        """
        Estimate cost for an agent based on model pricing.

        Sums the "input" and "output" entries of the model's PRICING record;
        a missing model or missing entry counts as infinite cost.
        """
        model = resolve_model(agent.model)
        pricing = PRICING.get(model, {})
        return pricing.get("input", float("inf")) + pricing.get("output", float("inf"))

    def _run_cheapest(self, request: AgentRequest) -> AgentResponse:
        """Use the cheapest available agent, falling back to pricier ones on failure."""
        return self._first_success(
            sorted(self.agents, key=self._get_agent_cost),
            lambda agent: agent._runtime.run(request),
            f"All {len(self.agents)} agents failed",
        )

    async def _run_cheapest_async(self, request: AgentRequest) -> AgentResponse:
        """Use the cheapest available agent, falling back to pricier ones on failure (async)."""
        return await self._first_success_async(
            sorted(self.agents, key=self._get_agent_cost),
            lambda agent: agent._runtime.run_async(request),
            f"All {len(self.agents)} agents failed",
        )

    def _run_capability(self, request: AgentRequest) -> AgentResponse:
        """Route based on required capabilities, with fallback among capable agents."""
        capable_agents = self._capable_agents(request)
        return self._first_success(
            capable_agents,
            lambda agent: agent._runtime.run(request),
            f"All capable agents ({len(capable_agents)}) failed",
        )

    async def _run_capability_async(self, request: AgentRequest) -> AgentResponse:
        """Route based on required capabilities, with fallback among capable agents (async)."""
        capable_agents = self._capable_agents(request)
        return await self._first_success_async(
            capable_agents,
            lambda agent: agent._runtime.run_async(request),
            f"All capable agents ({len(capable_agents)}) failed",
        )

    def _run_custom(self, request: AgentRequest) -> AgentResponse:
        """
        Use custom routing function.

        The agent chosen by ``custom_router`` is run once; there is no
        fallback if it fails.

        Raises:
            ValueError: If no ``custom_router`` is configured
        """
        if not self.custom_router:
            raise ValueError("custom_router not configured")

        result = self.custom_router(request, self.agents)
        return result.agent._runtime.run(request)

    async def _run_custom_async(self, request: AgentRequest) -> AgentResponse:
        """
        Use custom routing function (async).

        ``custom_router`` is called synchronously; only the chosen agent's
        run is awaited. There is no fallback if it fails.

        Raises:
            ValueError: If no ``custom_router`` is configured
        """
        if not self.custom_router:
            raise ValueError("custom_router not configured")

        result = self.custom_router(request, self.agents)
        return await result.agent._runtime.run_async(request)

    def _get_ordered_agents(self) -> list[Agent]:
        """
        Get agents in order based on strategy.

        ROUND_ROBIN rotates the list (advancing the shared cursor) and
        CHEAPEST sorts by estimated cost; every other strategy uses the
        configured order.
        """
        if self.strategy == RoutingStrategy.ROUND_ROBIN:
            return self._next_round_robin_order()

        if self.strategy == RoutingStrategy.CHEAPEST:
            return sorted(self.agents, key=self._get_agent_cost)

        return self.agents

    def __repr__(self) -> str:
        return f"AgentRouter(agents={len(self.agents)}, strategy={self.strategy.value})"