# Mentors4EDU's picture
# Upload 14 files
# f2bab5e verified
"""
Base agent class for distributed computing.
"""
import torch
import ray
import uuid
import asyncio
from typing import Dict, Any, Optional
from datetime import datetime
import logging
from .couchdb_client import CouchDBClient
from .config import settings
logger = logging.getLogger(__name__)
@ray.remote
class Agent:
    """Distributed computing agent for tensor operations and model training.

    On construction the agent generates a UUID, registers itself (with its
    device capabilities) through ``CouchDBClient`` and tries to start a
    periodic heartbeat.  ``run`` then polls the client for jobs and
    dispatches each one according to its ``type`` field.
    """

    # Seconds between successful heartbeats / before retrying a failed one.
    HEARTBEAT_INTERVAL = 30
    HEARTBEAT_RETRY_DELAY = 5
    # Seconds to wait when no job could be claimed / after a loop error.
    IDLE_POLL_DELAY = 1
    ERROR_RETRY_DELAY = 5

    def __init__(self):
        """Create the agent, register it, and start the heartbeat if possible.

        Raises:
            RuntimeError: If the database client reports that registration
                failed.
        """
        self.agent_id = str(uuid.uuid4())
        self.db_client = CouchDBClient()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.current_job: Optional[Dict] = None
        # Strong reference to the heartbeat task: the event loop only keeps
        # weak references to tasks, so an unreferenced task may be collected
        # before it finishes.
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._register_agent()
        self._start_heartbeat()

    def _register_agent(self):
        """Register agent with the cluster.

        Sends the device name, CUDA availability, CUDA device count and the
        total memory of CUDA device 0 (0 when CUDA is unavailable).

        Raises:
            RuntimeError: If ``register_agent`` returns a falsy value.
        """
        has_cuda = torch.cuda.is_available()
        capabilities = {
            "device": str(self.device),
            "cuda_available": has_cuda,
            "cuda_devices": torch.cuda.device_count() if has_cuda else 0,
            "memory_available": (
                torch.cuda.get_device_properties(0).total_memory if has_cuda else 0
            ),
        }
        success = self.db_client.register_agent(self.agent_id, capabilities)
        if not success:
            raise RuntimeError("Failed to register agent")

    def _start_heartbeat(self):
        """Start the periodic heartbeat task if it is not already running.

        The task can only be scheduled on a running event loop.  When this is
        called without one (e.g. from a synchronous ``__init__``), the start
        is deferred: ``run`` calls this method again from inside the loop.
        Calling it while a heartbeat task is alive is a no-op.
        """
        existing = getattr(self, "_heartbeat_task", None)
        if existing is not None and not existing.done():
            return

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Checked before the coroutine object is created so we never
            # leave behind a never-awaited coroutine.
            logger.warning(
                "No running event loop; heartbeat will start when run() is awaited"
            )
            return

        async def heartbeat_loop():
            while True:
                try:
                    self.db_client.update_heartbeat(self.agent_id)
                    await asyncio.sleep(self.HEARTBEAT_INTERVAL)
                except Exception as e:
                    # Cancellation is a BaseException and is not swallowed
                    # here, so shutdown() can still stop the loop.
                    logger.error("Heartbeat error: %s", e)
                    await asyncio.sleep(self.HEARTBEAT_RETRY_DELAY)

        self._heartbeat_task = loop.create_task(heartbeat_loop())

    def process_tensors(self, tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Process tensor operations.

        Args:
            tensors: Mapping of name to tensor.

        Returns:
            A new mapping with the same keys; each tensor is moved to
            ``self.device`` and passed through ``_compute_tensor``.
        """
        return {
            name: self._compute_tensor(tensor.to(self.device))
            for name, tensor in tensors.items()
        }

    def _compute_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """Compute operations on a single tensor.

        Currently an identity hook: the tensor is returned unchanged.
        """
        # Add custom tensor operations here
        return tensor

    async def run(self):
        """Main agent loop.

        Repeatedly tries to claim a job; a claimed job is processed, otherwise
        the loop sleeps briefly.  Any exception is logged and followed by a
        longer pause, so the loop never exits on its own.
        """
        # We are inside a running loop here, so a heartbeat that could not be
        # scheduled during __init__ can be started now (no-op if running).
        self._start_heartbeat()
        while True:
            try:
                # Try to claim a job
                job = self.db_client.claim_job(self.agent_id)
                if job:
                    self.current_job = job
                    await self._process_job(job)
                else:
                    await asyncio.sleep(self.IDLE_POLL_DELAY)
            except Exception as e:
                logger.error("Error in agent loop: %s", e)
                await asyncio.sleep(self.ERROR_RETRY_DELAY)

    async def _process_job(self, job: Dict[str, Any]):
        """Process a claimed job.

        Dispatches on ``job['type']`` and records the outcome with
        ``update_job_status``: ``'completed'`` with the handler's result, or
        ``'failed'`` with ``{'error': <message>}`` if anything raises.
        An unrecognised job type is reported as failed rather than silently
        completed with an empty result.  ``self.current_job`` is always
        cleared afterwards.

        Args:
            job: Job document; ``'type'``, ``'params'`` and ``'_id'`` are read.
        """
        handlers = {
            'gradient_computation': self._compute_gradients,
            'model_update': self._update_model,
        }
        try:
            job_type = job['type']
            params = job['params']
            handler = handlers.get(job_type)
            if handler is None:
                raise ValueError(f"Unsupported job type: {job_type!r}")
            result = await handler(params)
            # Store job results
            self.db_client.update_job_status(job['_id'], 'completed', result)
        except Exception as e:
            logger.error("Job processing error: %s", e)
            self.db_client.update_job_status(
                job['_id'],
                'failed',
                {'error': str(e)},
            )
        finally:
            self.current_job = None

    async def _compute_gradients(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Compute gradients for model training.

        Args:
            params: Job parameters; ``'checkpoint'`` (required) is passed to
                ``torch.load`` and ``'batch'`` (optional) is forwarded to
                ``_compute_model_gradients``.

        Returns:
            ``{'gradient_id': ...}`` with the value returned by
            ``store_gradients``.

        Raises:
            ValueError: If ``'checkpoint'`` is missing or empty.
            Exception: Any other error is logged and re-raised.
        """
        try:
            checkpoint = params.get('checkpoint')
            if not checkpoint:
                # Without this guard the code below would fail with an
                # UnboundLocalError on ``state_dict``.
                raise ValueError("Gradient job is missing the 'checkpoint' parameter")
            # Load model checkpoint
            # NOTE(security): torch.load unpickles the file; the checkpoint
            # location comes from the job document, so it must only ever
            # reference trusted files.
            state_dict = torch.load(checkpoint, map_location=self.device)
            # Compute gradients
            gradients = self._compute_model_gradients(state_dict, params.get('batch'))
            # Store gradients in CouchDB
            gradient_id = self.db_client.store_gradients(
                self.current_job['_id'],
                gradients,
            )
            return {'gradient_id': gradient_id}
        except Exception as e:
            logger.error("Gradient computation error: %s", e)
            raise

    def _compute_model_gradients(
        self,
        state_dict: Dict[str, torch.Tensor],
        batch: Optional[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Collect existing gradients from a state dict in serializable form.

        NOTE(review): ``batch`` is not used and no forward/backward pass is
        run here; only ``.grad`` values already attached to the tensors are
        collected, so the result is empty unless the loaded tensors carry
        gradients.  Confirm whether the actual computation is still to be
        implemented.

        Args:
            state_dict: Mapping of parameter name to tensor.
            batch: Unused at present.

        Returns:
            Mapping of parameter name to the gradient as nested Python lists,
            for parameters with ``requires_grad`` set and a non-``None`` grad.
        """
        # Convert gradients to serializable format
        gradients = {}
        for name, param in state_dict.items():
            if not param.requires_grad:
                continue
            grad = param.grad
            if grad is not None:
                # detach() keeps numpy() valid even if the grad itself tracks
                # gradients.
                gradients[name] = grad.detach().cpu().numpy().tolist()
        return gradients

    async def _update_model(self, params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Update model with new parameters.

        Args:
            params: Job parameters; ``'state'`` is stored via
                ``store_model_state`` when truthy.

        Returns:
            ``{'state_id': ...}`` when a state was stored, otherwise ``None``.

        Raises:
            Exception: Any error is logged and re-raised.
        """
        try:
            new_state = params.get('state')
            if not new_state:
                # NOTE(review): a job without 'state' ends up completed with a
                # None result — confirm this no-op is intended.
                return None
            # Apply model updates
            state_id = self.db_client.store_model_state(new_state)
            return {'state_id': state_id}
        except Exception as e:
            logger.error("Model update error: %s", e)
            raise

    def shutdown(self):
        """Shutdown the agent.

        Stops the heartbeat task (if one is running), reports the agent as
        inactive through the database client and frees cached CUDA memory.
        """
        task = getattr(self, "_heartbeat_task", None)
        if task is not None and not task.done():
            loop = task.get_loop()
            if not loop.is_closed():
                # Task.cancel is not thread-safe; hand it to the owning loop
                # so this works regardless of the calling thread.
                loop.call_soon_threadsafe(task.cancel)
        self._heartbeat_task = None

        # Update agent status to inactive
        # NOTE(review): this passes the agent id to update_job_status with no
        # result argument — verify against CouchDBClient that this is the
        # intended call for marking an agent (not a job) inactive.
        self.db_client.update_job_status(
            self.agent_id,
            'inactive',
        )
        # Clean up resources
        if torch.cuda.is_available():
            torch.cuda.empty_cache()