Custom Providers

Implement custom provider adapters to support additional LLM backends or specialized endpoints.

Provider Interface

All providers must extend BaseProvider and implement the abstract methods:

from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Iterator

from agent.providers.base import BaseProvider
from agent.messages import AgentRequest
from agent.response import AgentResponse
from agent.stream import StreamEvent
from agent.types import ProviderCapabilities


class CustomProvider(BaseProvider):
    """Your custom provider implementation."""
    
    name = "custom"
    capabilities = ProviderCapabilities(
        streaming=True,
        tools=True,
        structured_output=True,
        json_mode=True,
        vision=False,
        system_messages=True,
    )
    
    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        timeout: float = 120.0,
        max_retries: int = 2,
        **kwargs: Any,
    ):
        super().__init__(api_key, base_url, timeout, max_retries, **kwargs)
        # Initialize your client
        self.client = MyAPIClient(api_key, base_url)
    
    def run(self, request: AgentRequest) -> AgentResponse:
        """Execute a synchronous request."""
        # Convert request to your API format
        api_request = self._convert_request(request)
        
        # Call your API
        api_response = self.client.chat(api_request)
        
        # Convert to normalized response
        return self._convert_response(api_response)
    
    async def run_async(self, request: AgentRequest) -> AgentResponse:
        """Execute an asynchronous request."""
        api_request = self._convert_request(request)
        api_response = await self.client.chat_async(api_request)
        return self._convert_response(api_response)
    
    def stream(self, request: AgentRequest) -> Iterator[StreamEvent]:
        """Execute a streaming request."""
        api_request = self._convert_request(request)
        
        for chunk in self.client.chat_stream(api_request):
            yield self._convert_chunk(chunk)
    
    async def stream_async(self, request: AgentRequest) -> AsyncIterator[StreamEvent]:
        """Execute an async streaming request."""
        api_request = self._convert_request(request)
        
        async for chunk in self.client.chat_stream_async(api_request):
            yield self._convert_chunk(chunk)

Converting Requests

Transform the normalized AgentRequest to your API format:

def _convert_request(self, request: AgentRequest) -> dict:
    """Convert AgentRequest to API format."""
    messages = []
    
    # Handle system message
    if request.system:
        messages.append({
            "role": "system",
            "content": request.system,
        })
    
    # Convert existing messages
    for msg in request.messages:
        messages.append(self._convert_message(msg))
    
    # Add user input
    if request.input:
        messages.append({
            "role": "user",
            "content": request.input,
        })
    
    api_request = {
        "messages": messages,
        "model": self.extra_config.get("model", "default-model"),
    }
    
    # Add optional parameters
    if request.temperature is not None:
        api_request["temperature"] = request.temperature
    if request.max_tokens is not None:
        api_request["max_tokens"] = request.max_tokens
    
    # Handle tools
    if request.tools:
        api_request["tools"] = [
            tool.to_openai_schema()  # Or your own format
            for tool in request.tools
        ]
    
    return api_request

def _convert_message(self, msg: Message) -> dict:
    """Convert a single message."""
    result = {"role": msg.role, "content": msg.content}
    
    if msg.role == "assistant" and msg.tool_calls:
        result["tool_calls"] = msg.tool_calls
    
    if msg.role == "tool":
        result["tool_call_id"] = msg.tool_call_id
    
    return result

Converting Responses

Transform API responses to the normalized AgentResponse:

from agent.types import Usage, ToolCall

def _convert_response(self, api_response: dict) -> AgentResponse:
    """Convert API response to AgentResponse."""
    
    # Extract text content
    text = api_response.get("content", "")
    
    # Extract tool calls
    tool_calls = []
    for tc in api_response.get("tool_calls", []):
        tool_calls.append(ToolCall(
            id=tc["id"],
            name=tc["function"]["name"],
            arguments=json.loads(tc["function"]["arguments"]),
        ))
    
    # Extract usage
    usage = None
    if "usage" in api_response:
        usage = Usage(
            prompt_tokens=api_response["usage"]["prompt_tokens"],
            completion_tokens=api_response["usage"]["completion_tokens"],
            total_tokens=api_response["usage"]["total_tokens"],
        )
    
    return AgentResponse(
        text=text,
        content=[{"type": "text", "text": text}] if text else [],
        provider=self.name,
        model=api_response.get("model", ""),
        usage=usage,
        stop_reason=api_response.get("finish_reason"),
        tool_calls=tool_calls,
        raw=api_response,
    )

Converting Stream Events

Transform streaming chunks to normalized StreamEvent:

def _convert_chunk(self, chunk: dict) -> Iterator[StreamEvent]:
    """Convert a streaming chunk to events."""
    
    # Text delta
    if "delta" in chunk and "content" in chunk["delta"]:
        yield StreamEvent.text_delta(
            chunk["delta"]["content"],
            raw=chunk,
        )
    
    # Tool call
    if "delta" in chunk and "tool_calls" in chunk["delta"]:
        for tc in chunk["delta"]["tool_calls"]:
            if tc.get("function", {}).get("name"):
                # New tool call
                yield StreamEvent.tool_call_start(
                    ToolCall(
                        id=tc["id"],
                        name=tc["function"]["name"],
                        arguments={},
                    ),
                    raw=chunk,
                )
            elif tc.get("function", {}).get("arguments"):
                # Argument delta
                yield StreamEvent.tool_call_delta_event(
                    tc["id"],
                    {"arguments": tc["function"]["arguments"]},
                    raw=chunk,
                )
    
    # Usage
    if "usage" in chunk:
        yield StreamEvent.usage_event(
            Usage(
                prompt_tokens=chunk["usage"]["prompt_tokens"],
                completion_tokens=chunk["usage"]["completion_tokens"],
                total_tokens=chunk["usage"]["total_tokens"],
            ),
            raw=chunk,
        )
    
    # End of message
    if chunk.get("finish_reason"):
        yield StreamEvent.message_end(raw=chunk)

Error Handling

Convert API errors to Agent errors:

from agent.errors import (
    AuthenticationError,
    ProviderError,
    RateLimitError,
    TimeoutError as AgentTimeoutError,
)

def run(self, request: AgentRequest) -> AgentResponse:
    try:
        api_response = self.client.chat(...)
        return self._convert_response(api_response)
    except MyAuthError as e:
        raise AuthenticationError(str(e), raw=e)
    except MyRateLimitError as e:
        raise RateLimitError(
            str(e),
            provider=self.name,
            retry_after=e.retry_after,
            raw=e,
        )
    except MyTimeoutError as e:
        raise AgentTimeoutError(str(e), timeout=self.timeout, raw=e)
    except MyAPIError as e:
        raise ProviderError(
            str(e),
            provider=self.name,
            status_code=e.status_code,
            raw=e,
        )

Registering Your Provider

Register the provider for use with the Agent class:

from agent.providers.registry import ProviderRegistry

# Register with optional aliases
ProviderRegistry.register(
    "custom",
    CustomProvider,
    aliases=["my-provider", "custom-llm"],
)

# Now you can use it
from agent import Agent

agent = Agent(
    provider="custom",  # or "my-provider"
    model="my-model",
    api_key="...",
)

OpenAI-Compatible Providers

For APIs that follow the OpenAI format, extend the OpenAI provider:

from agent.providers.openai import OpenAIProvider
from agent.providers.registry import ProviderRegistry
from agent.types import ProviderCapabilities


class GroqProvider(OpenAIProvider):
    """Groq provider - OpenAI-compatible API."""
    
    name = "groq"
    capabilities = ProviderCapabilities(
        streaming=True,
        tools=True,
        structured_output=True,
        json_mode=True,
        vision=False,  # Groq doesn't support vision
        system_messages=True,
    )
    
    GROQ_BASE_URL = "https://api.groq.com/openai/v1"
    
    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        **kwargs,
    ):
        super().__init__(
            api_key=api_key,
            base_url=base_url or self.GROQ_BASE_URL,
            **kwargs,
        )


ProviderRegistry.register("groq", GroqProvider)

Testing Your Provider

import pytest
from agent import Agent
from agent.providers.registry import ProviderRegistry

# Register test provider
ProviderRegistry.register("test", TestProvider)

def test_basic_run():
    agent = Agent(provider="test", model="test-model")
    response = agent.run("Hello")
    assert response.text is not None
    assert response.provider == "test"

def test_streaming():
    agent = Agent(provider="test", model="test-model")
    events = list(agent.stream("Hello"))
    assert any(e.type == "text_delta" for e in events)
    assert any(e.type == "message_end" for e in events)

def test_tool_calling():
    @tool
    def my_tool(x: str) -> str:
        return x
    
    agent = Agent(provider="test", model="test-model", tools=[my_tool])
    response = agent.run("Use the tool")
    # Verify tool calls work

@pytest.mark.asyncio
async def test_async():
    agent = Agent(provider="test", model="test-model")
    response = await agent.run_async("Hello")
    assert response.text is not None

Complete Example

Here’s a complete example for a hypothetical API:

"""
MyLLM Provider Adapter
"""

import json
from typing import Any, AsyncIterator, Iterator

import httpx

from agent.errors import (
    AuthenticationError,
    ProviderError,
    RateLimitError,
)
from agent.messages import AgentRequest, Message
from agent.providers.base import BaseProvider
from agent.providers.registry import ProviderRegistry
from agent.response import AgentResponse
from agent.stream import StreamEvent
from agent.types import ProviderCapabilities, Usage, ToolCall


class MyLLMProvider(BaseProvider):
    """Provider adapter for MyLLM API."""
    
    name = "myllm"
    capabilities = ProviderCapabilities(
        streaming=True,
        tools=True,
        structured_output=True,
        json_mode=True,
        vision=True,
        system_messages=True,
    )
    
    BASE_URL = "https://api.myllm.com/v1"
    
    def __init__(self, api_key: str | None = None, **kwargs):
        super().__init__(api_key=api_key, **kwargs)
        self.client = httpx.Client(
            base_url=self.base_url or self.BASE_URL,
            headers={"Authorization": f"Bearer {api_key}"},
            timeout=self.timeout,
        )
        self.async_client = httpx.AsyncClient(
            base_url=self.base_url or self.BASE_URL,
            headers={"Authorization": f"Bearer {api_key}"},
            timeout=self.timeout,
        )
    
    def run(self, request: AgentRequest) -> AgentResponse:
        try:
            response = self.client.post(
                "/chat/completions",
                json=self._build_request(request),
            )
            response.raise_for_status()
            return self._parse_response(response.json())
        except httpx.HTTPStatusError as e:
            self._handle_error(e)
    
    async def run_async(self, request: AgentRequest) -> AgentResponse:
        try:
            response = await self.async_client.post(
                "/chat/completions",
                json=self._build_request(request),
            )
            response.raise_for_status()
            return self._parse_response(response.json())
        except httpx.HTTPStatusError as e:
            self._handle_error(e)
    
    def stream(self, request: AgentRequest) -> Iterator[StreamEvent]:
        req = self._build_request(request)
        req["stream"] = True
        
        with self.client.stream("POST", "/chat/completions", json=req) as response:
            for line in response.iter_lines():
                if line.startswith("data: "):
                    data = json.loads(line[6:])
                    yield from self._parse_chunk(data)
    
    async def stream_async(self, request: AgentRequest) -> AsyncIterator[StreamEvent]:
        req = self._build_request(request)
        req["stream"] = True
        
        async with self.async_client.stream("POST", "/chat/completions", json=req) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data = json.loads(line[6:])
                    for event in self._parse_chunk(data):
                        yield event
    
    def _build_request(self, request: AgentRequest) -> dict:
        # ... implementation
        pass
    
    def _parse_response(self, data: dict) -> AgentResponse:
        # ... implementation
        pass
    
    def _parse_chunk(self, data: dict) -> Iterator[StreamEvent]:
        # ... implementation
        pass
    
    def _handle_error(self, error: httpx.HTTPStatusError):
        if error.response.status_code == 401:
            raise AuthenticationError("Invalid API key", raw=error)
        elif error.response.status_code == 429:
            raise RateLimitError("Rate limited", provider=self.name, raw=error)
        else:
            raise ProviderError(str(error), provider=self.name, raw=error)


# Register the provider
ProviderRegistry.register("myllm", MyLLMProvider, aliases=["my-llm"])

Next Steps