# Source code for agent.providers.base
"""
Base provider interface.
All provider adapters must implement this interface.
"""
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator
from typing import Any
from agent.messages import AgentRequest
from agent.response import AgentResponse
from agent.stream import StreamEvent
from agent.types.config import ProviderCapabilities
class BaseProvider(ABC):
    """
    Base class for all provider adapters.

    Each provider must implement the core methods to handle
    request conversion, response normalization, and streaming.

    Attributes:
        name: Short provider identifier. It is used in the
            ``validate_config`` error message and in ``repr()``.
        capabilities: Feature flags read by the ``supports_*`` helpers.
    """

    # Provider identifier; subclasses are expected to replace "base".
    name: str = "base"
    # NOTE(review): this one default instance is shared by every subclass that
    # does not assign its own ``capabilities``. If ProviderCapabilities is
    # mutable, mutating it through one provider would affect all of them, so
    # subclasses should define their own instance instead of mutating this one.
    capabilities: ProviderCapabilities = ProviderCapabilities()

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        timeout: float = 120.0,
        max_retries: int = 2,
        **kwargs: Any,
    ) -> None:
        """
        Store the common connection settings for the provider.

        This base implementation only records the values. Interpreting them
        is left to subclasses.

        Args:
            api_key: Provider credential. ``validate_config`` reports an
                error when it is missing or empty.
            base_url: Optional endpoint override (``None`` by default).
                Presumably used by subclasses to target a non-default
                host; confirm against the concrete adapters.
            timeout: Request timeout, default 120.0. Presumably in
                seconds; confirm against the subclass clients.
            max_retries: Retry budget, default 2.
            **kwargs: Additional provider-specific options, stored
                unchanged in ``self.extra_config``.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.timeout = timeout
        self.max_retries = max_retries
        self.extra_config = kwargs

    @abstractmethod
    def run(self, request: AgentRequest) -> AgentResponse:
        """
        Execute a synchronous request.

        Args:
            request: Normalized agent request

        Returns:
            Normalized agent response
        """
        ...

    @abstractmethod
    async def run_async(self, request: AgentRequest) -> AgentResponse:
        """
        Execute an asynchronous request.

        Args:
            request: Normalized agent request

        Returns:
            Normalized agent response
        """
        ...

    @abstractmethod
    def stream(self, request: AgentRequest) -> Iterator[StreamEvent]:
        """
        Execute a streaming request.

        Args:
            request: Normalized agent request

        Yields:
            Normalized stream events
        """
        ...

    # Declared as a plain ``def`` returning an AsyncIterator, not as
    # ``async def``. This lets subclasses override it with an async
    # generator, which returns an async iterator when called rather than a
    # coroutine.
    @abstractmethod
    def stream_async(self, request: AgentRequest) -> AsyncIterator[StreamEvent]:
        """
        Execute an asynchronous streaming request.

        Subclasses implement this as an async generator (``async def`` with ``yield``).
        Callers should iterate with ``async for``, not ``await``.

        Args:
            request: Normalized agent request

        Yields:
            Normalized stream events
        """
        ...

    def supports_structured_output(self) -> bool:
        """Check if provider supports structured output."""
        return self.capabilities.structured_output

    def supports_vision(self) -> bool:
        """Check if provider supports vision/images."""
        return self.capabilities.vision

    def supports_streaming(self) -> bool:
        """Check if provider supports streaming."""
        return self.capabilities.streaming

    def supports_json_mode(self) -> bool:
        """Check if provider supports JSON mode."""
        return self.capabilities.json_mode

    def supports_native_schema(self) -> bool:
        """Check if provider supports native schema-enforced output."""
        return self.capabilities.native_schema_output

    def validate_config(self) -> list[str]:
        """
        Validate provider configuration.

        The base check only requires a truthy ``api_key``. Subclasses can
        extend the returned list with their own checks.

        Returns:
            List of validation error messages (empty if valid)
        """
        errors: list[str] = []
        if not self.api_key:
            errors.append(f"API key required for {self.name} provider")
        return errors

    def __repr__(self) -> str:
        """Return a debug representation that never reveals the API key."""
        # A set key is replaced by a quoted placeholder. A falsy key
        # (None or "") holds no secret, so show its real repr. This avoids
        # the previous output ``api_key='None'``, which looked like the key
        # was the literal string "None".
        masked_key = "'***'" if self.api_key else repr(self.api_key)
        return f"{type(self).__name__}(name={self.name!r}, api_key={masked_key})"