Source code for agent.schemas
"""
Agent structured output schemas.
Support for Pydantic-based schema definition and JSON Schema.
"""
import json
import re
from typing import Any, TypeVar
from pydantic import BaseModel, ValidationError
from agent.errors import SchemaValidationError
T = TypeVar("T", bound=BaseModel)
[docs]
class Schema:
"""
Wrapper for schema definitions.
Supports both Pydantic models and raw JSON Schema dictionaries.
"""
def __init__(
self,
schema: type[BaseModel] | dict[str, Any],
*,
strict: bool = True,
repair_attempts: int = 1,
):
self.schema = schema
self.strict = strict
self.repair_attempts = repair_attempts
self._is_pydantic = isinstance(schema, type) and issubclass(schema, BaseModel)
@property
def json_schema(self) -> dict[str, Any]:
"""Get the JSON Schema representation."""
if self._is_pydantic:
return self.schema.model_json_schema() # type: ignore
return self.schema # type: ignore
[docs]
def validate(self, data: Any) -> Any:
"""
Validate data against the schema.
Returns the validated/parsed object, or raises SchemaValidationError.
"""
if self._is_pydantic:
try:
if isinstance(data, str):
return self.schema.model_validate_json(data) # type: ignore
return self.schema.model_validate(data) # type: ignore
except ValidationError as e:
raise SchemaValidationError(
f"Schema validation failed: {e}",
schema=self.schema,
output=data,
) from e
else:
# For raw JSON Schema, just return the data
# (validation would require jsonschema library)
return data
[docs]
def parse_json(self, text: str) -> Any:
"""
Parse JSON text and validate against schema.
Handles extracting JSON from markdown code blocks.
"""
# Clean up the text
cleaned = text.strip()
# Try to extract JSON from markdown code blocks
if cleaned.startswith("```"):
lines = cleaned.split("\n")
# Remove first line (```json or ```)
lines = lines[1:]
# Remove last line (```)
if lines and lines[-1].strip() == "```":
lines = lines[:-1]
cleaned = "\n".join(lines)
# Parse JSON
try:
data = json.loads(cleaned)
except json.JSONDecodeError as e:
raise SchemaValidationError(
f"Failed to parse JSON: {e}",
schema=self.schema,
output=text,
) from e
return self.validate(data)
[docs]
def schema_to_prompt(schema: Schema) -> str:
"""Convert a schema to a prompt instruction."""
json_schema = schema.json_schema
# Get the schema name if available
name = json_schema.get("title", "response")
# Build the prompt
prompt = f"Respond with a JSON object for '{name}' matching this schema:\n\n```json\n{json.dumps(json_schema, indent=2)}\n```\n\n"
prompt += "IMPORTANT: Return ONLY the JSON object, no other text."
return prompt
[docs]
def repair_json(text: str, error: Exception) -> str:
"""
Attempt to repair malformed JSON.
Basic repairs:
- Add missing closing braces/brackets
- Fix trailing commas
- Fix unquoted keys
"""
# Remove trailing commas before closing braces/brackets
text = re.sub(r",\s*([}\]])", r"\1", text)
# Try to balance braces
open_braces = text.count("{")
close_braces = text.count("}")
if open_braces > close_braces:
text = text + "}" * (open_braces - close_braces)
open_brackets = text.count("[")
close_brackets = text.count("]")
if open_brackets > close_brackets:
text = text + "]" * (open_brackets - close_brackets)
return text