Spaces:
Sleeping
Sleeping
File size: 4,122 Bytes
0f691e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
# -- coding: utf-8 --
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os
import torchvision
def tensor2im(input_image, norm=1, imtype=np.uint8):
    """Convert the first image of a batched tensor into an HWC numpy image.

    Only ``input_image[0]`` is converted. It is transposed (C, H, W) ->
    (H, W, C). A single-channel image is tiled to three channels. Values are
    first brought into [-1, 1] and then mapped linearly to [0, 255].

    Args:
        input_image: a ``torch.Tensor`` whose first index selects the image.
            Anything that is not a tensor is returned unchanged.
        norm: how values are mapped into [-1, 1].
            ``1`` clamps to [-1, 1].
            ``2`` min-max normalises the selected image; a constant image
            maps to -1, i.e. black.
        imtype: numpy dtype of the returned array.

    Returns:
        An ``(H, W, C)`` array of ``imtype``, or ``input_image`` itself when
        it is not a tensor.

    Raises:
        ValueError: if ``norm`` is neither 1 nor 2.
    """
    if not isinstance(input_image, torch.Tensor):
        return input_image
    # Only out-of-place ops from here on: .cpu()/.float() are no-ops on a CPU
    # float32 tensor, so an in-place clamp_ would overwrite the caller's data.
    image = input_image.detach()[0].cpu().float()
    if norm == 1:  # clamp to [-1, 1]
        image = image.clamp(-1, 1)
    elif norm == 2:  # min-max normalise to [-1, 1]
        min_ = torch.min(image)
        max_ = torch.max(image)
        span = max_ - min_
        if span > 0:
            image = (image - min_) / span * 2 - 1
        else:
            # Constant image: 0/0 would give NaN, whose uint8 cast is undefined.
            image = torch.full_like(image, -1.0)
    else:
        raise ValueError('unsupported norm mode %r; expected 1 or 2' % (norm,))
    image_numpy = image.numpy()
    if image_numpy.shape[0] == 1:  # grayscale -> 3 channels
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(imtype)
def tensor2im3Channels(input_image, imtype=np.uint8):
    """Convert a single 3-D tensor into an HWC numpy image.

    The tensor is clamped to [-1, 1], transposed with axes (1, 2, 0), i.e.
    (C, H, W) -> (H, W, C), and mapped linearly to [0, 255]. No batch index
    is taken and no channel tiling is done.

    Args:
        input_image: a 3-D ``torch.Tensor``. Anything that is not a tensor
            is returned unchanged.
        imtype: numpy dtype of the returned array.

    Returns:
        An ``(H, W, C)`` array of ``imtype``, or ``input_image`` itself when
        it is not a tensor.
    """
    if not isinstance(input_image, torch.Tensor):
        return input_image
    # Out-of-place clamp: on a CPU float32 tensor .cpu()/.float() return the
    # same storage, so clamp_ would silently modify the caller's tensor.
    image_numpy = input_image.detach().cpu().float().clamp(-1, 1).numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(imtype)
def diagnose_network(net, name='network'):
    """Print ``name`` and then the average absolute gradient of ``net``.

    For each parameter that has a ``.grad``, the mean of ``|grad|`` is
    taken. Those per-parameter means are then averaged. Parameters without a
    gradient are skipped. If none have one, ``0.0`` is printed; otherwise
    the printed value is a tensor.
    """
    per_param = [
        torch.mean(torch.abs(param.grad.data))
        for param in net.parameters()
        if param.grad is not None
    ]
    average = sum(per_param, 0.0)
    if per_param:
        average = average / len(per_param)
    print(name)
    print(average)
def print_numpy(x, val=True, shp=False):
    """Print summary statistics of a numpy array.

    Args:
        x: numpy array (anything with ``.astype``). It is cast to float64
            first.
        val: when true, print mean/min/max/median/std over all elements.
        shp: when true, print the array's shape before the statistics.
    """
    values = x.astype(np.float64)
    if shp:
        print('shape,', values.shape)
    if not val:
        return
    flat = values.ravel()
    summary = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
    print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % summary)
def mkdirs(paths):
    """Create one directory, or each directory in a list/tuple of paths.

    Args:
        paths: a single path, or a ``list``/``tuple`` of paths. Each path is
            passed to ``mkdir``, which skips paths that already exist.
    """
    # A list/tuple is never a str, so no separate string check is needed.
    # Tuples used to reach mkdir whole and crash in os.path.exists.
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)
def mkdir(path):
    """Create directory ``path``, plus any missing parents, if nothing exists there.

    An existing path, whether directory or file, is left untouched.
    """
    if not os.path.exists(path):
        # exist_ok covers the window between the check and the call, in which
        # another process or thread may create the directory first.
        os.makedirs(path, exist_ok=True)
def print_current_losses(epoch, i, losses, t, t_data):
    """Print one progress line: epoch, iteration, timings and every loss.

    Each loss is rendered as ``name: value`` with three decimals and a
    trailing space, in the iteration order of ``losses``.
    """
    header = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
    body = ''.join('%s: %.3f ' % (loss_name, loss_value) for loss_name, loss_value in losses.items())
    print(header + body)
def display_current_results(writer, visuals, losses, step, save_result):
    """Log image grids and scalar losses for one step to ``writer``.

    Args:
        writer: object exposing ``add_image(tag, img, step)`` and
            ``add_scalar(tag, value, step)``. Presumably a TensorBoard
            ``SummaryWriter``; confirm at the caller.
        visuals: mapping of label -> image batch accepted by
            ``torchvision.utils.make_grid``.
        losses: mapping of loss name -> scalar value.
        step: global step attached to every logged item.
        save_result: unused; kept so existing call sites keep working.
    """
    for label, images in visuals.items():
        # Visuals whose label contains 'Mask' are gridded with their raw
        # values. All others are min-max rescaled per image (normalize=True
        # together with scale_each=True).
        grid = torchvision.utils.make_grid(
            images, normalize='Mask' not in label, scale_each=True)
        writer.add_image(label, grid, step)
    for k, v in losses.items():
        writer.add_scalar(k, v, step)
def VisualFeature(input_feature, imtype=np.uint8):
    """Min-max rescale a feature tensor to [0, 255] for visualisation.

    A single global min/max over the whole tensor is used. A constant tensor
    maps to 0 instead of the NaN that 0/0 would produce.

    Args:
        input_feature: a ``torch.Tensor`` with at least 2 dims. Anything
            that is not a tensor is returned unchanged.
        imtype: numpy dtype of the result.

    Returns:
        A numpy array of ``imtype`` with the tensor's shape. The exception is
        when ``size(1) == 3``: the tensor is then permuted (1, 2, 0), which
        requires exactly 3 dims.
    """
    if not isinstance(input_feature, torch.Tensor):
        return input_feature
    feature = input_feature.detach().cpu().float()
    # NOTE(review): this tests dim 1 but permutes as if the tensor were a 3-D
    # (C, H, W) image; for a 3-channel CHW input one would expect
    # size(0) == 3. Behaviour kept as-is; confirm the expected input layout.
    if feature.size(1) == 3:
        feature = feature.permute(1, 2, 0)
    min_ = torch.min(feature)
    max_ = torch.max(feature)
    span = max_ - min_
    if span > 0:
        feature = (feature - min_) / span * 2 - 1
    else:
        # Constant feature map: 0/0 would give NaN, whose integer cast is undefined.
        feature = torch.full_like(feature, -1.0)
    image_numpy = (feature.numpy() + 1) / 2.0 * 255.0
    return image_numpy.astype(imtype)
def save_image(image_numpy, image_path):
    """Write a numpy image array to ``image_path`` using PIL.

    The array is passed to ``Image.fromarray`` unchanged, so it must already
    have a dtype and shape PIL accepts (e.g. a uint8 HWC array). The output
    format is left to PIL, presumably inferred from the path's extension.
    """
    Image.fromarray(image_numpy).save(image_path)
|