Source code for agent.providers.gemini

"""
Google Gemini provider adapter.
"""

import uuid
from collections.abc import AsyncIterator, Iterator
from typing import Any, NoReturn

from agent.errors import (
    AuthenticationError,
    ProviderError,
    RateLimitError,
)
from agent.errors import (
    TimeoutError as AgentTimeoutError,
)
from agent.messages import AgentRequest, Message
from agent.providers.base import BaseProvider
from agent.providers.registry import ProviderRegistry
from agent.response import AgentResponse
from agent.stream import StreamEvent
from agent.types.config import ProviderCapabilities
from agent.types.response import Usage
from agent.types.tools import ToolCall

try:
    import google.generativeai as genai
    from google.generativeai.types import (
        Content,  # ty: ignore[unresolved-import]
        GenerationConfig,
        Part,  # ty: ignore[unresolved-import]
    )

    HAS_GEMINI = True
except ImportError:
    HAS_GEMINI = False
    genai: Any = None


class GeminiProvider(BaseProvider):
    """
    Google Gemini provider adapter.

    Supports Gemini Pro, Gemini Flash, and other Gemini models.

    The adapter translates normalized ``AgentRequest`` objects into calls on
    ``google.generativeai`` and converts the SDK responses, streaming chunks
    and exceptions back into the agent's own response, event and error types.
    """

    name = "gemini"
    capabilities = ProviderCapabilities(
        streaming=True,
        tools=True,
        structured_output=True,
        json_mode=True,
        vision=True,
        system_messages=True,
        batch=False,
        native_schema_output=True,
    )

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        timeout: float = 120.0,
        max_retries: int = 2,
        **kwargs: Any,
    ):
        """Configure the SDK and remember the default model.

        Args:
            api_key: Key handed to ``genai.configure``.
            base_url: Forwarded to the base provider only.
            timeout: Forwarded to the base provider; also reported on
                timeout errors raised by ``_handle_error``.
            max_retries: Forwarded to the base provider only.
            **kwargs: Extra provider options; ``model`` selects the default
                model name (``"gemini-1.5-pro"`` when absent).

        Raises:
            ImportError: If the Google Generative AI package is missing.
        """
        if not HAS_GEMINI:
            raise ImportError(
                "Google Generative AI package not installed. "
                "Install with: pip install agent-core-py[gemini]"
            )

        super().__init__(
            api_key=api_key,
            base_url=base_url,
            timeout=timeout,
            max_retries=max_retries,
            **kwargs,
        )

        # Configure the SDK
        assert genai is not None
        self._genai = genai
        self._genai.configure(api_key=api_key)

        # Store model name for later
        self._model_name = kwargs.get("model", "gemini-1.5-pro")

    def _resolve_model_name(self) -> str:
        """Return the model name in effect: ``extra_config`` override, else the default."""
        config = getattr(self, "extra_config", None) or {}
        return config.get("model", self._model_name)

    def _get_model(self, request: AgentRequest) -> Any:
        """Build a ``GenerativeModel`` configured for *request*.

        Sampling parameters, tools and the system instruction are all taken
        from the request; the model name comes from ``_resolve_model_name``.
        """
        # Build generation config
        generation_config = GenerationConfig(
            temperature=request.temperature,
            max_output_tokens=request.max_tokens,
            top_p=request.top_p,
            stop_sequences=request.stop,
        )

        # Build tools if present
        tools = None
        if request.tools:
            tools = [self._convert_tool(t) for t in request.tools]

        return self._genai.GenerativeModel(
            model_name=self._resolve_model_name(),
            generation_config=generation_config,
            tools=tools,
            system_instruction=request.system,
        )

    def run(self, request: AgentRequest) -> AgentResponse:
        """Execute a synchronous request.

        Any exception is routed through ``_handle_error``, which always raises.
        """
        try:
            model = self._get_model(request)
            contents = self._convert_messages(request)
            response = model.generate_content(contents)
            return self._convert_response(response)
        except Exception as e:
            self._handle_error(e)

    async def run_async(self, request: AgentRequest) -> AgentResponse:
        """Execute an asynchronous request.

        Any exception is routed through ``_handle_error``, which always raises.
        """
        try:
            model = self._get_model(request)
            contents = self._convert_messages(request)
            response = await model.generate_content_async(contents)
            return self._convert_response(response)
        except Exception as e:
            self._handle_error(e)

    def stream(self, request: AgentRequest) -> Iterator[StreamEvent]:
        """Execute a streaming request.

        Yields the events produced for every chunk, then a final
        ``message_end`` event.
        """
        try:
            model = self._get_model(request)
            contents = self._convert_messages(request)
            response = model.generate_content(contents, stream=True)

            for chunk in response:
                yield from self._convert_chunk(chunk)

            yield StreamEvent.message_end()
        except Exception as e:
            self._handle_error(e)

    async def stream_async(self, request: AgentRequest) -> AsyncIterator[StreamEvent]:
        """Execute an async streaming request.

        Yields the events produced for every chunk, then a final
        ``message_end`` event.
        """
        try:
            model = self._get_model(request)
            contents = self._convert_messages(request)
            response = await model.generate_content_async(contents, stream=True)

            async for chunk in response:
                for event in self._convert_chunk(chunk):
                    yield event

            yield StreamEvent.message_end()
        except Exception as e:
            self._handle_error(e)

    def _convert_messages(self, request: AgentRequest) -> list[Any]:
        """Convert normalized messages to Gemini format.

        Messages that convert to ``None`` are skipped. A non-empty
        ``request.input`` is appended as a trailing user turn.
        """
        # NOTE(review): Content/Part are imported from google.generativeai.types
        # under an unresolved-import suppression -- confirm the installed SDK
        # version actually exposes these helpers.
        contents = []
        for msg in request.messages:
            content = self._convert_message(msg)
            # Explicit None check: do not rely on the truthiness of SDK objects.
            if content is not None:
                contents.append(content)

        # Add input as user message
        if request.input:
            contents.append(Content(role="user", parts=[Part.from_text(request.input)]))

        return contents

    def _convert_message(self, msg: Message) -> Any | None:
        """Convert a single message to Gemini format.

        Returns:
            A ``Content`` object, or ``None`` for system messages and for
            messages that yield no parts at all.
        """
        if msg.role == "system":
            # System messages handled separately in Gemini
            return None

        # Map roles: Gemini only knows "user" and "model"; tool results are
        # sent back on the user side.
        role = "user" if msg.role in ("user", "tool") else "model"

        parts = []

        # Handle content
        if isinstance(msg.content, str):
            if msg.role == "tool":
                # Tool results need special handling
                parts.append(
                    Part.from_function_response(
                        name=msg.name or "tool",
                        response={"result": msg.content},
                    )
                )
            else:
                parts.append(Part.from_text(msg.content))
        else:
            # Content may be None on an assistant turn that only carries tool
            # calls; treat that as "no content parts" instead of crashing.
            for part in msg.content or []:
                if part.type == "text" and part.text:
                    parts.append(Part.from_text(part.text))
                elif part.type == "image" and part.image_data:
                    parts.append(
                        Part.from_data(
                            data=part.image_data,
                            mime_type=part.media_type or "image/png",
                        )
                    )
                elif part.type == "image_url" and part.image_url:
                    # Gemini prefers inline data, but we can try URL.
                    # Honour a declared media type; fall back to JPEG as before.
                    parts.append(
                        Part.from_uri(
                            uri=part.image_url,
                            mime_type=getattr(part, "media_type", None) or "image/jpeg",
                        )
                    )

        # Handle tool calls in assistant messages
        if msg.role == "assistant" and msg.tool_calls:
            for tc in msg.tool_calls:
                parts.append(
                    Part.from_function_call(
                        name=tc["name"],
                        args=tc.get("arguments", {}),
                    )
                )

        if not parts:
            # Nothing to send for this turn; the caller skips None results.
            return None

        return Content(role=role, parts=parts)

    def _convert_tool(self, tool_spec: Any) -> Any:
        """Convert tool spec to Gemini format.

        The spec's ``to_gemini_schema()`` result supplies the name,
        description and JSON-Schema-style parameters; each parameter is
        converted with ``_convert_schema_property``.
        """
        schema = tool_spec.to_gemini_schema()
        parameters = schema.get("parameters") or {}
        protos = self._genai.protos

        return protos.Tool(
            function_declarations=[
                protos.FunctionDeclaration(
                    name=schema["name"],
                    description=schema.get("description") or "",
                    parameters=protos.Schema(
                        type=protos.Type.OBJECT,
                        properties={
                            key: self._convert_schema_property(value)
                            for key, value in (parameters.get("properties") or {}).items()
                        },
                        required=parameters.get("required") or [],
                    ),
                )
            ]
        )

    def _convert_schema_property(self, prop: dict[str, Any]) -> Any:
        """Convert a JSON Schema property to Gemini format.

        Nested structure is preserved: ``items`` for arrays, ``properties``
        and ``required`` for objects, and ``enum`` for strings. Unknown or
        missing types fall back to STRING.
        """
        protos = self._genai.protos
        type_map = {
            "string": protos.Type.STRING,
            "integer": protos.Type.INTEGER,
            "number": protos.Type.NUMBER,
            "boolean": protos.Type.BOOLEAN,
            "array": protos.Type.ARRAY,
            "object": protos.Type.OBJECT,
        }

        json_type = prop.get("type", "string")
        if isinstance(json_type, (list, tuple)):
            # JSON Schema allows e.g. ["string", "null"]; a list is unhashable
            # and cannot be looked up directly, so pick the first real type.
            json_type = next((t for t in json_type if t != "null"), "string")
        if not isinstance(json_type, str):
            json_type = "string"

        fields: dict[str, Any] = {
            "type": type_map.get(json_type, protos.Type.STRING),
            "description": prop.get("description", ""),
        }

        if json_type == "string" and prop.get("enum"):
            fields["enum"] = [str(value) for value in prop["enum"]]

        if json_type == "array" and isinstance(prop.get("items"), dict):
            fields["items"] = self._convert_schema_property(prop["items"])

        if json_type == "object":
            nested = prop.get("properties")
            if isinstance(nested, dict) and nested:
                fields["properties"] = {
                    key: self._convert_schema_property(value)
                    for key, value in nested.items()
                }
            if prop.get("required"):
                fields["required"] = list(prop["required"])

        return protos.Schema(**fields)

    @staticmethod
    def _tool_call_from_part(part: Any) -> ToolCall | None:
        """Return a ``ToolCall`` for *part*, or ``None`` if it holds no function call.

        Protobuf-backed parts expose a ``function_call`` attribute even when
        it is unset, so attribute presence is not enough: a non-empty
        function name is required.
        """
        fc = getattr(part, "function_call", None)
        if fc is None or not getattr(fc, "name", None):
            return None
        return ToolCall(
            id=f"call_{uuid.uuid4().hex[:12]}",
            name=fc.name,
            arguments=dict(fc.args) if fc.args else {},
        )

    def _convert_response(self, response: Any) -> AgentResponse:
        """Convert Gemini response to normalized format.

        Only the first candidate is read. Text parts are concatenated;
        function-call parts become ``ToolCall`` objects with generated ids.
        """
        text_parts: list[str] = []
        tool_calls: list[ToolCall] = []
        stop_reason = None

        if response.candidates:
            candidate = response.candidates[0]

            for part in candidate.content.parts:
                part_text = getattr(part, "text", None)
                if part_text:
                    text_parts.append(part_text)
                    continue
                tool_call = self._tool_call_from_part(part)
                if tool_call is not None:
                    tool_calls.append(tool_call)

            # Determine stop reason
            finish_reason = candidate.finish_reason
            stop_reason = str(finish_reason.name) if finish_reason else None

        text = "".join(text_parts) if text_parts else None

        # Extract usage if available
        usage = None
        um = getattr(response, "usage_metadata", None)
        if um:
            usage = Usage(
                prompt_tokens=getattr(um, "prompt_token_count", 0),
                completion_tokens=getattr(um, "candidates_token_count", 0),
                total_tokens=getattr(um, "total_token_count", 0),
            )

        return AgentResponse(
            text=text,
            content=[{"type": "text", "text": text}] if text else [],
            provider=self.name,
            # Report the model that was actually requested, including any
            # extra_config override applied in _get_model.
            model=self._resolve_model_name(),
            usage=usage,
            stop_reason=stop_reason,
            tool_calls=tool_calls,
            raw=response,
        )

    def _convert_chunk(self, chunk: Any) -> Iterator[StreamEvent]:
        """Convert a streaming chunk to events.

        Yields a ``text_delta`` per non-empty text part and a
        ``tool_call_start`` per function-call part of the first candidate.
        """
        if not chunk.candidates:
            return

        candidate = chunk.candidates[0]
        for part in candidate.content.parts:
            part_text = getattr(part, "text", None)
            if part_text:
                yield StreamEvent.text_delta(part_text, raw=chunk)
                continue
            tool_call = self._tool_call_from_part(part)
            if tool_call is not None:
                yield StreamEvent.tool_call_start(tool_call, raw=chunk)

    def _handle_error(self, e: Exception) -> NoReturn:
        """Convert Gemini errors to Agent errors.

        Classification uses the exception's class name and lower-cased
        message text. Agent errors are re-raised unchanged.

        Raises:
            AuthenticationError: Permission or API-key failures.
            RateLimitError: Quota or rate-limit failures.
            AgentTimeoutError: Deadline or timeout failures.
            ProviderError: Anything else.
        """
        # Re-raise if already an Agent error
        if isinstance(e, (AuthenticationError, RateLimitError, AgentTimeoutError, ProviderError)):
            # Explicit re-raise so this also works outside an except block.
            raise e

        # Check for specific exception types from the SDK
        error_type = type(e).__name__
        error_str = str(e).lower()

        if (
            error_type in ("PermissionDenied", "Unauthenticated")
            or "api key" in error_str
            or "authentication" in error_str
            or "permission denied" in error_str
        ):
            raise AuthenticationError(str(e), raw=e) from e
        elif (
            error_type == "ResourceExhausted"
            or "rate limit" in error_str
            or "quota" in error_str
            or "429" in error_str
        ):
            raise RateLimitError(str(e), provider=self.name, raw=e) from e
        elif (
            error_type in ("DeadlineExceeded",)
            # Built-in TimeoutError: the agent's own class is imported
            # under the alias AgentTimeoutError.
            or isinstance(e, TimeoutError)
            or "timeout" in error_str
        ):
            raise AgentTimeoutError(str(e), timeout=self.timeout, raw=e) from e
        else:
            raise ProviderError(str(e), provider=self.name, raw=e) from e
# Make the adapter available under its primary name and its aliases.
ProviderRegistry.register(
    "gemini",
    GeminiProvider,
    aliases=["google", "palm"],
)