# cloud_agents/coordinator.py
"""
Coordinator for distributed model training.
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, List, Any, Optional
import asyncio
import logging
from huggingface_hub import snapshot_download
import os
from datetime import datetime
import ray
from .couchdb_client import CouchDBClient
from .config import settings
from .tensor_ops import TensorOps
logger = logging.getLogger(__name__)
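
# High-level flow per training step (as implemented below): the coordinator
# looks up active agents in CouchDB, enqueues one gradient-computation job per
# agent, waits for the jobs and collects their gradients, averages and clips
# them, applies the update to its local copy of the model, and finally pushes
# the new model state back out to the agents as model-update jobs.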
class Coordinator:
"""Coordinator for distributed training of OpenPeerLLM."""
def __init__(self):
self.db_client = CouchDBClient()
self.model_id = settings.MODEL_ID
self.batch_size = settings.BATCH_SIZE
        self.gradient_accumulation_steps = settings.GRADIENT_ACCUMULATION_STEPS
        # Learning rate used when applying averaged gradients. Assumed to be
        # provided by settings; falls back to a small default if it is not.
        self.learning_rate = getattr(settings, 'LEARNING_RATE', 1e-5)
self._initialize_model()
def _initialize_model(self):
"""Initialize the model and tokenizer."""
try:
# Download model and tokenizer from Hugging Face
cache_dir = snapshot_download(self.model_id)
self.model = AutoModelForCausalLM.from_pretrained(cache_dir)
self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)
# Store initial model state
initial_state = {
'model_state': self.model.state_dict(),
'step': 0,
'epoch': 0
}
self.db_client.store_model_state(initial_state)
except Exception as e:
logger.error(f"Failed to initialize model: {e}")
raise
async def coordinate_training(self, training_config: Dict[str, Any]):
"""Coordinate distributed training across agents."""
try:
num_epochs = training_config.get('num_epochs', 1)
steps_per_epoch = training_config.get('steps_per_epoch', 100)
for epoch in range(num_epochs):
logger.info(f"Starting epoch {epoch}")
await self._train_epoch(epoch, steps_per_epoch)
# Save checkpoint after each epoch
self._save_checkpoint(epoch)
except Exception as e:
logger.error(f"Training coordination error: {e}")
raise
async def _train_epoch(self, epoch: int, steps_per_epoch: int):
"""Train for one epoch."""
for step in range(steps_per_epoch):
# Get active agents
active_agents = self.db_client.get_active_agents()
if not active_agents:
logger.warning("No active agents available")
await asyncio.sleep(5)
continue
# Distribute gradient computation jobs
gradient_jobs = await self._distribute_gradient_computation(
active_agents,
self.batch_size
)
# Collect and process gradients
gradients = await self._collect_gradients(gradient_jobs)
if gradients:
# Update model with collected gradients
self._update_model_parameters(gradients)
# Distribute updated model state to agents
await self._distribute_model_update()
async def _distribute_gradient_computation(
self,
agents: List[Dict[str, Any]],
batch_size: int
) -> List[str]:
"""Distribute gradient computation jobs to available agents."""
job_ids = []
# Get current model state
current_state = self.db_client.get_latest_model_state()
if not current_state:
raise RuntimeError("No model state available")
        # Create one gradient-computation job per available agent
        for _agent in agents:
job_id = self.db_client.create_job(
'gradient_computation',
{
'batch_size': batch_size,
'state': current_state['state']
}
)
job_ids.append(job_id)
return job_ids
async def _collect_gradients(self, job_ids: List[str]) -> Optional[List[Dict[str, Any]]]:
"""Collect gradients from completed jobs."""
        timeout = 300  # seconds to wait for each job before giving up
        async def wait_for_job(job_id: str) -> Optional[Dict[str, Any]]:
            loop = asyncio.get_running_loop()
            start_time = loop.time()
            while True:
                if loop.time() - start_time > timeout:
                    logger.warning(f"Job {job_id} timed out")
                    return None
                job = self.db_client.get_job(job_id)
                if job['status'] == 'completed':
                    gradient_id = job['result']['gradient_id']
                    return self.db_client.get_gradients(gradient_id)
                elif job['status'] == 'failed':
                    logger.error(f"Job {job_id} failed: {job.get('result', {}).get('error')}")
                    return None
                # Poll the job document again after a short delay
                await asyncio.sleep(1)
# Wait for all gradient computations to complete
gradient_tasks = [wait_for_job(job_id) for job_id in job_ids]
gradients = await asyncio.gather(*gradient_tasks)
# Filter out None results (failed jobs)
return [g for g in gradients if g is not None]
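    # NOTE: the gradient documents returned by CouchDBClient.get_gradients() are
    # assumed to be mappings of parameter name -> nested lists of floats, which
    # is why _update_model_parameters() rebuilds them with torch.tensor() below.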
def _update_model_parameters(self, gradients: List[Dict[str, Any]]):
"""Update model parameters with collected gradients."""
try:
# Average gradients from all workers
avg_gradients = TensorOps.average_gradients([
{k: torch.tensor(v) for k, v in g.items()}
for g in gradients
])
            # Apply gradient clipping
            clipped_gradients = TensorOps.gradient_clipping(avg_gradients, max_norm=1.0)
            # Apply a plain SGD-style step: param <- param - lr * clipped_gradient
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    if name in clipped_gradients:
                        param.sub_(clipped_gradients[name] * self.learning_rate)
except Exception as e:
logger.error(f"Error updating model parameters: {e}")
raise
async def _distribute_model_update(self):
"""Distribute updated model state to all agents."""
try:
# Store updated model state
state = {
'model_state': self.model.state_dict(),
'timestamp': datetime.utcnow().isoformat()
}
state_id = self.db_client.store_model_state(state)
            # Create a model-update job for every active agent
            active_agents = self.db_client.get_active_agents()
            for _agent in active_agents:
self.db_client.create_job(
'model_update',
{
'state_id': state_id,
'state': state
}
)
except Exception as e:
logger.error(f"Error distributing model update: {e}")
raise
def _save_checkpoint(self, epoch: int):
"""Save a checkpoint of the current model state."""
try:
checkpoint_dir = os.path.join(os.getcwd(), 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
torch.save({
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict() if hasattr(self, 'optimizer') else None
}, checkpoint_path)
logger.info(f"Saved checkpoint for epoch {epoch}")
except Exception as e:
logger.error(f"Error saving checkpoint: {e}")
raise
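

# --- Example usage (illustrative sketch, not part of the original module) ---
# A minimal driver assuming CouchDB connection details and MODEL_ID are already
# configured in `config.settings`; the training_config keys mirror the defaults
# read in coordinate_training().
if __name__ == "__main__":
    coordinator = Coordinator()
    asyncio.run(coordinator.coordinate_training({
        'num_epochs': 1,
        'steps_per_epoch': 10,
    }))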