"""
Coordinator for distributed model training.
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, List, Any, Optional
import asyncio
from datetime import datetime
import logging
from huggingface_hub import snapshot_download
import os
import ray
from .couchdb_client import CouchDBClient
from .config import settings
from .tensor_ops import TensorOps

logger = logging.getLogger(__name__)


class Coordinator:
    """Coordinator for distributed training of OpenPeerLLM."""

    def __init__(self):
        self.db_client = CouchDBClient()
        self.model_id = settings.MODEL_ID
        self.batch_size = settings.BATCH_SIZE
        self.gradient_accumulation_steps = settings.GRADIENT_ACCUMULATION_STEPS
        self._initialize_model()

    def _initialize_model(self):
        """Initialize the model and tokenizer."""
        try:
            # Download model and tokenizer from Hugging Face
            cache_dir = snapshot_download(self.model_id)
            self.model = AutoModelForCausalLM.from_pretrained(cache_dir)
            self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)

            # Store initial model state
            initial_state = {
                'model_state': self.model.state_dict(),
                'step': 0,
                'epoch': 0
            }
            self.db_client.store_model_state(initial_state)
        except Exception as e:
            logger.error(f"Failed to initialize model: {e}")
            raise

    async def coordinate_training(self, training_config: Dict[str, Any]):
        """Coordinate distributed training across agents."""
        try:
            num_epochs = training_config.get('num_epochs', 1)
            steps_per_epoch = training_config.get('steps_per_epoch', 100)

            for epoch in range(num_epochs):
                logger.info(f"Starting epoch {epoch}")
                await self._train_epoch(epoch, steps_per_epoch)

                # Save checkpoint after each epoch
                self._save_checkpoint(epoch)
        except Exception as e:
            logger.error(f"Training coordination error: {e}")
            raise

    async def _train_epoch(self, epoch: int, steps_per_epoch: int):
        """Train for one epoch."""
        for step in range(steps_per_epoch):
            # Get active agents
            active_agents = self.db_client.get_active_agents()
            if not active_agents:
                logger.warning("No active agents available")
                await asyncio.sleep(5)
                continue

            # Distribute gradient computation jobs
            gradient_jobs = await self._distribute_gradient_computation(
                active_agents,
                self.batch_size
            )

            # Collect and process gradients
            gradients = await self._collect_gradients(gradient_jobs)
            if gradients:
                # Update model with collected gradients
                self._update_model_parameters(gradients)

                # Distribute updated model state to agents
                await self._distribute_model_update()

    async def _distribute_gradient_computation(
        self,
        agents: List[Dict[str, Any]],
        batch_size: int
    ) -> List[str]:
        """Distribute gradient computation jobs to available agents."""
        job_ids = []

        # Get current model state
        current_state = self.db_client.get_latest_model_state()
        if not current_state:
            raise RuntimeError("No model state available")

        # Create one gradient computation job per active agent
        for agent in agents:
            job_id = self.db_client.create_job(
                'gradient_computation',
                {
                    'batch_size': batch_size,
                    'state': current_state['state']
                }
            )
            job_ids.append(job_id)

        return job_ids

    async def _collect_gradients(self, job_ids: List[str]) -> Optional[List[Dict[str, Any]]]:
        """Collect gradients from completed jobs."""
        timeout = 300  # 5 minute timeout per job

        async def wait_for_job(job_id: str) -> Optional[Dict[str, Any]]:
            # asyncio has no get_event_time(); use the running loop's clock
            loop = asyncio.get_running_loop()
            start_time = loop.time()
            while True:
                if loop.time() - start_time > timeout:
                    logger.warning(f"Job {job_id} timed out")
                    return None

                job = self.db_client.get_job(job_id)
                if job['status'] == 'completed':
                    gradient_id = job['result']['gradient_id']
                    return self.db_client.get_gradients(gradient_id)
                elif job['status'] == 'failed':
                    logger.error(f"Job {job_id} failed: {job.get('result', {}).get('error')}")
                    return None

                await asyncio.sleep(1)

        # Wait for all gradient computations to complete
        gradient_tasks = [wait_for_job(job_id) for job_id in job_ids]
        gradients = await asyncio.gather(*gradient_tasks)

        # Filter out None results (failed or timed-out jobs)
        return [g for g in gradients if g is not None]

    def _update_model_parameters(self, gradients: List[Dict[str, Any]]):
        """Update model parameters with collected gradients."""
        try:
            # Average gradients from all workers
            avg_gradients = TensorOps.average_gradients([
                {k: torch.tensor(v) for k, v in g.items()}
                for g in gradients
            ])

            # Apply gradient clipping
            clipped_gradients = TensorOps.gradient_clipping(avg_gradients, max_norm=1.0)

            # Apply a plain SGD step. Hugging Face model configs do not carry a
            # learning rate, so read it from settings, falling back to a small
            # default if the setting is absent.
            lr = getattr(settings, 'LEARNING_RATE', 1e-5)
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    if name in clipped_gradients:
                        param.sub_(clipped_gradients[name] * lr)
        except Exception as e:
            logger.error(f"Error updating model parameters: {e}")
            raise

    async def _distribute_model_update(self):
        """Distribute updated model state to all agents."""
        try:
            # Store updated model state
            state = {
                'model_state': self.model.state_dict(),
                'timestamp': datetime.utcnow().isoformat()
            }
            state_id = self.db_client.store_model_state(state)

            # Create model update jobs for all active agents
            active_agents = self.db_client.get_active_agents()
            for agent in active_agents:
                self.db_client.create_job(
                    'model_update',
                    {
                        'state_id': state_id,
                        'state': state
                    }
                )
        except Exception as e:
            logger.error(f"Error distributing model update: {e}")
            raise

    def _save_checkpoint(self, epoch: int):
        """Save a checkpoint of the current model state."""
        try:
            checkpoint_dir = os.path.join(os.getcwd(), 'checkpoints')
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict() if hasattr(self, 'optimizer') else None
            }, checkpoint_path)
            logger.info(f"Saved checkpoint for epoch {epoch}")
        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")
            raise
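

# Usage sketch (illustrative, not part of the original module): one way a
# deployment might drive the Coordinator. The training_config keys shown
# ('num_epochs', 'steps_per_epoch') are the ones coordinate_training() reads;
# the process-startup details here are assumptions for illustration only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    coordinator = Coordinator()
    asyncio.run(coordinator.coordinate_training({
        'num_epochs': 3,
        'steps_per_epoch': 100,
    }))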