GoodWin commited on
Commit
51841ef
·
1 Parent(s): 916a4b8

Upload test_FaceDict.py

Browse files
Files changed (1) hide show
  1. test_FaceDict.py +287 -0
test_FaceDict.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from options.test_options import TestOptions
3
+ from data import CreateDataLoader
4
+ from models import create_model
5
+ from util.visualizer import save_crop
6
+ from util import html
7
+ import numpy as np
8
+ import math
9
+ from PIL import Image
10
+ import torchvision.transforms as transforms
11
+ import torch
12
+ import random
13
+ import cv2
14
+ import dlib
15
+ from skimage import transform as trans
16
+ from skimage import io
17
+ from data.image_folder import make_dataset
18
+ import sys
19
+ sys.path.append('FaceLandmarkDetection')
20
+ import face_alignment
21
+
22
+ ###########################################################################
23
+ ################# functions of crop and align face images #################
24
+ ###########################################################################
25
+ def get_5_points(img):
26
+ dets = detector(img, 1)
27
+ if len(dets) == 0:
28
+ return None
29
+ areas = []
30
+ if len(dets) > 1:
31
+ print('\t###### Warning: more than one face is detected. In this version, we only handle the largest one.')
32
+ for i in range(len(dets)):
33
+ area = (dets[i].rect.right()-dets[i].rect.left())*(dets[i].rect.bottom()-dets[i].rect.top())
34
+ areas.append(area)
35
+ ins = areas.index(max(areas))
36
+ shape = sp(img, dets[ins].rect)
37
+ single_points = []
38
+ for i in range(5):
39
+ single_points.append([shape.part(i).x, shape.part(i).y])
40
+ return np.array(single_points)
41
+
42
+ def align_and_save(img_path, save_path, save_input_path, save_param_path, upsample_scale=2):
43
+ out_size = (512, 512)
44
+ img = dlib.load_rgb_image(img_path)
45
+ h,w,_ = img.shape
46
+ source = get_5_points(img)
47
+ if source is None: #
48
+ print('\t################ No face is detected')
49
+ return
50
+ tform = trans.SimilarityTransform()
51
+ tform.estimate(source, reference)
52
+ M = tform.params[0:2,:]
53
+ crop_img = cv2.warpAffine(img, M, out_size)
54
+ io.imsave(save_path, crop_img) #save the crop and align face
55
+ io.imsave(save_input_path, img) #save the whole input image
56
+ tform2 = trans.SimilarityTransform()
57
+ tform2.estimate(reference, source*upsample_scale)
58
+ # inv_M = cv2.invertAffineTransform(M)
59
+ np.savetxt(save_param_path, tform2.params[0:2,:],fmt='%.3f') #save the inverse affine parameters
60
+
61
+ def reverse_align(input_path, face_path, param_path, save_path, upsample_scale=2):
62
+ out_size = (512, 512)
63
+ input_img = dlib.load_rgb_image(input_path)
64
+ h,w,_ = input_img.shape
65
+ face512 = dlib.load_rgb_image(face_path)
66
+ inv_M = np.loadtxt(param_path)
67
+ inv_crop_img = cv2.warpAffine(face512, inv_M, (w*upsample_scale,h*upsample_scale))
68
+ mask = np.ones((512, 512, 3), dtype=np.float32) #* 255
69
+ inv_mask = cv2.warpAffine(mask, inv_M, (w*upsample_scale,h*upsample_scale))
70
+ upsample_img = cv2.resize(input_img, (w*upsample_scale, h*upsample_scale))
71
+ inv_mask_erosion_removeborder = cv2.erode(inv_mask, np.ones((2 * upsample_scale, 2 * upsample_scale), np.uint8))# to remove the black border
72
+ inv_crop_img_removeborder = inv_mask_erosion_removeborder * inv_crop_img
73
+ total_face_area = np.sum(inv_mask_erosion_removeborder)//3
74
+ w_edge = int(total_face_area ** 0.5) // 20 #compute the fusion edge based on the area of face
75
+ erosion_radius = w_edge * 2
76
+ inv_mask_center = cv2.erode(inv_mask_erosion_removeborder, np.ones((erosion_radius, erosion_radius), np.uint8))
77
+ blur_size = w_edge * 2
78
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center,(blur_size + 1, blur_size + 1),0)
79
+ merge_img = inv_soft_mask * inv_crop_img_removeborder + (1 - inv_soft_mask) * upsample_img
80
+ io.imsave(save_path, merge_img.astype(np.uint8))
81
+
82
+ ###########################################################################
83
+ ################ functions of preparing the test images ###################
84
+ ###########################################################################
85
+ def AddUpSample(img):
86
+ return img.resize((512, 512), Image.BICUBIC)
87
+ def get_part_location(partpath, imgname):
88
+ Landmarks = []
89
+ if not os.path.exists(os.path.join(partpath,imgname+'.txt')):
90
+ print(os.path.join(partpath,imgname+'.txt'))
91
+ print('\t################ No landmark file')
92
+ return 0
93
+ with open(os.path.join(partpath,imgname+'.txt'),'r') as f:
94
+ for line in f:
95
+ tmp = [np.float(i) for i in line.split(' ') if i != '\n']
96
+ Landmarks.append(tmp)
97
+ Landmarks = np.array(Landmarks)
98
+ Map_LE = list(np.hstack((range(17,22), range(36,42))))
99
+ Map_RE = list(np.hstack((range(22,27), range(42,48))))
100
+ Map_NO = list(range(29,36))
101
+ Map_MO = list(range(48,68))
102
+ try:
103
+ #left eye
104
+ Mean_LE = np.mean(Landmarks[Map_LE],0)
105
+ L_LE = np.max((np.max(np.max(Landmarks[Map_LE],0) - np.min(Landmarks[Map_LE],0))/2,16))
106
+ Location_LE = np.hstack((Mean_LE - L_LE + 1, Mean_LE + L_LE)).astype(int)
107
+ #right eye
108
+ Mean_RE = np.mean(Landmarks[Map_RE],0)
109
+ L_RE = np.max((np.max(np.max(Landmarks[Map_RE],0) - np.min(Landmarks[Map_RE],0))/2,16))
110
+ Location_RE = np.hstack((Mean_RE - L_RE + 1, Mean_RE + L_RE)).astype(int)
111
+ #nose
112
+ Mean_NO = np.mean(Landmarks[Map_NO],0)
113
+ L_NO = np.max((np.max(np.max(Landmarks[Map_NO],0) - np.min(Landmarks[Map_NO],0))/2,16))
114
+ Location_NO = np.hstack((Mean_NO - L_NO + 1, Mean_NO + L_NO)).astype(int)
115
+ #mouth
116
+ Mean_MO = np.mean(Landmarks[Map_MO],0)
117
+ L_MO = np.max((np.max(np.max(Landmarks[Map_MO],0) - np.min(Landmarks[Map_MO],0))/2,16))
118
+ Location_MO = np.hstack((Mean_MO - L_MO + 1, Mean_MO + L_MO)).astype(int)
119
+ except:
120
+ return 0
121
+ return torch.from_numpy(Location_LE).unsqueeze(0), torch.from_numpy(Location_RE).unsqueeze(0), torch.from_numpy(Location_NO).unsqueeze(0), torch.from_numpy(Location_MO).unsqueeze(0)
122
+
123
+ def obtain_inputs(img_path, Landmark_path, img_name):
124
+ A_paths = os.path.join(img_path,img_name)
125
+ A = Image.open(A_paths).convert('RGB')
126
+ Part_locations = get_part_location(Landmark_path, img_name)
127
+ if Part_locations == 0:
128
+ return 0
129
+ C = A
130
+ A = AddUpSample(A)
131
+ A = transforms.ToTensor()(A)
132
+ C = transforms.ToTensor()(C)
133
+ A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A) #
134
+ C = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(C) #
135
+ return {'A':A.unsqueeze(0), 'C':C.unsqueeze(0), 'A_paths': A_paths,'Part_locations': Part_locations}
136
+
137
+ if __name__ == '__main__':
138
+ opt = TestOptions().parse()
139
+ opt.nThreads = 1 # test code only supports nThreads = 1
140
+ opt.batchSize = 1 # test code only supports batchSize = 1
141
+ opt.serial_batches = True # no shuffle
142
+ opt.no_flip = True # no flip
143
+ opt.display_id = -1 # no visdom display
144
+ opt.which_epoch = 'latest' #
145
+
146
+ #######################################################################
147
+ ########################### Test Param ################################
148
+ #######################################################################
149
+ # opt.gpu_ids = [0] # gpu id. if use cpu, set opt.gpu_ids = []
150
+ # TestImgPath = './TestData/TestWhole' # test image path
151
+ # ResultsDir = './Results/TestWholeResults' #save path
152
+ # UpScaleWhole = 4 # the upsamle scale. It should be noted that our face results are fixed to 512.
153
+ TestImgPath = opt.test_path
154
+ ResultsDir = opt.results_dir
155
+ UpScaleWhole = opt.upscale_factor
156
+
157
+ print('\n###################### Now Running the X {} task ##############################'.format(UpScaleWhole))
158
+
159
+ #######################################################################
160
+ ###########Step 1: Crop and Align Face from the whole Image ###########
161
+ #######################################################################
162
+ print('\n###############################################################################')
163
+ print('####################### Step 1: Crop and Align Face ###########################')
164
+ print('###############################################################################\n')
165
+
166
+ detector = dlib.cnn_face_detection_model_v1('./packages/mmod_human_face_detector.dat')
167
+ sp = dlib.shape_predictor('./packages/shape_predictor_5_face_landmarks.dat')
168
+ reference = np.load('./packages/FFHQ_template.npy') / 2
169
+ SaveInputPath = os.path.join(ResultsDir,'Step0_Input')
170
+ if not os.path.exists(SaveInputPath):
171
+ os.makedirs(SaveInputPath)
172
+ SaveCropPath = os.path.join(ResultsDir,'Step1_CropImg')
173
+ if not os.path.exists(SaveCropPath):
174
+ os.makedirs(SaveCropPath)
175
+
176
+ SaveParamPath = os.path.join(ResultsDir,'Step1_AffineParam') #save the inverse affine parameters
177
+ if not os.path.exists(SaveParamPath):
178
+ os.makedirs(SaveParamPath)
179
+
180
+ ImgPaths = make_dataset(TestImgPath)
181
+ for i, ImgPath in enumerate(ImgPaths):
182
+ ImgName = os.path.split(ImgPath)[-1]
183
+ print('Crop and Align {} image'.format(ImgName))
184
+ SavePath = os.path.join(SaveCropPath,ImgName)
185
+ SaveInput = os.path.join(SaveInputPath,ImgName)
186
+ SaveParam = os.path.join(SaveParamPath, ImgName+'.npy')
187
+ align_and_save(ImgPath, SavePath, SaveInput, SaveParam, UpScaleWhole)
188
+
189
+ #######################################################################
190
+ ####### Step 2: Face Landmark Detection from the Cropped Image ########
191
+ #######################################################################
192
+ print('\n###############################################################################')
193
+ print('####################### Step 2: Face Landmark Detection #######################')
194
+ print('###############################################################################\n')
195
+
196
+ SaveLandmarkPath = os.path.join(ResultsDir,'Step2_Landmarks')
197
+ if len(opt.gpu_ids) > 0:
198
+ dev = 'cuda:{}'.format(opt.gpu_ids[0])
199
+ else:
200
+ dev = 'cpu'
201
+ FD = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,device=dev, flip_input=False)
202
+ if not os.path.exists(SaveLandmarkPath):
203
+ os.makedirs(SaveLandmarkPath)
204
+ ImgPaths = make_dataset(SaveCropPath)
205
+ for i,ImgPath in enumerate(ImgPaths):
206
+ ImgName = os.path.split(ImgPath)[-1]
207
+ print('Detecting {}'.format(ImgName))
208
+ Img = io.imread(ImgPath)
209
+ try:
210
+ PredsAll = FD.get_landmarks(Img)
211
+ except:
212
+ print('\t################ Error in face detection, continue...')
213
+ continue
214
+ if PredsAll is None:
215
+ print('\t################ No face, continue...')
216
+ continue
217
+ ins = 0
218
+ if len(PredsAll)!=1:
219
+ hights = []
220
+ for l in PredsAll:
221
+ hights.append(l[8,1] - l[19,1])
222
+ ins = hights.index(max(hights))
223
+ # print('\t################ Warning: Detected too many face, only handle the largest one...')
224
+ # continue
225
+ preds = PredsAll[ins]
226
+ AddLength = np.sqrt(np.sum(np.power(preds[27][0:2]-preds[33][0:2],2)))
227
+ SaveName = ImgName+'.txt'
228
+ np.savetxt(os.path.join(SaveLandmarkPath,SaveName),preds[:,0:2],fmt='%.3f')
229
+
230
+ #######################################################################
231
+ ####################### Step 3: Face Restoration ######################
232
+ #######################################################################
233
+
234
+ print('\n###############################################################################')
235
+ print('####################### Step 3: Face Restoration ##############################')
236
+ print('###############################################################################\n')
237
+
238
+ SaveRestorePath = os.path.join(ResultsDir,'Step3_RestoreCropFace')# Only Face Results
239
+ if not os.path.exists(SaveRestorePath):
240
+ os.makedirs(SaveRestorePath)
241
+ model = create_model(opt)
242
+ model.setup(opt)
243
+ # test
244
+ ImgPaths = make_dataset(SaveCropPath)
245
+ total = 0
246
+ for i, ImgPath in enumerate(ImgPaths):
247
+ ImgName = os.path.split(ImgPath)[-1]
248
+ print('Restoring {}'.format(ImgName))
249
+ torch.cuda.empty_cache()
250
+ data = obtain_inputs(SaveCropPath, SaveLandmarkPath, ImgName)
251
+ if data == 0:
252
+ print('\t################ Error in landmark file, continue...')
253
+ continue #
254
+ total = total + 1
255
+ model.set_input(data)
256
+ try:
257
+ model.test()
258
+ visuals = model.get_current_visuals()
259
+ save_crop(visuals,os.path.join(SaveRestorePath,ImgName))
260
+ except Exception as e:
261
+ print('\t################ Error in enhancing this image: {}'.format(str(e)))
262
+ print('\t################ continue...')
263
+ continue
264
+
265
+ #######################################################################
266
+ ############ Step 4: Paste the Results to the Input Image #############
267
+ #######################################################################
268
+
269
+ print('\n###############################################################################')
270
+ print('############### Step 4: Paste the Restored Face to the Input Image ############')
271
+ print('###############################################################################\n')
272
+
273
+ SaveFianlPath = os.path.join(ResultsDir,'Step4_FinalResults')
274
+ if not os.path.exists(SaveFianlPath):
275
+ os.makedirs(SaveFianlPath)
276
+ ImgPaths = make_dataset(SaveRestorePath)
277
+ for i,ImgPath in enumerate(ImgPaths):
278
+ ImgName = os.path.split(ImgPath)[-1]
279
+ print('Final Restoring {}'.format(ImgName))
280
+ WholeInputPath = os.path.join(TestImgPath,ImgName)
281
+ FaceResultPath = os.path.join(SaveRestorePath, ImgName)
282
+ ParamPath = os.path.join(SaveParamPath, ImgName+'.npy')
283
+ SaveWholePath = os.path.join(SaveFianlPath, ImgName)
284
+ reverse_align(WholeInputPath, FaceResultPath, ParamPath, SaveWholePath, UpScaleWhole)
285
+
286
+ print('\nAll results are saved in {} \n'.format(ResultsDir))
287
+