# cloud_agents/tensor_ops.py
"""
Tensor operations for distributed computing.
"""
import torch
import numpy as np
from typing import Dict, List, Union


class TensorOps:
    """Utility class for distributed tensor operations."""

    @staticmethod
    def split_tensor(tensor: torch.Tensor, num_parts: int) -> List[torch.Tensor]:
        """Split a tensor into multiple parts for distributed processing."""
        # torch.chunk splits along dim 0 by default; the last chunk may be
        # smaller (or fewer chunks returned) if the split is uneven.
        return list(torch.chunk(tensor, num_parts))

    @staticmethod
    def merge_tensors(tensors: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
        """Merge multiple tensors back into a single tensor."""
        return torch.cat(tensors, dim=dim)

    @staticmethod
    def average_gradients(gradients: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Average gradients from multiple workers."""
        avg_gradients = {}
        for key in gradients[0].keys():
            avg_gradients[key] = torch.mean(torch.stack([g[key] for g in gradients]), dim=0)
        return avg_gradients

    @staticmethod
    def serialize_tensor(tensor: torch.Tensor) -> Dict[str, Union[List, str]]:
        """Serialize a tensor for storage/transmission."""
        return {
            'data': tensor.cpu().numpy().tolist(),
            'shape': list(tensor.shape),
            'dtype': str(tensor.dtype)
        }

    @staticmethod
    def deserialize_tensor(tensor_dict: Dict[str, Union[List, str]]) -> torch.Tensor:
        """Deserialize a tensor from storage/transmission format."""
        data = np.array(tensor_dict['data'])
        shape = tensor_dict['shape']
        # str(tensor.dtype) looks like "torch.float32"; recover the torch dtype
        # attribute from the suffix.
        dtype = getattr(torch, tensor_dict['dtype'].split('.')[-1])
        return torch.tensor(data, dtype=dtype).reshape(shape)

    @staticmethod
    def gradient_clipping(gradients: Dict[str, torch.Tensor], max_norm: float) -> Dict[str, torch.Tensor]:
        """Clip gradients by their global norm to prevent exploding gradients."""
        # clip_grad_norm_ reads parameters' .grad attributes, so calling it on raw
        # gradient tensors is a no-op; scale the tensors by the global norm instead.
        grads = [g for g in gradients.values() if g is not None]
        if grads:
            total_norm = torch.norm(torch.stack([g.norm() for g in grads]))
            clip_coef = max_norm / (total_norm + 1e-6)
            if clip_coef < 1:
                for g in grads:
                    g.mul_(clip_coef)
        return gradients

    @staticmethod
    def reduce_precision(tensor: torch.Tensor, bits: int = 16) -> torch.Tensor:
        """Reduce tensor precision for efficient transmission."""
        if bits == 16:
            return tensor.half()
        elif bits == 32:
            return tensor.float()
        else:
            raise ValueError(f"Unsupported precision: {bits} bits (only 16 and 32 are supported)")

    @staticmethod
    def shard_tensor(tensor: torch.Tensor, shard_size: int) -> List[torch.Tensor]:
        """Shard a tensor into smaller pieces for distributed processing."""
        return [tensor[i:i + shard_size] for i in range(0, tensor.size(0), shard_size)]

    @staticmethod
    def compute_parameter_norm(parameters: Dict[str, torch.Tensor]) -> float:
        """Compute the global L2 norm over all parameters."""
        total_norm = 0.0
        for param in parameters.values():
            total_norm += param.norm().item() ** 2
        return total_norm ** 0.5
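

# A minimal usage sketch, not part of the original upload: it exercises the
# helpers above on synthetic tensors and fabricated per-"worker" gradients,
# assuming a purely local run (no real distributed backend involved).
if __name__ == "__main__":
    # Split a small batch across two hypothetical workers, then merge it back.
    batch = torch.randn(8, 4)
    parts = TensorOps.split_tensor(batch, num_parts=2)
    restored = TensorOps.merge_tensors(parts, dim=0)
    assert torch.equal(batch, restored)

    # Round-trip a tensor through the JSON-friendly serialization format.
    payload = TensorOps.serialize_tensor(batch)
    decoded = TensorOps.deserialize_tensor(payload)
    assert torch.allclose(batch, decoded)

    # Average synthetic gradients from three workers, clip, and report the norm.
    worker_grads = [{"weight": torch.randn(4, 4)} for _ in range(3)]
    avg = TensorOps.average_gradients(worker_grads)
    clipped = TensorOps.gradient_clipping(avg, max_norm=1.0)
    print("clipped gradient norm:", TensorOps.compute_parameter_norm(clipped))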