File size: 3,091 Bytes
f2bab5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""

Tensor operations for distributed computing.

"""
import torch
import numpy as np
from typing import Dict, List, Optional, Union, Tuple

class TensorOps:
    """Utility class for distributed tensor operations.

    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def split_tensor(tensor: torch.Tensor, num_parts: int, dim: int = 0) -> List[torch.Tensor]:
        """Split a tensor into up to ``num_parts`` chunks for distributed processing.

        Thin wrapper over ``torch.chunk``: the chunks are views of ``tensor`` and the
        last one may be smaller. ``torch.chunk`` can return FEWER than ``num_parts``
        chunks (e.g. 6 rows into 4 parts yields 3 chunks of 2), so callers must not
        assume ``len(result) == num_parts``.

        Args:
            tensor: Tensor to split.
            num_parts: Requested number of chunks.
            dim: Dimension to split along (default 0, the previous behavior).

        Returns:
            List of chunks whose concatenation along ``dim`` equals ``tensor``.
        """
        # torch.chunk returns a tuple; convert so the result matches the annotation
        # (and shard_tensor, which also returns a list).
        return list(torch.chunk(tensor, num_parts, dim=dim))

    @staticmethod
    def merge_tensors(tensors: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
        """Merge multiple tensors back into a single tensor by concatenating along ``dim``."""
        return torch.cat(tensors, dim=dim)

    @staticmethod
    def average_gradients(gradients: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Average gradients from multiple workers, key by key.

        Args:
            gradients: One dict per worker. The key set is taken from the first
                dict; every other dict must contain those keys too (a missing
                key raises ``KeyError``). Tensors under the same key must be
                stackable (same shape).

        Returns:
            Dict mapping each key to the element-wise mean over workers.

        Raises:
            ValueError: If ``gradients`` is empty.
        """
        if not gradients:
            raise ValueError("average_gradients requires at least one gradient dict")
        return {
            key: torch.stack([g[key] for g in gradients]).mean(dim=0)
            for key in gradients[0]
        }

    @staticmethod
    def serialize_tensor(tensor: torch.Tensor) -> Dict[str, Union[List, str]]:
        """Serialize a tensor into a JSON-friendly dict for storage/transmission.

        Returns:
            ``{'data': nested Python lists, 'shape': list of ints, 'dtype': str}``,
            where ``dtype`` is e.g. ``'torch.float32'``.
        """
        # detach() is required because .numpy() raises on tensors that require
        # grad (e.g. model parameters). Tensor.tolist() also supports dtypes that
        # NumPy lacks, such as bfloat16.
        return {
            'data': tensor.detach().cpu().tolist(),
            'shape': list(tensor.shape),
            'dtype': str(tensor.dtype)
        }

    @staticmethod
    def deserialize_tensor(tensor_dict: Dict[str, Union[List, str]]) -> torch.Tensor:
        """Rebuild a tensor from the dict produced by ``serialize_tensor``.

        Raises:
            ValueError: If ``tensor_dict['dtype']`` does not name a ``torch.dtype``.
        """
        data = np.array(tensor_dict['data'])
        shape = tensor_dict['shape']
        dtype_name = str(tensor_dict['dtype']).split('.')[-1]
        # Only accept real dtypes. Otherwise an arbitrary torch attribute named by
        # the (possibly untrusted) payload would be passed through as a dtype.
        dtype = getattr(torch, dtype_name, None)
        if not isinstance(dtype, torch.dtype):
            raise ValueError(f"Unsupported tensor dtype: {tensor_dict['dtype']!r}")
        # reshape restores the shape of empty / 0-d tensors, whose nested-list form
        # does not encode it unambiguously.
        return torch.tensor(data, dtype=dtype).reshape(shape)

    @staticmethod
    def gradient_clipping(gradients: Dict[str, Optional[torch.Tensor]], max_norm: float) -> Dict[str, Optional[torch.Tensor]]:
        """Clip gradients in place so that their global L2 norm is at most ``max_norm``.

        This applies the same rule as ``torch.nn.utils.clip_grad_norm_``. The total
        norm is computed over all non-``None`` gradients. If it exceeds
        ``max_norm``, every gradient is multiplied by
        ``max_norm / (total_norm + 1e-6)``.

        ``clip_grad_norm_`` itself cannot be used here: it clips the ``.grad``
        attribute of parameters, and the plain gradient tensors in this dict have
        no ``.grad``, so calling it on them was a silent no-op.

        Args:
            gradients: Mapping of name -> gradient tensor; ``None`` entries are
                skipped and preserved.
            max_norm: Maximum allowed global L2 norm.

        Returns:
            The same dict object, with its tensors scaled in place.
        """
        grads = [g for g in gradients.values() if g is not None]
        if not grads:
            return gradients
        total_norm = sum(float(g.detach().norm(2)) ** 2 for g in grads) ** 0.5
        # 1e-6 mirrors torch's epsilon and guards against division by zero.
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1.0:
            # no_grad allows in-place scaling even if a tensor requires grad.
            with torch.no_grad():
                for g in grads:
                    g.mul_(clip_coef)
        return gradients

    @staticmethod
    def reduce_precision(tensor: torch.Tensor, bits: int = 16) -> torch.Tensor:
        """Cast a tensor to a lower floating-point precision for efficient transmission.

        Args:
            tensor: Tensor to cast.
            bits: 16 -> ``torch.float16``, 32 -> ``torch.float32``.

        Raises:
            ValueError: For any other ``bits`` value.
        """
        if bits == 16:
            return tensor.half()
        elif bits == 32:
            return tensor.float()
        else:
            raise ValueError("Unsupported precision bits")

    @staticmethod
    def shard_tensor(tensor: torch.Tensor, shard_size: int) -> List[torch.Tensor]:
        """Shard a tensor along dim 0 into consecutive slices of ``shard_size`` rows.

        The last shard holds the remainder and may be smaller. An empty tensor
        yields an empty list.

        Raises:
            ValueError: If ``shard_size`` is not positive (previously 0 crashed
                inside ``range`` and negative values silently returned ``[]``).
        """
        if shard_size <= 0:
            raise ValueError(f"shard_size must be positive, got {shard_size}")
        return [tensor[i:i + shard_size] for i in range(0, tensor.size(0), shard_size)]

    @staticmethod
    def compute_parameter_norm(parameters: Dict[str, torch.Tensor]) -> float:
        """Compute the global L2 norm of all parameters as a Python float.

        Equal to sqrt(sum of squared per-tensor L2 norms); 0.0 for an empty dict.
        """
        total_norm = 0.0
        for param in parameters.values():
            # .item() moves each per-tensor norm to the host as a Python float.
            total_norm += param.norm().item() ** 2
        return total_norm ** 0.5