update misc.py, now support latest version of pytorch
Browse files- util/misc.py +3 -2
util/misc.py
CHANGED
|
@@ -14,7 +14,8 @@ from pathlib import Path
|
|
| 14 |
|
| 15 |
import torch
|
| 16 |
import torch.distributed as dist
|
| 17 |
-
from torch._six import inf
|
|
|
|
| 18 |
import numpy as np
|
| 19 |
|
| 20 |
def log_codefiles(data_root,save_root):
|
|
@@ -303,7 +304,7 @@ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
|
|
| 303 |
if len(parameters) == 0:
|
| 304 |
return torch.tensor(0.)
|
| 305 |
device = parameters[0].grad.device
|
| 306 |
-
if norm_type == inf:
|
| 307 |
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
|
| 308 |
else:
|
| 309 |
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
|
|
|
|
| 14 |
|
| 15 |
import torch
|
| 16 |
import torch.distributed as dist
|
| 17 |
+
#from torch._six import inf
|
| 18 |
+
import inf
|
| 19 |
import numpy as np
|
| 20 |
|
| 21 |
def log_codefiles(data_root,save_root):
|
|
|
|
| 304 |
if len(parameters) == 0:
|
| 305 |
return torch.tensor(0.)
|
| 306 |
device = parameters[0].grad.device
|
| 307 |
+
if norm_type == math.inf:
|
| 308 |
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
|
| 309 |
else:
|
| 310 |
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
|