"""
Coordinator for distributed model training.
"""

import asyncio
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional

import ray
import torch
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer

from .config import settings
from .couchdb_client import CouchDBClient
from .tensor_ops import TensorOps

logger = logging.getLogger(__name__)
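
# NOTE: the helper modules are expected to expose at least the interface used below:
#   CouchDBClient -> get_active_agents(), get_latest_model_state(), store_model_state(state),
#                    create_job(job_type, payload) -> job_id, get_job(job_id), get_gradients(gradient_id)
#   TensorOps     -> average_gradients(grad_dicts), gradient_clipping(grads, max_norm)
#   settings      -> MODEL_ID, BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS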


class Coordinator:
    """Coordinator for distributed training of OpenPeerLLM."""

    def __init__(self):
        self.db_client = CouchDBClient()
        self.model_id = settings.MODEL_ID
        self.batch_size = settings.BATCH_SIZE
        self.gradient_accumulation_steps = settings.GRADIENT_ACCUMULATION_STEPS
        self._initialize_model()

    def _initialize_model(self):
        """Initialize the model and tokenizer."""
        try:
            # Download the model snapshot and load the weights and tokenizer locally.
            cache_dir = snapshot_download(self.model_id)
            self.model = AutoModelForCausalLM.from_pretrained(cache_dir)
            self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)

            # Persist the initial state so every agent starts from the same parameters.
            initial_state = {
                'model_state': self.model.state_dict(),
                'step': 0,
                'epoch': 0
            }
            self.db_client.store_model_state(initial_state)

        except Exception as e:
            logger.error(f"Failed to initialize model: {e}")
            raise

    async def coordinate_training(self, training_config: Dict[str, Any]):
        """Coordinate distributed training across agents."""
        try:
            num_epochs = training_config.get('num_epochs', 1)
            steps_per_epoch = training_config.get('steps_per_epoch', 100)

            for epoch in range(num_epochs):
                logger.info(f"Starting epoch {epoch}")
                await self._train_epoch(epoch, steps_per_epoch)

                self._save_checkpoint(epoch)
        except Exception as e:
            logger.error(f"Training coordination error: {e}")
            raise

    async def _train_epoch(self, epoch: int, steps_per_epoch: int):
        """Train for one epoch."""
        for step in range(steps_per_epoch):
            # Wait until at least one agent is registered as active.
            active_agents = self.db_client.get_active_agents()
            if not active_agents:
                logger.warning("No active agents available")
                await asyncio.sleep(5)
                continue

            # Fan out gradient-computation jobs to the available agents.
            gradient_jobs = await self._distribute_gradient_computation(
                active_agents,
                self.batch_size
            )

            # Collect the resulting gradients and apply them to the shared model.
            gradients = await self._collect_gradients(gradient_jobs)
            if gradients:
                self._update_model_parameters(gradients)

            # Broadcast the updated parameters to all agents.
            await self._distribute_model_update()

    async def _distribute_gradient_computation(
        self,
        agents: List[Dict[str, Any]],
        batch_size: int
    ) -> List[str]:
        """Distribute gradient computation jobs to available agents."""
        job_ids = []

        current_state = self.db_client.get_latest_model_state()
        if not current_state:
            raise RuntimeError("No model state available")

        for agent in agents:
            job_id = self.db_client.create_job(
                'gradient_computation',
                {
                    'batch_size': batch_size,
                    'state': current_state['state']
                }
            )
            job_ids.append(job_id)

        return job_ids

    async def _collect_gradients(self, job_ids: List[str]) -> Optional[List[Dict[str, Any]]]:
        """Collect gradients from completed jobs."""
        timeout = 300

        async def wait_for_job(job_id: str) -> Optional[Dict[str, Any]]:
            # Poll the job document until it completes, fails, or times out.
            start_time = asyncio.get_running_loop().time()
            while True:
                if asyncio.get_running_loop().time() - start_time > timeout:
                    logger.warning(f"Job {job_id} timed out")
                    return None

                job = self.db_client.get_job(job_id)
                if job['status'] == 'completed':
                    gradient_id = job['result']['gradient_id']
                    return self.db_client.get_gradients(gradient_id)
                elif job['status'] == 'failed':
                    logger.error(f"Job {job_id} failed: {job.get('result', {}).get('error')}")
                    return None

                await asyncio.sleep(1)

        # Wait on all jobs concurrently and drop any that failed or timed out.
        gradient_tasks = [wait_for_job(job_id) for job_id in job_ids]
        gradients = await asyncio.gather(*gradient_tasks)

        return [g for g in gradients if g is not None]

    def _update_model_parameters(self, gradients: List[Dict[str, Any]]):
        """Update model parameters with collected gradients."""
        try:
            # Average the per-agent gradients element-wise.
            avg_gradients = TensorOps.average_gradients([
                {k: torch.tensor(v) for k, v in g.items()}
                for g in gradients
            ])

            # Clip the averaged gradients to keep updates stable.
            clipped_gradients = TensorOps.gradient_clipping(avg_gradients, max_norm=1.0)

            # Apply a plain SGD-style update using the learning rate from the model config.
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    if name in clipped_gradients:
                        param.sub_(clipped_gradients[name] * self.model.config.learning_rate)

        except Exception as e:
            logger.error(f"Error updating model parameters: {e}")
            raise

    async def _distribute_model_update(self):
        """Distribute updated model state to all agents."""
        try:
            state = {
                'model_state': self.model.state_dict(),
                'timestamp': datetime.utcnow().isoformat()
            }
            state_id = self.db_client.store_model_state(state)

            active_agents = self.db_client.get_active_agents()
            for agent in active_agents:
                self.db_client.create_job(
                    'model_update',
                    {
                        'state_id': state_id,
                        'state': state
                    }
                )

        except Exception as e:
            logger.error(f"Error distributing model update: {e}")
            raise

    def _save_checkpoint(self, epoch: int):
        """Save a checkpoint of the current model state."""
        try:
            checkpoint_dir = os.path.join(os.getcwd(), 'checkpoints')
            os.makedirs(checkpoint_dir, exist_ok=True)

            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict() if hasattr(self, 'optimizer') else None
            }, checkpoint_path)

            logger.info(f"Saved checkpoint for epoch {epoch}")

        except Exception as e:
            logger.error(f"Error saving checkpoint: {e}")
            raise
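

# Minimal usage sketch (illustrative, not part of the library API): run the coordinator
# as a standalone process. The 'num_epochs' and 'steps_per_epoch' keys are the ones
# coordinate_training() reads; the values here are only examples.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(Coordinator().coordinate_training({
        'num_epochs': 1,
        'steps_per_epoch': 100,
    }))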