culture committed on
Commit c483844 · 1 Parent(s): a229316

Upload gfpgan/data/ffhq_degradation_dataset.py

gfpgan/data/ffhq_degradation_dataset.py ADDED
@@ -0,0 +1,230 @@
import cv2
import math
import numpy as np
import os.path as osp
import torch
import torch.utils.data as data
from basicsr.data import degradations as degradations
from basicsr.data.data_util import paths_from_folder
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
                                               normalize)


@DATASET_REGISTRY.register()
class FFHQDegradationDataset(data.Dataset):
    """FFHQ dataset for GFPGAN.

    It reads high-resolution images and then generates low-quality (LQ) images on-the-fly.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwargs.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
        Please see more options in the code.
    """

    def __init__(self, opt):
        super(FFHQDegradationDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']
        self.out_size = opt['out_size']

        self.crop_components = opt.get('crop_components', False)  # whether to crop facial components
        self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1)  # ratio by which to enlarge eye regions

        if self.crop_components:
            # load the component list from a pre-processed pth file
            self.components_list = torch.load(opt.get('component_path'))

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                # each line starts with the image filename; keep the stem as the lmdb key
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend: scan file list from a folder
            self.paths = paths_from_folder(self.gt_folder)

        # degradation configurations
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']
        self.blur_sigma = opt['blur_sigma']
        self.downsample_range = opt['downsample_range']
        self.noise_range = opt['noise_range']
        self.jpeg_range = opt['jpeg_range']

        # color jitter
        self.color_jitter_prob = opt.get('color_jitter_prob')
        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
        # to gray
        self.gray_prob = opt.get('gray_prob')

        logger = get_root_logger()
        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
        logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
        logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
        logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')

        if self.color_jitter_prob is not None:
            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
        if self.gray_prob is not None:
            logger.info(f'Use random gray. Prob: {self.gray_prob}')
        self.color_jitter_shift /= 255.

    @staticmethod
    def color_jitter(img, shift):
        """jitter color: randomly jitter the RGB values, in numpy format"""
        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
        img = img + jitter_val
        img = np.clip(img, 0, 1)
        return img

    @staticmethod
    def color_jitter_pt(img, brightness, contrast, saturation, hue):
        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor format"""
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)

            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)

            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)

            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img
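
    # Note: this mirrors torchvision.transforms.ColorJitter, which likewise applies
    # the four adjustments in a random order; here the factor ranges are passed in
    # explicitly so they can come from the dataset options.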

    def get_component_coordinates(self, index, status):
        """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file.

        Each component is stored as (center_x, center_y, half_len); the returned
        locations are (x1, y1, x2, y2) bounding boxes.
        """
        components_bbox = self.components_list[f'{index:08d}']
        if status[0]:  # hflip
            # exchange the right and left eyes
            tmp = components_bbox['left_eye']
            components_bbox['left_eye'] = components_bbox['right_eye']
            components_bbox['right_eye'] = tmp
            # mirror the x (width) coordinate
            components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
            components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
            components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]

        # get coordinates
        locations = []
        for part in ['left_eye', 'right_eye', 'mouth']:
            mean = components_bbox[part][0:2]
            half_len = components_bbox[part][2]
            if 'eye' in part:
                half_len *= self.eye_enlarge_ratio
            loc = np.hstack((mean - half_len + 1, mean + half_len))
            loc = torch.from_numpy(loc).float()
            locations.append(loc)
        return locations

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
        h, w, _ = img_gt.shape

        # get facial component coordinates
        if self.crop_components:
            locations = self.get_component_coordinates(index, status)
            loc_left_eye, loc_right_eye, loc_mouth = locations

        # ------------------------ generate lq image ------------------------ #
        # blur
        kernel = degradations.random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            self.blur_sigma,
            self.blur_sigma, [-math.pi, math.pi],
            noise_range=None)
        img_lq = cv2.filter2D(img_gt, -1, kernel)
        # downsample
        scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
        img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
        # noise
        if self.noise_range is not None:
            img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
        # jpeg compression
        if self.jpeg_range is not None:
            img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)

        # resize to original size
        img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
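
        # The steps above follow the degradation model GFPGAN uses to synthesize
        # LQ faces: blur the HQ image with a random kernel, downsample by a random
        # factor from downsample_range, add Gaussian noise and JPEG compression,
        # then resize back to the original resolution.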

        # random color jitter (only for lq)
        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
            img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
        # random to gray (only for lq)
        if self.gray_prob and np.random.uniform() < self.gray_prob:
            img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
            img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
            if self.opt.get('gt_gray'):  # whether to convert GT to gray images
                img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
                img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])  # repeat the color channels

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

        # random color jitter (pytorch version) (only for lq)
        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)

        # round and clip
        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.

        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        normalize(img_lq, self.mean, self.std, inplace=True)

        if self.crop_components:
            return_dict = {
                'lq': img_lq,
                'gt': img_gt,
                'gt_path': gt_path,
                'loc_left_eye': loc_left_eye,
                'loc_right_eye': loc_right_eye,
                'loc_mouth': loc_mouth
            }
            return return_dict
        else:
            return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
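
For reference, a minimal usage sketch (hypothetical paths and values; in GFPGAN training this dataset is normally built by basicsr from a YAML config rather than instantiated by hand, and the `opt` keys follow the class docstring):

from torch.utils.data import DataLoader
from gfpgan.data.ffhq_degradation_dataset import FFHQDegradationDataset

opt = dict(
    dataroot_gt='datasets/ffhq_512',  # hypothetical GT folder
    io_backend=dict(type='disk'),
    mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5],
    use_hflip=True, out_size=512,
    blur_kernel_size=41, kernel_list=['iso', 'aniso'], kernel_prob=[0.5, 0.5],
    blur_sigma=[0.1, 10], downsample_range=[0.8, 8],
    noise_range=[0, 20], jpeg_range=[60, 100])
dataset = FFHQDegradationDataset(opt)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
batch = next(iter(loader))  # dict with 'lq', 'gt' tensors and 'gt_path'
print(batch['lq'].shape, batch['gt'].shape)  # torch.Size([4, 3, 512, 512]) each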