Spaces:
No application file
No application file
culture
commited on
Commit
·
21b1ebc
1
Parent(s):
5995b60
Delete tests/test_gfpgan_model.py
Browse files- tests/test_gfpgan_model.py +0 -132
tests/test_gfpgan_model.py
DELETED
|
@@ -1,132 +0,0 @@
|
|
| 1 |
-
import tempfile
|
| 2 |
-
import torch
|
| 3 |
-
import yaml
|
| 4 |
-
from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator
|
| 5 |
-
from basicsr.data.paired_image_dataset import PairedImageDataset
|
| 6 |
-
from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
|
| 7 |
-
|
| 8 |
-
from gfpgan.archs.arcface_arch import ResNetArcFace
|
| 9 |
-
from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1
|
| 10 |
-
from gfpgan.models.gfpgan_model import GFPGANModel
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def test_gfpgan_model():
|
| 14 |
-
with open('tests/data/test_gfpgan_model.yml', mode='r') as f:
|
| 15 |
-
opt = yaml.load(f, Loader=yaml.FullLoader)
|
| 16 |
-
|
| 17 |
-
# build model
|
| 18 |
-
model = GFPGANModel(opt)
|
| 19 |
-
# test attributes
|
| 20 |
-
assert model.__class__.__name__ == 'GFPGANModel'
|
| 21 |
-
assert isinstance(model.net_g, GFPGANv1) # generator
|
| 22 |
-
assert isinstance(model.net_d, StyleGAN2Discriminator) # discriminator
|
| 23 |
-
# facial component discriminators
|
| 24 |
-
assert isinstance(model.net_d_left_eye, FacialComponentDiscriminator)
|
| 25 |
-
assert isinstance(model.net_d_right_eye, FacialComponentDiscriminator)
|
| 26 |
-
assert isinstance(model.net_d_mouth, FacialComponentDiscriminator)
|
| 27 |
-
# identity network
|
| 28 |
-
assert isinstance(model.network_identity, ResNetArcFace)
|
| 29 |
-
# losses
|
| 30 |
-
assert isinstance(model.cri_pix, L1Loss)
|
| 31 |
-
assert isinstance(model.cri_perceptual, PerceptualLoss)
|
| 32 |
-
assert isinstance(model.cri_gan, GANLoss)
|
| 33 |
-
assert isinstance(model.cri_l1, L1Loss)
|
| 34 |
-
# optimizer
|
| 35 |
-
assert isinstance(model.optimizers[0], torch.optim.Adam)
|
| 36 |
-
assert isinstance(model.optimizers[1], torch.optim.Adam)
|
| 37 |
-
|
| 38 |
-
# prepare data
|
| 39 |
-
gt = torch.rand((1, 3, 512, 512), dtype=torch.float32)
|
| 40 |
-
lq = torch.rand((1, 3, 512, 512), dtype=torch.float32)
|
| 41 |
-
loc_left_eye = torch.rand((1, 4), dtype=torch.float32)
|
| 42 |
-
loc_right_eye = torch.rand((1, 4), dtype=torch.float32)
|
| 43 |
-
loc_mouth = torch.rand((1, 4), dtype=torch.float32)
|
| 44 |
-
data = dict(gt=gt, lq=lq, loc_left_eye=loc_left_eye, loc_right_eye=loc_right_eye, loc_mouth=loc_mouth)
|
| 45 |
-
model.feed_data(data)
|
| 46 |
-
# check data shape
|
| 47 |
-
assert model.lq.shape == (1, 3, 512, 512)
|
| 48 |
-
assert model.gt.shape == (1, 3, 512, 512)
|
| 49 |
-
assert model.loc_left_eyes.shape == (1, 4)
|
| 50 |
-
assert model.loc_right_eyes.shape == (1, 4)
|
| 51 |
-
assert model.loc_mouths.shape == (1, 4)
|
| 52 |
-
|
| 53 |
-
# ----------------- test optimize_parameters -------------------- #
|
| 54 |
-
model.feed_data(data)
|
| 55 |
-
model.optimize_parameters(1)
|
| 56 |
-
assert model.output.shape == (1, 3, 512, 512)
|
| 57 |
-
assert isinstance(model.log_dict, dict)
|
| 58 |
-
# check returned keys
|
| 59 |
-
expected_keys = [
|
| 60 |
-
'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
|
| 61 |
-
'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
|
| 62 |
-
'l_d_right_eye', 'l_d_mouth'
|
| 63 |
-
]
|
| 64 |
-
assert set(expected_keys).issubset(set(model.log_dict.keys()))
|
| 65 |
-
|
| 66 |
-
# ----------------- remove pyramid_loss_weight-------------------- #
|
| 67 |
-
model.feed_data(data)
|
| 68 |
-
model.optimize_parameters(100000) # large than remove_pyramid_loss = 50000
|
| 69 |
-
assert model.output.shape == (1, 3, 512, 512)
|
| 70 |
-
assert isinstance(model.log_dict, dict)
|
| 71 |
-
# check returned keys
|
| 72 |
-
expected_keys = [
|
| 73 |
-
'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
|
| 74 |
-
'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
|
| 75 |
-
'l_d_right_eye', 'l_d_mouth'
|
| 76 |
-
]
|
| 77 |
-
assert set(expected_keys).issubset(set(model.log_dict.keys()))
|
| 78 |
-
|
| 79 |
-
# ----------------- test save -------------------- #
|
| 80 |
-
with tempfile.TemporaryDirectory() as tmpdir:
|
| 81 |
-
model.opt['path']['models'] = tmpdir
|
| 82 |
-
model.opt['path']['training_states'] = tmpdir
|
| 83 |
-
model.save(0, 1)
|
| 84 |
-
|
| 85 |
-
# ----------------- test the test function -------------------- #
|
| 86 |
-
model.test()
|
| 87 |
-
assert model.output.shape == (1, 3, 512, 512)
|
| 88 |
-
# delete net_g_ema
|
| 89 |
-
model.__delattr__('net_g_ema')
|
| 90 |
-
model.test()
|
| 91 |
-
assert model.output.shape == (1, 3, 512, 512)
|
| 92 |
-
assert model.net_g.training is True # should back to training mode after testing
|
| 93 |
-
|
| 94 |
-
# ----------------- test nondist_validation -------------------- #
|
| 95 |
-
# construct dataloader
|
| 96 |
-
dataset_opt = dict(
|
| 97 |
-
name='Demo',
|
| 98 |
-
dataroot_gt='tests/data/gt',
|
| 99 |
-
dataroot_lq='tests/data/gt',
|
| 100 |
-
io_backend=dict(type='disk'),
|
| 101 |
-
scale=4,
|
| 102 |
-
phase='val')
|
| 103 |
-
dataset = PairedImageDataset(dataset_opt)
|
| 104 |
-
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
|
| 105 |
-
assert model.is_train is True
|
| 106 |
-
with tempfile.TemporaryDirectory() as tmpdir:
|
| 107 |
-
model.opt['path']['visualization'] = tmpdir
|
| 108 |
-
model.nondist_validation(dataloader, 1, None, save_img=True)
|
| 109 |
-
assert model.is_train is True
|
| 110 |
-
# check metric_results
|
| 111 |
-
assert 'psnr' in model.metric_results
|
| 112 |
-
assert isinstance(model.metric_results['psnr'], float)
|
| 113 |
-
|
| 114 |
-
# validation
|
| 115 |
-
with tempfile.TemporaryDirectory() as tmpdir:
|
| 116 |
-
model.opt['is_train'] = False
|
| 117 |
-
model.opt['val']['suffix'] = 'test'
|
| 118 |
-
model.opt['path']['visualization'] = tmpdir
|
| 119 |
-
model.opt['val']['pbar'] = True
|
| 120 |
-
model.nondist_validation(dataloader, 1, None, save_img=True)
|
| 121 |
-
# check metric_results
|
| 122 |
-
assert 'psnr' in model.metric_results
|
| 123 |
-
assert isinstance(model.metric_results['psnr'], float)
|
| 124 |
-
|
| 125 |
-
# if opt['val']['suffix'] is None
|
| 126 |
-
model.opt['val']['suffix'] = None
|
| 127 |
-
model.opt['name'] = 'demo'
|
| 128 |
-
model.opt['path']['visualization'] = tmpdir
|
| 129 |
-
model.nondist_validation(dataloader, 1, None, save_img=True)
|
| 130 |
-
# check metric_results
|
| 131 |
-
assert 'psnr' in model.metric_results
|
| 132 |
-
assert isinstance(model.metric_results['psnr'], float)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|