# Mentors4EDU's picture
# Upload 14 files
# f2bab5e verified
"""
Scaling manager for horizontal scaling of cloud agents.
"""
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

import ray

from .agent import Agent
from .config import settings
from .couchdb_client import CouchDBClient
# Module-level logger named after this module (via __name__).
logger = logging.getLogger(__name__)
class ScalingManager:
    """Manager for horizontal scaling of cloud agents.

    Agents run as Ray actors built from ``Agent``. Cluster load is read from the
    records returned by ``CouchDBClient.get_active_agents()``; this class reads
    the ``_id``, ``current_job`` and ``last_heartbeat`` keys of each record.

    Only agents started by this instance (tracked in ``agent_refs``) can be shut
    down from here; other records are counted toward load but cannot be removed.
    """

    def __init__(self):
        self.db_client = CouchDBClient()
        self._initialize_ray()
        self.min_agents = 2
        self.max_agents = 10
        self.scale_up_threshold = 0.8  # Scale up when at least 80% of agents are busy
        self.scale_down_threshold = 0.3  # Scale down when at most 30% of agents are busy
        self.max_scale_up_step = 2  # Most agents added in one scaling round
        self.max_scale_down_step = 1  # Most agents removed in one scaling round
        self.heartbeat_timeout = timedelta(minutes=5)  # Older heartbeat => agent is dead
        self.monitor_interval = 60  # Seconds between successful monitoring rounds
        self.error_retry_interval = 5  # Seconds to wait after a failed round
        self.shutdown_timeout = 30  # Seconds to wait for a graceful agent shutdown
        self.agent_refs: Dict[str, ray.actor.ActorHandle] = {}
        # ObjectRefs of the in-flight ``Agent.run`` calls, keyed by agent id.
        self._run_refs: Dict[str, Any] = {}
        # ``ray.remote(Agent)`` wrapper, created on first use and then reused.
        self._remote_agent_cls = None

    def _initialize_ray(self):
        """Connect to the Ray cluster via Ray Client unless Ray is already initialized."""
        if not ray.is_initialized():
            ray.init(address=f"ray://{settings.COORDINATOR_HOST}:{settings.RAY_HEAD_PORT}")

    async def monitor_and_scale(self):
        """Run the health-check / scaling loop forever.

        Each round removes agents with stale heartbeats, then rescales the
        cluster, then sleeps ``monitor_interval`` seconds. A failed round is
        logged with its traceback and retried after ``error_retry_interval``
        seconds. Task cancellation is not caught (``asyncio.CancelledError`` is
        not an ``Exception`` subclass on Python 3.8+), so it stops the loop.
        """
        while True:
            try:
                await self._check_agent_health()
                await self._scale_cluster()
            except Exception as e:
                logger.exception(f"Error in monitor and scale loop: {e}")
                await asyncio.sleep(self.error_retry_interval)
            else:
                await asyncio.sleep(self.monitor_interval)

    @staticmethod
    def _parse_heartbeat(value: str) -> datetime:
        """Parse an ISO-8601 heartbeat timestamp into a timezone-aware UTC datetime.

        Naive timestamps are assumed to be UTC; a trailing ``Z`` is accepted
        (``datetime.fromisoformat`` rejects it before Python 3.11).

        Raises:
            TypeError: if ``value`` is not a string.
            ValueError: if ``value`` is not a valid ISO-8601 timestamp.
        """
        if not isinstance(value, str):
            raise TypeError(f"heartbeat must be an ISO-8601 string, got {type(value).__name__}")
        if value.endswith('Z'):
            value = value[:-1] + '+00:00'
        parsed = datetime.fromisoformat(value)
        if parsed.tzinfo is None:
            return parsed.replace(tzinfo=timezone.utc)
        return parsed.astimezone(timezone.utc)

    async def _check_agent_health(self):
        """Remove agents whose last heartbeat is older than ``heartbeat_timeout``.

        Records with a missing or unparsable heartbeat, and agents whose removal
        fails, are logged and skipped so one bad agent does not stop the check
        for the others.

        Raises:
            Exception: re-raised after logging if the active agents cannot be
                fetched from the database.
        """
        try:
            active_agents = self.db_client.get_active_agents()
        except Exception as e:
            logger.error(f"Error checking agent health: {e}")
            raise
        current_time = datetime.now(timezone.utc)
        for agent in active_agents:
            try:
                agent_id = agent['_id']
                last_heartbeat = self._parse_heartbeat(agent['last_heartbeat'])
            except (KeyError, TypeError, ValueError) as e:
                logger.warning(f"Skipping agent record with invalid heartbeat: {e}")
                continue
            if current_time - last_heartbeat <= self.heartbeat_timeout:
                continue
            # Agent is considered dead
            logger.warning(f"Agent {agent_id} appears to be dead. Removing...")
            try:
                await self._remove_agent(agent_id)
            except Exception as e:
                logger.error(f"Error checking agent health: {e}")
            # NOTE(review): nothing here updates the agent's database record, so a
            # dead agent may keep appearing in get_active_agents(); confirm whether
            # Agent.shutdown or CouchDBClient is responsible for marking it inactive.

    @staticmethod
    def _count_busy(agents: List[Dict[str, Any]]) -> int:
        """Return how many agent records have a non-None ``current_job``."""
        return sum(1 for agent in agents if agent['current_job'] is not None)

    def _plan_scaling(self, total_agents: int, busy_agents: int) -> int:
        """Return how many agents to add (positive) or remove (negative); 0 for no change.

        The cluster is first topped up to ``min_agents`` (and never below one
        agent). Otherwise utilization (busy / total) is compared with the scale
        thresholds, moving at most ``max_scale_up_step`` / ``max_scale_down_step``
        agents per round without crossing ``max_agents`` / ``min_agents``.
        """
        floor = max(1, min(self.min_agents, self.max_agents))
        if total_agents < floor:
            return floor - total_agents
        # total_agents >= floor >= 1 here, so the division is safe.
        utilization = busy_agents / total_agents
        if utilization >= self.scale_up_threshold and total_agents < self.max_agents:
            return min(self.max_scale_up_step, self.max_agents - total_agents)
        if utilization <= self.scale_down_threshold and total_agents > self.min_agents:
            return -min(self.max_scale_down_step, total_agents - self.min_agents)
        return 0

    async def _scale_cluster(self):
        """Add or remove agents according to ``_plan_scaling``.

        Scale-down only picks idle agents this instance started, since only
        those have an actor handle in ``agent_refs`` and can actually be removed.

        Raises:
            Exception: re-raised after logging on any failure.
        """
        try:
            active_agents = self.db_client.get_active_agents()
            delta = self._plan_scaling(len(active_agents), self._count_busy(active_agents))
            if delta > 0:
                logger.info(f"Scaling up: Adding {delta} agents")
                for _ in range(delta):
                    await self._add_agent()
            elif delta < 0:
                num_to_remove = -delta
                logger.info(f"Scaling down: Removing {num_to_remove} agents")
                idle_agents = [
                    a['_id'] for a in active_agents
                    if a['current_job'] is None and a['_id'] in self.agent_refs
                ]
                if not idle_agents:
                    logger.info("No idle agents managed by this instance can be removed")
                for _ in range(num_to_remove):
                    if not idle_agents:
                        break
                    await self._remove_agent(idle_agents.pop())
        except Exception as e:
            logger.error(f"Error scaling cluster: {e}")
            raise

    async def _add_agent(self) -> str:
        """Start one new ``Agent`` actor and return its id.

        The actor requests one CPU, plus half a GPU when the cluster currently
        reports at least half a GPU available. ``Agent.run`` is launched without
        waiting for it; its ObjectRef is kept in ``_run_refs``.

        Raises:
            Exception: re-raised after logging if the actor cannot be created or
                started; a partially created actor is killed and untracked first.
        """
        agent_ref = None
        agent_id = None
        try:
            if self._remote_agent_cls is None:
                self._remote_agent_cls = ray.remote(Agent)
            # ray.get_gpu_ids() only lists GPUs assigned to the calling worker
            # (none on a driver), so ask the cluster what is available instead.
            gpu_available = ray.available_resources().get('GPU', 0) >= 0.5
            agent_ref = self._remote_agent_cls.options(
                num_cpus=1,
                num_gpus=0.5 if gpu_available else 0,
            ).remote()
            # Store reference
            agent_id = await agent_ref.get_id.remote()
            self.agent_refs[agent_id] = agent_ref
            # Start the agent without waiting for it to finish: ``run`` appears to
            # be the agent's long-lived loop (TODO confirm), and a blocking wait
            # inside this coroutine would stall the whole event loop.
            self._run_refs[agent_id] = agent_ref.run.remote()
            logger.info(f"Added new agent {agent_id}")
            return agent_id
        except Exception as e:
            logger.error(f"Error adding agent: {e}")
            if agent_id is not None:
                self.agent_refs.pop(agent_id, None)
                self._run_refs.pop(agent_id, None)
            if agent_ref is not None:
                try:
                    ray.kill(agent_ref)
                except Exception:
                    logger.exception("Failed to clean up partially created agent")
            raise

    async def _remove_agent(self, agent_id: str):
        """Shut down, kill and stop tracking the actor registered under ``agent_id``.

        A graceful ``shutdown`` call is attempted first, bounded by
        ``shutdown_timeout`` seconds. It is expected to fail for agents that are
        already dead, so a failure is logged and the actor is killed regardless.
        Ids with no tracked actor are logged and ignored.

        Raises:
            Exception: re-raised after logging if ``ray.kill`` fails.
        """
        agent_ref = self.agent_refs.pop(agent_id, None)
        self._run_refs.pop(agent_id, None)
        if agent_ref is None:
            logger.info(f"Agent {agent_id} is not managed by this instance; nothing to remove")
            return
        try:
            await asyncio.wait_for(agent_ref.shutdown.remote(), timeout=self.shutdown_timeout)
        except Exception as e:
            logger.warning(f"Graceful shutdown of agent {agent_id} failed: {e}")
        finally:
            # Always kill, even if the graceful shutdown failed or was cancelled.
            try:
                ray.kill(agent_ref)
            except Exception as e:
                logger.error(f"Error removing agent: {e}")
                raise
        logger.info(f"Removed agent {agent_id}")

    def get_cluster_status(self) -> Dict[str, Any]:
        """Return a snapshot of agent counts and utilization from the database.

        Returns:
            Dict with ``total_agents``, ``busy_agents``, ``idle_agents``,
            ``utilization`` (0 when there are no agents), ``can_scale_up`` and
            ``can_scale_down``.

        Raises:
            Exception: re-raised after logging on any failure.
        """
        try:
            active_agents = self.db_client.get_active_agents()
            total_agents = len(active_agents)
            busy_agents = self._count_busy(active_agents)
            return {
                'total_agents': total_agents,
                'busy_agents': busy_agents,
                'idle_agents': total_agents - busy_agents,
                'utilization': busy_agents / total_agents if total_agents > 0 else 0,
                'can_scale_up': total_agents < self.max_agents,
                'can_scale_down': total_agents > self.min_agents
            }
        except Exception as e:
            logger.error(f"Error getting cluster status: {e}")
            raise