File size: 6,012 Bytes
f2bab5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""

Base agent class for distributed computing.

"""
import torch
import ray
import uuid
import asyncio
from typing import Dict, Any, Optional
from datetime import datetime
import logging
from .couchdb_client import CouchDBClient
from .config import settings

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

@ray.remote
class Agent:
    """Distributed computing agent for tensor operations and model training.

    On construction the agent generates a random id, registers itself through
    ``CouchDBClient`` (raising ``RuntimeError`` if registration fails) and
    starts a periodic heartbeat. Jobs are claimed and processed by :meth:`run`.
    """

    def __init__(self):
        self.agent_id = str(uuid.uuid4())
        self.db_client = CouchDBClient()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.current_job: Optional[Dict] = None
        # Handle to the background heartbeat task. Kept so the task is not
        # garbage collected while running and so shutdown() can cancel it.
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._register_agent()
        self._start_heartbeat()

    def _register_agent(self):
        """Register this agent and its hardware capabilities with the cluster.

        Raises:
            RuntimeError: if ``db_client.register_agent`` reports failure.
        """
        capabilities = {
            "device": str(self.device),
            "cuda_available": torch.cuda.is_available(),
            "cuda_devices": torch.cuda.device_count() if torch.cuda.is_available() else 0,
            # NOTE(review): despite the key name this is the *total* memory of
            # CUDA device 0 only, not currently free memory.
            "memory_available": torch.cuda.get_device_properties(0).total_memory if torch.cuda.is_available() else 0
        }
        success = self.db_client.register_agent(self.agent_id, capabilities)
        if not success:
            raise RuntimeError("Failed to register agent")

    def _start_heartbeat(self):
        """Start the background heartbeat task if it is not already running.

        ``asyncio.create_task`` requires a running event loop. When none is
        running (e.g. the constructor executes outside the loop), the start is
        deferred; :meth:`run` calls this method again from inside the loop.
        Calling it while a heartbeat task is alive is a no-op.
        """
        if self._heartbeat_task is not None and not self._heartbeat_task.done():
            return
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            logger.debug("No running event loop; deferring heartbeat start to run()")
            return

        async def heartbeat_loop():
            while True:
                try:
                    # NOTE(review): update_heartbeat is a synchronous call and
                    # blocks the event loop while it runs; heartbeats are also
                    # starved while a job's synchronous work is executing.
                    self.db_client.update_heartbeat(self.agent_id)
                    await asyncio.sleep(30)
                except Exception as e:
                    logger.error("Heartbeat error: %s", e)
                    # Retry sooner than the normal interval after a failure.
                    await asyncio.sleep(5)

        self._heartbeat_task = asyncio.create_task(heartbeat_loop())

    def process_tensors(self, tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Move each tensor to the agent's device and apply ``_compute_tensor``.

        Returns:
            A new dict with the same keys, mapping to the computed tensors.
        """
        results = {}
        for name, tensor in tensors.items():
            tensor = tensor.to(self.device)
            results[name] = self._compute_tensor(tensor)
        return results

    def _compute_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """Compute operations on a single tensor (currently the identity)."""
        # Add custom tensor operations here
        return tensor

    async def run(self):
        """Main agent loop: repeatedly claim and process jobs.

        Sleeps 1s when no job is available and 5s after an unexpected error.
        Runs until the surrounding task is cancelled.
        """
        # Start the heartbeat here if the constructor had no running loop.
        self._start_heartbeat()
        while True:
            try:
                job = self.db_client.claim_job(self.agent_id)
                if job:
                    self.current_job = job
                    await self._process_job(job)
                else:
                    await asyncio.sleep(1)
            except Exception as e:
                logger.exception("Error in agent loop: %s", e)
                await asyncio.sleep(5)

    async def _process_job(self, job: Dict[str, Any]):
        """Dispatch a claimed job by its ``'type'`` and record the outcome.

        On success the job is marked ``'completed'`` with the handler's result.
        On any error (including a missing key or unknown job type) it is marked
        ``'failed'`` with ``{'error': <message>}``. ``current_job`` is always
        cleared afterwards.
        """
        try:
            job_type = job['type']
            params = job['params']

            if job_type == 'gradient_computation':
                result = await self._compute_gradients(params)
            elif job_type == 'model_update':
                result = await self._update_model(params)
            else:
                # Fail unknown types instead of reporting them as 'completed'
                # with no result.
                raise ValueError(f"Unknown job type: {job_type!r}")

            # Store job results
            self.db_client.update_job_status(
                job['_id'],
                'completed',
                result
            )
        except Exception as e:
            logger.exception("Job processing error: %s", e)
            self.db_client.update_job_status(
                job['_id'],
                'failed',
                {'error': str(e)}
            )
        finally:
            self.current_job = None

    async def _compute_gradients(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Compute gradients from a checkpoint and store them in CouchDB.

        Returns:
            ``{'gradient_id': ...}`` when ``params['checkpoint']`` is set,
            otherwise ``None``.
        """
        try:
            checkpoint = params.get('checkpoint')
            if checkpoint:
                # NOTE(review): torch.load unpickles arbitrary objects; only
                # load checkpoints from trusted sources.
                state_dict = torch.load(checkpoint, map_location=self.device)
                gradients = self._compute_model_gradients(state_dict, params.get('batch'))
                # Relies on _process_job having set current_job before the call.
                gradient_id = self.db_client.store_gradients(
                    self.current_job['_id'],
                    gradients
                )
                return {'gradient_id': gradient_id}
        except Exception as e:
            logger.exception("Gradient computation error: %s", e)
            raise

    def _compute_model_gradients(self, state_dict: Dict[str, torch.Tensor], batch: Dict[str, Any]) -> Dict[str, Any]:
        """Collect existing ``.grad`` values from ``state_dict`` as nested lists.

        NOTE(review): no forward/backward pass is run and ``batch`` is unused.
        Tensors loaded from a saved state_dict usually carry no ``.grad``, so
        this likely returns ``{}`` -- TODO implement the actual computation.
        """
        # Convert gradients to a JSON-serializable format
        gradients = {}
        for name, param in state_dict.items():
            if param.requires_grad:
                grad = param.grad
                if grad is not None:
                    gradients[name] = grad.cpu().numpy().tolist()
        return gradients

    async def _update_model(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Store ``params['state']`` as a new model state.

        Returns:
            ``{'state_id': ...}`` when a state is given, otherwise ``None``.
        """
        try:
            new_state = params.get('state')
            if new_state:
                state_id = self.db_client.store_model_state(new_state)
                return {'state_id': state_id}
        except Exception as e:
            logger.exception("Model update error: %s", e)
            raise

    def shutdown(self):
        """Stop the heartbeat, mark the agent inactive and free cached GPU memory."""
        # Cancel first so no heartbeat refreshes the agent after it is marked
        # inactive.
        if self._heartbeat_task is not None and not self._heartbeat_task.done():
            self._heartbeat_task.cancel()
        self._heartbeat_task = None
        # NOTE(review): this passes the agent id to a *job* status API; confirm
        # whether CouchDBClient offers an agent-status update instead.
        self.db_client.update_job_status(
            self.agent_id,
            'inactive'
        )
        if torch.cuda.is_available():
            torch.cuda.empty_cache()