culture commited on
Commit
21b1ebc
·
1 Parent(s): 5995b60

Delete tests/test_gfpgan_model.py

Browse files
Files changed (1) hide show
  1. tests/test_gfpgan_model.py +0 -132
tests/test_gfpgan_model.py DELETED
@@ -1,132 +0,0 @@
1
- import tempfile
2
- import torch
3
- import yaml
4
- from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator
5
- from basicsr.data.paired_image_dataset import PairedImageDataset
6
- from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
7
-
8
- from gfpgan.archs.arcface_arch import ResNetArcFace
9
- from gfpgan.archs.gfpganv1_arch import FacialComponentDiscriminator, GFPGANv1
10
- from gfpgan.models.gfpgan_model import GFPGANModel
11
-
12
-
13
- def test_gfpgan_model():
14
- with open('tests/data/test_gfpgan_model.yml', mode='r') as f:
15
- opt = yaml.load(f, Loader=yaml.FullLoader)
16
-
17
- # build model
18
- model = GFPGANModel(opt)
19
- # test attributes
20
- assert model.__class__.__name__ == 'GFPGANModel'
21
- assert isinstance(model.net_g, GFPGANv1) # generator
22
- assert isinstance(model.net_d, StyleGAN2Discriminator) # discriminator
23
- # facial component discriminators
24
- assert isinstance(model.net_d_left_eye, FacialComponentDiscriminator)
25
- assert isinstance(model.net_d_right_eye, FacialComponentDiscriminator)
26
- assert isinstance(model.net_d_mouth, FacialComponentDiscriminator)
27
- # identity network
28
- assert isinstance(model.network_identity, ResNetArcFace)
29
- # losses
30
- assert isinstance(model.cri_pix, L1Loss)
31
- assert isinstance(model.cri_perceptual, PerceptualLoss)
32
- assert isinstance(model.cri_gan, GANLoss)
33
- assert isinstance(model.cri_l1, L1Loss)
34
- # optimizer
35
- assert isinstance(model.optimizers[0], torch.optim.Adam)
36
- assert isinstance(model.optimizers[1], torch.optim.Adam)
37
-
38
- # prepare data
39
- gt = torch.rand((1, 3, 512, 512), dtype=torch.float32)
40
- lq = torch.rand((1, 3, 512, 512), dtype=torch.float32)
41
- loc_left_eye = torch.rand((1, 4), dtype=torch.float32)
42
- loc_right_eye = torch.rand((1, 4), dtype=torch.float32)
43
- loc_mouth = torch.rand((1, 4), dtype=torch.float32)
44
- data = dict(gt=gt, lq=lq, loc_left_eye=loc_left_eye, loc_right_eye=loc_right_eye, loc_mouth=loc_mouth)
45
- model.feed_data(data)
46
- # check data shape
47
- assert model.lq.shape == (1, 3, 512, 512)
48
- assert model.gt.shape == (1, 3, 512, 512)
49
- assert model.loc_left_eyes.shape == (1, 4)
50
- assert model.loc_right_eyes.shape == (1, 4)
51
- assert model.loc_mouths.shape == (1, 4)
52
-
53
- # ----------------- test optimize_parameters -------------------- #
54
- model.feed_data(data)
55
- model.optimize_parameters(1)
56
- assert model.output.shape == (1, 3, 512, 512)
57
- assert isinstance(model.log_dict, dict)
58
- # check returned keys
59
- expected_keys = [
60
- 'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
61
- 'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
62
- 'l_d_right_eye', 'l_d_mouth'
63
- ]
64
- assert set(expected_keys).issubset(set(model.log_dict.keys()))
65
-
66
- # ----------------- remove pyramid_loss_weight-------------------- #
67
- model.feed_data(data)
68
- model.optimize_parameters(100000) # large than remove_pyramid_loss = 50000
69
- assert model.output.shape == (1, 3, 512, 512)
70
- assert isinstance(model.log_dict, dict)
71
- # check returned keys
72
- expected_keys = [
73
- 'l_g_pix', 'l_g_percep', 'l_g_style', 'l_g_gan', 'l_g_gan_left_eye', 'l_g_gan_right_eye', 'l_g_gan_mouth',
74
- 'l_g_comp_style_loss', 'l_identity', 'l_d', 'real_score', 'fake_score', 'l_d_r1', 'l_d_left_eye',
75
- 'l_d_right_eye', 'l_d_mouth'
76
- ]
77
- assert set(expected_keys).issubset(set(model.log_dict.keys()))
78
-
79
- # ----------------- test save -------------------- #
80
- with tempfile.TemporaryDirectory() as tmpdir:
81
- model.opt['path']['models'] = tmpdir
82
- model.opt['path']['training_states'] = tmpdir
83
- model.save(0, 1)
84
-
85
- # ----------------- test the test function -------------------- #
86
- model.test()
87
- assert model.output.shape == (1, 3, 512, 512)
88
- # delete net_g_ema
89
- model.__delattr__('net_g_ema')
90
- model.test()
91
- assert model.output.shape == (1, 3, 512, 512)
92
- assert model.net_g.training is True # should back to training mode after testing
93
-
94
- # ----------------- test nondist_validation -------------------- #
95
- # construct dataloader
96
- dataset_opt = dict(
97
- name='Demo',
98
- dataroot_gt='tests/data/gt',
99
- dataroot_lq='tests/data/gt',
100
- io_backend=dict(type='disk'),
101
- scale=4,
102
- phase='val')
103
- dataset = PairedImageDataset(dataset_opt)
104
- dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
105
- assert model.is_train is True
106
- with tempfile.TemporaryDirectory() as tmpdir:
107
- model.opt['path']['visualization'] = tmpdir
108
- model.nondist_validation(dataloader, 1, None, save_img=True)
109
- assert model.is_train is True
110
- # check metric_results
111
- assert 'psnr' in model.metric_results
112
- assert isinstance(model.metric_results['psnr'], float)
113
-
114
- # validation
115
- with tempfile.TemporaryDirectory() as tmpdir:
116
- model.opt['is_train'] = False
117
- model.opt['val']['suffix'] = 'test'
118
- model.opt['path']['visualization'] = tmpdir
119
- model.opt['val']['pbar'] = True
120
- model.nondist_validation(dataloader, 1, None, save_img=True)
121
- # check metric_results
122
- assert 'psnr' in model.metric_results
123
- assert isinstance(model.metric_results['psnr'], float)
124
-
125
- # if opt['val']['suffix'] is None
126
- model.opt['val']['suffix'] = None
127
- model.opt['name'] = 'demo'
128
- model.opt['path']['visualization'] = tmpdir
129
- model.nondist_validation(dataloader, 1, None, save_img=True)
130
- # check metric_results
131
- assert 'psnr' in model.metric_results
132
- assert isinstance(model.metric_results['psnr'], float)