"""Tensor operations for distributed computing."""

import torch
import numpy as np
from typing import Dict, List, Union


class TensorOps:
    """Utility class for distributed tensor operations."""

    @staticmethod
    def split_tensor(tensor: torch.Tensor, num_parts: int) -> List[torch.Tensor]:
        """Split a tensor into multiple parts for distributed processing."""
        # torch.chunk returns a tuple; convert to a list to match the return annotation.
        return list(torch.chunk(tensor, num_parts))

    @staticmethod
    def merge_tensors(tensors: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
        """Merge multiple tensors back into a single tensor."""
        return torch.cat(tensors, dim=dim)

    @staticmethod
    def average_gradients(gradients: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Average gradients from multiple workers (all workers must report the same keys)."""
        avg_gradients = {}
        for key in gradients[0].keys():
            # Stack each worker's gradient for this parameter and take the elementwise mean.
            avg_gradients[key] = torch.mean(torch.stack([g[key] for g in gradients]), dim=0)
        return avg_gradients

    @staticmethod
    def serialize_tensor(tensor: torch.Tensor) -> Dict[str, Union[List, str]]:
        """Serialize a tensor for storage/transmission."""
        return {
            'data': tensor.cpu().numpy().tolist(),
            'shape': list(tensor.shape),
            'dtype': str(tensor.dtype)
        }

    @staticmethod
    def deserialize_tensor(tensor_dict: Dict[str, Union[List, str]]) -> torch.Tensor:
        """Deserialize a tensor from storage/transmission format."""
        data = np.array(tensor_dict['data'])
        shape = tensor_dict['shape']
        # str(torch.dtype) looks like "torch.float32"; strip the prefix to recover the dtype.
        dtype = getattr(torch, tensor_dict['dtype'].split('.')[-1])
        return torch.tensor(data, dtype=dtype).reshape(shape)

    @staticmethod
    def gradient_clipping(gradients: Dict[str, torch.Tensor], max_norm: float) -> Dict[str, torch.Tensor]:
        """Clip each gradient tensor's norm to prevent exploding gradients."""
        # clip_grad_norm_ reads the .grad attribute of parameters, so calling it on raw
        # gradient tensors is a no-op; rescale each tensor by its own norm instead.
        for key, grad in gradients.items():
            if grad is not None:
                norm = grad.norm().item()
                if norm > max_norm:
                    gradients[key] = grad * (max_norm / (norm + 1e-6))
        return gradients

    @staticmethod
    def reduce_precision(tensor: torch.Tensor, bits: int = 16) -> torch.Tensor:
        """Reduce tensor precision for efficient transmission."""
        if bits == 16:
            return tensor.half()
        elif bits == 32:
            return tensor.float()
        else:
            raise ValueError(f"Unsupported precision bits: {bits}")

    @staticmethod
    def shard_tensor(tensor: torch.Tensor, shard_size: int) -> List[torch.Tensor]:
        """Shard a tensor along its first dimension into pieces of at most shard_size rows."""
        return [tensor[i:i + shard_size] for i in range(0, tensor.size(0), shard_size)]

    @staticmethod
    def compute_parameter_norm(parameters: Dict[str, torch.Tensor]) -> float:
        """Compute the total L2 norm over all parameters."""
        total_norm = 0.0
        for param in parameters.values():
            total_norm += param.norm().item() ** 2
        return total_norm ** 0.5
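

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: exercises split/merge,
    # serialize/deserialize round-tripping, gradient averaging, and clipping on
    # toy tensors. The tensor shapes and worker count below are illustrative
    # assumptions only.
    x = torch.arange(12, dtype=torch.float32).reshape(4, 3)

    # Split across two hypothetical workers, then merge back along dim 0.
    parts = TensorOps.split_tensor(x, num_parts=2)
    merged = TensorOps.merge_tensors(parts, dim=0)
    assert torch.equal(merged, x)

    # Round-trip through the serialization format.
    restored = TensorOps.deserialize_tensor(TensorOps.serialize_tensor(x))
    assert torch.allclose(restored, x)

    # Average per-parameter gradients reported by two workers, then clip.
    grads = [{"w": torch.ones(3)}, {"w": 3 * torch.ones(3)}]
    avg = TensorOps.average_gradients(grads)  # -> tensor([2., 2., 2.])
    clipped = TensorOps.gradient_clipping(avg, max_norm=1.0)
    print(clipped["w"].norm())  # <= max_norm after clipping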