import asyncio
import inspect
import json
import logging
import time
import uuid
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Any, Optional, Callable

import redis
from sqlalchemy import create_engine, Column, Integer, String, DateTime, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
# Design Patterns Implementation for AI Agents
# 1. Strategy Pattern for Tool Execution
class ToolExecutionStrategy(ABC):
    """Interface that every tool-execution backend has to implement.

    Both operations are abstract, so this class cannot be instantiated
    directly; concrete strategies supply the actual behaviour.
    """

    @abstractmethod
    async def execute(self, tool_name: str, parameters: Dict[str, Any]) -> Any:
        """Run the tool named ``tool_name`` with ``parameters`` and return its result."""

    @abstractmethod
    def validate_parameters(self, tool_name: str, parameters: Dict[str, Any]) -> bool:
        """Report whether ``parameters`` are acceptable for ``tool_name``."""
class LocalToolExecutionStrategy(ToolExecutionStrategy):
    """Strategy that runs tools registered in the current process.

    Tools live in ``self.tools`` keyed by name and are invoked with the
    caller's parameters expanded as keyword arguments.
    """

    def __init__(self):
        self.tools: Dict[str, Callable] = {}
        self.logger = logging.getLogger(__name__)

    async def execute(self, tool_name: str, parameters: Dict[str, Any]) -> Any:
        """Invoke a registered tool and return whatever it produces.

        Both coroutine functions and plain synchronous callables are
        supported: the tool is called, and its result is awaited only when it
        is awaitable. A synchronous tool runs directly on the event loop, so
        it should not block for long.

        Args:
            tool_name: Name the tool was registered under.
            parameters: Keyword arguments forwarded to the tool.

        Returns:
            The tool's (awaited) return value.

        Raises:
            ValueError: If validation fails or no tool is registered under
                ``tool_name``.
            Exception: Anything raised by the tool itself; it is logged and
                re-raised unchanged.
        """
        if not self.validate_parameters(tool_name, parameters):
            raise ValueError(f"Invalid parameters for tool {tool_name}")
        tool_func = self.tools.get(tool_name)
        # Compare against None so a registered callable object that happens to
        # be falsy is still treated as present.
        if tool_func is None:
            raise ValueError(f"Tool {tool_name} not found")
        try:
            self.logger.info("Executing tool %s with parameters %s", tool_name, parameters)
            result = tool_func(**parameters)
            # register_tool accepts any Callable; awaiting a non-awaitable
            # result would raise TypeError *after* the tool already ran.
            if inspect.isawaitable(result):
                result = await result
            self.logger.info("Tool %s executed successfully", tool_name)
            return result
        except Exception as e:
            # .exception() logs at ERROR level and keeps the traceback.
            self.logger.exception("Tool execution failed: %s", e)
            raise

    def validate_parameters(self, tool_name: str, parameters: Dict[str, Any]) -> bool:
        """Placeholder validation hook; currently accepts every input."""
        return True

    def register_tool(self, name: str, func: Callable):
        """Register ``func`` under ``name``, replacing any previous entry."""
        self.tools[name] = func
class RemoteToolExecutionStrategy(ToolExecutionStrategy):
    """Strategy meant to run tools through a remote API.

    The endpoint and key are stored on the instance; the current ``execute``
    is a stand-in that only simulates the round trip.
    """

    def __init__(self, api_endpoint: str, api_key: str):
        self.logger = logging.getLogger(__name__)
        self.api_endpoint, self.api_key = api_endpoint, api_key

    async def execute(self, tool_name: str, parameters: Dict[str, Any]) -> Any:
        """Log the call, pause briefly, and return a fixed success payload."""
        self.logger.info(f"Executing remote tool {tool_name}")
        simulated_latency = 0.1  # stand-in for the real API round trip
        await asyncio.sleep(simulated_latency)
        payload = {"status": "success", "result": "remote_execution_result"}
        return payload

    def validate_parameters(self, tool_name: str, parameters: Dict[str, Any]) -> bool:
        """Accept every parameter set (no validation is performed)."""
        return True
# 2. Observer Pattern for Event Management
class AgentEvent:
    """Record of something that happened to an agent.

    Holds the event type, the originating agent's id, an arbitrary data
    payload, and the wall-clock time (``time.time()``) at construction.
    """

    def __init__(self, event_type: str, agent_id: str, data: Dict[str, Any]):
        # Stamp first; the remaining fields are stored exactly as given.
        self.timestamp = time.time()
        self.event_type, self.agent_id, self.data = event_type, agent_id, data
class EventObserver(ABC):
    """Interface for objects that want to be notified about agent events."""

    @abstractmethod
    async def handle_event(self, event: AgentEvent):
        """React to a single published ``event``."""
class LoggingObserver(EventObserver):
    """Observer that writes one INFO log line per received event."""

    def __init__(self):
        self.logger = logging.getLogger(__name__)

    async def handle_event(self, event: AgentEvent):
        """Log the event's type, originating agent and timestamp."""
        line = f"Event: {event.event_type} from {event.agent_id} at {event.timestamp}"
        self.logger.info(line)
class MetricsObserver(EventObserver):
    """Observer that counts how many events of each type it has seen."""

    def __init__(self):
        # Maps event type -> number of events received with that type.
        self.metrics = {}

    async def handle_event(self, event: AgentEvent):
        """Increment the counter kept for ``event.event_type``."""
        kind = event.event_type
        tallies = self.metrics
        if kind in tallies:
            tallies[kind] += 1
        else:
            tallies[kind] = 1
class EventManager:
    """Subject of the observer pattern: fans events out to subscribers."""

    def __init__(self):
        self.observers: List[EventObserver] = []

    def subscribe(self, observer: EventObserver):
        """Add ``observer`` to the list of event recipients."""
        self.observers.append(observer)

    def unsubscribe(self, observer: EventObserver):
        """Remove ``observer`` if it is subscribed; otherwise do nothing."""
        if observer in self.observers:
            self.observers.remove(observer)

    async def publish_event(self, event: AgentEvent):
        """Deliver ``event`` to every observer subscribed when publishing starts.

        Observers are awaited one after another. An exception raised by one
        observer is logged and does not stop delivery to the others.
        """
        # Iterate over a snapshot: a handler (or another task running while we
        # await) may subscribe/unsubscribe, and mutating the live list during
        # iteration would silently skip observers. Consequently an observer
        # removed mid-dispatch still receives the current event.
        for observer in list(self.observers):
            try:
                await observer.handle_event(event)
            except Exception as e:
                # Best-effort delivery: record the failure with its traceback
                # and carry on with the remaining observers.
                logging.getLogger(__name__).exception("Observer error: %s", e)
# 3. Memento Pattern for State Management
@dataclass
class AgentMemento:
    """Snapshot of an agent's state, identified by agent and checkpoint id."""

    # Identifier of the agent the snapshot belongs to.
    agent_id: str
    # Conversation entries captured at snapshot time.
    conversation_history: List[Dict[str, Any]]
    # Context mapping captured at snapshot time.
    context: Dict[str, Any]
    # Time associated with the snapshot, as a float.
    timestamp: float
    # Identifier of this particular checkpoint.
    checkpoint_id: str
class AgentStateManager:
    """Caretaker that persists and reloads agent mementos through Redis.

    Snapshots are stored as JSON under ``agent_state:<agent_id>:<checkpoint_id>``
    and expire after ``ttl_seconds``.
    """

    def __init__(self, redis_client: redis.Redis, ttl_seconds: int = 3600):
        """Create a manager.

        Args:
            redis_client: Synchronous Redis client used for storage.
            ttl_seconds: Expiry applied to each saved snapshot. Defaults to
                3600 (one hour), the value that used to be hard-coded.
        """
        self.redis_client = redis_client
        self.ttl_seconds = ttl_seconds
        self.logger = logging.getLogger(__name__)

    async def save_state(self, memento: AgentMemento) -> str:
        """Persist ``memento`` and return its checkpoint id.

        Raises:
            Exception: Any serialization or Redis error, logged and re-raised.
        """
        try:
            key = f"agent_state:{memento.agent_id}:{memento.checkpoint_id}"
            # Serialize before the first await so the stored snapshot reflects
            # the state at call time, even if the agent keeps running.
            payload = json.dumps({
                "conversation_history": memento.conversation_history,
                "context": memento.context,
                "timestamp": memento.timestamp,
            })
            # The client is synchronous; calling it directly inside a coroutine
            # would block the event loop for the whole network round trip.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None, self.redis_client.setex, key, self.ttl_seconds, payload
            )
            self.logger.info("Saved state for agent %s", memento.agent_id)
            return memento.checkpoint_id
        except Exception as e:
            self.logger.exception("Failed to save state: %s", e)
            raise

    async def restore_state(self, agent_id: str, checkpoint_id: str) -> Optional[AgentMemento]:
        """Load a snapshot, or return ``None`` when it cannot be loaded.

        ``None`` is returned both when the key does not exist (or has expired)
        and when reading or decoding fails; failures are logged.
        """
        try:
            key = f"agent_state:{agent_id}:{checkpoint_id}"
            loop = asyncio.get_running_loop()
            state_data = await loop.run_in_executor(None, self.redis_client.get, key)
            if not state_data:
                return None
            data = json.loads(state_data)
            return AgentMemento(
                agent_id=agent_id,
                conversation_history=data["conversation_history"],
                context=data["context"],
                timestamp=data["timestamp"],
                checkpoint_id=checkpoint_id,
            )
        except asyncio.CancelledError:
            # Never convert task cancellation into a "not found" result
            # (CancelledError derives from Exception on Python 3.7).
            raise
        except Exception as e:
            self.logger.exception("Failed to restore state: %s", e)
            return None
# 4. Factory Pattern for Agent Creation
class AgentType(Enum):
    """Kinds of agents known to the system; each value is a string tag."""

    CUSTOMER_SERVICE = "customer_service"
    TECHNICAL_SUPPORT = "technical_support"
    SALES = "sales"
    GENERAL_PURPOSE = "general_purpose"
class AgentConfiguration:
    """Plain container bundling the settings an agent is built with."""

    def __init__(self, agent_type: AgentType, llm_config: Dict[str, Any],
                 tools: List[str], memory_config: Dict[str, Any]):
        # All four values are stored exactly as passed in (no copying).
        self.agent_type, self.llm_config = agent_type, llm_config
        self.tools, self.memory_config = tools, memory_config
class BaseAgent(ABC):
    """Abstract agent holding conversation state and checkpoint support.

    Subclasses implement ``process_message``; checkpointing is delegated to
    the injected state manager, and lifecycle events go to the event manager.
    """

    def __init__(self, agent_id: str, config: AgentConfiguration,
                 tool_strategy: ToolExecutionStrategy, event_manager: EventManager,
                 state_manager: AgentStateManager):
        self.agent_id = agent_id
        self.config = config
        self.tool_strategy = tool_strategy
        self.event_manager = event_manager
        self.state_manager = state_manager
        self.conversation_history: List[Dict[str, Any]] = []
        self.context: Dict[str, Any] = {}
        self.logger = logging.getLogger(__name__)

    @abstractmethod
    async def process_message(self, message: str, context: Optional[Dict[str, Any]] = None) -> str:
        """Handle one incoming ``message`` and return the agent's reply."""

    async def create_checkpoint(self) -> str:
        """Snapshot the current history and context and return the checkpoint id.

        The id keeps the ``checkpoint_<unix-seconds>`` prefix and appends a
        random suffix, so several checkpoints taken within the same second no
        longer share a key and overwrite one another.

        Raises:
            Exception: Whatever the state manager raises while saving.
        """
        checkpoint_id = f"checkpoint_{int(time.time())}_{uuid.uuid4().hex[:8]}"
        memento = AgentMemento(
            agent_id=self.agent_id,
            # Shallow copies: later appends to the live list/dict do not leak
            # into the memento handed to the state manager.
            conversation_history=self.conversation_history.copy(),
            context=self.context.copy(),
            timestamp=time.time(),
            checkpoint_id=checkpoint_id,
        )
        await self.state_manager.save_state(memento)
        await self.event_manager.publish_event(
            AgentEvent("checkpoint_created", self.agent_id, {"checkpoint_id": checkpoint_id})
        )
        return checkpoint_id

    async def restore_from_checkpoint(self, checkpoint_id: str) -> bool:
        """Replace history and context with a stored checkpoint.

        Returns:
            ``True`` when the checkpoint was loaded and applied, ``False`` when
            the state manager returned nothing for it.
        """
        memento = await self.state_manager.restore_state(self.agent_id, checkpoint_id)
        if not memento:
            return False
        self.conversation_history = memento.conversation_history
        self.context = memento.context
        await self.event_manager.publish_event(
            AgentEvent("state_restored", self.agent_id, {"checkpoint_id": checkpoint_id})
        )
        return True
class CustomerServiceAgent(BaseAgent):
    """Agent variant that answers with customer-service responses."""

    async def process_message(self, message: str, context: Dict[str, Any] = None) -> str:
        """Record the message, announce it, and return the reply.

        The ``context`` argument is accepted but not used by this
        implementation.
        """
        history = self.conversation_history
        history.append({"role": "user", "content": message})

        # Tell observers a message arrived before producing the answer.
        received = AgentEvent("message_received", self.agent_id, {"message": message})
        await self.event_manager.publish_event(received)

        reply = f"Customer Service Response to: {message}"
        # Append via the attribute: the history may have been rebound while
        # the publish above was awaited.
        self.conversation_history.append({"role": "assistant", "content": reply})
        return reply
class TechnicalSupportAgent(BaseAgent):
    """Agent variant that answers with technical-support responses."""

    async def process_message(self, message: str, context: Dict[str, Any] = None) -> str:
        """Record the message, announce it, and return the reply.

        The ``context`` argument is accepted but not used by this
        implementation.
        """
        history = self.conversation_history
        history.append({"role": "user", "content": message})

        # Tell observers a message arrived before producing the answer.
        received = AgentEvent("message_received", self.agent_id, {"message": message})
        await self.event_manager.publish_event(received)

        reply = f"Technical Support Response to: {message}"
        # Append via the attribute: the history may have been rebound while
        # the publish above was awaited.
        self.conversation_history.append({"role": "assistant", "content": reply})
        return reply
class AgentFactory:
    """Factory that builds specialized agents sharing one set of collaborators."""

    # Agent type -> concrete class. Supporting a new type means adding an
    # entry here plus a configuration in ``_setup_default_configs``.
    _AGENT_CLASSES = {
        AgentType.CUSTOMER_SERVICE: CustomerServiceAgent,
        AgentType.TECHNICAL_SUPPORT: TechnicalSupportAgent,
    }

    def __init__(self, tool_strategy: ToolExecutionStrategy, event_manager: EventManager,
                 state_manager: AgentStateManager):
        self.tool_strategy = tool_strategy
        self.event_manager = event_manager
        self.state_manager = state_manager
        self.agent_configs = self._setup_default_configs()

    def _setup_default_configs(self) -> Dict[AgentType, AgentConfiguration]:
        """Return the built-in configuration for each supported agent type."""
        return {
            AgentType.CUSTOMER_SERVICE: AgentConfiguration(
                agent_type=AgentType.CUSTOMER_SERVICE,
                llm_config={"model": "gpt-4", "temperature": 0.1},
                tools=["search_knowledge_base", "create_ticket", "get_customer_info"],
                memory_config={"max_history": 50, "context_window": 4000},
            ),
            AgentType.TECHNICAL_SUPPORT: AgentConfiguration(
                agent_type=AgentType.TECHNICAL_SUPPORT,
                llm_config={"model": "gpt-4", "temperature": 0.0},
                tools=["run_diagnostics", "search_documentation", "create_ticket"],
                memory_config={"max_history": 100, "context_window": 8000},
            ),
        }

    def create_agent(self, agent_type: AgentType, agent_id: Optional[str] = None) -> BaseAgent:
        """Create an agent of ``agent_type``.

        Args:
            agent_type: Which kind of agent to build.
            agent_id: Explicit identifier. When omitted, one is generated as
                ``<type>_<unix-seconds>_<random-suffix>``; the suffix keeps
                agents created within the same second from sharing an id.

        Raises:
            ValueError: If no configuration or no implementing class exists
                for ``agent_type``.
        """
        if agent_id is None:
            agent_id = f"{agent_type.value}_{int(time.time())}_{uuid.uuid4().hex[:8]}"
        config = self.agent_configs.get(agent_type)
        if config is None:
            raise ValueError(f"No configuration found for agent type {agent_type}")
        agent_cls = self._AGENT_CLASSES.get(agent_type)
        if agent_cls is None:
            raise ValueError(f"Unsupported agent type: {agent_type}")
        return agent_cls(agent_id, config, self.tool_strategy,
                         self.event_manager, self.state_manager)
# 5. Chain of Responsibility Pattern
class RequestHandler(ABC):
    """Base link of a chain of responsibility for request dictionaries."""

    def __init__(self):
        self._next_handler: Optional[RequestHandler] = None

    def set_next(self, handler: 'RequestHandler') -> 'RequestHandler':
        """Attach ``handler`` as successor and return it, enabling chained calls."""
        self._next_handler = handler
        return handler

    @abstractmethod
    async def handle(self, request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Forward ``request`` to the successor.

        Subclasses must override this method; they reach this default
        forwarding through ``super().handle(request)``. Returns the
        successor's result, or ``None`` when there is no successor.
        """
        successor = self._next_handler
        if not successor:
            return None
        return await successor.handle(request)
class AuthenticationHandler(RequestHandler):
    """Chain link that rejects requests not flagged as authenticated."""

    async def handle(self, request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Forward authenticated requests; answer the rest with a 401 payload."""
        is_authenticated = request.get("authenticated", False)
        if is_authenticated:
            return await super().handle(request)
        return {"error": "Authentication required", "status": 401}
class RateLimitHandler(RequestHandler):
    """Chain link enforcing a per-user request quota.

    Requests that carry no ``user_id`` are not counted and always pass.
    """

    def __init__(self, max_requests: int = 100, window_seconds: Optional[float] = None):
        """Create the limiter.

        Args:
            max_requests: Requests allowed per user (per window, if one is set).
            window_seconds: Length of the fixed counting window. ``None`` (the
                default) keeps the historical behaviour: counters never reset,
                so ``max_requests`` is a lifetime quota for this handler.
        """
        super().__init__()
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.request_counts = {}
        # user_id -> monotonic start time of that user's current window.
        self._window_started = {}

    async def handle(self, request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Count the request and forward it, or answer with a 429 payload."""
        user_id = request.get("user_id")
        # Compare against None: a truthiness test let falsy ids such as 0 or
        # "" bypass the limiter entirely.
        if user_id is not None:
            if self.window_seconds is not None:
                now = time.monotonic()
                started = self._window_started.get(user_id)
                if started is None or now - started >= self.window_seconds:
                    # Open a fresh window for this user.
                    self._window_started[user_id] = now
                    self.request_counts[user_id] = 0
            count = self.request_counts.get(user_id, 0)
            if count >= self.max_requests:
                return {"error": "Rate limit exceeded", "status": 429}
            self.request_counts[user_id] = count + 1
        # Within limits (or anonymous): hand over to the next handler.
        return await super().handle(request)
class AgentProcessingHandler(RequestHandler):
    """Chain link that hands the request's message to an agent.

    It answers every request itself and never forwards to a successor.
    """

    def __init__(self, agent: BaseAgent):
        super().__init__()
        self.agent = agent

    async def handle(self, request: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return the agent's reply with status 200, or a 400 payload without a message."""
        text = request.get("message")
        if not text:
            return {"error": "No message provided", "status": 400}
        reply = await self.agent.process_message(text, request.get("context"))
        return {"response": reply, "status": 200}
# Usage Example and Integration
async def main():
    """Demonstrate the patterns end to end.

    Wires up observers, a local tool strategy and the agent factory, pushes
    one request through the handler chain, then exercises checkpoint and
    restore. Uses a Redis client pointed at ``localhost:6379``.
    """
    # Setup infrastructure
    redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
    try:
        event_manager = EventManager()
        state_manager = AgentStateManager(redis_client)

        # Setup observers
        logging_observer = LoggingObserver()
        metrics_observer = MetricsObserver()
        event_manager.subscribe(logging_observer)
        event_manager.subscribe(metrics_observer)

        # Setup tool execution strategy
        tool_strategy = LocalToolExecutionStrategy()

        # Create agent factory and agents
        agent_factory = AgentFactory(tool_strategy, event_manager, state_manager)
        cs_agent = agent_factory.create_agent(AgentType.CUSTOMER_SERVICE, "cs_agent_001")
        # Built only to show the factory producing a second agent type; it is
        # not used further in this example.
        tech_agent = agent_factory.create_agent(AgentType.TECHNICAL_SUPPORT, "tech_agent_001")

        # Chain: authentication -> rate limiting -> agent processing
        auth_handler = AuthenticationHandler()
        rate_limit_handler = RateLimitHandler(max_requests=50)
        cs_processing_handler = AgentProcessingHandler(cs_agent)
        auth_handler.set_next(rate_limit_handler).set_next(cs_processing_handler)

        # Process a request through the chain
        request = {
            "authenticated": True,
            "user_id": "user123",
            "message": "I need help with my account",
            "context": {"session_id": "session_456"},
        }
        result = await auth_handler.handle(request)
        print(f"Processing result: {result}")

        # Create checkpoint
        checkpoint_id = await cs_agent.create_checkpoint()
        print(f"Created checkpoint: {checkpoint_id}")

        # Process another message, then roll back to the checkpoint
        await cs_agent.process_message("Follow-up question")
        restored = await cs_agent.restore_from_checkpoint(checkpoint_id)
        print(f"Restored from checkpoint: {restored}")

        # Check metrics
        print(f"Event metrics: {metrics_observer.metrics}")
    finally:
        # Release the client's connections even when a step above fails.
        redis_client.close()


if __name__ == "__main__":
    asyncio.run(main())