# Deep-Multi-scale / models / base_model.py
# GoodWin's picture
# Add files
# 0f691e2
# raw
# history blame
# 7.36 kB
import os
import torch
from collections import OrderedDict
from . import networks
class BaseModel():
    """Common base class for models: bookkeeping, checkpoint I/O and helpers.

    Subclasses are expected to:
      * call ``initialize(opt)`` first,
      * register networks as attributes named ``'net' + name`` and list the
        names in ``self.model_names``,
      * expose losses as attributes named ``'loss_' + name`` and list the
        names in ``self.loss_names``,
      * list the attribute names of images to visualize in
        ``self.visual_names``,
      * define ``self.optimizers`` before ``setup()`` is called in training
        mode (``initialize`` does not create it).
    """

    # modify parser to add command line options,
    # and also change the default values if needed
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Return ``parser`` unchanged; subclasses may add or override options."""
        return parser

    def name(self):
        """Return the model's name."""
        return 'BaseModel'

    def initialize(self, opt):
        """Store the options and reset the bookkeeping lists.

        Reads ``opt.gpu_ids``, ``opt.isTrain``, ``opt.checkpoints_dir``,
        ``opt.name`` and ``opt.resize_or_crop``.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # The first listed GPU is the primary device; no GPU ids means CPU.
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        if opt.resize_or_crop != 'scale_width':
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.image_paths = []
        # NOTE: self.optimizers is intentionally not created here; setup()
        # and update_learning_rate() read it, so subclasses must define it.

    def set_input(self, input):
        """Store the input for later use by ``forward``."""
        self.input = input

    def forward(self):
        """Run the forward pass; a no-op here, to be overridden by subclasses."""
        pass

    # load and print networks; create schedulers
    def setup(self, opt, parser=None):
        """Create schedulers (training), load checkpoints if needed, print networks.

        Schedulers are built, one per entry of ``self.optimizers``, only in
        training mode. Networks are loaded from ``opt.which_epoch`` at test
        time or when ``opt.continue_train`` is set. ``parser`` is accepted
        but unused.
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)
        self.print_networks(opt.verbose)

    # make models eval mode during test time
    def eval(self):
        """Put every registered network into evaluation mode."""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    # used in test time, wrapping `forward` in no_grad() so we don't save
    # intermediate steps for backprop
    def test(self):
        """Run ``forward`` with gradient tracking disabled."""
        with torch.no_grad():
            self.forward()

    # get image paths
    def get_image_paths(self):
        """Return ``self.image_paths``."""
        return self.image_paths

    def optimize_parameters(self):
        """Run one optimization step; a no-op here, to be overridden by subclasses."""
        pass

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        """Step every scheduler, then print the first optimizer's learning rate."""
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    # return visualization images. train.py will display these images, and save the images to a html
    def get_current_visuals(self):
        """Return an ``OrderedDict`` mapping each visual name to its attribute value.

        Non-string entries of ``self.visual_names`` are ignored.
        """
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    # return training losses/errors. train.py will print out these errors as debugging information
    def get_current_losses(self):
        """Return an ``OrderedDict`` mapping each loss name to a Python float.

        Values are read from the attributes ``'loss_' + name``. Non-string
        entries of ``self.loss_names`` are ignored.
        """
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    # save models to the disk
    def save_networks(self, which_epoch):
        """Save each registered network's state dict under ``self.save_dir``.

        Files are named ``'<which_epoch>_net_<name>.pth'``. A network wrapped
        in ``torch.nn.DataParallel`` is unwrapped first, so the saved keys
        never carry the ``module.`` prefix and match what ``load_networks``
        expects, whether or not the wrapper is present.
        """
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            save_filename = '%s_net_%s.pth' % (which_epoch, name)
            save_path = os.path.join(self.save_dir, save_filename)
            net = getattr(self, 'net' + name)
            # Unwrap exactly like load_networks does instead of assuming that
            # "GPU ids configured" implies "wrapped in DataParallel".
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            # .cpu() moves the parameters in place, so the saved tensors are
            # CPU tensors regardless of where the network was running.
            torch.save(net.cpu().state_dict(), save_path)
            if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                # Move the network back to the primary GPU after saving.
                net.cuda(self.gpu_ids[0])

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Remove InstanceNorm statistics entries from ``state_dict`` in place.

        Walks ``module`` along the attribute path ``keys`` (a dotted
        state-dict key split on ``'.'``), starting at index ``i``. At the
        final component, if the owning module's class name starts with
        ``'InstanceNorm'``:
          * ``running_mean`` / ``running_var`` are dropped when the module's
            own attribute is ``None``;
          * ``num_batches_tracked`` is always dropped.

        NOTE: ``load_networks`` currently does not call this helper.
        """
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            is_instance_norm = module.__class__.__name__.startswith('InstanceNorm')
            if is_instance_norm and key in ('running_mean', 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if is_instance_norm and key == 'num_batches_tracked':
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    # load models from the disk
    def load_networks(self, which_epoch):
        """Load each registered network from ``'<which_epoch>_net_<name>.pth'``.

        Loading is partial and best-effort:
          * a missing checkpoint file is skipped without error;
          * only checkpoint entries whose key exists in the network's own
            state dict are used; all other network entries keep their
            current values.
        """
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            load_filename = '%s_net_%s.pth' % (which_epoch, name)
            load_path = os.path.join(self.save_dir, load_filename)
            net = getattr(self, 'net' + name)
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            if not os.path.exists(load_path):
                # Deliberate best-effort: networks without a checkpoint keep
                # their current weights.
                continue
            # if you are using PyTorch newer than 0.4 (e.g., built from
            # GitHub source), you can remove str() on self.device
            state_dict = torch.load(load_path, map_location=str(self.device))
            if hasattr(state_dict, '_metadata'):
                del state_dict._metadata
            # Start from the network's current state and overwrite only the
            # entries the checkpoint also provides.
            model_dict = net.state_dict()
            new_dict = {k: v for k, v in state_dict.items() if k in model_dict}
            model_dict.update(new_dict)
            net.load_state_dict(model_dict)

    # print network information
    def print_networks(self, verbose):
        """Print each network and its parameter count when ``verbose`` is truthy.

        Nothing is printed when ``verbose`` is falsy.
        """
        if verbose:
            print('---------- Networks initialized -------------')
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            net = getattr(self, 'net' + name)
            num_params = sum(param.numel() for param in net.parameters())
            if verbose:
                print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        if verbose:
            print('-----------------------------------------------')

    # set requires_grad=False to avoid computation
    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` on every parameter of the given network(s).

        ``nets`` may be a single network or a list of networks; ``None``
        entries are skipped.
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad