"""Scaling manager for horizontal scaling of cloud agents."""
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

import ray

from .couchdb_client import CouchDBClient
from .agent import Agent
from .config import settings

logger = logging.getLogger(__name__)


class ScalingManager:
    """Manager for horizontal scaling of cloud agents.

    Agents are started as Ray actors and tracked locally in ``agent_refs``.
    Cluster load is read from CouchDB via ``CouchDBClient.get_active_agents()``.
    Each returned agent record is indexed by ``_id``, ``current_job`` and
    ``last_heartbeat`` (an ISO-8601 string parsed with ``fromisoformat``).
    """

    def __init__(self):
        self.db_client = CouchDBClient()
        self._initialize_ray()
        self.min_agents = 2
        self.max_agents = 10
        self.scale_up_threshold = 0.8  # Scale up when 80% of agents are busy
        self.scale_down_threshold = 0.3  # Scale down when less than 30% of agents are busy
        # Agents whose last heartbeat is older than this are treated as dead.
        self.heartbeat_timeout = timedelta(minutes=5)
        # Seconds between monitor iterations, and back-off after a failed iteration.
        self.check_interval = 60
        self.error_backoff = 5
        # Ray actor handles for the agents this manager started, keyed by agent id.
        self.agent_refs: Dict[str, ray.actor.ActorHandle] = {}

    def _initialize_ray(self):
        """Initialize Ray for distributed computing (connects as a Ray client)."""
        if not ray.is_initialized():
            ray.init(address=f"ray://{settings.COORDINATOR_HOST}:{settings.RAY_HEAD_PORT}")

    async def monitor_and_scale(self):
        """Run the health-check / scaling loop forever.

        Each iteration removes agents with stale heartbeats, then rebalances the
        cluster size. A failed iteration is logged with its traceback and
        retried after ``error_backoff`` seconds, so one failure does not stop
        monitoring.
        """
        while True:
            try:
                await self._check_agent_health()
                await self._scale_cluster()
            except Exception as e:
                logger.exception("Error in monitor and scale loop: %s", e)
                await asyncio.sleep(self.error_backoff)
            else:
                await asyncio.sleep(self.check_interval)

    @staticmethod
    def _parse_heartbeat(value: str) -> datetime:
        """Parse an ISO-8601 heartbeat string into an aware UTC datetime.

        Naive timestamps are assumed to be in UTC. Aware ones are converted to
        UTC, so naive and aware values are never subtracted from each other.
        """
        parsed = datetime.fromisoformat(value)
        if parsed.tzinfo is None:
            return parsed.replace(tzinfo=timezone.utc)
        return parsed.astimezone(timezone.utc)

    @staticmethod
    def _count_busy(agents: List[Dict[str, Any]]) -> int:
        """Return how many agent records have a job assigned (``current_job`` is not None)."""
        return sum(1 for a in agents if a['current_job'] is not None)

    async def _check_agent_health(self):
        """Remove agents whose last heartbeat is older than ``heartbeat_timeout``.

        A record with a missing or unparseable heartbeat is logged and skipped,
        so it cannot abort the check for the remaining agents.

        Raises:
            Exception: re-raised if listing agents or removing an agent fails.
        """
        try:
            active_agents = self.db_client.get_active_agents()
            current_time = datetime.now(timezone.utc)

            for agent in active_agents:
                try:
                    last_heartbeat = self._parse_heartbeat(agent['last_heartbeat'])
                except (KeyError, TypeError, ValueError) as e:
                    logger.warning(f"Agent {agent.get('_id')} has an invalid heartbeat: {e}")
                    continue

                if current_time - last_heartbeat > self.heartbeat_timeout:
                    # Agent is considered dead
                    logger.warning(f"Agent {agent['_id']} appears to be dead. Removing...")
                    await self._remove_agent(agent['_id'])
        except Exception as e:
            logger.error(f"Error checking agent health: {e}")
            raise

    async def _scale_cluster(self):
        """Grow or shrink the cluster based on agent utilization.

        - Fewer than ``max(1, min_agents)`` agents: add agents up to that floor.
        - Utilization >= ``scale_up_threshold``: add up to 2 agents, capped by
          ``max_agents``.
        - Utilization <= ``scale_down_threshold``: remove 1 idle agent, never
          going below ``min_agents``.

        Raises:
            Exception: re-raised if listing, adding or removing agents fails.
        """
        try:
            active_agents = self.db_client.get_active_agents()
            total_agents = len(active_agents)
            busy_agents = self._count_busy(active_agents)

            # Enforce the configured floor; always keep at least one agent running.
            floor = max(1, self.min_agents)
            if total_agents < floor:
                num_to_add = floor - total_agents
                logger.info(f"Below minimum: Adding {num_to_add} agents")
                for _ in range(num_to_add):
                    await self._add_agent()
                return

            # total_agents >= 1 here, so the division is safe.
            utilization = busy_agents / total_agents

            # Scale up if needed
            if utilization >= self.scale_up_threshold and total_agents < self.max_agents:
                num_to_add = min(2, self.max_agents - total_agents)  # Add up to 2 agents at a time
                logger.info(f"Scaling up: Adding {num_to_add} agents")
                for _ in range(num_to_add):
                    await self._add_agent()

            # Scale down if needed
            elif utilization <= self.scale_down_threshold and total_agents > self.min_agents:
                num_to_remove = min(1, total_agents - self.min_agents)  # Remove 1 agent at a time
                logger.info(f"Scaling down: Removing {num_to_remove} agents")
                idle_agents = [a for a in active_agents if a['current_job'] is None]
                # NOTE(review): _remove_agent only acts on agents present in
                # agent_refs; an idle agent started elsewhere is not removed.
                for _ in range(num_to_remove):
                    if idle_agents:
                        await self._remove_agent(idle_agents.pop()['_id'])
        except Exception as e:
            logger.error(f"Error scaling cluster: {e}")
            raise

    async def _add_agent(self):
        """Create a new Ray agent actor, track it in ``agent_refs``, and start it.

        Returns:
            The id reported by the new agent's ``get_id`` method.

        Raises:
            Exception: re-raised if the actor cannot be created or started.
        """
        try:
            # Create new agent actor using Ray.
            # NOTE(review): ray.get_gpu_ids() reports GPUs assigned to the calling
            # worker; on a driver this may be empty even when the cluster has
            # GPUs. Confirm this is the intended check.
            agent_ref = ray.remote(Agent).options(
                num_cpus=1,
                num_gpus=0.5 if ray.get_gpu_ids() else 0,
            ).remote()

            # Store reference
            agent_id = await agent_ref.get_id.remote()
            self.agent_refs[agent_id] = agent_ref

            # Start agent. Awaiting the ObjectRef (rather than ray.get) keeps the
            # asyncio event loop responsive while the remote call runs.
            # NOTE(review): this still waits for run() to return; if run() is a
            # long-lived loop, it should be fired without awaiting. Confirm.
            await agent_ref.run.remote()

            logger.info(f"Added new agent {agent_id}")
            return agent_id
        except Exception as e:
            logger.error(f"Error adding agent: {e}")
            raise

    async def _remove_agent(self, agent_id: str):
        """Stop and stop tracking an agent that this manager started.

        A graceful ``shutdown`` is attempted first. If that call fails (for
        example because the actor has already died), the failure is logged and
        the actor is still killed and dropped from ``agent_refs``, so the handle
        is not leaked. Agents not present in ``agent_refs`` are left untouched.

        Raises:
            Exception: re-raised if killing the actor fails.
        """
        try:
            # Get agent reference
            agent_ref = self.agent_refs.get(agent_id)
            if agent_ref is not None:
                # Shutdown agent gracefully; best-effort because the actor may be dead.
                try:
                    await agent_ref.shutdown.remote()
                except Exception as e:
                    logger.warning(f"Graceful shutdown of agent {agent_id} failed: {e}")

                # Remove from Ray, and always drop local tracking even if kill fails.
                try:
                    ray.kill(agent_ref)
                finally:
                    self.agent_refs.pop(agent_id, None)

            logger.info(f"Removed agent {agent_id}")
        except Exception as e:
            logger.error(f"Error removing agent: {e}")
            raise

    def get_cluster_status(self) -> Dict[str, Any]:
        """Get current status of the cluster.

        Returns:
            A dict with agent counts (total/busy/idle), the busy fraction
            ``utilization`` (0 when there are no agents), and whether scaling up
            or down is still permitted by ``max_agents`` / ``min_agents``.

        Raises:
            Exception: re-raised if the agent list cannot be fetched.
        """
        try:
            active_agents = self.db_client.get_active_agents()
            total_agents = len(active_agents)
            busy_agents = self._count_busy(active_agents)

            return {
                'total_agents': total_agents,
                'busy_agents': busy_agents,
                'idle_agents': total_agents - busy_agents,
                'utilization': busy_agents / total_agents if total_agents > 0 else 0,
                'can_scale_up': total_agents < self.max_agents,
                'can_scale_down': total_agents > self.min_agents,
            }
        except Exception as e:
            logger.error(f"Error getting cluster status: {e}")
            raise