Spaces:
Build error
Build error
| import torch | |
| from torch.nn import init | |
def init_weights(net, init_type="normal", init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal
                             (the kaiming branch does not use it).

    Raises:
        NotImplementedError -- if init_type is not one of the names above. The name is
                               only checked when a Conv*/Linear* layer with a weight is
                               visited, so a network without such layers never raises.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """

    def init_func(m):
        """Initialize a single sub-module; called once per module by net.apply."""
        classname = m.__class__.__name__
        # Layers built without learnable parameters (e.g. bias=False or
        # BatchNorm2d(affine=False)) expose these attributes as None.
        weight = getattr(m, "weight", None)
        bias = getattr(m, "bias", None)

        is_conv_or_linear = "Conv" in classname or "Linear" in classname
        if is_conv_or_linear and weight is not None:
            if init_type == "normal":
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == "xavier":
                init.xavier_normal_(weight.data, gain=init_gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(weight.data, gain=init_gain)
            else:
                raise NotImplementedError("initialization method [%s] is not implemented" % init_type)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif "BatchNorm2d" in classname:
            # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            # Guard against affine=False, where weight and bias are both None.
            if weight is not None:
                init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function <init_func> to every sub-module
def init_net(net, init_type="normal", init_gain=0.02, gpu_ids=None):
    """Initialize a network: 1. move it to the first requested GPU (if any); 2. initialize the network weights

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2.
                              None or an empty list leaves the network where it is.

    Raises:
        AssertionError -- if gpu_ids is non-empty but CUDA is not available.

    Return an initialized network.
    """
    if gpu_ids is not None and len(gpu_ids) > 0:
        # Raised explicitly instead of via `assert` so the check is not stripped
        # under `python -O`; the exception type is kept for existing callers.
        if not torch.cuda.is_available():
            raise AssertionError("gpu_ids was given but CUDA is not available")
        # NOTE: only the first id is used; DataParallel wrapping across all of
        # gpu_ids is currently disabled.
        net.to(gpu_ids[0])
    init_weights(net, init_type, init_gain=init_gain)
    return net