File size: 8,217 Bytes
f2bab5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
"""

Coordinator for distributed model training.

"""
import asyncio
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import ray
import torch
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer

from .config import settings
from .couchdb_client import CouchDBClient
from .tensor_ops import TensorOps

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

class Coordinator:
    """Coordinator for distributed training of OpenPeerLLM.

    Keeps the reference copy of the model, creates gradient-computation jobs
    through ``CouchDBClient``, averages and clips the gradients that come
    back, applies them to the local model and publishes the new state.

    Attributes:
        db_client: ``CouchDBClient`` used for every job/state operation.
        model_id: Model identifier read from ``settings.MODEL_ID``.
        batch_size: Batch size sent with every gradient job.
        gradient_accumulation_steps: Read from settings; not otherwise used
            inside this class.
        model / tokenizer: Set by ``_initialize_model``.
    """

    def __init__(self):
        """Read settings, create the DB client and load the model.

        Construction is not cheap: it ends by calling
        ``_initialize_model``, which downloads/loads the model and stores
        the initial state through the DB client.
        """
        self.db_client = CouchDBClient()
        self.model_id = settings.MODEL_ID
        self.batch_size = settings.BATCH_SIZE
        self.gradient_accumulation_steps = settings.GRADIENT_ACCUMULATION_STEPS
        self._initialize_model()

    def _initialize_model(self):
        """Initialize the model and tokenizer and store the initial state.

        Raises:
            Exception: Anything raised while downloading, loading or storing
                is logged and re-raised unchanged.
        """
        try:
            # Download the snapshot from Hugging Face, then load both the
            # model and the tokenizer from that local directory.
            cache_dir = snapshot_download(self.model_id)
            self.model = AutoModelForCausalLM.from_pretrained(cache_dir)
            self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)

            # Store initial model state
            initial_state = {
                'model_state': self.model.state_dict(),
                'step': 0,
                'epoch': 0,
            }
            self.db_client.store_model_state(initial_state)

        except Exception as exc:
            logger.exception("Failed to initialize model: %s", exc)
            raise

    async def coordinate_training(self, training_config: Dict[str, Any]):
        """Coordinate distributed training across agents.

        Args:
            training_config: Mapping with the optional keys ``'num_epochs'``
                (default 1) and ``'steps_per_epoch'`` (default 100).

        Raises:
            Exception: Any error from an epoch or a checkpoint is logged and
                re-raised unchanged.
        """
        try:
            num_epochs = training_config.get('num_epochs', 1)
            steps_per_epoch = training_config.get('steps_per_epoch', 100)

            for epoch in range(num_epochs):
                logger.info(f"Starting epoch {epoch}")
                await self._train_epoch(epoch, steps_per_epoch)

                # Save checkpoint after each epoch
                self._save_checkpoint(epoch)
        except Exception as exc:
            logger.exception("Training coordination error: %s", exc)
            raise

    async def _train_epoch(self, epoch: int, steps_per_epoch: int):
        """Train for one epoch.

        Each step fetches the active agents, creates one gradient job per
        agent, waits for the results and, if any gradients came back, updates
        the model and publishes the new state.

        A step with no active agents sleeps five seconds and is then skipped;
        it still counts towards ``steps_per_epoch``.

        Args:
            epoch: Index of the current epoch (not used by the loop itself).
            steps_per_epoch: Number of steps to attempt.
        """
        for _ in range(steps_per_epoch):
            # Get active agents
            active_agents = self.db_client.get_active_agents()
            if not active_agents:
                logger.warning("No active agents available")
                await asyncio.sleep(5)
                continue

            # Distribute gradient computation jobs
            gradient_jobs = await self._distribute_gradient_computation(
                active_agents,
                self.batch_size,
            )

            # Collect and process gradients
            gradients = await self._collect_gradients(gradient_jobs)
            if gradients:
                # Update model with collected gradients
                self._update_model_parameters(gradients)

                # Distribute updated model state to agents
                await self._distribute_model_update()

    async def _distribute_gradient_computation(
        self,
        agents: List[Dict[str, Any]],
        batch_size: int
    ) -> List[str]:
        """Create one gradient-computation job per available agent.

        Args:
            agents: Active agents; only their count is used here, the jobs
                are not addressed to a particular agent.
            batch_size: Batch size placed in every job payload.

        Returns:
            The job ids returned by ``db_client.create_job``, one per agent.

        Raises:
            RuntimeError: If the DB client has no model state.
        """
        # Get current model state
        current_state = self.db_client.get_latest_model_state()
        if not current_state:
            raise RuntimeError("No model state available")

        # NOTE(review): this reads current_state['state'], while the states
        # stored by this class use the key 'model_state' - presumably the DB
        # client wraps the stored document; confirm against CouchDBClient.
        return [
            self.db_client.create_job(
                'gradient_computation',
                {
                    'batch_size': batch_size,
                    'state': current_state['state'],
                },
            )
            for _ in agents
        ]

    async def _wait_for_job(
        self,
        job_id: str,
        timeout: float,
        poll_interval: float
    ) -> Optional[Dict[str, Any]]:
        """Poll one job until it completes, fails or times out.

        Returns the gradients of a completed job, or ``None`` when the job
        failed or did not finish within ``timeout`` seconds.
        """
        # The event loop's monotonic clock; wall-clock changes cannot
        # shorten or extend the timeout.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout

        while True:
            if loop.time() > deadline:
                logger.warning(f"Job {job_id} timed out")
                return None

            job = self.db_client.get_job(job_id)
            # A job document that is not available yet is treated like a
            # job that is still pending.
            status = job.get('status') if job else None

            if status == 'completed':
                gradient_id = job['result']['gradient_id']
                return self.db_client.get_gradients(gradient_id)
            if status == 'failed':
                error = (job.get('result') or {}).get('error')
                logger.error(f"Job {job_id} failed: {error}")
                return None

            await asyncio.sleep(poll_interval)

    async def _collect_gradients(
        self,
        job_ids: List[str],
        timeout: float = 300.0,
        poll_interval: float = 1.0
    ) -> Optional[List[Dict[str, Any]]]:
        """Collect gradients from completed jobs.

        All jobs are polled concurrently.

        Args:
            job_ids: Ids of the gradient jobs to wait for.
            timeout: Seconds to wait for each job (default: 5 minutes).
            poll_interval: Seconds to sleep between two polls of a job.

        Returns:
            The gradients of every job that completed, in ``job_ids`` order.
            Failed and timed-out jobs are left out, so the list may be empty;
            ``None`` is never returned.
        """
        results = await asyncio.gather(
            *(self._wait_for_job(job_id, timeout, poll_interval) for job_id in job_ids)
        )

        # Filter out None results (failed or timed-out jobs)
        return [result for result in results if result is not None]

    def _update_model_parameters(self, gradients: List[Dict[str, Any]]):
        """Update model parameters with collected gradients.

        The per-worker gradients are converted to tensors, averaged and
        clipped with ``TensorOps``, then subtracted from the matching model
        parameters scaled by the learning rate.

        Args:
            gradients: One mapping of parameter name to gradient values per
                worker. An empty list is a no-op.

        Raises:
            Exception: Any error is logged and re-raised unchanged.
        """
        if not gradients:
            logger.warning("No gradients to apply; skipping model update")
            return

        try:
            # Average gradients from all workers
            worker_gradients = [
                {name: torch.tensor(values) for name, values in worker.items()}
                for worker in gradients
            ]
            avg_gradients = TensorOps.average_gradients(worker_gradients)

            # Apply gradient clipping
            clipped_gradients = TensorOps.gradient_clipping(avg_gradients, max_norm=1.0)

            # Read once, before any parameter is touched.
            # NOTE(review): the learning rate is taken from the model config;
            # confirm the loaded config actually defines `learning_rate`.
            learning_rate = self.model.config.learning_rate

            # Update model parameters
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    if name in clipped_gradients:
                        param.sub_(clipped_gradients[name] * learning_rate)

        except Exception as exc:
            logger.exception("Error updating model parameters: %s", exc)
            raise

    async def _distribute_model_update(self):
        """Store the updated model state and create an update job per agent.

        Raises:
            Exception: Any error is logged and re-raised unchanged.
        """
        try:
            # Store updated model state, stamped with a timezone-aware UTC
            # time in ISO 8601 form.
            state = {
                'model_state': self.model.state_dict(),
                'timestamp': datetime.now(timezone.utc).isoformat(),
            }
            state_id = self.db_client.store_model_state(state)

            # Create model update jobs for all active agents. As with the
            # gradient jobs, only the number of agents is used.
            for _ in self.db_client.get_active_agents():
                self.db_client.create_job(
                    'model_update',
                    {
                        'state_id': state_id,
                        'state': state,
                    },
                )

        except Exception as exc:
            logger.exception("Error distributing model update: %s", exc)
            raise

    def _save_checkpoint(self, epoch: int):
        """Save a checkpoint of the current model state.

        The file is written to ``checkpoints/checkpoint_epoch_<epoch>.pt``
        under the current working directory; the directory is created when
        missing.

        Args:
            epoch: Epoch index stored in the checkpoint and used in the file
                name.

        Raises:
            Exception: Any error is logged and re-raised unchanged.
        """
        try:
            checkpoint_dir = os.path.join(os.getcwd(), 'checkpoints')
            os.makedirs(checkpoint_dir, exist_ok=True)

            # This class never assigns `self.optimizer`; the entry is None
            # unless something else attached one.
            optimizer = getattr(self, 'optimizer', None)

            checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict() if optimizer is not None else None,
            }, checkpoint_path)

            logger.info(f"Saved checkpoint for epoch {epoch}")

        except Exception as exc:
            logger.exception("Error saving checkpoint: %s", exc)
            raise